Commit 6d9bd67a authored by Gaëtan Caillaut's avatar Gaëtan Caillaut

Attention

from .modules import *

import torch
from torch import nn
from torch.nn import functional as F
from math import sqrt

__all__ = [
    "Attention"
]
class Attention(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        # Projection matrices for keys, queries and values,
        # initialised uniformly in [0, 1).
        self.key = nn.Parameter(torch.rand((in_dim, hidden_dim)))
        self.query = nn.Parameter(torch.rand((in_dim, hidden_dim)))
        self.value = nn.Parameter(torch.rand((in_dim, out_dim)))
        # Pre-computed scaling factor for scaled dot-product attention.
        self._sqrt_hidden = sqrt(hidden_dim)

    @classmethod
    def from_weights(cls, key, query, value):
        # Build an Attention module from pre-existing key, query and value matrices.
        in_dim, hidden_dim = key.shape
        out_dim = value.shape[1]
        x = cls(in_dim, out_dim, hidden_dim)
        with torch.no_grad():
            x.key = nn.Parameter(key)
            x.query = nn.Parameter(query)
            x.value = nn.Parameter(value)
        return x

    def forward(self, input):
        # Project the input onto the key, query and value spaces.
        key = torch.matmul(input, self.key)
        query = torch.matmul(input, self.query)
        value = torch.matmul(input, self.value)
        # Scaled dot-product attention: softmax(Q K^T / sqrt(hidden_dim)) V.
        key_t = torch.transpose(key, -2, -1)
        qk = torch.matmul(query, key_t) / self._sqrt_hidden
        attention = F.softmax(qk, dim=-1)
        return torch.matmul(attention, value)
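
The forward pass above computes single-head scaled dot-product attention, softmax(Q K^T / sqrt(hidden_dim)) V, where Q, K and V are linear projections of the same input. A minimal usage sketch follows; the dimensions, tensor shapes and variable names are illustrative assumptions, not part of the commit:

# Minimal usage sketch (illustrative only): an attention layer mapping
# 3-dimensional token embeddings to 2-dimensional outputs.
import torch
from minibert import Attention

layer = Attention(in_dim=3, out_dim=2)   # hidden_dim defaults to out_dim
tokens = torch.rand((5, 3))              # 5 tokens, 3 features each (assumed shape)
out = layer(tokens)
print(out.shape)                         # torch.Size([5, 2])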
import unittest
import torch
from torch.nn import functional as F
from math import sqrt

from minibert import Attention


class TestAttention(unittest.TestCase):
    def test_attention_given_matrix(self):
        # Fixed key, query and value projection matrices.
        k = torch.tensor([[0, 0.5], [1, 0], [0.5, 0.5]], dtype=torch.float)
        q = torch.tensor([[0, 0.5], [0, 0], [0.5, 0.5]], dtype=torch.float)
        v = torch.tensor([[0.5, 0.5], [1, 0.5], [1, 1]], dtype=torch.float)
        attention = Attention.from_weights(k, q, v)

        x = torch.tensor(
            [[1, 0, 1], [1, 1, 1], [0, 0, 1]], dtype=torch.float)
        # Hand-computed projections: xk = x @ k, xq = x @ q, xv = x @ v.
        xk = torch.tensor([[0.5, 1], [1.5, 1], [0.5, 0.5]], dtype=torch.float)
        xq = torch.tensor([[0.5, 1], [0.5, 1], [0.5, 0.5]], dtype=torch.float)
        xv = torch.tensor([[1.5, 1.5], [2.5, 2], [1, 1]], dtype=torch.float)

        # Expected output of scaled dot-product attention on the projections.
        x_qk = torch.matmul(xq, xk.t()) / sqrt(2)
        expected = torch.matmul(F.softmax(x_qk, dim=-1), xv)
        actual = attention(x)
        self.assertTrue(torch.equal(expected, actual))


if __name__ == '__main__':
    unittest.main()