Commit 41f51b95 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

MiniBert

parent 6d9bd67a
......@@ -4,7 +4,8 @@ from torch.nn import functional as F
from math import sqrt
__all__ = [
"Attention"
"Attention",
"MiniBert"
]
......@@ -42,3 +43,28 @@ class Attention(nn.Module):
qk = torch.matmul(query, key_t) / self._sqrt_hidden
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, value)
class MiniBert(nn.Module):
    """A minimal BERT-like module: token embeddings followed by one attention layer."""

    def __init__(self, embedding_dim, voc_size, hidden_dim=None):
        """Build the model.

        Args:
            embedding_dim: size of each token embedding vector.
            voc_size: number of entries in the vocabulary.
            hidden_dim: inner dimension passed to the attention layer;
                defaults to ``embedding_dim`` when omitted.
        """
        super(MiniBert, self).__init__()
        inner_dim = embedding_dim if hidden_dim is None else hidden_dim
        self.embedding = nn.Embedding(voc_size, embedding_dim)
        self.attention = Attention(
            embedding_dim, embedding_dim, hidden_dim=inner_dim)

    @classmethod
    def from_weights(cls, embeddings, key, query, value, freeze=False):
        """Alternate constructor: build a model from pre-trained tensors.

        Args:
            embeddings: ``(voc_size, embedding_dim)`` embedding matrix.
            key, query, value: weight tensors for the attention projections;
                the hidden dimension is taken from ``key.shape[1]``.
            freeze: when True, the embedding weights are excluded from training.

        Returns:
            A ``MiniBert`` instance whose sub-modules wrap the given tensors.
        """
        voc_size, embedding_dim = embeddings.shape
        model = cls(embedding_dim, voc_size, hidden_dim=key.shape[1])
        with torch.no_grad():
            model.embedding = nn.Embedding.from_pretrained(
                embeddings, freeze=freeze)
            model.attention = Attention.from_weights(key, query, value)
        return model

    def forward(self, input):
        """Embed the token ids, then run the embeddings through attention."""
        embedded = self.embedding(input)
        return self.attention(embedded)
import unittest
import torch
from minibert import MiniBert
class TestMiniBert(unittest.TestCase):
    """Smoke tests for the MiniBert model."""

    def test_minibert_not_fail(self):
        """A forward pass on a small integer batch runs and keeps (batch, seq) dims.

        The original test only checked that no exception was raised and never
        inspected ``out``; we additionally assert on the output shape so a
        silently-broken forward pass is caught.
        """
        minibert = MiniBert(10, 10)
        x = torch.tensor([
            [0, 1, 3, 4],
            [0, 1, 3, 4],
            [0, 1, 3, 4]
        ])
        out = minibert(x)
        # The attention output is softmax(QK^T) @ V, so the leading
        # (batch, sequence) dimensions of the input must be preserved.
        self.assertEqual(out.shape[:2], x.shape)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment