Commit 2bb4cebc authored by Gaëtan Caillaut

AttentionEmbedding

parent 365934ae
@@ -23,6 +23,8 @@ class MiniBertConfiguration:
self.normalize_embeddings = kwargs.get(
"normalize_embeddings", True)
self.merge_attention_and_embeddings = kwargs.get(
"merge_attention_and_embeddings", False)
class MiniBertForTrainingConfiguration(MiniBertConfiguration):
def __init__(self, **kwargs):
@@ -21,6 +21,46 @@ class PositionEmbeddingType(Enum):
NONE = 3
class PositionnalEmbedding(nn.Module):
def __init__(self, embedding_dim, max_seq_len, position_type=PositionEmbeddingType.TRAINED):
super().__init__()
self.embedding_dim = embedding_dim
self.max_seq_len = max_seq_len
if position_type == PositionEmbeddingType.TRAINED:
self.position_embeddings = nn.Embedding(
max_seq_len, embedding_dim)
elif position_type == PositionEmbeddingType.FIXED:
            # Fixed sinusoidal embeddings, see Attention is all you need,
            # section 3.5 (https://arxiv.org/pdf/1706.03762.pdf):
            # PE(pos, 2k) = sin(pos / 10000^(2k/d)), PE(pos, 2k+1) = cos(pos / 10000^(2k/d))
            d = embedding_dim
            positions = torch.zeros((max_seq_len, d), dtype=torch.float)
            for pos in range(max_seq_len):
                for i in range(d):
                    # dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k/d)
                    angle = pos / pow(10000, 2 * (i // 2) / d)
                    if i % 2 == 0:
                        positions[pos, i] = sin(angle)
                    else:
                        positions[pos, i] = cos(angle)
            self.position_embeddings = nn.Embedding.from_pretrained(
                positions, freeze=True
            )
elif position_type == PositionEmbeddingType.NONE:
positions = torch.zeros(
(max_seq_len, embedding_dim), dtype=torch.float)
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
        else:
            raise ValueError(f"Invalid position type: {position_type}")
self.register_buffer(
"position_ids", torch.arange(max_seq_len).expand((1, -1)))
    def forward(self, input):
        # input holds token ids of shape (batch, seq_len); return the position
        # vectors for the first seq_len positions, shape (1, seq_len, embedding_dim)
        seq_len = input.shape[-1]
        return self.position_embeddings(self.position_ids[:, :seq_len])
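For reference, a standalone sketch (not part of this commit; the helper name sinusoidal_table is hypothetical) of the same fixed table from the FIXED branch above, built without the double Python loop and assuming an even embedding_dim:

import torch

def sinusoidal_table(max_seq_len, embedding_dim):
    # positions (L, 1) divided by the pair frequencies 10000^(2k/d), k = 0, 2, 4, ...
    pos = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
    k = torch.arange(0, embedding_dim, 2, dtype=torch.float)
    angles = pos / (10000.0 ** (k / embedding_dim))
    table = torch.zeros((max_seq_len, embedding_dim))
    table[:, 0::2] = torch.sin(angles)   # even dimensions
    table[:, 1::2] = torch.cos(angles)   # odd dimensions
    return table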
class MiniBertEmbedding(nn.Module):
def __init__(self, voc_size, embedding_dim, position_count, position_type, normalize_embeddings):
super().__init__()
@@ -75,6 +115,46 @@ class MiniBertEmbedding(nn.Module):
return emb
class AttentionEmbedding(nn.Module):
def __init__(self, embedding_dim, voc_size, out_dim=None, position_type=PositionEmbeddingType.TRAINED, normalize_embeddings=True):
        super().__init__()
        if out_dim is None:
            out_dim = embedding_dim
        self.embedding_dim = embedding_dim
        self.voc_size = voc_size
        self.out_dim = out_dim
        # the token embeddings play the role of the key, query and value projections
        self.key = nn.Embedding(voc_size, embedding_dim)
        self.query = nn.Embedding(voc_size, embedding_dim)
        self.value = nn.Embedding(voc_size, out_dim)
        self._sqrt_embedding = sqrt(embedding_dim)
        # positions are supported up to a fixed maximum sequence length of 1024
        self.position_embedding = PositionnalEmbedding(
            embedding_dim, 1024, position_type=position_type)
        self.norm = None
        self.normalize_embeddings = normalize_embeddings
        if normalize_embeddings:
            # NOTE: the same LayerNorm is shared by keys, queries and values,
            # which assumes out_dim == embedding_dim (the default)
            self.norm = nn.LayerNorm(embedding_dim)
    def forward(self, input):
        # input: (batch, seq_len) token ids
        pos = self.position_embedding(input)
        key = self.key(input) + pos
        query = self.query(input) + pos
        value = self.value(input)
        if self.normalize_embeddings:
            key = self.norm(key)
            query = self.norm(query)
            value = self.norm(value)
        # scaled dot-product attention computed directly on the embeddings
        key_t = torch.transpose(key, -2, -1)
        qk = torch.matmul(query, key_t) / self._sqrt_embedding
        attention = F.softmax(qk, dim=-1)
        return torch.matmul(attention, value)
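A usage sketch (hypothetical sizes, not part of the commit), run in the module where these classes are defined: AttentionEmbedding maps token ids straight to contextualized vectors, replacing the separate MiniBertEmbedding followed by Attention:

# sketch only: toy vocabulary size and dimensions chosen for illustration
emb = AttentionEmbedding(embedding_dim=16, voc_size=100,
                         position_type=PositionEmbeddingType.TRAINED)
tokens = torch.randint(0, 100, (2, 7))   # (batch=2, seq_len=7) token ids
out = emb(tokens)                        # (2, 7, 16) contextualized vectors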
class Attention(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim=None, key_is_query=False):
super(Attention, self).__init__()
@@ -124,24 +204,35 @@ class MiniBert(nn.Module):
self._voc_size = len(configuration.vocabulary)
self._embedding_dim = configuration.embedding_dim
        if configuration.merge_attention_and_embeddings:
            self.attention_embedding = AttentionEmbedding(
                self._embedding_dim,
                self._voc_size,
                position_type=configuration.position_type,
                normalize_embeddings=configuration.normalize_embeddings
            )
        else:
            self.embedding = MiniBertEmbedding(
                self._voc_size,
                self._embedding_dim,
                position_count=configuration.position_embeddings_count,
                position_type=configuration.position_type,
                normalize_embeddings=configuration.normalize_embeddings
            )
            self.attention = Attention(
                self._embedding_dim,
                self._embedding_dim,
                hidden_dim=configuration.hidden_dim,
                key_is_query=configuration.key_is_query
            )

    def forward(self, input):
        if self.configuration.merge_attention_and_embeddings:
            # fused module: embedding lookup and attention in a single step
            x = self.attention_embedding(input)
        else:
            x = self.embedding(input)
            x = self.attention(x)
        return x
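A configuration sketch (keyword names taken from this diff; any setting not listed is assumed to have a default in MiniBertConfiguration): setting merge_attention_and_embeddings selects the fused path in MiniBert.forward:

# sketch only: assumes MiniBertConfiguration fills in defaults for omitted settings
config = MiniBertConfiguration(
    vocabulary=["le", "chat", "dort"],
    merge_attention_and_embeddings=True,
)
model = MiniBert(config)
tokens = torch.tensor([[0, 1, 2]])
out = model(tokens)   # produced by AttentionEmbedding alone, no separate Attention layer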
@@ -37,6 +37,7 @@ if __name__ == "__main__":
help="1 (train position embeddings), 2 (fixed position embeddings) or 3 (no position embeddings)")
parser.add_argument("--no-embedding-normalization", action="store_true",
help="Do not normalize embeddings")
parser.add_argument("--merge-attention-embedding", action="store_true")
parser.add_argument("--gpu", action="store_true",
help="Use GPU acceleration")
parser.add_argument("--activation", default="gelu", type=str)
@@ -97,7 +98,8 @@ if __name__ == "__main__":
reveal_mask_prob=0.1,
first_layer_output_size=args.d,
activation_fun=args.activation,
        key_is_query=args.k_is_q,
        merge_attention_and_embeddings=args.merge_attention_embedding
)
model = MiniBertForTraining(configuration)
@@ -114,9 +116,17 @@ if __name__ == "__main__":
print(
f"No embedding for '{x}'. Generating random vector.", file=sys.stderr)
torch.rand((args.d, ), out=embs[i, :])
    if args.merge_attention_embedding:
        # initialize the query, key and value tables from the same pretrained
        # vectors; clone so the three embedding tables do not share storage
        model.minibert.attention_embedding.query.weight = torch.nn.Parameter(
            embs)
        model.minibert.attention_embedding.key.weight = torch.nn.Parameter(
            embs.detach().clone())
        model.minibert.attention_embedding.value.weight = torch.nn.Parameter(
            embs.detach().clone())
    else:
        model.minibert.embedding.word_embeddings.weight = torch.nn.Parameter(
            embs)
if args.gpu:
model = model.to("cuda")
@@ -126,6 +136,7 @@ if __name__ == "__main__":
x_eval_test, y_eval_test = build_eval_batches(
sequences_test, voc2idx, mask_idx, args.gpu)
model.train()
writer = SummaryWriter(log_dir=args.logdir)
for epoch in range(1, args.epochs + 1):
cumloss = 0
@@ -150,9 +161,9 @@
writer.add_scalar("True positives/train", nb_pos_train, epoch)
writer.add_scalar("Precision/test", precision_test, epoch)
writer.add_scalar("True positives/test", nb_pos_test, epoch)
        # if epoch % 1000 == 0:
        #     writer.add_embedding(model.minibert.embedding.word_embeddings.weight,
        #                          metadata=voc, global_step=epoch, tag="Embeddings")
writer.flush()
writer.close()