Commit 0a3f23b1 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

move batch data to device

parent 935c095e
......@@ -118,6 +118,10 @@ if __name__ == "__main__":
n_train = 0
for batch_id, (x, attention_mask, wids) in enumerate(train_loader):
x =
attention_mask =
wids =
output, loss = model(x, attention_mask, wids)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment