Commit 9e3b1262 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

chaque modèle a son dossier

parent 520f8890
......@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
from sklearn.metrics import pairwise
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
from pathlib import Path
from operator import is_not
......@@ -31,10 +32,10 @@ from eval import *
# "lemmatized": liste de chaines de caractères
# }
# }
def get_acr_def_dict(model, tokenizer, acronyms, device="cuda"):
definitions_file = "output/minibert/definitions.tar"
def get_acr_def_dict(model, tokenizer, acronyms, outdir, device="cuda"):
definitions_file = os.path.join(outdir, "definitions.tar")
if os.path.exists(definitions_file) and False:
if os.path.exists(definitions_file):
definitions = torch.load(definitions_file)
else:
definitions = {}
......@@ -59,7 +60,7 @@ def get_acr_def_dict(model, tokenizer, acronyms, device="cuda"):
return definitions
def minibert_wsd(model_path, glob):
def minibert_wsd(model_path, glob, outdir):
device = "cuda"
pin_memory = device != "cpu"
......@@ -83,7 +84,7 @@ def minibert_wsd(model_path, glob):
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
definitions = get_acr_def_dict(model, tokenizer, acronyms, device)
definitions = get_acr_def_dict(model, tokenizer, acronyms, outdir, device)
if glob:
all_defs = []
......@@ -119,34 +120,35 @@ def minibert_wsd(model_path, glob):
idef = torch.argmax(dists).item()
annotated[isent]["acronymes"][iacr]["prediction"] = definitions[acr]["lemmatized"][idef]
fname = os.path.basename(os.path.dirname(model_path))
if glob:
predictions_path = f"output/minibert/predictions_{fname}_glob.json"
predictions_path = os.path.join(outdir, "predictions_glob.json")
else:
predictions_path = f"output/minibert/predictions_{fname}.json"
predictions_path = os.path.join(outdir, "predictions.json")
with open(predictions_path, "w", encoding="UTF-8") as f:
f.write(json.dumps(annotated, indent=4, ensure_ascii=False))
def all_minibert_wsd(args):
    """Run miniBERT word-sense disambiguation for every checkpoint under args.path.

    Each sub-directory of ``args.path`` is expected to hold one trained model
    with a ``checkpoint-00100.tar`` file.  Predictions for each model are
    written to a matching per-model folder under ``output/minibert/`` (the
    commit intent: each model gets its own output directory).

    Parameters:
        args: parsed CLI namespace providing at least
            - ``path``: directory containing one sub-directory per model
            - ``glob``: bool flag forwarded to ``minibert_wsd``

    Side effects:
        creates ``output/minibert/<model>/`` directories and delegates the
        prediction-file writing to ``minibert_wsd``.
    """
    for md in os.listdir(args.path):
        cp_path = os.path.join(args.path, md, "checkpoint-00100.tar")
        # Mirror the model directory name under output/minibert so each
        # model's predictions land in a folder of its own.
        outdir = os.path.join("output", "minibert", md)
        Path(outdir).mkdir(exist_ok=True, parents=True)
        minibert_wsd(cp_path, args.glob, outdir)
def eval_minibert(args):
    """Score every predictions file produced by the miniBERT WSD runs.

    Walks ``output/minibert/**/predictions*.json`` (one sub-directory per
    model), computes the precision of each file and writes a ranked CSV.

    Parameters:
        args: parsed CLI namespace (not read here; kept so the function
            matches the sub-command callback signature).

    Side effects:
        writes ``output/minibert/scores.csv`` with columns
        ``file`` (parent directory, i.e. the model name),
        ``pos`` (precision) and
        ``glob`` (whether the file is a ``*_glob.json`` prediction),
        sorted by precision in descending order.
    """
    # sorted() gives a deterministic row order; Path.glob order is
    # filesystem-dependent. Final ordering is by score below anyway.
    pred_files = sorted(Path("output/minibert").glob("**/predictions*.json"))
    resd = {
        "file": [],
        "pos": [],
        "glob": []
    }
    for f in pred_files:
        annot = load_annot(str(f))
        # acc/count come from `from eval import *`; only precision is kept.
        prec, rapp, prm = acc(count(annot))
        # The parent directory name identifies the model that produced the file.
        resd["file"].append(f.parent.name)
        resd["pos"].append(prec)
        # "_glob.json" suffix marks the glob=True prediction run — presumably
        # predictions over a pooled definition set; confirm against minibert_wsd.
        resd["glob"].append(str(f).endswith("_glob.json"))
    df = pd.DataFrame(resd)
    df.sort_values("pos", inplace=True, ascending=False)
    df.to_csv("output/minibert/scores.csv", index=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment