Commit d426d9a2 authored by Gaëtan Caillaut

Move to GPU inside the training loop

parent f8978c95
@@ -46,11 +46,9 @@ class HateSpeechDataset(Dataset):
 
 class HateSpeechCollater:
-    def __init__(self, tokenizer, device=None):
+    def __init__(self, tokenizer):
         self.tokenizer = tokenizer
         self.labels2id = {"Hate": 0, "Offensive": 1, "Other": 2}
-        self.device = torch.device(
-            "cpu") if device is None else torch.device(device)
 
     def __call__(self, input):
         sentences = []
@@ -58,22 +56,19 @@ class HateSpeechCollater:
         for sent, lab in input:
             sentences.append(sent)
             labels.append(self.labels2id[lab])
-        labels_tensor = torch.tensor(
-            labels, dtype=torch.long, device=self.device)
+        labels_tensor = torch.tensor(labels, dtype=torch.long)
         # labels_onehot = one_hot(labels_tensor, num_classes=3)
         try:
             inputs = self.tokenizer(
                 sentences, return_tensors="pt", padding=True, truncation=True)
-            inputs["input_ids"] = inputs["input_ids"].to(self.device)
-            inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
         except TypeError:
             pad_id = self.tokenizer.token_to_id("<pad>")
             self.tokenizer.enable_padding(pad_id=pad_id)
             encoded = self.tokenizer.encode_batch(sentences)
             inputs = {
-                "input_ids": torch.tensor([x.ids for x in encoded], device=self.device),
-                "attention_mask": torch.tensor([x.attention_mask for x in encoded], device=self.device)
+                "input_ids": torch.tensor([x.ids for x in encoded]),
+                "attention_mask": torch.tensor([x.attention_mask for x in encoded])
             }
         return inputs, labels_tensor
...
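With the device argument removed, the collater always builds CPU tensors, which is the safer pattern when the DataLoader uses worker subprocesses (PyTorch advises against creating CUDA tensors inside workers). A minimal sketch of how the collater might be wired up after this change; the dataset, tokenizer, and batch size here are illustrative assumptions, not part of the commit:

    import torch
    from torch.utils.data import DataLoader

    # Assumed setup: dataset and tokenizer come from the surrounding script.
    collater = HateSpeechCollater(tokenizer)  # no device argument anymore
    train_loader = DataLoader(dataset, batch_size=16, collate_fn=collater)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")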
@@ -56,6 +56,11 @@ if __name__ == "__main__":
         for batch, (inputs, labels) in enumerate(train_loader, 1):
             print(f"FOLD {fold} - BATCH {batch}")
+            inputs["input_ids"] = inputs["input_ids"].to(device)
+            inputs["attention_mask"] = inputs["attention_mask"].to(
+                device)
+            labels = labels.to(device)
+
             optimizer.zero_grad()
             outputs = model(**inputs, labels=labels)
             loss = outputs.loss
...
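Read together with the collater change, the training step now follows the standard PyTorch pattern: batches stay on the CPU until the loop moves them to the target device. A sketch of how the loop plausibly reads after this commit; only the three .to(device) lines come from the diff, while the model/optimizer setup and the backward/step calls are assumed from the usual pattern:

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # assumed: model moved to the device once, before the loop

    for batch, (inputs, labels) in enumerate(train_loader, 1):
        print(f"FOLD {fold} - BATCH {batch}")
        # Batches arrive on the CPU from the collater; move them here.
        inputs["input_ids"] = inputs["input_ids"].to(device)
        inputs["attention_mask"] = inputs["attention_mask"].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()   # assumed: standard backward/step, not in the diff
        optimizer.step()

Doing the transfer once per batch keeps the collate function device-agnostic, so the same code path works unchanged for CPU-only runs.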