Commit d7b14542 authored by Anthony Larcher

debug

parent 6c627622
@@ -1006,7 +1006,7 @@ def update_training_dictionary(dataset_description,
    return dataset_opts, model_opts, training_opts
def get_network(model_opts):
def get_network(model_opts, local_rank):
    """
    :param model_opts:
@@ -1041,18 +1041,19 @@ def get_network(model_opts):
        if name.split(".")[0] in model_opts["reset_parts"]:
            param.requires_grad = False
    logging.critical(model)
    logging.critical("model_parameters_count: {:d}".format(
        sum(p.numel()
            for p in model.sequence_network.parameters()
            if p.requires_grad) + \
        sum(p.numel()
            for p in model.before_speaker_embedding.parameters()
            if p.requires_grad) + \
        sum(p.numel()
            for p in model.stat_pooling.parameters()
            if p.requires_grad)))
    if local_rank < 1:
        logging.info(model)
        logging.info("Model_parameters_count: {:d}".format(
            sum(p.numel()
                for p in model.sequence_network.parameters()
                if p.requires_grad) + \
            sum(p.numel()
                for p in model.before_speaker_embedding.parameters()
                if p.requires_grad) + \
            sum(p.numel()
                for p in model.stat_pooling.parameters()
                if p.requires_grad)))
    return model
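The new `local_rank` argument is used to silence duplicate logging when the same code runs in several distributed processes. A minimal sketch of that pattern, outside of SIDEKIT (the `build_model` helper and the `torch.nn.Linear` stand-in are hypothetical placeholders, not the project's API):

```python
import logging
import torch

def build_model(model_opts, local_rank):
    # Hypothetical stand-in for the x-vector network built from model_opts.
    model = torch.nn.Linear(10, 2)
    # Log the model and its trainable parameter count only on the first
    # process (local_rank 0, or a negative rank for single-process runs),
    # mirroring the "if local_rank < 1" guard added in get_network().
    if local_rank < 1:
        n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logging.info(model)
        logging.info("Model_parameters_count: %d", n_params)
    return model
```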
@@ -1306,13 +1307,17 @@ def new_xtrain(dataset_description,
    # Initialize the model
    model = get_network(model_opts)
    speaker_number = model.speaker_number
    #speaker_number = model.speaker_number
    embedding_size = model.embedding_size
    # Set the device and manage parallel processing
    #device = torch.cuda.device(local_rank)
    torch.cuda.set_device(local_rank)
    device = torch.device(local_rank)
    if training_opts["multi_gpu"]:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
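This hunk pins each process to one GPU before converting BatchNorm layers and moving the model. A hedged sketch of that setup, assuming the job is launched with `torchrun` so that the process-group environment variables are set (the `setup_ddp_model` helper, the `LOCAL_RANK` lookup, and the final DistributedDataParallel wrap are illustrative, not taken from this file):

```python
import os
import torch

def setup_ddp_model(model, multi_gpu=True):
    # Rank of this process on the local node, as exported by torchrun.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # One process group per training job; NCCL is the usual backend for GPUs.
    torch.distributed.init_process_group(backend="nccl")
    # Bind this process to a single GPU, then build a device handle for it.
    torch.cuda.set_device(local_rank)
    device = torch.device(local_rank)
    if multi_gpu:
        # Synchronise BatchNorm statistics across all processes.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)
    # Wrap the model so gradients are all-reduced across processes.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    return model, device
```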
......