Commit c4dad041 authored by Gaëtan Caillaut

MiniBertForSequenceClassification

parent 11ed5db1
@@ -6,6 +6,8 @@ __all__ = [
"MiniBertForMLMConfiguration",
"MiniBertForRegressionConfiguration",
"MiniBertForMLMAndRegressionConfiguration",
"MiniBertForSequenceClassificationConfiguration",
"MiniBertForMLMAndSequenceClassificationConfiguration",
]
@@ -64,6 +66,18 @@ class MiniBertForRegressionConfiguration(MiniBertConfiguration):
self.output_size = kwargs["output_size"]
class MiniBertForSequenceClassificationConfiguration(MiniBertConfiguration):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Prediction layers
self.first_layer_output_size = kwargs.get(
"first_layer_output_size", self.embedding_dim)
self.first_layer_activation_fun = kwargs.get("activation_fun", "gelu")
self.second_layer_activation_fun = kwargs.get("activation_fun", "none")
self.output_size = kwargs["output_size"]
class MiniBertForMLMAndRegressionConfiguration(MiniBertConfiguration):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -98,3 +112,39 @@ class MiniBertForMLMAndRegressionConfiguration(MiniBertConfiguration):
self.reg_second_layer_activation_fun = kwargs.get(
"activation_fun", "none")
self.reg_output_size = kwargs["reg_output_size"]
class MiniBertForMLMAndSequenceClassificationConfiguration(MiniBertConfiguration):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.mask_idx = kwargs["mask_idx"]
self.mask_token = kwargs.get("mask_token", "<mask>")
# Masking strategy
self.mask_prob = kwargs.get("mask_prob", 0.15)
self.keep_mask_prob = kwargs.get("keep_mask_prob", 0.8)
self.corrupt_mask_prob = kwargs.get("corrupt_mask_prob", 0.1)
self.reveal_mask_prob = kwargs.get("reveal_mask_prob", 0.1)
if self.keep_mask_prob + self.corrupt_mask_prob + self.reveal_mask_prob != 1:
raise ValueError("Sum of masking strategies is not 1")
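# With the defaults above, 15% of the input tokens are selected for the MLM
# objective; of those, 80% are replaced by the mask token, 10% by a random
# token and 10% are left unchanged, so the three shares must sum to 1.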
# Prediction layers
# MLM
self.mlm_first_layer_output_size = kwargs.get(
"mlm_first_layer_output_size", self.embedding_dim)
self.mlm_first_layer_activation_fun = kwargs.get(
"activation_fun", "gelu")
self.mlm_second_layer_activation_fun = kwargs.get(
"activation_fun", "none")
# Sequence classification
self.cls_first_layer_output_size = kwargs.get(
"cls_first_layer_output_size", self.embedding_dim)
self.cls_first_layer_activation_fun = kwargs.get(
"activation_fun", "relu")
self.cls_second_layer_activation_fun = kwargs.get(
"activation_fun", "sigmoid")
self.cls_output_size = kwargs["cls_output_size"]
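A minimal sketch of instantiating the new classification-only configuration, with base keyword arguments copied from the test further down; the import path is an assumption, and position_type / attention_type (passed explicitly in the test) are omitted here on the assumption that the base configuration provides defaults.

    # Hypothetical import path; the diff only shows the names added to __all__.
    from minibert import MiniBertForSequenceClassificationConfiguration

    # `output_size` is the only new required keyword; the first-layer size
    # defaults to embedding_dim and the activations to gelu / none (see above).
    config = MiniBertForSequenceClassificationConfiguration(
        vocabulary=["<mask>", "a", "b", "c", "d", "e", "<pad>"],
        embedding_dim=4,
        position_embeddings_count=5,
        output_size=3,
    )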
@@ -13,6 +13,8 @@ __all__ = [
"MiniBertForMLM",
"MiniBertForRegression",
"MiniBertForMLMAndRegression",
"MiniBertForSequenceClassification",
"MiniBertForMLMAndSequenceClassification",
]
@@ -154,6 +156,39 @@ class MiniBertForRegression(nn.Module):
return x
class MiniBertForSequenceClassification(nn.Module):
def __init__(self, configuration):
super().__init__()
self.minibert = MiniBert(configuration)
self.configuration = configuration
self._voc_size = len(configuration.vocabulary)
self._embedding_dim = configuration.embedding_dim
l1_in_size = self._embedding_dim * configuration.position_embeddings_count
self.l1 = nn.Linear(
l1_in_size, configuration.first_layer_output_size, bias=True)
self.l1_activation_fun = parse_activation_function(
configuration.first_layer_activation_fun)
self.l2 = nn.Linear(configuration.first_layer_output_size,
configuration.output_size, bias=True)
self.l2_activation_fun = parse_activation_function(
configuration.second_layer_activation_fun)
def forward(self, input):
x = self.minibert(input)
d_batch = x.size(0)
x = x.view(d_batch, -1)
x = self.l1(x)
x = self.l1_activation_fun(x)
x = self.l2(x)
x = self.l2_activation_fun(x)
return x
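# Usage sketch: forward() flattens the encoder output and l1 expects
# embedding_dim * position_embeddings_count input features, so each batch
# must span exactly position_embeddings_count token positions, e.g.
#   tokens = torch.randint(len(vocab), (8, config.position_embeddings_count))
#   logits = model(tokens)  # -> (8, config.output_size)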
class MiniBertForMLMAndRegression(nn.Module):
def __init__(self, configuration):
super().__init__()
@@ -243,3 +278,100 @@ class MiniBertForMLMAndRegression(nn.Module):
x = self.reg_l2(x)
x = self.reg_l2_activation_fun(x)
return x
class MiniBertForMLMAndSequenceClassification(nn.Module):
def __init__(self, configuration):
super().__init__()
self.minibert = MiniBert(configuration)
self.configuration = configuration
self._voc_size = len(configuration.vocabulary)
self._embedding_dim = configuration.embedding_dim
self.mlm_l1 = nn.Linear(self._embedding_dim,
configuration.mlm_first_layer_output_size, bias=False)
self.mlm_l2 = nn.Linear(
configuration.mlm_first_layer_output_size, self._voc_size, bias=True)
self._in_cls_size = self._embedding_dim * \
configuration.position_embeddings_count
self.cls_l1 = nn.Linear(
self._in_cls_size, configuration.cls_first_layer_output_size, bias=False)
self.cls_l2 = nn.Linear(
configuration.cls_first_layer_output_size, configuration.cls_output_size, bias=True)
self.mask_idx = configuration.mask_idx
self.mlm_l1_activation_fun = parse_activation_function(
configuration.mlm_first_layer_activation_fun)
self.cls_l1_activation_fun = parse_activation_function(
configuration.cls_first_layer_activation_fun)
self.mlm_l2_activation_fun = parse_activation_function(
configuration.mlm_second_layer_activation_fun)
self.cls_l2_activation_fun = parse_activation_function(
configuration.cls_second_layer_activation_fun)
self._mask_prob = configuration.mask_prob
self._keep_mask_prob = configuration.keep_mask_prob
self._inv_corrupt_mask_prob = 1 - configuration.corrupt_mask_prob
# task == 0 -> MLM
# task == 1 -> Sequence classification
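# In training mode, task 0 returns a (logits, loss) tuple; in eval mode it
# returns only the logits. Task 1 always returns a (batch, cls_output_size)
# tensor of classification scores.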
def forward(self, input, task):
if task == 0:
return self.forward_mlm(input)
elif task == 1:
return self.forward_cls(input)
else:
raise Exception(
f"`task` parameter must be either 0 or 1, not {task}.")
def forward_mlm(self, input):
if self.training:
masked_input = input.clone().detach()
masked = torch.rand_like(
input, dtype=torch.float) <= self._mask_prob
masking_strategy = torch.rand_like(input, dtype=torch.float)
masking = masked & (masking_strategy <=
self._keep_mask_prob) # Keep masks
corrupt = masked & (self._inv_corrupt_mask_prob <
masking_strategy) # Corrupt masks
replacements = torch.randint(
self._voc_size, (torch.sum(corrupt), ), device=input.device)
masked_input[masking] = self.mask_idx
masked_input[corrupt] = replacements
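# Tokens selected by `masked` but caught by neither `masking` nor `corrupt`
# stay unchanged, which realises the reveal_mask_prob share of the strategy.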
x = self.minibert(masked_input)
else:
x = self.minibert(input)
x = self.mlm_l1(x)
x = self.mlm_l1_activation_fun(x)
x = self.mlm_l2(x)
x = self.mlm_l2_activation_fun(x)
if self.training:
labels = input.clone().detach()
labels[~masked] = -1
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fn(x.view(-1, self._voc_size), labels.view(-1))
return (x, loss)
else:
return x
def forward_cls(self, input):
x = self.minibert(input)
d_batch = x.size(0)
x = x.view(d_batch, -1)
x = self.cls_l1(x)
x = self.cls_l1_activation_fun(x)
x = self.cls_l2(x)
x = self.cls_l2_activation_fun(x)
return x
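A rough training-step sketch for the joint model, assuming the keyword arguments used in the test below are sufficient; the import path, the optimizer and the external classification criterion are illustrative assumptions, while the forward signatures (task 0 returning a (logits, loss) tuple in training mode, task 1 returning classification scores) follow the code above.

    import torch
    import torch.nn as nn

    # Hypothetical import path; only the exported names appear in this diff.
    from minibert import (
        MiniBertForMLMAndSequenceClassification,
        MiniBertForMLMAndSequenceClassificationConfiguration,
    )

    vocabulary = ["<mask>", "a", "b", "c", "d", "e", "<pad>"]
    # position_type / attention_type are passed explicitly in the test; they
    # are omitted here on the assumption that the base configuration has defaults.
    config = MiniBertForMLMAndSequenceClassificationConfiguration(
        vocabulary=vocabulary,
        embedding_dim=4,
        position_embeddings_count=5,
        mask_idx=0,
        cls_output_size=2,
    )
    model = MiniBertForMLMAndSequenceClassification(config)

    sentences = torch.tensor([[1, 2, 3, 4, 5]] * 4)             # MLM batch
    x_cls = torch.tensor([[3, 4, 5, 6, 6], [1, 4, 2, 6, 6]])    # classification batch
    y_cls = torch.tensor([1, 0])                                 # labels in [0, cls_output_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    cls_criterion = nn.CrossEntropyLoss()  # assumed external loss for the cls head

    model.train()
    optimizer.zero_grad()
    _, mlm_loss = model(sentences, 0)   # task 0: the model masks its input, returns (logits, loss)
    cls_logits = model(x_cls, 1)        # task 1: (batch, cls_output_size) scores
    cls_loss = cls_criterion(cls_logits, y_cls)
    (mlm_loss + cls_loss).backward()
    optimizer.step()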
@@ -74,6 +74,49 @@ class TestMiniBert(unittest.TestCase):
torch.Size((x_acr.size(0), x_acr.size(1), len(vocabulary)))
)
def test_minibert_mlm_and_sequence_classification(self):
vocabulary = ["<mask>", "a", "b", "c", "d", "e", "<pad>"]
sentences = torch.tensor([
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
])
x_cls = torch.tensor([
[3, 4, 5, 6, 6],
[1, 4, 2, 6, 6],
])
y_cls = torch.tensor([1, 2])
cls_out_size = 2
for pos in PositionalEmbeddingType:
for att in AttentionType:
config = MiniBertForMLMAndSequenceClassificationConfiguration(
vocabulary=vocabulary,
embedding_dim=4,
position_embeddings_count=5,
mask_idx=0,
cls_output_size=cls_out_size,
position_type=pos,
attention_type=att
)
model = MiniBertForMLMAndSequenceClassification(config)
mlm_output = model.forward(sentences, 0)
cls_output = model.forward(x_cls, 1)
self.assertIsInstance(mlm_output, tuple)
self.assertEqual(
mlm_output[0].size(),
torch.Size(
(sentences.size(0), sentences.size(1), len(vocabulary)))
)
self.assertEqual(
cls_output.size(),
torch.Size((x_cls.size(0), cls_out_size))
)
if __name__ == '__main__':
unittest.main()