Commit 9ff94c59 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

parallel

parent 1873385c
...@@ -1067,7 +1067,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, speaker_number): ...@@ -1067,7 +1067,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, speaker_number):
Set the dataloaders according to the dataset_yaml Set the dataloaders according to the dataset_yaml
First we load the dataframe from CSV file in order to split it for training and validation purpose First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two Then we provide those two
""" """
df = pandas.read_csv(dataset_opts["dataset_csv"]) df = pandas.read_csv(dataset_opts["dataset_csv"])
...@@ -1294,7 +1294,7 @@ def new_xtrain(dataset_description, ...@@ -1294,7 +1294,7 @@ def new_xtrain(dataset_description,
torch.manual_seed(training_opts["seed"]) torch.manual_seed(training_opts["seed"])
torch.cuda.manual_seed(training_opts["seed"]) torch.cuda.manual_seed(training_opts["seed"])
# Display the entire configurations as YAML dictionnaries # Display the entire configurations as YAML dictionaries
if local_rank < 1: if local_rank < 1:
monitor.logger.info("\n*********************************\nDataset options\n*********************************\n") monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False)) monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
...@@ -1309,11 +1309,12 @@ def new_xtrain(dataset_description, ...@@ -1309,11 +1309,12 @@ def new_xtrain(dataset_description,
embedding_size = model.embedding_size embedding_size = model.embedding_size
# Set the device and manage parallel processing # Set the device and manage parallel processing
device = torch.cuda.device(local_rank) #device = torch.cuda.device(local_rank)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
device = torch.device(local_rank)
model.to(device) model.to(device)
# If multi-gpu
""" [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c """ [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
- Add the following line right after "if __name__ == '__main__':" in your main script : - Add the following line right after "if __name__ == '__main__':" in your main script :
parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.') parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment