Commit b4fc174c authored by Gaëtan Caillaut
Browse files

to_tensor respects max_seq_size

parent 0a3f23b1
......@@ -72,8 +72,10 @@ class TrainData:
(len(encoded), self.max_seq_size), -1, dtype=torch.long)
for i, encoded in enumerate(self.tokenizer.encode_batch(sentences)):
sequence_tensor[i, :] = torch.tensor(encoded.ids)
attention_mask_tensor[i, :] = torch.tensor(encoded.attention_mask)
sequence_tensor[i, :] = torch.tensor(
encoded.ids[:self.max_seq_size])
attention_mask_tensor[i, :] = torch.tensor(
encoded.attention_mask[:self.max_seq_size])
for j, wid in enumerate(encoded.word_ids):
if wid is None:
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment