Commit 3de11dde authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

bugfixes

parent 497fe561
...@@ -1124,12 +1124,11 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0): ...@@ -1124,12 +1124,11 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"] num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
) )
else: else:
batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'], side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"], spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"], examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"], samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
batch_size=batch_size, batch_size=dataset_opts["batch_size"],
seed=training_opts['torch_seed'], seed=training_opts['torch_seed'],
rank=0, rank=0,
num_process=torch.cuda.device_count(), num_process=torch.cuda.device_count(),
...@@ -1137,7 +1136,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0): ...@@ -1137,7 +1136,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
) )
training_loader = DataLoader(training_set, training_loader = DataLoader(training_set,
batch_size=batch_size * dataset_opts["train"]["sampler"]["augmentation_replica"], batch_size=dataset_opts["batch_size"] * dataset_opts["train"]["sampler"]["augmentation_replica"],
shuffle=False, shuffle=False,
drop_last=True, drop_last=True,
pin_memory=True, pin_memory=True,
...@@ -1147,7 +1146,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0): ...@@ -1147,7 +1146,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
worker_init_fn=seed_worker) worker_init_fn=seed_worker)
validation_loader = DataLoader(validation_set, validation_loader = DataLoader(validation_set,
batch_size=batch_size, batch_size=dataset_opts["batch_size"],
drop_last=False, drop_last=False,
pin_memory=True, pin_memory=True,
num_workers=training_opts["num_cpu"], num_workers=training_opts["num_cpu"],
...@@ -1552,8 +1551,8 @@ def train_epoch(model, ...@@ -1552,8 +1551,8 @@ def train_epoch(model,
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
if loss_criteria == 'aam': if loss_criteria == 'aam':
output_tuple, _ = model(data, target=target) output_tuple, _ = model(data, target=target)
output, no_margin_output = output_tuple margin_output, output = output_tuple
loss = criterion(output, target) loss = criterion(margin_output, target)
elif loss_criteria == 'smn': elif loss_criteria == 'smn':
output_tuple, _ = model(data, target=target) output_tuple, _ = model(data, target=target)
loss, output = output_tuple loss, output = output_tuple
...@@ -1588,7 +1587,7 @@ def train_epoch(model, ...@@ -1588,7 +1587,7 @@ def train_epoch(model,
optimizer.step() optimizer.step()
running_loss += loss.item() running_loss += loss.item()
accuracy += (torch.argmax(no_margin_output.data, 1) == target).sum().cpu() accuracy += (torch.argmax(output.data, 1) == target).sum().cpu()
batch_count += 1 batch_count += 1
if math.fmod(batch_idx, training_opts["log_interval"]) == 0: if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment