Commit 3de11dde authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

bugfixes

parent 497fe561
......@@ -1124,12 +1124,11 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
)
else:
batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
batch_size=batch_size,
batch_size=dataset_opts["batch_size"],
seed=training_opts['torch_seed'],
rank=0,
num_process=torch.cuda.device_count(),
......@@ -1137,7 +1136,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
)
training_loader = DataLoader(training_set,
batch_size=batch_size * dataset_opts["train"]["sampler"]["augmentation_replica"],
batch_size=dataset_opts["batch_size"] * dataset_opts["train"]["sampler"]["augmentation_replica"],
shuffle=False,
drop_last=True,
pin_memory=True,
......@@ -1147,7 +1146,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
worker_init_fn=seed_worker)
validation_loader = DataLoader(validation_set,
batch_size=batch_size,
batch_size=dataset_opts["batch_size"],
drop_last=False,
pin_memory=True,
num_workers=training_opts["num_cpu"],
......@@ -1552,8 +1551,8 @@ def train_epoch(model,
with torch.cuda.amp.autocast():
if loss_criteria == 'aam':
output_tuple, _ = model(data, target=target)
output, no_margin_output = output_tuple
loss = criterion(output, target)
margin_output, output = output_tuple
loss = criterion(margin_output, target)
elif loss_criteria == 'smn':
output_tuple, _ = model(data, target=target)
loss, output = output_tuple
......@@ -1588,7 +1587,7 @@ def train_epoch(model,
optimizer.step()
running_loss += loss.item()
accuracy += (torch.argmax(no_margin_output.data, 1) == target).sum().cpu()
accuracy += (torch.argmax(output.data, 1) == target).sum().cpu()
batch_count += 1
if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment