Commit 4414cb29 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

parallel

parent 89b36652
...@@ -1296,12 +1296,12 @@ def new_xtrain(dataset_description, ...@@ -1296,12 +1296,12 @@ def new_xtrain(dataset_description,
# Display the entire configurations as YAML dictionaries # Display the entire configurations as YAML dictionaries
if local_rank < 1: if local_rank < 1:
monitor.logger.info("\n*********************************\nDataset options\n*********************************\n") monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False)) monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
monitor.logger.info("\n*********************************\nModel options\n*********************************\n") monitor.logger.info("\n*********************************\nModel options\n*********************************\n")
monitor.logger.info(yaml.dump(model_opts, default_flow_style=False)) monitor.logger.info(yaml.dump(model_opts, default_flow_style=False))
monitor.logger.info("\n*********************************\nTraining options\n*********************************\n") monitor.logger.info("\n*********************************\nTraining options\n*********************************\n")
monitor.logger.info(yaml.dump(training_opts, default_flow_style=False)) monitor.logger.info(yaml.dump(training_opts, default_flow_style=False))
# Initialize the model # Initialize the model
model = get_network(model_opts) model = get_network(model_opts)
...@@ -1309,8 +1309,9 @@ def new_xtrain(dataset_description, ...@@ -1309,8 +1309,9 @@ def new_xtrain(dataset_description,
embedding_size = model.embedding_size embedding_size = model.embedding_size
# Set the device and manage parallel processing # Set the device and manage parallel processing
device = torch.cuda.device(local_rank)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
device = torch.device("cuda") #device = torch.device("cuda")
model.to(device) model.to(device)
# If multi-gpu # If multi-gpu
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment