Commit 7ce25f50 authored by Anthony Larcher

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents e27bcc5c 0946c781
@@ -250,7 +250,7 @@ class ArcMarginProduct(torch.nn.Module):
         self.th = math.cos(math.pi - self.m)
         self.mm = math.sin(math.pi - self.m) * self.m

-    def __change_params(self, s=None, m=None):
+    def change_params(self, s=None, m=None):
         if s is None:
             s = self.s
         if m is None:
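For context on why the rename matters: the double-underscore prefix name-mangles the method, so it could not be called from `get_network` below. A minimal sketch of what `change_params` presumably looks like, reconstructed from the constructor lines shown in the hunk; the `cos_m`/`sin_m` updates are assumptions, only `th` and `mm` appear above:

```python
import math
import torch

class ArcMarginProduct(torch.nn.Module):
    # Sketch only: just the renamed method, not the full class.
    def change_params(self, s=None, m=None):
        """Update the AAM scale s and margin m, then refresh derived constants."""
        if s is None:
            s = self.s
        if m is None:
            m = self.m
        self.s = s
        self.m = m
        self.cos_m = math.cos(self.m)                 # assumed attribute
        self.sin_m = math.sin(self.m)                 # assumed attribute
        self.th = math.cos(math.pi - self.m)          # same formulas as the constructor
        self.mm = math.sin(math.pi - self.m) * self.m
```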
@@ -1021,6 +1021,10 @@ def get_network(model_opts, local_rank):
             if name.split(".")[0] in model_opts["reset_parts"]:
                 param.requires_grad = False

+    if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
+        model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
+        print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
+
     if local_rank < 1:
         logging.info(model)
         logging.info("Model_parameters_count: {:d}".format(
@@ -1314,6 +1318,7 @@ def xtrain(dataset_description,
         local_rank = int(os.environ['RANK'])

     # Test to optimize
     torch.backends.cudnn.benchmark = True
+    torch.autograd.profiler.emit_nvtx(enabled=False)

     dataset_opts, model_opts, training_opts = update_training_dictionary(dataset_description,
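Note that `torch.autograd.profiler.emit_nvtx` is a context manager, so the bare call above constructs it without entering it; when NVTX ranges are actually wanted, the usual pattern looks roughly like this (sketch, not part of the patch):

```python
import torch

# cudnn autotuner: picks the fastest convolution algorithms for fixed input shapes.
torch.backends.cudnn.benchmark = True

# NVTX ranges are only emitted while the context manager is active,
# typically under an nvprof / Nsight Systems capture.
with torch.autograd.profiler.emit_nvtx(enabled=True):
    pass  # training or profiling code would run here
```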
@@ -1363,7 +1368,10 @@ def xtrain(dataset_description,

     # Set the device and manage parallel processing
     torch.cuda.set_device(local_rank)
-    device = torch.device(local_rank)
+    if local_rank >= 0:
+        device = torch.device(local_rank)
+    else:
+        device = torch.device("cuda")

     if training_opts["multi_gpu"]:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
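The new branch lets single-process runs fall back to the default CUDA device. A small sketch of how `local_rank` is typically resolved before this point; the `RANK` variable mirrors the earlier hunk, while the -1 default is an assumption:

```python
import os
import torch

# -1 is assumed to mean "not launched through torch.distributed".
local_rank = int(os.environ.get("RANK", -1))

if local_rank >= 0:
    device = torch.device(local_rank)   # pin this process to its own GPU
else:
    device = torch.device("cuda")       # single-process run: default CUDA device
```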
@@ -1521,7 +1529,8 @@ def train_epoch(model,

         if scaler is not None:
             with torch.cuda.amp.autocast():
                 if loss_criteria == 'aam':
-                    output, _ = model(data, target=target)
+                    output_tuple, _ = model(data, target=target)
+                    output, no_margin_output = output_tuple
                     loss = criterion(output, target)
                 elif loss_criteria == 'smn':
                     output_tuple, _ = model(data, target=target)
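The unpacking assumes the model's AAM head now returns both the margin-penalized logits (fed to the loss) and the plain scaled cosines. A minimal sketch of such a head, reduced to what the hunk relies on; the function name, signature, and the simplified margin (no easy-margin correction) are assumptions:

```python
import torch
import torch.nn.functional as F

def aam_logits(x, weight, target, s=30.0, m=0.2):
    """Sketch of an ArcFace-style head returning (margined, no-margin) logits."""
    cosine = F.linear(F.normalize(x), F.normalize(weight))       # no-margin scores
    theta = torch.acos(cosine.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    phi = torch.cos(theta + m)                                    # add margin to the target angle
    one_hot = torch.zeros_like(cosine).scatter_(1, target.unsqueeze(1), 1.0)
    margined = s * torch.where(one_hot.bool(), phi, cosine)
    return margined, s * cosine
```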
@@ -1557,7 +1566,8 @@ def train_epoch(model,
                 optimizer.step()

         running_loss += loss.item()
-        accuracy += (torch.argmax(output.data, 1) == target).sum()
+        accuracy += (torch.argmax(no_margin_output.data, 1) == target).sum().cpu()
+        batch_count += 1

         if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
             batch_size = target.shape[0]
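Switching the accuracy count to `no_margin_output` matters because the additive angular margin deliberately lowers the target-class score, so argmax over the margined logits under-counts correct predictions; `.cpu()` simply keeps the running counter off the GPU. Illustrative numbers (assumed, not from the patch):

```python
import math

# A sample at 75 degrees from its own class weight and 80 degrees from the best
# competitor is still ranked correctly by the plain cosine scores:
cos_true, cos_other = math.cos(math.radians(75)), math.cos(math.radians(80))
print(cos_true > cos_other)                          # True  -> counted as correct

# Adding a 0.2 rad margin to the true-class angle flips the ranking:
print(math.cos(math.radians(75) + 0.2) > cos_other)  # False -> would be missed
```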
@@ -1569,17 +1579,18 @@ def train_epoch(model,
                 batch_idx + 1,
                 training_loader.__len__(),
                 100. * batch_idx / training_loader.__len__(),
-                loss.item(),
-                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
-            running_loss = 0.0
+                running_loss / batch_count,
+                100.0 * accuracy / (batch_count*target.shape[0])))
+            running_loss = 0.0
+            accuracy = 0.0
+            batch_count = 0
+    running_loss = 0.0

     if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
         scheduler.step(training_monitor.best_eer)
     else:
         scheduler.step()

     if aam_scheduler is not None:
-        model.after_speaker_embedding.margin = aam_scheduler.step()
+        model.after_speaker_embedding.margin = aam_scheduler.__step__()

     return model


 def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
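`aam_scheduler.__step__()` is expected to return the margin the AAM head should use for the next epoch. A hypothetical sketch of such a scheduler; the class name, constructor arguments, and the linear ramp are all assumptions, only the `__step__` name and the returned margin come from the hunk:

```python
class AAMMarginScheduler:
    """Hypothetical margin scheduler: linearly ramps the AAM margin across epochs."""

    def __init__(self, start_margin=0.0, final_margin=0.2, ramp_epochs=10):
        self.margin = start_margin
        self.final_margin = final_margin
        self.increment = (final_margin - start_margin) / ramp_epochs

    def __step__(self):
        # Called once per epoch; returns the margin the model should switch to.
        self.margin = min(self.margin + self.increment, self.final_margin)
        return self.margin
```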