Commit 1873385c authored by Anthony Larcher's avatar Anthony Larcher
Browse files

parallel

parent 4414cb29
...@@ -1311,7 +1311,6 @@ def new_xtrain(dataset_description, ...@@ -1311,7 +1311,6 @@ def new_xtrain(dataset_description,
# Set the device and manage parallel processing # Set the device and manage parallel processing
device = torch.cuda.device(local_rank) device = torch.cuda.device(local_rank)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
#device = torch.device("cuda")
model.to(device) model.to(device)
# If multi-gpu # If multi-gpu
...@@ -1333,18 +1332,17 @@ def new_xtrain(dataset_description, ...@@ -1333,18 +1332,17 @@ def new_xtrain(dataset_description,
if local_rank < 1: if local_rank < 1:
print("Let's use", torch.cuda.device_count(), "GPUs!") print("Let's use", torch.cuda.device_count(), "GPUs!")
torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend='nccl', init_method='env://')
model = torch.nn.parallel.DistributedDataParallel( model = torch.nn.parallel.DistributedDataParallel(model,
model, device_ids=[local_rank],
device_ids=[local_rank], output_device=local_rank)
output_device=local_rank
)
else: else:
print("Train on a single GPU") print("Train on a single GPU")
# Initialise data loaders # Initialise data loaders
training_loader, validation_loader, sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts, training_loader, validation_loader, \
training_opts, sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
speaker_number) training_opts,
speaker_number)
if local_rank < 1: if local_rank < 1:
monitor.logger.info(f"Start training process") monitor.logger.info(f"Start training process")
...@@ -1372,6 +1370,8 @@ def new_xtrain(dataset_description, ...@@ -1372,6 +1370,8 @@ def new_xtrain(dataset_description,
break break
sampler.set_epoch(epoch) sampler.set_epoch(epoch)
if training_opts["multi_gpu"]:
torch.distributed.barrier()
model = new_train_epoch(model, model = new_train_epoch(model,
training_opts, training_opts,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment