Commit 1873385c authored by Anthony Larcher

parallel

parent 4414cb29
@@ -1311,7 +1311,6 @@ def new_xtrain(dataset_description,
    # Set the device and manage parallel processing
    device = torch.cuda.device(local_rank)
    torch.cuda.set_device(local_rank)
    #device = torch.device("cuda")
    model.to(device)

    # If multi-gpu
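A side note on the device handling above (not part of the commit): torch.cuda.device(local_rank) builds a context manager for `with` blocks rather than a device object, so passing it to model.to() is unlikely to behave as intended; torch.device("cuda", local_rank) is the object that .to() expects. A minimal illustration, assuming a CUDA-capable machine (local_rank = 0 is a placeholder):

import torch

local_rank = 0                               # placeholder rank for illustration

ctx = torch.cuda.device(local_rank)          # context manager (for `with` blocks), not a device
dev = torch.device("cuda", local_rank)       # a real device object that .to() accepts

torch.cuda.set_device(local_rank)            # make "cuda" default to this GPU for the process
x = torch.zeros(2, 2, device=dev)            # tensor explicitly placed on cuda:local_rank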
@@ -1333,18 +1332,17 @@ def new_xtrain(dataset_description,
        if local_rank < 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank
        )
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank],
                                                          output_device=local_rank)
    else:
        print("Train on a single GPU")

    # Initialise data loaders
    training_loader, validation_loader, sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                                                              training_opts,
                                                                                                              speaker_number)
    training_loader, validation_loader, \
    sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                          training_opts,
                                                                          speaker_number)

    if local_rank < 1:
        monitor.logger.info(f"Start training process")
@@ -1372,6 +1370,8 @@ def new_xtrain(dataset_description,
            break

        sampler.set_epoch(epoch)
        if training_opts["multi_gpu"]:
            torch.distributed.barrier()

        model = new_train_epoch(model,
                                training_opts,
......
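For context (not part of the commit): below is a minimal, self-contained sketch of the multi-GPU pattern this diff wires together: joining the process group over NCCL, wrapping the model in DistributedDataParallel, reshuffling with DistributedSampler.set_epoch each epoch, and a barrier before each epoch starts. It assumes a torchrun launch that sets LOCAL_RANK and WORLD_SIZE and a CUDA-capable machine; the toy model, dataset, and main() are placeholders, not SIDEKIT code.

import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    multi_gpu = int(os.environ.get("WORLD_SIZE", 1)) > 1

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = torch.nn.Linear(20, 4).to(device)          # stand-in for the x-vector model

    if multi_gpu:
        # Every rank joins the process group; NCCL is the usual backend for GPUs.
        dist.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank],
                                                          output_device=local_rank)

    dataset = TensorDataset(torch.randn(64, 20), torch.randint(0, 4, (64,)))
    sampler = DistributedSampler(dataset) if multi_gpu else None
    loader = DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=(sampler is None))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(2):
        if multi_gpu:
            sampler.set_epoch(epoch)                    # different shuffle on every epoch
            dist.barrier()                              # keep all ranks in step before the epoch
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(features), labels)
            loss.backward()
            optimizer.step()

    if multi_gpu:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launched with, for example, torchrun --nproc_per_node=2 script.py it runs one process per GPU; started as a single process it falls back to plain single-GPU training.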