Commit 4d6d2cc6 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

use cuda if available (in build_one_tensor_batch)

parent 0d37bf55
Pipeline #630 canceled with stages
...@@ -36,10 +36,14 @@ def build_batches(seqs, bs=5): ...@@ -36,10 +36,14 @@ def build_batches(seqs, bs=5):
def build_one_tensor_batch(b, voc2idx): def build_one_tensor_batch(b, voc2idx):
return torch.tensor([ batch = torch.tensor([
[voc2idx[x] for x in sent] for sent in b [voc2idx[x] for x in sent] for sent in b
], dtype=torch.long, requires_grad=False) ], dtype=torch.long, requires_grad=False)
if torch.cuda.is_available():
batch = batch.to("cuda")
return batch
def build_tensor_batches(batches, voc2idx): def build_tensor_batches(batches, voc2idx):
res = [] res = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment