MiniBert: commit 88318f1b
Authored Nov 05, 2020 by Gaëtan Caillaut
MiniBertEmbedding + masking strategy
Parent: 53eb4577
3 changed files
.gitignore
-__pycache__
\ No newline at end of file
+__pycache__
+runs
\ No newline at end of file
minibert/modules.py
@@ -5,10 +5,28 @@ from math import sqrt
 __all__ = [
     "Attention",
-    "MiniBert"
+    "MiniBert",
+    "MiniBertForTraining",
+    "MiniBertEmbedding"
 ]


+class MiniBertEmbedding(nn.Module):
+    def __init__(self, voc_size, embedding_dim):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(voc_size, embedding_dim)
+        self.position_embeddings = nn.Embedding(1024, embedding_dim)
+        self.norm = nn.LayerNorm(embedding_dim)
+        self.register_buffer("position_ids", torch.arange(1024).expand((1, -1)))
+
+    def forward(self, input):
+        seq_len = input.shape[-1]
+        emb = self.word_embeddings(input)
+        pos = self.position_embeddings(self.position_ids[:, :seq_len])
+        return self.norm(emb + pos)
+
+
 class Attention(nn.Module):
     def __init__(self, in_dim, out_dim, hidden_dim=None):
         super(Attention, self).__init__()
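The new MiniBertEmbedding sums a learned word embedding and a learned absolute position embedding (capped at 1024 positions) and applies LayerNorm to the result. A minimal usage sketch, assuming the package layout matches the file path and with illustrative sizes that are not taken from the repository:

import torch
from minibert.modules import MiniBertEmbedding  # assumed import path

voc_size, embedding_dim = 100, 16  # illustrative values
emb = MiniBertEmbedding(voc_size, embedding_dim)

tokens = torch.randint(voc_size, (2, 5))  # batch of 2 sequences, 5 token ids each
out = emb(tokens)
print(out.shape)  # torch.Size([2, 5, 16]): word + position embeddings, layer-normalized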
@@ -46,11 +64,13 @@ class Attention(nn.Module):
 class MiniBert(nn.Module):
-    def __init__(self, embedding_dim, voc_size, hidden_dim=None):
+    def __init__(self, embedding_dim, voc_size, mask_idx, hidden_dim=None):
         super(MiniBert, self).__init__()

         if hidden_dim is None:
             hidden_dim = embedding_dim

-        self.embedding = nn.Embedding(voc_size, embedding_dim)
+        self.mask_idx = mask_idx
+        self.embedding = MiniBertEmbedding(voc_size, embedding_dim)
         self.attention = Attention(embedding_dim, embedding_dim, hidden_dim=hidden_dim)
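MiniBert now takes the mask index in its constructor and delegates embedding to MiniBertEmbedding before the single Attention layer. A hedged sketch of a forward pass (sizes and mask index are illustrative; the output shape assumes Attention keeps out_dim equal to embedding_dim, as configured here):

import torch
from minibert.modules import MiniBert  # assumed import path

voc_size, embedding_dim, mask_idx = 100, 16, 0  # illustrative values
model = MiniBert(embedding_dim, voc_size, mask_idx)

tokens = torch.randint(voc_size, (2, 5))
contextual = model(tokens)  # expected shape (2, 5, 16): attention over the normalized embeddings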
@@ -67,4 +87,65 @@ class MiniBert(nn.Module):
     def forward(self, input):
         x = self.embedding(input)
-        return self.attention(x)
+        x = self.attention(x)
+        return x
+
+
+class MiniBertForTraining(nn.Module):
+    def __init__(self, embedding_dim, voc_size, mask_idx, hidden_dim=None, mask_prob=0.15, train=True):
+        super(MiniBertForTraining, self).__init__()
+        self.minibert = MiniBert(embedding_dim, voc_size, mask_idx, hidden_dim=hidden_dim)
+        self.l1 = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.l2 = nn.Linear(embedding_dim, voc_size, bias=True)
+
+        self.mask_idx = mask_idx
+        self.train = train
+        self.voc_size = voc_size
+        self.mask_prob = mask_prob
+
+    def forward(self, input):
+        prev_grad = torch.is_grad_enabled()
+        torch.set_grad_enabled(self.train)
+
+        if self.train:
+            # masked_input = input.detach().clone()
+            masked_input = input.clone()
+            masked = torch.rand_like(input, dtype=torch.float) <= self.mask_prob
+            masking_strategy = torch.rand_like(input, dtype=torch.float)
+            # 80% of cases: mask the token
+            # 10% of cases: keep the token
+            # 10% of cases: replace it with a random token
+            masking = masked & (masking_strategy <= 0.8)  # mask
+            corrupt = masked & (0.9 < masking_strategy)  # replace
+            replacements = torch.randint(self.voc_size, (torch.sum(corrupt),))
+
+            masked_input[masking] = self.mask_idx
+            masked_input[corrupt] = replacements
+
+            x = self.minibert(masked_input)
+        else:
+            x = self.minibert(input)
+
+        x = self.l1(x)
+        x = F.gelu(x)
+        x = self.l2(x)
+
+        if self.train:
+            # labels = input.detach().clone()
+            labels = input.clone()
+            labels[~masked] = -1
+            loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
+            loss = loss_fn(x.view(-1, self.voc_size), labels.view(-1))
+
+            torch.set_grad_enabled(prev_grad)
+            return (x, loss)
+        else:
+            torch.set_grad_enabled(prev_grad)
+            return x
+
+    def set_train(self, value):
+        self.train = value
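The core of the commit is the BERT-style corruption in MiniBertForTraining.forward: each position is selected with probability mask_prob, and among the selected positions roughly 80% are replaced by mask_idx, 10% are kept unchanged and 10% are swapped for a random vocabulary id; the cross-entropy loss only counts the selected positions because all other labels are set to -1 and ignore_index=-1. A minimal training-loop sketch under those semantics, with made-up hyperparameters and random token ids standing in for real data:

import torch
from minibert.modules import MiniBertForTraining  # assumed import path

voc_size, embedding_dim, mask_idx = 100, 16, 1  # illustrative values
model = MiniBertForTraining(embedding_dim, voc_size, mask_idx, mask_prob=0.15)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch = torch.randint(voc_size, (8, 32))  # 8 random "sentences" of 32 token ids
for step in range(10):
    logits, loss = model(batch)  # masking, corruption and loss all happen inside forward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.set_train(False)
logits = model(batch)  # inference path: no corruption, logits only, gradients disabled

Note that self.train shadows nn.Module's train() method, so switching modes goes through set_train rather than the usual train()/eval() calls.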
test/test_attention.py
@@ -21,11 +21,32 @@ class TestAttention(unittest.TestCase):
         xv = torch.tensor([[1.5, 1.5], [2.5, 2], [1, 1]], dtype=torch.float)

         x_qk = torch.matmul(xq, xk.t()) / sqrt(2)
-        expected = torch.matmul(F.softmax(x_qk), xv)
+        expected = torch.matmul(F.softmax(x_qk, dim=1), xv)
         actual = attention(x)
         self.assertTrue(torch.equal(expected, actual))

+    def test_attention_given_batch(self):
+        k = torch.tensor([[0, 0.5], [1, 0], [0.5, 0.5]], dtype=torch.float)
+        q = torch.tensor([[0, 0.5], [0, 0], [0.5, 0.5]], dtype=torch.float)
+        v = torch.tensor([[0.5, 0.5], [1, 0.5], [1, 1]], dtype=torch.float)
+        attention = Attention.from_weights(k, q, v)
+
+        x = torch.tensor([
+            [1, 0, 1],
+            [1, 1, 1],
+            [0, 0, 1]
+        ], dtype=torch.float)
+        batch = torch.stack([x, x, x])
+
+        xk = torch.tensor([[0.5, 1], [1.5, 1], [0.5, 0.5]], dtype=torch.float)
+        xq = torch.tensor([[0.5, 1], [0.5, 1], [0.5, 0.5]], dtype=torch.float)
+        xv = torch.tensor([[1.5, 1.5], [2.5, 2], [1, 1]], dtype=torch.float)
+
+        x_qk = torch.matmul(xq, xk.t()) / sqrt(2)
+        expected = torch.matmul(F.softmax(x_qk, dim=1), xv)
+        expected = torch.stack([expected, expected, expected])
+        actual = attention(batch)
+        self.assertTrue(torch.equal(expected, actual))
+

 if __name__ == '__main__':
     unittest.main()
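Both tests build the expected value by hand: project the input with the K, Q and V weight matrices, scale the query-key products by sqrt(d), take the softmax over the key dimension (the dim=1 fix in the first test), then weight the values. A standalone sketch of that reference computation, reusing the k, q, v and x tensors from the batched test and independent of the repository's Attention class:

import torch
import torch.nn.functional as F
from math import sqrt

def reference_attention(x, k, q, v):
    # x: (seq, in_dim); k, q, v: (in_dim, head_dim) projection weights
    xq, xk, xv = x @ q, x @ k, x @ v
    scores = xq @ xk.t() / sqrt(xk.shape[-1])  # scaled dot-product, sqrt(2) here
    return F.softmax(scores, dim=1) @ xv       # softmax over keys, then weight the values

x = torch.tensor([[1, 0, 1], [1, 1, 1], [0, 0, 1]], dtype=torch.float)
k = torch.tensor([[0, 0.5], [1, 0], [0.5, 0.5]], dtype=torch.float)
q = torch.tensor([[0, 0.5], [0, 0], [0.5, 0.5]], dtype=torch.float)
v = torch.tensor([[0.5, 0.5], [1, 0.5], [1, 1]], dtype=torch.float)
print(reference_attention(x, k, q, v))  # matches the expected tensor built in the tests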