Commit ca227311 authored by Anthony Larcher

correct valid and aam change param

parent 0077125d
@@ -1316,6 +1316,7 @@ def xtrain(dataset_description,
         local_rank = int(os.environ['RANK'])
     # Test to optimize
     torch.backends.cudnn.benchmark = True
+    torch.autograd.profiler.emit_nvtx(enabled=False)
     dataset_opts, model_opts, training_opts = update_training_dictionary(dataset_description,
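Note that `torch.autograd.profiler.emit_nvtx` is a context manager; constructed as a bare call with `enabled=False` it simply leaves NVTX annotation disabled. When kernel-level profiling is wanted, the usual pattern looks like the sketch below (the toy `model` and `data` are placeholders, not names from this file, and a CUDA-capable machine is assumed):

import torch

model = torch.nn.Linear(10, 2).cuda()   # toy stand-ins for the real model/batch
data = torch.randn(4, 10).cuda()

# NVTX ranges only show up under an external profiler such as Nsight Systems;
# each autograd op is then emitted as a named range on the CUDA timeline.
with torch.autograd.profiler.emit_nvtx():
    output = model(data)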
@@ -1365,7 +1366,10 @@
     # Set the device and manage parallel processing
     torch.cuda.set_device(local_rank)
-    device = torch.device(local_rank)
+    if local_rank >= 0:
+        device = torch.device(local_rank)
+    else:
+        device = torch.device("cuda")
     if training_opts["multi_gpu"]:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
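The new branch makes the device choice robust to `local_rank` being negative, the usual sentinel for a non-distributed run. A minimal sketch of the pattern, assuming the job is launched with `torchrun` (which exports `RANK`/`LOCAL_RANK`); everything except the environment keys is illustrative:

import os
import torch

# torchrun exports RANK and LOCAL_RANK; fall back to -1 when the script
# runs as a plain single process.
local_rank = int(os.environ.get("RANK", -1))

if local_rank >= 0:
    torch.cuda.set_device(local_rank)   # pin this process to its own GPU
    device = torch.device(local_rank)
else:
    device = torch.device("cuda")       # single process: use the default GPU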
@@ -1523,7 +1527,8 @@ def train_epoch(model,
         if scaler is not None:
             with torch.cuda.amp.autocast():
                 if loss_criteria == 'aam':
-                    output, _ = model(data, target=target)
+                    output_tuple, _ = model(data, target=target)
+                    output, no_margin_output = output_tuple
                     loss = criterion(output, target)
                 elif loss_criteria == 'smn':
                     output_tuple, _ = model(data, target=target)
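The AAM branch now unpacks two logits tensors: the margin-penalised scores used for the loss, and the plain cosine scores (`no_margin_output`) used further down for accuracy. A generic ArcFace-style head showing why both are returned; this is a sketch of the technique, not SIDEKIT's actual layer:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class AAMSoftmaxHead(nn.Module):
    """Additive angular margin head: returns (margined logits, plain logits)."""

    def __init__(self, emb_dim, n_classes, margin=0.2, scale=30.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_classes, emb_dim))
        self.margin = margin
        self.scale = scale

    def forward(self, x, target):
        # cosine similarity between L2-normalized embeddings and class weights
        cos = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
        # add the margin m to the angle of the true class only: cos(theta + m)
        one_hot = F.one_hot(target, cos.size(1)).bool()
        margined = torch.where(one_hot, torch.cos(theta + self.margin), cos)
        # the loss uses the penalised scores, accuracy the raw cosines
        return self.scale * margined, self.scale * cos

Since the margin deflates the true-class score, an argmax over the margined logits under-reports training accuracy, which is exactly what the next hunk corrects.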
@@ -1559,7 +1564,8 @@
             optimizer.step()
         running_loss += loss.item()
-        accuracy += (torch.argmax(output.data, 1) == target).sum()
+        accuracy += (torch.argmax(no_margin_output.data, 1) == target).sum().cpu()
+        batch_count += 1
         if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
             batch_size = target.shape[0]
@@ -1571,17 +1577,18 @@
                 batch_idx + 1,
                 training_loader.__len__(),
                 100. * batch_idx / training_loader.__len__(),
-                loss.item(),
-                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
-            running_loss = 0.0
-        running_loss = 0.0
+                running_loss / batch_count,
+                100.0 * accuracy / (batch_count*target.shape[0])))
+            running_loss = 0.0
+            accuracy = 0.0
+            batch_count = 0
     if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
         scheduler.step(training_monitor.best_eer)
     else:
         scheduler.step()
     if aam_scheduler is not None:
-        model.after_speaker_embedding.margin = aam_scheduler.step()
+        model.after_speaker_embedding.margin = aam_scheduler.__step__()
     return model
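The logging change replaces the instantaneous `loss.item()` and a since-epoch-start accuracy denominator with sums that are reset at each log interval, so every report covers only the batches since the previous one. The corrected bookkeeping in isolation, as a sketch (`train_step`, `log_interval` and the loader are hypothetical stand-ins):

running_loss, accuracy, batch_count = 0.0, 0, 0

for batch_idx, (data, target) in enumerate(training_loader):
    loss_value, logits = train_step(data, target)    # hypothetical helper
    running_loss += loss_value
    accuracy += (logits.argmax(1) == target).sum().item()
    batch_count += 1

    if batch_idx % log_interval == 0:
        # average over the interval, not over the whole epoch so far
        print(f"loss={running_loss / batch_count:.6f}  "
              f"acc={100.0 * accuracy / (batch_count * target.shape[0]):.3f}%")
        running_loss, accuracy, batch_count = 0.0, 0, 0

As in the diff, the accuracy denominator assumes a constant batch size across the interval.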
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
......
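Finally, `aam_scheduler.__step__()` at the end of `train_epoch` is expected to return the margin to install on the model's AAM head for the next epoch. One plausible shape for such a scheduler, assuming a linear warm-up of the margin; the class and its fields are hypothetical, only the `__step__` name and the assignment are taken from the diff:

class AAMScheduler:
    """Hypothetical: ramps the additive angular margin linearly over epochs."""

    def __init__(self, final_margin=0.2, ramp_epochs=20):
        self.final_margin = final_margin
        self.ramp_epochs = ramp_epochs
        self.epoch = 0

    def __step__(self):
        # called once per epoch; returns the margin for the next epoch
        self.epoch += 1
        return self.final_margin * min(self.epoch / self.ramp_epochs, 1.0)

# usage, mirroring the diff:
# model.after_speaker_embedding.margin = aam_scheduler.__step__()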