Commit 88318f1b authored by Gaëtan Caillaut

MiniBertEmbedding + masking strategy

parent 53eb4577
-__pycache__
\ No newline at end of file
+__pycache__
+runs
\ No newline at end of file
@@ -5,10 +5,28 @@ from math import sqrt
 __all__ = [
     "Attention",
-    "MiniBert"
+    "MiniBert",
+    "MiniBertForTraining",
+    "MiniBertEmbedding"
 ]


+class MiniBertEmbedding(nn.Module):
+    def __init__(self, voc_size, embedding_dim):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(voc_size, embedding_dim)
+        self.position_embeddings = nn.Embedding(1024, embedding_dim)
+        self.norm = nn.LayerNorm(embedding_dim)
+        self.register_buffer(
+            "position_ids", torch.arange(1024).expand((1, -1)))
+
+    def forward(self, input):
+        seq_len = input.shape[-1]
+        emb = self.word_embeddings(input)
+        pos = self.position_embeddings(self.position_ids[:, :seq_len])
+        return self.norm(emb + pos)
+
+
 class Attention(nn.Module):
     def __init__(self, in_dim, out_dim, hidden_dim=None):
         super(Attention, self).__init__()
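For context, a minimal shape check for the MiniBertEmbedding class added above: token embeddings and learned position embeddings (up to 1024 positions) are summed and layer-normalized. This sketch is not part of the commit; the vocabulary size, embedding dimension, and batch shape are illustrative assumptions, and it presumes torch and MiniBertEmbedding are in scope.

# Hypothetical sketch, not part of the commit: sizes are arbitrary.
emb = MiniBertEmbedding(voc_size=100, embedding_dim=16)
tokens = torch.randint(100, (2, 8))   # batch of 2 sequences, 8 token ids each
out = emb(tokens)                     # word + position embeddings, then LayerNorm
assert out.shape == (2, 8, 16)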
@@ -46,11 +64,13 @@ class Attention(nn.Module):
 class MiniBert(nn.Module):
-    def __init__(self, embedding_dim, voc_size, hidden_dim=None):
+    def __init__(self, embedding_dim, voc_size, mask_idx, hidden_dim=None):
         super(MiniBert, self).__init__()
         if hidden_dim is None:
             hidden_dim = embedding_dim

-        self.embedding = nn.Embedding(voc_size, embedding_dim)
+        self.mask_idx = mask_idx
+        self.embedding = MiniBertEmbedding(voc_size, embedding_dim)
         self.attention = Attention(
             embedding_dim, embedding_dim, hidden_dim=hidden_dim)
@@ -67,4 +87,65 @@ class MiniBert(nn.Module):
     def forward(self, input):
         x = self.embedding(input)
-        return self.attention(x)
+        x = self.attention(x)
+        return x
+
+
+class MiniBertForTraining(nn.Module):
+    def __init__(self, embedding_dim, voc_size, mask_idx, hidden_dim=None, mask_prob=0.15, train=True):
+        super(MiniBertForTraining, self).__init__()
+        self.minibert = MiniBert(
+            embedding_dim, voc_size, mask_idx, hidden_dim=hidden_dim)
+        self.l1 = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.l2 = nn.Linear(embedding_dim, voc_size, bias=True)
+
+        self.mask_idx = mask_idx
+        self.train = train
+        self.voc_size = voc_size
+        self.mask_prob = mask_prob
+
+    def forward(self, input):
+        prev_grad = torch.is_grad_enabled()
+        torch.set_grad_enabled(self.train)
+
+        if self.train:
+            # masked_input = input.detach().clone()
+            masked_input = input.clone()
+            masked = torch.rand_like(
+                input, dtype=torch.float) <= self.mask_prob
+
+            masking_strategy = torch.rand_like(input, dtype=torch.float)
+            # in 80% of cases, replace the token with the mask token
+            # in 10% of cases, keep the original token
+            # in 10% of cases, replace it with a random token
+            masking = masked & (masking_strategy <= 0.8)  # mask
+            corrupt = masked & (0.9 < masking_strategy)  # replace with a random token
+            replacements = torch.randint(self.voc_size, (torch.sum(corrupt), ))
+
+            masked_input[masking] = self.mask_idx
+            masked_input[corrupt] = replacements
+
+            x = self.minibert(masked_input)
+        else:
+            x = self.minibert(input)
+
+        x = self.l1(x)
+        x = F.gelu(x)
+        x = self.l2(x)
+
+        if self.train:
+            # labels = input.detach().clone()
+            labels = input.clone()
+            labels[~masked] = -1
+
+            loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
+            loss = loss_fn(x.view(-1, self.voc_size), labels.view(-1))
+
+            torch.set_grad_enabled(prev_grad)
+            return (x, loss)
+        else:
+            torch.set_grad_enabled(prev_grad)
+            return x
+
+    def set_train(self, value):
+        self.train = value
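A possible usage sketch for the MiniBertForTraining module added above. In training mode the module masks roughly mask_prob of the input tokens itself (80% replaced by mask_idx, 10% kept, 10% corrupted with a random token) and returns both the logits and the cross-entropy loss computed only on the masked positions. The sketch is not from the commit: the vocabulary size, mask index, batch shape, and optimizer are assumptions, and torch plus the classes above are assumed to be in scope.

# Hypothetical sketch, not from the commit; hyperparameters are assumptions.
model = MiniBertForTraining(embedding_dim=64, voc_size=1000, mask_idx=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch = torch.randint(4, 1000, (8, 32))   # 8 sequences of 32 token ids (ids >= 4 to keep special ids free)
logits, loss = model(batch)               # masking and the MLM loss are handled inside forward()
optimizer.zero_grad()
loss.backward()
optimizer.step()

model.set_train(False)                    # inference mode: no masking, logits only
logits = model(batch)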
@@ -21,11 +21,32 @@ class TestAttention(unittest.TestCase):
         xv = torch.tensor([[1.5, 1.5], [2.5, 2], [1, 1]], dtype=torch.float)
         x_qk = torch.matmul(xq, xk.t()) / sqrt(2)
-        expected = torch.matmul(F.softmax(x_qk), xv)
+        expected = torch.matmul(F.softmax(x_qk, dim=1), xv)
         actual = attention(x)

         self.assertTrue(torch.equal(expected, actual))

+    def test_attention_given_batch(self):
+        k = torch.tensor([[0, 0.5], [1, 0], [0.5, 0.5]], dtype=torch.float)
+        q = torch.tensor([[0, 0.5], [0, 0], [0.5, 0.5]], dtype=torch.float)
+        v = torch.tensor([[0.5, 0.5], [1, 0.5], [1, 1]], dtype=torch.float)
+        attention = Attention.from_weights(k, q, v)
+
+        x = torch.tensor(
+            [[1, 0, 1], [1, 1, 1], [0, 0, 1]], dtype=torch.float)
+        batch = torch.stack([x, x, x])
+
+        xk = torch.tensor([[0.5, 1], [1.5, 1], [0.5, 0.5]], dtype=torch.float)
+        xq = torch.tensor([[0.5, 1], [0.5, 1], [0.5, 0.5]], dtype=torch.float)
+        xv = torch.tensor([[1.5, 1.5], [2.5, 2], [1, 1]], dtype=torch.float)
+
+        x_qk = torch.matmul(xq, xk.t()) / sqrt(2)
+        expected = torch.matmul(F.softmax(x_qk, dim=1), xv)
+        expected = torch.stack([expected, expected, expected])
+
+        actual = attention(batch)
+        self.assertTrue(torch.equal(expected, actual))
+
+
 if __name__ == '__main__':
     unittest.main()