Commit 6e6c58ce authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

Check if position_type is valid

parent 4e6c135b
......@@ -47,12 +47,14 @@ class MiniBertEmbedding(nn.Module):
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
else:
elif position_type == PositionEmbeddingType.NONE:
positions = torch.zeros(
(position_count, embedding_dim), dtype=torch.float)
self.position_embeddings = nn.Embedding.from_pretrained(
positions, freeze=True
)
else:
raise Exception("Invalid position type")
self.register_buffer(
"position_ids", torch.arange(position_count).expand((1, -1)))
......
......@@ -33,7 +33,7 @@ if __name__ == "__main__":
help="If set, key and query share the same matrix")
parser.add_argument("--init-w2v", action="store_true",
help="Initialize first embedding layer with word2vec")
parser.add_argument("--position-type", default=PositionEmbeddingType.TRAINED,
parser.add_argument("--position-type", default=PositionEmbeddingType.TRAINED, type=int,
help="1 (train position embeddings), 2 (fixed position embeddings) or 3 (no position embeddings)")
parser.add_argument("--no-embedding-normalization", action="store_true",
help="Do not normalize embeddings")
......@@ -44,6 +44,7 @@ if __name__ == "__main__":
parser.add_argument("--save", type=str, required=False)
args = parser.parse_args()
args.position_type = PositionEmbeddingType(args.position_type)
if args.gpu and not torch.cuda.is_available():
print("--gpu flag set, but CUDA is not available. Unsetting flag.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment