Commit 0a3f23b1 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

move batch data to device

parent 935c095e
......@@ -118,6 +118,10 @@ if __name__ == "__main__":
n_train = 0
for batch_id, (x, attention_mask, wids) in enumerate(train_loader):
x = x.to(device)
attention_mask = attention_mask.to(device)
wids = wids.to(device)
optimizer.zero_grad()
output, loss = model(x, attention_mask, wids)
loss.backward()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment