Commit 536633fc authored by Gaëtan Caillaut

MiniBERT for sequence classification with attention

parent 2ddb43e5
@@ -249,25 +249,34 @@ class MiniBertForSequenceClassificationWithAttention(nn.Module):
         self.l2_activation_fun = parse_activation_function(
             configuration.second_layer_activation_fun)
 
-    def forward(self, input, attention_mask=None):
-        x = self.minibert(input, attention_mask)
+    def forward(self, input, attention_mask=None, return_attention=False):
+        if return_attention:
+            x, att_layers = self.minibert(
+                input, attention_mask, return_attention=return_attention)
+        else:
+            x = self.minibert(input, attention_mask)
 
         aw = torch.matmul(x, self.attention_vec)
         if attention_mask is not None:
             attention_modifier = (1 - attention_mask.float()) * (-1000)
+            print("attention_modifier:", attention_modifier.size())
             aw = aw + attention_modifier
         att = nn.functional.softmax(aw, dim=-1)
         att = torch.unsqueeze(att, 1)
 
         x = torch.matmul(att, x)
-        x = torch.squeeze(x)
+        #x = torch.squeeze(x)
+        x = x.view(x.size(0), -1)
 
         x = self.l1(x)
         x = self.l1_activation_fun(x)
         x = self.l2(x)
         x = self.l2_activation_fun(x, dim=1)
-        return x
+        if return_attention:
+            return (x, att_layers, torch.squeeze(att))
+        else:
+            return x
 
 
 class MiniBertForMLMAndRegression(nn.Module):
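The new forward pass pools the token embeddings returned by self.minibert with a learned attention vector before the two-layer classification head, and with return_attention=True it now also hands back the attentions returned by MiniBERT (att_layers) together with the pooling weights. Below is a minimal standalone sketch of that pooling step using dummy tensors; the shapes and the names x, attention_vec and attention_mask are placeholders standing in for the MiniBERT output, self.attention_vec and the padding mask, not code from this repository.

import torch
import torch.nn as nn

batch_size, seq_len, hidden = 2, 8, 16
x = torch.randn(batch_size, seq_len, hidden)        # stand-in for the MiniBERT token embeddings
attention_vec = torch.randn(hidden)                 # stand-in for the learned self.attention_vec
attention_mask = torch.ones(batch_size, seq_len)    # 1 = real token, 0 = padding
attention_mask[:, 5:] = 0                           # pretend the last three positions are padding

aw = torch.matmul(x, attention_vec)                 # (batch, seq_len) raw attention scores
aw = aw + (1 - attention_mask.float()) * (-1000)    # push padded positions towards -inf
att = nn.functional.softmax(aw, dim=-1)             # (batch, seq_len) pooling weights, ~0 on padding
att = torch.unsqueeze(att, 1)                       # (batch, 1, seq_len)

x = torch.matmul(att, x)                            # (batch, 1, hidden) weighted sum over tokens
x = x.view(x.size(0), -1)                           # (batch, hidden), ready for the linear layers

print(x.size())                                     # torch.Size([2, 16])
print(torch.squeeze(att).size())                    # torch.Size([2, 8]), what return_attention exposes

Replacing the earlier torch.squeeze(x) with x.view(x.size(0), -1) keeps the batch dimension even when the batch contains a single example, which squeeze would otherwise drop.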