Commit 82aac4e3 authored by Caillaut Gaetan's avatar Caillaut Gaetan
Browse files

fix MyCamembert

parent 42094045
......@@ -18,13 +18,13 @@ class MyCamembertForSequenceClassification(torch.nn.Module):
self.camembert = CamembertModel.from_pretrained("camembert-base")
self.l1 = torch.nn.Linear(768, 768 // 2, bias=True)
self.l1_activation_fun = parse_activation_function("gelu")
self.l1_activation_fun = torch.nn.functional.gelu
self.l2 = torch.nn.Linear(768//2, num_labels, bias=True)
self.l2_activation_fun = parse_activation_function("none")
def forward(self, input, attention_mask=None):
x = self.camembert(input_ids=input, attention_mask=attention_mask)
outputs = self.camembert(input_ids=input, attention_mask=attention_mask)
x = outputs.last_hidden_state
# Average tokens for sentence classification
if attention_mask is None:
......@@ -40,7 +40,6 @@ class MyCamembertForSequenceClassification(torch.nn.Module):
x = self.l1(x)
x = self.l1_activation_fun(x)
x = self.l2(x)
x = self.l2_activation_fun(x, dim=1)
return x
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment