Commit 69f39238 authored by Anthony Larcher

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents d7f0153d 9ff94c59
@@ -1067,7 +1067,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, speaker_number):
    Set the dataloaders according to the dataset_yaml
    First we load the dataframe from CSV file in order to split it for training and validation purpose
    Then we provide those two
    Then we provide those two
    """
    df = pandas.read_csv(dataset_opts["dataset_csv"])
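As background for the hunk above, the load-and-split step that the docstring describes could look roughly like the sketch below. The validation_ratio key, the fixed seed and the shuffling strategy are assumptions chosen for illustration, not the actual sidekit implementation.

import numpy
import pandas

def split_dataset(dataset_opts):
    # Load the dataset description and hold out part of it for validation,
    # as described in the get_loaders docstring above
    df = pandas.read_csv(dataset_opts["dataset_csv"])

    # Shuffle the rows reproducibly, then cut off a validation fraction
    rng = numpy.random.default_rng(seed=42)
    order = rng.permutation(len(df))
    cut = int(len(df) * (1.0 - dataset_opts.get("validation_ratio", 0.1)))

    training_df = df.iloc[order[:cut]]
    validation_df = df.iloc[order[cut:]]
    return training_df, validation_df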
@@ -1294,7 +1294,7 @@ def new_xtrain(dataset_description,
    torch.manual_seed(training_opts["seed"])
    torch.cuda.manual_seed(training_opts["seed"])
    # Display the entire configurations as YAML dictionnaries
    # Display the entire configurations as YAML dictionaries
    if local_rank < 1:
        monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
        monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
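For reference, the seeding and configuration dump in this hunk boil down to the following sketch; the option values are placeholders and print() stands in for monitor.logger.info().

import torch
import yaml

training_opts = {"seed": 42, "multi_gpu": True}   # placeholder values
dataset_opts = {"dataset_csv": "data/train.csv"}  # placeholder values

# Seed the CPU and GPU generators so runs are reproducible
torch.manual_seed(training_opts["seed"])
torch.cuda.manual_seed(training_opts["seed"])

# default_flow_style=False prints nested dictionaries in block style,
# one key per line, which is what makes the logged configuration readable
print(yaml.dump(dataset_opts, default_flow_style=False))
print(yaml.dump(training_opts, default_flow_style=False))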
@@ -1309,12 +1309,12 @@ def new_xtrain(dataset_description,
    embedding_size = model.embedding_size
    # Set the device and manage parallel processing
    device = torch.cuda.device(local_rank)
    #device = torch.cuda.device(local_rank)
    torch.cuda.set_device(local_rank)
    #device = torch.device("cuda")
    device = torch.device(local_rank)
    model.to(device)
    # If multi-gpu
    """ [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
    - Add the following line right after "if __name__ == '__main__':" in your main script :
        parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
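The [HOW TO] notes quoted above come from the linked gist. A condensed sketch of the local_rank plumbing they describe is given below; only the --local_rank argument is taken from the quoted text, the surrounding arrangement is assumed.

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, metavar='N',
                    help='Local process rank.')
args = parser.parse_args()

# Each process spawned by the distributed launcher handles one GPU,
# identified by its local rank; -1 means plain single-process training
if args.local_rank >= 0:
    torch.cuda.set_device(args.local_rank)
    device = torch.device(args.local_rank)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")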
@@ -1333,18 +1333,17 @@ def new_xtrain(dataset_description,
        if local_rank < 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank
        )
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank],
                                                          output_device=local_rank)
    else:
        print("Train on a single GPU")
    # Initialise data loaders
    training_loader, validation_loader, sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                                                              training_opts,
                                                                                                              speaker_number)
    training_loader, validation_loader, \
    sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                          training_opts,
                                                                          speaker_number)
    if local_rank < 1:
        monitor.logger.info(f"Start training process")
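Summarising the multi-GPU branch of this hunk, the wrapping step amounts to the sketch below. setup_ddp is a hypothetical helper name; the backend, init_method and DistributedDataParallel arguments are the ones shown in the diff.

import torch

def setup_ddp(model, local_rank, multi_gpu=True):
    if multi_gpu:
        # NCCL backend for GPU training; init_method='env://' reads
        # MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment
        # variables set by the launcher
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank],
                                                          output_device=local_rank)
    else:
        print("Train on a single GPU")
    return model

With this arrangement the script is usually started through the PyTorch launcher, for example python -m torch.distributed.launch --nproc_per_node=4 your_training_script.py (a placeholder name), which sets the required environment variables and passes --local_rank to every process.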
@@ -1372,6 +1371,8 @@ def new_xtrain(dataset_description,
            break
        sampler.set_epoch(epoch)
        if training_opts["multi_gpu"]:
            torch.distributed.barrier()
        model = new_train_epoch(model,
                                training_opts,
......
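The last hunk adds a synchronisation point to the epoch loop: torch.distributed.barrier() makes every rank wait for the others before the next epoch starts. The neighbouring sampler.set_epoch(epoch) call is what makes DistributedSampler reshuffle differently from epoch to epoch; the short, self-contained illustration below (not sidekit code) shows that behaviour.

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# num_replicas and rank are given explicitly so no process group is needed here
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(3):
    # Without set_epoch() the sampler seeds its shuffle identically every
    # epoch, so all epochs would iterate the data in the same order
    sampler.set_epoch(epoch)
    print(f"epoch {epoch}:", list(iter(sampler)))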