Commit 0d37bf55 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

use cuda if available

parent 8c8d7d84
Pipeline #629 canceled with stages
...@@ -97,6 +97,8 @@ if __name__ == "__main__": ...@@ -97,6 +97,8 @@ if __name__ == "__main__":
emb_dim = 64 emb_dim = 64
voc_size = len(voc) voc_size = len(voc)
model = MiniBertForTraining(emb_dim, voc_size, mask_idx, hidden_dim=64) model = MiniBertForTraining(emb_dim, voc_size, mask_idx, hidden_dim=64)
if torch.cuda.is_available():
model = model.to("cuda")
learning_rate = 1e-3 learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
......
...@@ -109,7 +109,8 @@ if __name__ == "__main__": ...@@ -109,7 +109,8 @@ if __name__ == "__main__":
model = MiniBertForTraining(emb_dim, voc_size, mask_idx, hidden_dim=64) model = MiniBertForTraining(emb_dim, voc_size, mask_idx, hidden_dim=64)
model.minibert.embedding.word_embeddings.weight = torch.nn.Parameter(embs) model.minibert.embedding.word_embeddings.weight = torch.nn.Parameter(embs)
if torch.cuda.is_available():
model = model.to("cuda")
learning_rate = 1e-3 learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment