Commit 9ff94c59 authored by Anthony Larcher

parallel

parent 1873385c
@@ -1294,7 +1294,7 @@ def new_xtrain(dataset_description,
     torch.manual_seed(training_opts["seed"])
     torch.cuda.manual_seed(training_opts["seed"])
-    # Display the entire configurations as YAML dictionnaries
+    # Display the entire configurations as YAML dictionaries
     if local_rank < 1:
         monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
         monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
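
For reference, a minimal sketch of what yaml.dump produces for such an options dictionary; the keys below are hypothetical stand-ins, not the actual contents of dataset_opts:

import yaml

# Hypothetical options; the real dataset_opts is loaded from the
# dataset_description YAML file passed to new_xtrain().
dataset_opts = {"batch_size": 64, "duration": 4.0, "sample_rate": 16000}
print(yaml.dump(dataset_opts, default_flow_style=False))
# batch_size: 64
# duration: 4.0
# sample_rate: 16000
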
@@ -1309,11 +1309,12 @@ def new_xtrain(dataset_description,
     embedding_size = model.embedding_size
     # Set the device and manage parallel processing
-    device = torch.cuda.device(local_rank)
+    #device = torch.cuda.device(local_rank)
+    torch.cuda.set_device(local_rank)
+    device = torch.device(local_rank)
     model.to(device)
     # If multi-gpu
     """ [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
     - Add the following line right after "if __name__ == '__main__':" in your main script :
         parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
......
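
Following the gist referenced in the docstring above, a minimal sketch of the launch pattern this change plugs into; apart from the --local_rank argument and the torch.cuda.set_device / torch.device calls shown in the diff, everything here is an assumption drawn from that gist rather than code from this repository:

import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

if __name__ == '__main__':
    # torch.distributed.launch passes --local_rank to every spawned process.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
    args = parser.parse_args()

    # Same device setup as the commit: bind this process to its GPU first,
    # then build the device handle from the local rank.
    torch.cuda.set_device(args.local_rank)
    device = torch.device(args.local_rank)

    # One process per GPU; NCCL is the usual backend for multi-GPU training.
    dist.init_process_group(backend='nccl', init_method='env://')

    model = torch.nn.Linear(10, 10).to(device)  # stand-in for the actual xtractor model
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

Each process would be started with something like python -m torch.distributed.launch --nproc_per_node=NUM_GPUS train.py, which sets the rank environment variables that init_method='env://' reads.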