Commit 9ff94c59 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

parallel

parent 1873385c
......@@ -1067,7 +1067,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, speaker_number):
Set the dataloaders according to the dataset_yaml
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
Then we provide those two
"""
df = pandas.read_csv(dataset_opts["dataset_csv"])
......@@ -1294,7 +1294,7 @@ def new_xtrain(dataset_description,
torch.manual_seed(training_opts["seed"])
torch.cuda.manual_seed(training_opts["seed"])
# Display the entire configurations as YAML dictionnaries
# Display the entire configurations as YAML dictionaries
if local_rank < 1:
monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
......@@ -1309,11 +1309,12 @@ def new_xtrain(dataset_description,
embedding_size = model.embedding_size
# Set the device and manage parallel processing
device = torch.cuda.device(local_rank)
#device = torch.cuda.device(local_rank)
torch.cuda.set_device(local_rank)
device = torch.device(local_rank)
model.to(device)
# If multi-gpu
""" [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
- Add the following line right after "if __name__ == '__main__':" in your main script :
parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment