Commit ce5dd07b authored by Gaëtan Caillaut

Merge branch 'refactoring' into 'master'

Refactoring

See merge request !2
parents e5603fb6 784bf267
__pycache__
runs
dist
minibert_pkg_gcaillaut.egg-info
\ No newline at end of file
minibert.egg-info
\ No newline at end of file
- Use nn.Linear modules for the key, query and value matrices?
- use load_state_dict in the from_weights methods
\ No newline at end of file
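The two TODO items above could be implemented roughly as in the sketch below; the LinearAttention class name and the transposed state-dict keys are assumptions for illustration, not existing code.

# Hypothetical sketch of the TODO items: nn.Linear projections for key/query/value,
# and load_state_dict inside a from_weights constructor. Names are illustrative.
from math import sqrt
import torch
from torch import nn
from torch.nn import functional as F

class LinearAttention(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        hidden_dim = out_dim if hidden_dim is None else hidden_dim
        # nn.Linear replaces the raw nn.Parameter weight matrices
        self.key = nn.Linear(in_dim, hidden_dim, bias=False)
        self.query = nn.Linear(in_dim, hidden_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)
        self._sqrt_hidden = sqrt(hidden_dim)

    @classmethod
    def from_weights(cls, key, query, value):
        # key/query: (in_dim, hidden_dim), value: (in_dim, out_dim),
        # following the existing Attention.from_weights convention.
        in_dim, hidden_dim = key.shape
        out_dim = value.shape[1]
        x = cls(in_dim, out_dim, hidden_dim)
        # load_state_dict instead of reassigning nn.Parameter attributes;
        # nn.Linear stores its weights transposed, hence the .t() calls.
        x.load_state_dict({
            "key.weight": key.t(),
            "query.weight": query.t(),
            "value.weight": value.t(),
        })
        return x

    def forward(self, input):
        q, k, v = self.query(input), self.key(input), self.value(input)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self._sqrt_hidden
        return torch.matmul(F.softmax(scores, dim=-1), v)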
from .attention import *
from .embeddings import *
from .modules import *
from .configuration import *
from enum import Enum
from math import sqrt
from .embeddings import PositionalEmbeddingType, PositionalEmbedding
import torch
from torch import nn
from torch.nn import functional as F
__all__ = [
"AttentionType",
"AttentionEmbedding",
"Attention",
"NonTransformingAttention",
]
class AttentionType(Enum):
SelfAttention = 1
AttentionEmbedding = 2
NonTransformingAttention = 3
class AttentionEmbedding(nn.Module):
def __init__(self, embedding_dim, voc_size, out_dim=None, max_seq_len=1024, position_type=PositionalEmbeddingType.TRAINED, normalize_embeddings=True):
super(AttentionEmbedding, self).__init__()
if out_dim is None:
out_dim = embedding_dim
self.embedding_dim = embedding_dim
self.voc_size = voc_size
self.out_dim = out_dim
self.max_seq_len = max_seq_len
self.key = nn.Embedding(voc_size, embedding_dim)
self.query = nn.Embedding(voc_size, embedding_dim)
self.value = nn.Embedding(voc_size, out_dim)
self._sqrt_embedding = sqrt(embedding_dim)
self.position_embedding = PositionalEmbedding(
embedding_dim, max_seq_len, ptype=position_type)
self.norm = None
self.normalize_embeddings = normalize_embeddings
if normalize_embeddings:
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, input):
pos = self.position_embedding(input)
key = self.key(input) + pos
query = self.query(input) + pos
value = self.value(input)
if self.normalize_embeddings:
key = self.norm(key)
query = self.norm(query)
value = self.norm(value)
key_t = torch.transpose(key, -2, -1)
qk = torch.matmul(query, key_t) / self._sqrt_embedding
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, value)
class Attention(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim=None, key_is_query=False):
super(Attention, self).__init__()
if hidden_dim is None:
hidden_dim = out_dim
self.in_dim = in_dim
self.out_dim = out_dim
self.key_is_query = key_is_query
self.hidden_dim = hidden_dim
self.key = nn.Parameter(torch.rand((in_dim, hidden_dim)))
if key_is_query:
self.query = self.key
else:
self.query = nn.Parameter(torch.rand((in_dim, hidden_dim)))
self.value = nn.Parameter(torch.rand((in_dim, out_dim)))
self._sqrt_hidden = sqrt(hidden_dim)
@classmethod
def from_weights(cls, key, query, value):
in_dim, hidden_dim = key.shape
out_dim = value.shape[1]
x = cls(in_dim, out_dim, hidden_dim)
with torch.no_grad():
x.key = nn.Parameter(key)
x.query = nn.Parameter(query)
x.value = nn.Parameter(value)
return x
def forward(self, input):
key = torch.matmul(input, self.key)
query = torch.matmul(input, self.query)
value = torch.matmul(input, self.value)
key_t = torch.transpose(key, -2, -1)
qk = torch.matmul(query, key_t) / self._sqrt_hidden
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, value)
class NonTransformingAttention(nn.Module):
def __init__(self, dim):
super(NonTransformingAttention, self).__init__()
self.dim = dim
self._sqrt_dim = sqrt(dim)
def forward(self, input):
query = input
key = input
key_t = torch.transpose(key, -2, -1)
qk = torch.matmul(query, key_t) / self._sqrt_dim
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, input)
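For reference, a minimal usage sketch of the attention modules above; the batch size, sequence length and dimensions are arbitrary illustrative values.

# Illustrative shape check for the attention modules defined above.
import torch

att = Attention(in_dim=16, out_dim=16, hidden_dim=8)
x = torch.rand(2, 10, 16)               # (batch, seq_len, in_dim)
print(att(x).shape)                     # torch.Size([2, 10, 16])

emb_att = AttentionEmbedding(embedding_dim=16, voc_size=100)
tokens = torch.randint(100, (2, 10))    # (batch, seq_len) token ids
print(emb_att(tokens).shape)            # torch.Size([2, 10, 16])

nta = NonTransformingAttention(dim=16)
print(nta(x).shape)                     # torch.Size([2, 10, 16])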
from minibert.modules import AttentionType, PositionEmbeddingType
from minibert.attention import AttentionType
from .embeddings import PositionalEmbeddingType
__all__ = [
"MiniBertConfiguration",
"MiniBertForMLMConfiguration",
"MiniBertForRegressionConfiguration",
"MiniBertForMLMAndRegressionConfiguration",
]
@@ -22,7 +24,7 @@ class MiniBertConfiguration:
self.position_embeddings_count = kwargs.get(
"position_embeddings_count", 1024)
self.position_type = kwargs.get(
"position_type", PositionEmbeddingType.TRAINED)
"position_type", PositionalEmbeddingType.TRAINED)
self.normalize_embeddings = kwargs.get(
"normalize_embeddings", True)
@@ -58,3 +60,32 @@ class MiniBertForRegressionConfiguration(MiniBertConfiguration):
"first_layer_output_size", self.embedding_dim)
self.activation_fun = kwargs.get("activation_fun", "gelu")
self.output_size = kwargs["output_size"]
class MiniBertForMLMAndRegressionConfiguration(MiniBertConfiguration):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.mask_idx = kwargs["mask_idx"]
self.mask_token = kwargs.get("mask_token", "<mask>")
        # Masking strategy (see the worked example after this class)
self.mask_prob = kwargs.get("mask_prob", 0.15)
self.keep_mask_prob = kwargs.get("keep_mask_prob", 0.8)
self.corrupt_mask_prob = kwargs.get("corrupt_mask_prob", 0.1)
self.reveal_mask_prob = kwargs.get("reveal_mask_prob", 0.1)
        # Compare with a small tolerance: the default probabilities (0.8, 0.1, 0.1)
        # do not sum to exactly 1.0 in floating-point arithmetic.
        if abs(self.keep_mask_prob + self.corrupt_mask_prob + self.reveal_mask_prob - 1) > 1e-6:
            raise ValueError("keep_mask_prob + corrupt_mask_prob + reveal_mask_prob must sum to 1")
# Prediction layers
## MLM
self.mlm_first_layer_output_size = kwargs.get(
"mlm_first_layer_output_size", self.embedding_dim)
self.mlm_activation_fun = kwargs.get("mlm_activation_fun", "gelu")
## Regression
self.reg_first_layer_output_size = kwargs.get(
"reg_first_layer_output_size", self.embedding_dim)
self.reg_activation_fun = kwargs.get("reg_activation_fun", "gelu")
self.reg_output_size = kwargs["reg_output_size"]
\ No newline at end of file
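As a worked example of the masking strategy configured above (defaults shown; this is the BERT-style 80/10/10 split of the selected tokens):

# Worked example of the default masking probabilities.
mask_prob = 0.15          # a token is selected for prediction with probability 0.15
keep_mask_prob = 0.80     # selected and replaced by <mask>: 0.15 * 0.80 = 12% of all tokens
corrupt_mask_prob = 0.10  # selected and replaced by a random token: 0.15 * 0.10 = 1.5%
reveal_mask_prob = 0.10   # selected but left unchanged in the input: 0.15 * 0.10 = 1.5%
assert abs(keep_mask_prob + corrupt_mask_prob + reveal_mask_prob - 1) < 1e-6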
from enum import Enum
from math import sin, cos
import torch
from torch import nn
__all__ = [
"PositionalEmbeddingType",
"PositionalEmbedding",
"MiniBertEmbedding",
]
class PositionalEmbeddingType(Enum):
TRAINED = 1
FIXED = 2
NONE = 3
class PositionalEmbedding(nn.Module):
def __init__(self, embedding_dim, max_seq_len, ptype=PositionalEmbeddingType.TRAINED):
super().__init__()
self.embedding_dim = embedding_dim
self.max_seq_len = max_seq_len
if ptype == PositionalEmbeddingType.TRAINED:
self.embeddings = nn.Embedding(max_seq_len, embedding_dim)
elif ptype == PositionalEmbeddingType.FIXED:
# See Attention is all you need, section 3.5 (https://arxiv.org/pdf/1706.03762.pdf)
positions = torch.zeros(
(max_seq_len, embedding_dim), dtype=torch.float)
for pos in range(max_seq_len):
for i in range(embedding_dim):
if i % 2 == 0:
positions[pos, i] = sin(
pos / pow(10000, 2 * i / embedding_dim))
else:
positions[pos, i] = cos(
pos / pow(10000, 2 * i / embedding_dim))
self.embeddings = nn.Embedding.from_pretrained(
positions, freeze=True)
elif ptype == PositionalEmbeddingType.NONE:
positions = torch.zeros(
(max_seq_len, embedding_dim), dtype=torch.float)
self.embeddings = nn.Embedding.from_pretrained(
positions, freeze=True)
else:
raise Exception("Invalid position type")
self.register_buffer(
"position_ids", torch.arange(max_seq_len).expand((1, -1)))
def forward(self, input):
seq_len = input.shape[-1]
return self.embeddings(self.position_ids[:, :seq_len])
class MiniBertEmbedding(nn.Module):
def __init__(self, voc_size, embedding_dim, max_seq_len, position_type, normalize_embeddings):
super().__init__()
self.max_seq_len = max_seq_len
self.position_type = position_type
self.normalize_embeddings = normalize_embeddings
self.word_embeddings = nn.Embedding(voc_size, embedding_dim)
self.position_embeddings = PositionalEmbedding(
embedding_dim, max_seq_len, ptype=position_type)
self.norm = None
if normalize_embeddings:
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, input):
word_emb = self.word_embeddings(input)
pos_emb = self.position_embeddings(input)
emb = word_emb + pos_emb
if self.normalize_embeddings:
return self.norm(emb)
else:
return emb
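A small illustrative shape check for the embedding modules above; the sizes are arbitrary.

# Illustrative usage of the embedding modules defined above.
import torch

emb = MiniBertEmbedding(voc_size=100, embedding_dim=16, max_seq_len=64,
                        position_type=PositionalEmbeddingType.FIXED,
                        normalize_embeddings=True)
tokens = torch.randint(100, (2, 10))    # (batch, seq_len) token ids
print(emb(tokens).shape)                # torch.Size([2, 10, 16])

# With PositionalEmbeddingType.FIXED the position table is precomputed with the
# sin/cos loop above and frozen; with NONE it is all zeros, i.e. positions are ignored.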
@@ -5,220 +5,17 @@ from math import sqrt, sin, cos, pow
from enum import Enum
from .activations import parse_activation_function
from .attention import Attention, AttentionType, AttentionEmbedding, NonTransformingAttention
from .embeddings import MiniBertEmbedding
__all__ = [
"Attention",
"NonTransformingAttention",
"MiniBert",
"MiniBertForMLM",
"MiniBertForRegression",
"MiniBertEmbedding",
"PositionEmbeddingType",
"AttentionType",
"MiniBertForMLMAndRegression",
]
class PositionEmbeddingType(Enum):
TRAINED = 1
FIXED = 2
NONE = 3
class AttentionType(Enum):
SelfAttention = 1
AttentionEmbedding = 2
NonTransformingAttention = 3
class PositionnalEmbedding(nn.Module):
def __init__(self, embedding_dim, max_seq_len, position_type=PositionEmbeddingType.TRAINED):
super().__init__()
self.embedding_dim = embedding_dim
self.max_seq_len = max_seq_len
if position_type == PositionEmbeddingType.TRAINED:
self.position_embeddings = nn.Embedding(
max_seq_len, embedding_dim)
elif position_type == PositionEmbeddingType.FIXED:
# See Attention is all you need, section 3.5 (https://arxiv.org/pdf/1706.03762.pdf)
d = embedding_dim
positions = torch.zeros((max_seq_len, d), dtype=torch.float)
for pos in range(max_seq_len):
for i in range(d):
if i % 2 == 0:
positions[pos, i] = sin(pos / pow(10000, 2 * i / d))
else:
positions[pos, i] = cos(pos / pow(10000, 2 * i / d))
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
elif position_type == PositionEmbeddingType.NONE:
positions = torch.zeros(
(max_seq_len, embedding_dim), dtype=torch.float)
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
else:
raise Exception("Invalid position type")
self.register_buffer(
"position_ids", torch.arange(max_seq_len).expand((1, -1)))
def forward(self, input):
seq_len = input.shape[-1]
return self.position_embeddings(self.position_ids[:, :seq_len])
class MiniBertEmbedding(nn.Module):
def __init__(self, voc_size, embedding_dim, position_count, position_type, normalize_embeddings):
super().__init__()
self.position_count = position_count
self.position_type = position_type
self.normalize_embeddings = normalize_embeddings
self.word_embeddings = nn.Embedding(voc_size, embedding_dim)
if position_type == PositionEmbeddingType.TRAINED:
self.position_embeddings = nn.Embedding(
position_count, embedding_dim)
elif position_type == PositionEmbeddingType.FIXED:
# See Attention is all you need, section 3.5 (https://arxiv.org/pdf/1706.03762.pdf)
d = embedding_dim
positions = torch.zeros((position_count, d), dtype=torch.float)
for pos in range(position_count):
for i in range(d):
if i % 2 == 0:
positions[pos, i] = sin(pos / pow(10000, 2 * i / d))
else:
positions[pos, i] = cos(pos / pow(10000, 2 * i / d))
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
elif position_type == PositionEmbeddingType.NONE:
positions = torch.zeros(
(position_count, embedding_dim), dtype=torch.float)
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
else:
raise Exception("Invalid position type")
self.register_buffer(
"position_ids", torch.arange(position_count).expand((1, -1)))
self.norm = None
if normalize_embeddings:
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, input):
seq_len = input.shape[-1]
word_emb = self.word_embeddings(input)
pos_emb = self.position_embeddings(self.position_ids[:, :seq_len])
emb = word_emb + pos_emb
if self.normalize_embeddings:
return self.norm(emb)
else:
return emb
class AttentionEmbedding(nn.Module):
def __init__(self, embedding_dim, voc_size, out_dim=None, position_type=PositionEmbeddingType.TRAINED, normalize_embeddings=True):
super(AttentionEmbedding, self).__init__()
if out_dim is None:
out_dim = embedding_dim
self.embedding_dim = embedding_dim
self.voc_size = voc_size
self.out_dim = out_dim
self.key = nn.Embedding(voc_size, embedding_dim)
self.query = nn.Embedding(voc_size, embedding_dim)
self.value = nn.Embedding(voc_size, out_dim)
self._sqrt_embedding = sqrt(embedding_dim)
self.position_embedding = PositionnalEmbedding(
embedding_dim, 1024, position_type=position_type)
self.norm = None
self.normalize_embeddings = normalize_embeddings
if normalize_embeddings:
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, input):
pos = self.position_embedding(input)
key = self.key(input) + pos
query = self.query(input) + pos
value = self.value(input)
if self.normalize_embeddings:
key = self.norm(key)
query = self.norm(query)
value = self.norm(value)
key_t = torch.transpose(key, -2, -1)
qk = torch.matmul(query, key_t) / self._sqrt_embedding
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, value)
class Attention(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim=None, key_is_query=False):
super(Attention, self).__init__()
if hidden_dim is None:
hidden_dim = out_dim
self.in_dim = in_dim
self.out_dim = out_dim
self.key_is_query = key_is_query
self.hidden_dim = hidden_dim
self.key = nn.Parameter(torch.rand((in_dim, hidden_dim)))
if key_is_query:
self.query = self.key
else:
self.query = nn.Parameter(torch.rand((in_dim, hidden_dim)))
self.value = nn.Parameter(torch.rand((in_dim, out_dim)))
self._sqrt_hidden = sqrt(hidden_dim)
@classmethod
def from_weights(cls, key, query, value):
in_dim, hidden_dim = key.shape
out_dim = value.shape[1]
x = cls(in_dim, out_dim, hidden_dim)
with torch.no_grad():
x.key = nn.Parameter(key)
x.query = nn.Parameter(query)
x.value = nn.Parameter(value)
return x
def forward(self, input):
key = torch.matmul(input, self.key)
query = torch.matmul(input, self.query)
value = torch.matmul(input, self.value)
key_t = torch.transpose(key, -2, -1)
qk = torch.matmul(query, key_t) / self._sqrt_hidden
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, value)
class NonTransformingAttention(nn.Module):
def __init__(self, dim):
super(NonTransformingAttention, self).__init__()
self.dim = dim
self._sqrt_dim = sqrt(dim)
def forward(self, input):
query = input
key = input
key_t = torch.transpose(key, -2, -1)
qk = torch.matmul(query, key_t) / self._sqrt_dim
attention = F.softmax(qk, dim=-1)
return torch.matmul(attention, input)
class MiniBert(nn.Module):
def __init__(self, configuration):
super(MiniBert, self).__init__()
@@ -239,7 +36,7 @@ class MiniBert(nn.Module):
self.embedding = MiniBertEmbedding(
self._voc_size,
self._embedding_dim,
position_count=configuration.position_embeddings_count,
max_seq_len=configuration.position_embeddings_count,
position_type=configuration.position_type,
normalize_embeddings=configuration.normalize_embeddings
)
@@ -349,3 +146,87 @@ class MiniBertForRegression(nn.Module):
x = self.activation_fun(x)
x = self.l2(x)
return x
class MiniBertForMLMAndRegression(nn.Module):
def __init__(self, configuration):
super().__init__()
self.minibert = MiniBert(configuration)
self.configuration = configuration
self._voc_size = len(configuration.vocabulary)
self._embedding_dim = configuration.embedding_dim
self.mlm_l1 = nn.Linear(self._embedding_dim,
configuration.mlm_first_layer_output_size, bias=False)
self.mlm_l2 = nn.Linear(
configuration.mlm_first_layer_output_size, self._voc_size, bias=True)
self.reg_l1 = nn.Linear(self._embedding_dim,
configuration.reg_first_layer_output_size, bias=False)
self.reg_l2 = nn.Linear(
configuration.reg_first_layer_output_size, configuration.reg_output_size, bias=True)
self.mask_idx = configuration.mask_idx
self.mlm_activation_fun = parse_activation_function(
configuration.mlm_activation_fun)
self.reg_activation_fun = parse_activation_function(
configuration.reg_activation_fun)
self._mask_prob = configuration.mask_prob
self._keep_mask_prob = configuration.keep_mask_prob
self._inv_corrupt_mask_prob = 1 - configuration.corrupt_mask_prob
# task == 0 -> MLM
# task == 1 -> Regression
def forward(self, input, task):
if task == 0:
return self.forward_mlm(input)
elif task == 1:
return self.forward_reg(input)
else:
raise Exception(
f"`task` parameter must be either 0 or 1, not {task}.")
def forward_mlm(self, input):
if self.training:
masked_input = input.clone().detach()
masked = torch.rand_like(
input, dtype=torch.float) <= self._mask_prob
masking_strategy = torch.rand_like(input, dtype=torch.float)
masking = masked & (masking_strategy <=
self._keep_mask_prob) # Keep masks
corrupt = masked & (self._inv_corrupt_mask_prob <
masking_strategy) # Corrupt masks
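            # Masked tokens whose strategy value falls in (keep_mask_prob, 1 - corrupt_mask_prob]
            # are neither replaced by <mask> nor corrupted: they stay unchanged in the
            # input but are still predicted ("revealed") in the loss below.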
replacements = torch.randint(
self._voc_size, (torch.sum(corrupt), ), device=input.device)
masked_input[masking] = self.mask_idx
masked_input[corrupt] = replacements
x = self.minibert(masked_input)
else:
x = self.minibert(input)
x = self.mlm_l1(x)
x = self.mlm_activation_fun(x)
x = self.mlm_l2(x)
if self.training:
labels = input.clone().detach()
labels[~masked] = -1
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fn(x.view(-1, self._voc_size), labels.view(-1))
return (x, loss)
else:
return x
def forward_reg(self, input):
x = self.minibert(input)
x = self.reg_l1(x)
x = self.reg_activation_fun(x)
x = self.reg_l2(x)
return x
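A hypothetical end-to-end usage of the new multi-task model; every configuration kwarg except mask_idx and reg_output_size is an assumption here, since the base MiniBertConfiguration constructor is collapsed in this diff.

# Hypothetical usage sketch (assumed kwargs are marked in the comments).
import torch

config = MiniBertForMLMAndRegressionConfiguration(
    vocabulary=["<mask>", "a", "b", "c"],  # assumed kwarg of the base configuration
    embedding_dim=16,                      # assumed kwarg of the base configuration
    mask_idx=0,
    reg_output_size=3,
)
model = MiniBertForMLMAndRegression(config)
tokens = torch.randint(len(config.vocabulary), (2, 32))

model.train()
logits, loss = model(tokens, task=0)   # MLM head: returns (logits, loss) in training mode
loss.backward()

model.eval()
scores = model(tokens, task=1)         # regression head: (batch, seq_len, reg_output_size)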
@@ -5,7 +5,7 @@ with open("README.md", "rt", encoding="utf-8") as fh:
setuptools.setup(
name="minibert",
version="0.1.0",
version="0.2.0",
author="Gaëtan Caillaut",
author_email="gaetan.caillaut@univ-lemans.fr",
description="A simplified implementation of BERT",
......
from minibert.embeddings import PositionalEmbeddingType
import unittest
imp