Commit 514a1c0f authored by Gaëtan Caillaut

WSD with minibert

parent 94ca918c
@@ -24,6 +24,15 @@ except:
from eval import *
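# Rebuild a MiniBertForMLM model from a saved training checkpoint (stored configuration, device and weights).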
def load_minibert_model(path, device):
    checkpoint = torch.load(path, map_location=torch.device(device))
    configuration_dict = checkpoint["configuration"]
    device = checkpoint["device"]
    configuration = MiniBertForMLMConfiguration(**configuration_dict)
    model = MiniBertForMLM(configuration).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model
# Returns the definition embeddings
# {
#     acr_token: {
@@ -33,7 +42,7 @@ from eval import *
#     }
# }
def get_acr_def_dict(model, tokenizer, acronyms, outdir, device="cuda"):
    definitions_file = os.path.join(outdir, "definitions.tar")
    definitions_file = os.path.join(outdir, "definitions_acr.tar")
    if os.path.exists(definitions_file):
        definitions = torch.load(definitions_file)
@@ -45,7 +54,8 @@ def get_acr_def_dict(model, tokenizer, acronyms, outdir, device="cuda"):
            attention_mask = torch.tensor([encoded.attention_mask], device=device)
            # wids = torch.tensor([encoded.word_ids], device=device)
            output = torch.squeeze(model.minibert(x, attention_mask))
            output = model.minibert(x, attention_mask)
            output = output.view(len(encoded), -1)
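            # Reshape the MiniBERT output to one row per input token before averaging into a single definition vector.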
            mean_vec = torch.mean(output, dim=0, keepdim=False)
            if acr not in definitions:
@@ -76,12 +86,7 @@ def minibert_wsd(model_path, glob, outdir):
    tokenizer = Tokenizer.from_file("../minibert-sncf/data/tokenizer.json")
    collater = SncfCollater(tokenizer, pad_token)

    checkpoint = torch.load(model_path, map_location=torch.device(device))
    configuration_dict = checkpoint["configuration"]
    device = checkpoint["device"]
    configuration = MiniBertForMLMConfiguration(**configuration_dict)
    model = MiniBertForMLM(configuration).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = load_minibert_model(model_path, device)
    model.eval()

    definitions = get_acr_def_dict(model, tokenizer, acronyms, outdir, device)
@@ -153,6 +158,248 @@ def eval_minibert(args):
df.sort_values("pos", inplace=True, ascending=False)
df.to_csv("output/minibert/scores.csv", index=False)
def wsd_sncf(model, outdir):
    device = "cuda"
    pin_memory = device != "cpu"
    mask_token = "<mask>"
    pad_token = "<pad>"
    max_seq_size = 32
    seq_stride = 32
    bs = 512

    tokenizer = Tokenizer.from_file("../minibert-sncf/data/tokenizer.json")
    dataset = SncfDataset("../minibert-sncf/data", max_seq_size, seq_stride)
    collater = SncfCollater(tokenizer, pad_token)
    loader = DataLoader(dataset, collate_fn=collater, batch_size=bs, pin_memory=pin_memory)
    acronyms = pd.read_csv("data/acr_eval.csv")
    acronyms.dropna(inplace=True)

    definitions = get_acr_def_dict(model, tokenizer, acronyms, outdir, device="cuda")
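    # definitions maps each acronym token to its candidate definitions ("definitions", "lemmatized")
    # and one mean MiniBERT embedding per definition ("vectors").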
    # Tensor holding the token ids of the acronyms
    acr_ids = torch.tensor(list(filter(partial(is_not, None), (tokenizer.token_to_id(a) for a in acronyms["Token"]))), device=device)
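    # Collect the contextual embedding of every acronym occurrence in the corpus (cached in contextualized_acr.tar).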
    contextualized_path = str(Path(outdir, "contextualized_acr.tar"))
    if os.path.exists(contextualized_path):
        contextualized = torch.load(contextualized_path, map_location=torch.device(device))
    else:
        contextualized = defaultdict(list)
        for x, att_mask, wid in loader:
            x = x.to(device)
            att_mask = att_mask.to(device)
            y = model.minibert(x, att_mask)

            is_acr = torch.tensor([i in acr_ids for i in x.view(-1)], device=device).view(x.size())
            acr_pos = torch.nonzero(is_acr)
            for pos in acr_pos:
                i, j = pos
                acr = tokenizer.id_to_token(x[i, j].item())
                v = y[i, j, :]
                contextualized[acr].append(v.tolist())

        contextualized = {k: torch.tensor(v, device=device) for k, v in contextualized.items()}
        torch.save(contextualized, contextualized_path)
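    # WSD: assign each contextual acronym occurrence to its closest definition embedding
    # (dot-product argmax), then average the occurrences assigned to each sense.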
    ## wsd
    disambiguated_path = str(Path(outdir, "disambiguated_acr.tar"))
    if os.path.exists(disambiguated_path):
        disambiguated = torch.load(disambiguated_path, map_location=torch.device(device))
    else:
        disambiguated = {}
        for acr, vecs in contextualized.items():
            if acr not in disambiguated:
                disambiguated[acr] = []

            defdict = definitions[acr]
            dists = torch.matmul(vecs, torch.transpose(defdict["vectors"], 0, 1))
            senses = torch.argmax(dists, dim=1)

            for i in range(len(defdict["definitions"])):
                vs = vecs[senses == i, :]
                if vs.size(0) == 0:
                    mean_vec = torch.zeros((1, vecs.size(1)), device=device)
                else:
                    mean_vec = torch.mean(vecs[senses == i, :], dim=0)
                wsddict = {
                    "acr_id": tokenizer.token_to_id(acr),
                    "vector": mean_vec,
                    "definition": defdict["definitions"][i],
                    "lemmatized": defdict["lemmatized"][i]
                }
                disambiguated[acr].append(wsddict)
        torch.save(disambiguated, disambiguated_path)
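    # Neighbors: feed each averaged sense vector through the MLM head and keep the 4 highest-scoring
    # vocabulary tokens, skipping the acronym itself and the ignored tokens.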
    ## neighbors
    ignored = torch.tensor([id for w, id in tokenizer.get_vocab().items() if w in ("avoir", "etre") or len(w) < 3], device=device)
    neighbors_path = str(Path(outdir, "neighbors_acr.tar"))
    if os.path.exists(neighbors_path):
        neighbors = torch.load(neighbors_path)
    else:
        neighbors = {}
        for acr, senses in disambiguated.items():
            if acr not in neighbors:
                neighbors[acr] = []
            for i in range(len(senses)):
                y = torch.squeeze(model.mlm_head(senses[i]["vector"].view(1, -1)))

                # If the acronym shows up among its own neighbors, it is discarded
                predicted = torch.argsort(y, descending=True).tolist()
                neighbors_id = []
                for pid in predicted:
                    if pid == senses[i]["acr_id"] or pid in ignored:
                        continue
                    else:
                        neighbors_id.append(pid)
                    if len(neighbors_id) == 4:
                        break

                neighbors_tokens = [tokenizer.id_to_token(id) for id in neighbors_id]
                neighbors[acr].append(neighbors_tokens)
        torch.save(neighbors, neighbors_path)
    ## to csv
    cols = {
        "acronyme": [],
        "definition": [],
        "lemmatized": [],
        "voisin1": [],
        "voisin2": [],
        "voisin3": [],
        "voisin4": []
    }
    for acr in disambiguated.keys():
        neighs = neighbors[acr]
        defs = disambiguated[acr]
        for i in range(len(neighs)):
            cols["acronyme"].append(acr)
            cols["definition"].append(defs[i]["definition"])
            cols["lemmatized"].append(defs[i]["lemmatized"])
            cols["voisin1"].append(neighs[i][0])
            cols["voisin2"].append(neighs[i][1])
            cols["voisin3"].append(neighs[i][2])
            cols["voisin4"].append(neighs[i][3])

    csv_path = str(Path(outdir, "neighbors_acr.csv"))
    df = pd.DataFrame(cols)
    df.sort_values("acronyme", inplace=True)
    df.to_csv(csv_path, index=False)
def lexicon_neighbors(model, outdir):
    device = "cuda"
    pin_memory = device != "cpu"
    mask_token = "<mask>"
    pad_token = "<pad>"
    max_seq_size = 32
    seq_stride = 32
    bs = 512

    lexicon = pd.read_csv("data/lex_eval.csv")
    lexicon.dropna(inplace=True)

    tokenizer = Tokenizer.from_file("../minibert-sncf/data/tokenizer.json")
    dataset = SncfDataset("../minibert-sncf/data", max_seq_size, seq_stride)
    collater = SncfCollater(tokenizer, pad_token)
    loader = DataLoader(dataset, collate_fn=collater, batch_size=bs, pin_memory=pin_memory)
    # Tensor holding the token ids of the lexicon entries
    lex_ids = torch.tensor(list(filter(partial(is_not, None), (tokenizer.token_to_id(l) for l in lexicon["Token"]))), device=device)
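    # Collect the contextual embedding of every lexicon-term occurrence in the corpus (cached in contextualized_lex.tar).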
    contextualized_path = str(Path(outdir, "contextualized_lex.tar"))
    if os.path.exists(contextualized_path):
        contextualized = torch.load(contextualized_path, map_location=torch.device(device))
    else:
        contextualized = defaultdict(list)
        for x, att_mask, wid in loader:
            x = x.to(device)
            att_mask = att_mask.to(device)
            y = model.minibert(x, att_mask)

            is_lex = torch.tensor([i in lex_ids for i in x.view(-1)], device=device).view(x.size())
            lex_pos = torch.nonzero(is_lex)
            for pos in lex_pos:
                i, j = pos
                lex = tokenizer.id_to_token(x[i, j].item())
                v = y[i, j, :]
                contextualized[lex].append(v.tolist())

        ctx_mean = {}
        for k, lexdef, lemm in zip(lexicon["Token"], lexicon["Definition"], lexicon["Lemmatized"]):
            vs = contextualized[k]
            if len(vs) == 0:
                vs = torch.zeros((0, model.configuration.embedding_dim), device=device)
                vs_mean = torch.zeros(model.configuration.embedding_dim, device=device)
            else:
                vs = torch.tensor(vs, device=device)
                vs_mean = torch.mean(vs, dim=0)
            ctx_mean[k] = {
                "lex_id": tokenizer.token_to_id(k),
                "vectors": vs,
                "mean": vs_mean,
                "definition": lexdef,
                "lemmatized": lemm
            }

        torch.save(ctx_mean, contextualized_path)
        contextualized = ctx_mean
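    # contextualized now maps each lexicon token to {"lex_id", "vectors", "mean", "definition", "lemmatized"},
    # where "mean" is the token's average contextual embedding over the corpus.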
    ## neighbors
    ignored = torch.tensor([id for w, id in tokenizer.get_vocab().items() if w in ("avoir", "etre") or len(w) < 3], device=device)
    neighbors_path = str(Path(outdir, "neighbors_lex.tar"))
    if os.path.exists(neighbors_path):
        neighbors = torch.load(neighbors_path)
    else:
        neighbors = {}
        for lex, d in contextualized.items():
            y = torch.squeeze(model.mlm_head(d["mean"].view(1, -1)))

            # If the word shows up among its own neighbors, it is discarded
            predicted = torch.argsort(y, descending=True).tolist()
            neighbors_id = []
            for pid in predicted:
                if pid == d["lex_id"] or pid in ignored:
                    continue
                else:
                    neighbors_id.append(pid)
                if len(neighbors_id) == 4:
                    break

            neighbors_tokens = [tokenizer.id_to_token(id) for id in neighbors_id]
            neighbors[lex] = neighbors_tokens
        torch.save(neighbors, neighbors_path)
    ## to csv
    cols = {
        "term": [],
        "definition": [],
        "lemmatized": [],
        "voisin1": [],
        "voisin2": [],
        "voisin3": [],
        "voisin4": []
    }
    for lex in neighbors.keys():
        neighs = neighbors[lex]
        defs = contextualized[lex]
        cols["term"].append(lex)
        cols["definition"].append(defs["definition"])
        cols["lemmatized"].append(defs["lemmatized"])
        cols["voisin1"].append(neighs[0])
        cols["voisin2"].append(neighs[1])
        cols["voisin3"].append(neighs[2])
        cols["voisin4"].append(neighs[3])

    csv_path = str(Path(outdir, "neighbors_lex.csv"))
    df = pd.DataFrame(cols)
    df.sort_values("term", inplace=True)
    df.to_csv(csv_path, index=False)
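# For each trained model found under args.path, reload its checkpoint-00100.tar snapshot and compute
# its lexicon neighbors (the acronym WSD pass is left commented out).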
def compute_neighbors_for_sncf(args):
    for md in os.listdir(args.path):
        cp_path = os.path.join(args.path, md, "checkpoint-00100.tar")
        outdir = os.path.join("output", "minibert", md)

        model = load_minibert_model(cp_path, "cuda")
        model.eval()

        # wsd_sncf(model, outdir)
        lexicon_neighbors(model, outdir)
if __name__ == "__main__":
    import argparse
@@ -169,5 +416,9 @@ if __name__ == "__main__":
    eval_parser = subparsers.add_parser("eval")
    eval_parser.set_defaults(func=eval_minibert)

    sncf_parser = subparsers.add_parser("sncf")
    sncf_parser.add_argument("-p", "--path", default="../minibert-sncf/models")
    sncf_parser.set_defaults(func=compute_neighbors_for_sncf)
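    # Example invocation (script filename is hypothetical): `python eval_minibert.py sncf --path ../minibert-sncf/models`
    # runs compute_neighbors_for_sncf on every model directory found under --path.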
    args = parser.parse_args()
    args.func(args)
\ No newline at end of file