Commit 520f8890 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

eval minibert

parent 780eb86d
......@@ -20,6 +20,7 @@ except:
import sys
sys.path.append("../minibert")
from minibert import *
from eval import *
# Retourne les embeddings de definitions
......@@ -33,7 +34,7 @@ except:
def get_acr_def_dict(model, tokenizer, acronyms, device="cuda"):
definitions_file = "output/minibert/definitions.tar"
if os.path.exists(definitions_file):
if os.path.exists(definitions_file) and False:
definitions = torch.load(definitions_file)
else:
definitions = {}
......@@ -43,8 +44,8 @@ def get_acr_def_dict(model, tokenizer, acronyms, device="cuda"):
attention_mask = torch.tensor([encoded.attention_mask], device=device)
# wids = torch.tensor([encoded.word_ids], device=device)
output = model.minibert(x, attention_mask)
mean_vec = torch.mean(output, dim=1)
output = torch.squeeze(model.minibert(x, attention_mask))
mean_vec = torch.mean(output, dim=0, keepdim=False)
if acr not in definitions:
definitions[acr] = { "vectors": [], "definitions": [], "lemmatized": [] }
......@@ -58,7 +59,7 @@ def get_acr_def_dict(model, tokenizer, acronyms, device="cuda"):
return definitions
def minibert_wsd(args):
def minibert_wsd(model_path, glob):
device = "cuda"
pin_memory = device != "cpu"
......@@ -74,7 +75,7 @@ def minibert_wsd(args):
tokenizer = Tokenizer.from_file("../minibert-sncf/data/tokenizer.json")
collater = SncfCollater(tokenizer, pad_token)
checkpoint = torch.load(args.model, map_location=torch.device(device))
checkpoint = torch.load(model_path, map_location=torch.device(device))
configuration_dict = checkpoint["configuration"]
device = checkpoint["device"]
configuration = MiniBertForMLMConfiguration(**configuration_dict)
......@@ -84,6 +85,14 @@ def minibert_wsd(args):
definitions = get_acr_def_dict(model, tokenizer, acronyms, device)
if glob:
all_defs = []
all_defs_vecs = []
for d in definitions.values():
all_defs.extend(d["lemmatized"])
all_defs_vecs.append(d["vectors"])
all_defs_vecs = torch.vstack(all_defs_vecs)
json_path = "data/annotation.json"
with open(json_path, "r", encoding="UTF-8") as f:
json_data = f.read()
......@@ -101,14 +110,47 @@ def minibert_wsd(args):
acr = tok["token"]
v = embeddings[0, i, :].view(-1, 1)
dists = torch.matmul(definitions[acr]["vectors"], v).view(-1)
idef = torch.argmax(dists).item()
annotated[isent]["acronymes"][iacr]["prediction"] = definitions[acr]["lemmatized"][idef]
predictions_path = "output/minibert/predictions.json"
if glob:
dists = torch.matmul(all_defs_vecs, v).view(-1)
idef = torch.argmax(dists).item()
annotated[isent]["acronymes"][iacr]["prediction"] = all_defs[idef]
else:
dists = torch.matmul(definitions[acr]["vectors"], v).view(-1)
idef = torch.argmax(dists).item()
annotated[isent]["acronymes"][iacr]["prediction"] = definitions[acr]["lemmatized"][idef]
fname = os.path.basename(os.path.dirname(model_path))
if glob:
predictions_path = f"output/minibert/predictions_{fname}_glob.json"
else:
predictions_path = f"output/minibert/predictions_{fname}.json"
with open(predictions_path, "w", encoding="UTF-8") as f:
f.write(json.dumps(annotated, indent=4, ensure_ascii=False))
def all_minibert_wsd(args, checkpoint_name="checkpoint-00100.tar"):
    """Run WSD prediction with every trained MiniBERT model found under args.path.

    Each subdirectory of ``args.path`` is expected to hold one training run;
    the checkpoint file ``checkpoint_name`` inside it is loaded and passed to
    ``minibert_wsd`` together with the ``args.glob`` flag.

    Args:
        args: argparse namespace with ``path`` (models root directory) and
            ``glob`` (compare against all definitions globally) attributes.
        checkpoint_name: name of the checkpoint file inside each model
            directory (default kept for backward compatibility).
    """
    for entry in os.listdir(args.path):
        model_dir = os.path.join(args.path, entry)
        # Skip stray files (logs, archives) that may sit next to model dirs.
        if not os.path.isdir(model_dir):
            continue
        cp_path = os.path.join(model_dir, checkpoint_name)
        minibert_wsd(cp_path, args.glob)
def eval_minibert(args):
    """Score every prediction file in output/minibert and write a ranked CSV.

    Reads each ``predictions_*.json`` file, computes its precision via the
    project's ``load_annot``/``count``/``acc`` helpers, and stores the results
    (sorted by decreasing precision) in ``output/minibert/scores.csv``.

    Args:
        args: argparse namespace (unused; kept for the subparser callback
            interface).
    """
    pred_dir = "output/minibert"
    rows = {"file": [], "pos": []}
    for fname in os.listdir(pred_dir):
        if not fname.startswith("predictions_"):
            continue
        annot = load_annot(os.path.join(pred_dir, fname))
        # acc() returns (precision, recall, ?) — only precision is reported.
        precision, _recall, _third = acc(count(annot))
        rows["file"].append(fname)
        rows["pos"].append(precision)
    scores = pd.DataFrame(rows).sort_values("pos", ascending=False)
    scores.to_csv("output/minibert/scores.csv", index=False)
if __name__ == "__main__":
import argparse
......@@ -117,8 +159,13 @@ if __name__ == "__main__":
subparsers = parser.add_subparsers()
wsd_parser = subparsers.add_parser("wsd")
wsd_parser.add_argument("-m", "--model", default="../minibert-sncf/models/d64_self-attention_fixed_gelu_norm/checkpoint-00100.tar")
wsd_parser.set_defaults(func=minibert_wsd)
# wsd_parser.add_argument("-m", "--model", default="../minibert-sncf/models/d64_self-attention_fixed_gelu_norm/checkpoint-00100.tar")
wsd_parser.add_argument("-p", "--path", default="../minibert-sncf/models")
wsd_parser.add_argument("-g", "--glob", action="store_true")
wsd_parser.set_defaults(func=all_minibert_wsd)
eval_parser = subparsers.add_parser("eval")
eval_parser.set_defaults(func=eval_minibert)
args = parser.parse_args()
args.func(args)
args.func(args)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment