Commit f8978c95 authored by Gaëtan Caillaut
Browse files

ajout d'un paramètre device à HateSpeechCollater

parent f7b3e791
...@@ -46,9 +46,11 @@ class HateSpeechDataset(Dataset): ...@@ -46,9 +46,11 @@ class HateSpeechDataset(Dataset):
class HateSpeechCollater: class HateSpeechCollater:
def __init__(self, tokenizer): def __init__(self, tokenizer, device=None):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.labels2id = {"Hate": 0, "Offensive": 1, "Other": 2} self.labels2id = {"Hate": 0, "Offensive": 1, "Other": 2}
self.device = torch.device(
"cpu") if device is None else torch.device(device)
def __call__(self, input): def __call__(self, input):
sentences = [] sentences = []
...@@ -56,19 +58,22 @@ class HateSpeechCollater: ...@@ -56,19 +58,22 @@ class HateSpeechCollater:
for sent, lab in input: for sent, lab in input:
sentences.append(sent) sentences.append(sent)
labels.append(self.labels2id[lab]) labels.append(self.labels2id[lab])
labels_tensor = torch.tensor(labels, dtype=torch.long) labels_tensor = torch.tensor(
labels, dtype=torch.long, device=self.device)
# labels_onehot = one_hot(labels_tensor, num_classes=3) # labels_onehot = one_hot(labels_tensor, num_classes=3)
try: try:
inputs = self.tokenizer( inputs = self.tokenizer(
sentences, return_tensors="pt", padding=True, truncation=True) sentences, return_tensors="pt", padding=True, truncation=True)
inputs["input_ids"] = inputs["input_ids"].to(self.device)
inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
except TypeError: except TypeError:
pad_id = self.tokenizer.token_to_id("<pad>") pad_id = self.tokenizer.token_to_id("<pad>")
self.tokenizer.enable_padding(pad_id=pad_id) self.tokenizer.enable_padding(pad_id=pad_id)
encoded = self.tokenizer.encode_batch(sentences) encoded = self.tokenizer.encode_batch(sentences)
inputs = { inputs = {
"input_ids": torch.tensor([x.ids for x in encoded]), "input_ids": torch.tensor([x.ids for x in encoded], device=self.device),
"attention_mask": torch.tensor([x.attention_mask for x in encoded]) "attention_mask": torch.tensor([x.attention_mask for x in encoded], device=self.device)
} }
return inputs, labels_tensor return inputs, labels_tensor
......
...@@ -19,7 +19,7 @@ if __name__ == "__main__": ...@@ -19,7 +19,7 @@ if __name__ == "__main__":
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base') tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
dataset = HateSpeechDataset(args.input) dataset = HateSpeechDataset(args.input)
collater = HateSpeechCollater(tokenizer) collater = HateSpeechCollater(tokenizer, device)
class_names = collater.class_names() class_names = collater.class_names()
for fold, (train, test) in enumerate(dataset.iter_folds(args.folds, True), 1): for fold, (train, test) in enumerate(dataset.iter_folds(args.folds, True), 1):
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment