Commit 35bd6c7f authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

eval_model peut recevoir des modèles huggingface

parent 6f34a7a8
...@@ -13,7 +13,11 @@ def eval_model(model, dataloader, device): ...@@ -13,7 +13,11 @@ def eval_model(model, dataloader, device):
labels = labels.to(device) labels = labels.to(device)
logits = model(x, attention_mask) logits = model(x, attention_mask)
predicted = torch.argmax(logits, dim=-1).tolist() try:
predicted = torch.argmax(logits, dim=-1).tolist()
except TypeError:
# Si le modèle provient de huggingface
predicted = torch.argmax(logits["logits"], dim=-1).tolist()
for pred, gold in zip(predicted, labels.tolist()): for pred, gold in zip(predicted, labels.tolist()):
if gold not in confusion: if gold not in confusion:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment