Commit 2413e9e8 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

set model_type in eval

parent 304c66f9
......@@ -52,7 +52,8 @@ def parse_run_path(s):
res["position"] = x[2]
res["architecture"] = x[5]
res["attention-scaling"] = x[6]
res["frozen"] = run_name.endswith("_frozen")
res["frozen"] = "_frozen" in run_name
res["wa"] = "_wa" in run_name
return res
......@@ -98,6 +99,10 @@ def eval_t1(run, e, dev_loader, test_loader):
scal = run["attention-scaling"]
frozen = run["frozen"]
crps = run["corpus"]
if run["wa"]:
model_type = MiniBertForSequenceClassificationWithAttention
else:
model_type = MiniBertForSequenceClassification
model, _, _, _ = t1_model_from_params(
mlm_path(f"models/{crps}", d, att, pos, arch, scal),
......@@ -112,7 +117,8 @@ def eval_t1(run, e, dev_loader, test_loader):
checkpoint_path=str(Path(run["path"], f"checkpoint-{e:05}.tar")),
height=height,
depth=depth,
attention_scaling=parse_attention_scaling(scal)
attention_scaling=parse_attention_scaling(scal),
model_type=model_type
)
model.eval()
dev_tp, dev_fp, dev_fn, dev_r, dev_p, dev_f = fmeasure_deft2018_t1(
......@@ -149,6 +155,10 @@ def eval_t2(run, e, dev_loader, test_loader):
scal = run["attention-scaling"]
frozen = run["frozen"]
crps = run["corpus"]
if run["wa"]:
model_type = MiniBertForSequenceClassificationWithAttention
else:
model_type = MiniBertForSequenceClassification
model, _, _, _ = t2_model_from_params(
mlm_path(f"models/{crps}", d, att, pos, arch, scal),
......@@ -163,7 +173,8 @@ def eval_t2(run, e, dev_loader, test_loader):
checkpoint_path=str(Path(run["path"], f"checkpoint-{e:05}.tar")),
height=height,
depth=depth,
attention_scaling=parse_attention_scaling(scal)
attention_scaling=parse_attention_scaling(scal),
model_type=model_type
)
model.eval()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment