Commit 0fb39dc7 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

test dataloaders

parent adb0f0c1
......@@ -338,7 +338,8 @@ class SideSet(Dataset):
chunk_per_segment=1,
overlap=0.,
dataset_df=None,
min_duration=0.165
min_duration=0.165,
output_format="pytorch"
):
"""
......@@ -355,7 +356,7 @@ class SideSet(Dataset):
self.data_file_extension = dataset["data_file_extension"]
self.transformation = ''
self.min_duration = min_duration
self.output_format = output_format
if set_type == "train":
self.duration = dataset["train"]["duration"]
......@@ -503,7 +504,10 @@ class SideSet(Dataset):
self.add_reverb[index]
))
return torch.from_numpy(sig).type(torch.FloatTensor), speaker_idx
if self.output_format == "pytorch":
return torch.from_numpy(sig).type(torch.FloatTensor), torch.from_numpy(speaker_idx).type(torch.LongTensor)
else:
return sig, speaker_idx
def __len__(self):
"""
......
......@@ -706,6 +706,7 @@ def xtrain(speaker_number,
if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model)
#model = DDP(model)
else:
print("Train on a single GPU")
model.to(device)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment