Commit f8978c95 authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

ajout d'un paramètre device à HateSpeechCollater

parent f7b3e791
......@@ -46,9 +46,11 @@ class HateSpeechDataset(Dataset):
class HateSpeechCollater:
def __init__(self, tokenizer):
def __init__(self, tokenizer, device=None):
self.tokenizer = tokenizer
self.labels2id = {"Hate": 0, "Offensive": 1, "Other": 2}
self.device = torch.device(
"cpu") if device is None else torch.device(device)
def __call__(self, input):
sentences = []
......@@ -56,19 +58,22 @@ class HateSpeechCollater:
for sent, lab in input:
sentences.append(sent)
labels.append(self.labels2id[lab])
labels_tensor = torch.tensor(labels, dtype=torch.long)
labels_tensor = torch.tensor(
labels, dtype=torch.long, device=self.device)
# labels_onehot = one_hot(labels_tensor, num_classes=3)
try:
inputs = self.tokenizer(
sentences, return_tensors="pt", padding=True, truncation=True)
inputs["input_ids"] = inputs["input_ids"].to(self.device)
inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
except TypeError:
pad_id = self.tokenizer.token_to_id("<pad>")
self.tokenizer.enable_padding(pad_id=pad_id)
encoded = self.tokenizer.encode_batch(sentences)
inputs = {
"input_ids": torch.tensor([x.ids for x in encoded]),
"attention_mask": torch.tensor([x.attention_mask for x in encoded])
"input_ids": torch.tensor([x.ids for x in encoded], device=self.device),
"attention_mask": torch.tensor([x.attention_mask for x in encoded], device=self.device)
}
return inputs, labels_tensor
......
......@@ -19,7 +19,7 @@ if __name__ == "__main__":
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
dataset = HateSpeechDataset(args.input)
collater = HateSpeechCollater(tokenizer)
collater = HateSpeechCollater(tokenizer, device)
class_names = collater.class_names()
for fold, (train, test) in enumerate(dataset.iter_folds(args.folds, True), 1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment