Commit d7b14542 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent 6c627622
...@@ -1006,7 +1006,7 @@ def update_training_dictionary(dataset_description, ...@@ -1006,7 +1006,7 @@ def update_training_dictionary(dataset_description,
return dataset_opts, model_opts, training_opts return dataset_opts, model_opts, training_opts
def get_network(model_opts): def get_network(model_opts, local_rank):
""" """
:param model_opts: :param model_opts:
...@@ -1041,18 +1041,19 @@ def get_network(model_opts): ...@@ -1041,18 +1041,19 @@ def get_network(model_opts):
if name.split(".")[0] in model_opts["reset_parts"]: if name.split(".")[0] in model_opts["reset_parts"]:
param.requires_grad = False param.requires_grad = False
logging.critical(model) if local_rank < 1:
logging.info(model)
logging.critical("model_parameters_count: {:d}".format(
sum(p.numel() logging.info("Model_parameters_count: {:d}".format(
for p in model.sequence_network.parameters() sum(p.numel()
if p.requires_grad) + \ for p in model.sequence_network.parameters()
sum(p.numel() if p.requires_grad) + \
for p in model.before_speaker_embedding.parameters() sum(p.numel()
if p.requires_grad) + \ for p in model.before_speaker_embedding.parameters()
sum(p.numel() if p.requires_grad) + \
for p in model.stat_pooling.parameters() sum(p.numel()
if p.requires_grad))) for p in model.stat_pooling.parameters()
if p.requires_grad)))
return model return model
...@@ -1306,13 +1307,17 @@ def new_xtrain(dataset_description, ...@@ -1306,13 +1307,17 @@ def new_xtrain(dataset_description,
# Initialize the model # Initialize the model
model = get_network(model_opts) model = get_network(model_opts)
speaker_number = model.speaker_number #speaker_number = model.speaker_number
embedding_size = model.embedding_size embedding_size = model.embedding_size
# Set the device and manage parallel processing # Set the device and manage parallel processing
#device = torch.cuda.device(local_rank) #device = torch.cuda.device(local_rank)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
device = torch.device(local_rank) device = torch.device(local_rank)
if training_opts["multi_gpu"]:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device) model.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment