Commit 31c365cb authored by Gaëtan Caillaut
Browse files

camembert v2

parent cb65454a
......@@ -17,10 +17,10 @@ class MyCamembertForSequenceClassification(torch.nn.Module):
super(MyCamembertForSequenceClassification, self).__init__()
self.camembert = CamembertModel.from_pretrained("camembert-base")
self.l1 = nn.Linear(768, 768 / 2, bias=True)
self.l1 = torch.nn.Linear(768, 768 / 2, bias=True)
self.l1_activation_fun = parse_activation_function("gelu")
self.l2 = nn.Linear(768/2, num_labels, bias=True)
self.l2 = torch.nn.Linear(768/2, num_labels, bias=True)
self.l2_activation_fun = parse_activation_function("none")
def forward(self, input, attention_mask=None):
......@@ -1768,7 +1768,7 @@ if __name__ == "__main__":
cam_t1_parser.add_argument("--epochs-between-save", default=10, type=int)
cam_t1_parser.add_argument("--show-progress", default=50, type=int)
cam_t1_parser.add_argument("--sample", action="store_true")
cam_t1_parser.set_defaults(func=finetune_t1_camembert)
cam_t1_parser.set_defaults(func=finetune_t1_camembert_v2)
cam_t2_parser = subparsers.add_parser("camembert-t2-v2")
cam_t2_parser.add_argument("train")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment