Commit d426d9a2 authored by Gaëtan Caillaut

Move to GPU inside the training loop

parent f8978c95
@@ -46,11 +46,9 @@ class HateSpeechDataset(Dataset):


 class HateSpeechCollater:
-    def __init__(self, tokenizer, device=None):
+    def __init__(self, tokenizer):
         self.tokenizer = tokenizer
         self.labels2id = {"Hate": 0, "Offensive": 1, "Other": 2}
-        self.device = torch.device(
-            "cpu") if device is None else torch.device(device)

     def __call__(self, input):
         sentences = []
@@ -58,22 +56,19 @@ class HateSpeechCollater:
         for sent, lab in input:
             sentences.append(sent)
             labels.append(self.labels2id[lab])
-        labels_tensor = torch.tensor(
-            labels, dtype=torch.long, device=self.device)
+        labels_tensor = torch.tensor(labels, dtype=torch.long)
         # labels_onehot = one_hot(labels_tensor, num_classes=3)
         try:
             inputs = self.tokenizer(
                 sentences, return_tensors="pt", padding=True, truncation=True)
-            inputs["input_ids"] = inputs["input_ids"].to(self.device)
-            inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
         except TypeError:
             pad_id = self.tokenizer.token_to_id("<pad>")
             self.tokenizer.enable_padding(pad_id=pad_id)
             encoded = self.tokenizer.encode_batch(sentences)
             inputs = {
-                "input_ids": torch.tensor([x.ids for x in encoded], device=self.device),
-                "attention_mask": torch.tensor([x.attention_mask for x in encoded], device=self.device)
+                "input_ids": torch.tensor([x.ids for x in encoded]),
+                "attention_mask": torch.tensor([x.attention_mask for x in encoded])
             }
         return inputs, labels_tensor
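For context, a minimal usage sketch (not part of the commit): with the `device` argument removed, the collater always builds CPU tensors, so it can be passed directly as a `DataLoader` `collate_fn`. The `train_dataset` and `tokenizer` names below are assumptions based on the surrounding repository.

```python
from torch.utils.data import DataLoader

# Hypothetical wiring; assumes a HateSpeechDataset instance and a tokenizer
# have already been created elsewhere in the script.
collater = HateSpeechCollater(tokenizer)   # no device argument anymore
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collater,                   # batches are assembled on the CPU
    pin_memory=True,                       # possible because the batch stays on the CPU
)
```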
@@ -56,6 +56,11 @@ if __name__ == "__main__":
         for batch, (inputs, labels) in enumerate(train_loader, 1):
             print(f"FOLD {fold} - BATCH {batch}")
+            inputs["input_ids"] = inputs["input_ids"].to(device)
+            inputs["attention_mask"] = inputs["attention_mask"].to(
+                device)
+            labels = labels.to(device)
             optimizer.zero_grad()
             outputs = model(**inputs, labels=labels)
             loss = outputs.loss
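Taken together, the commit follows the usual PyTorch pattern of keeping collation on the CPU and moving each batch to the target device inside the training loop: collate functions can run in DataLoader worker processes, where creating CUDA tensors is problematic, and CPU-side batches also allow pinned memory. Below is a hedged sketch of the resulting loop; `model`, `optimizer`, and `train_loader` are assumed to exist as in the rest of the script, and the dict comprehension is just shorthand for the two explicit `.to(device)` calls added above.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # assumed: model built earlier in the script

for batch, (inputs, labels) in enumerate(train_loader, 1):
    # Move the CPU batch produced by HateSpeechCollater onto the device.
    inputs = {k: v.to(device) for k, v in inputs.items()}  # input_ids + attention_mask
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model(**inputs, labels=labels)  # HF-style model that returns a loss
    loss = outputs.loss
    loss.backward()
    optimizer.step()
```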