Commit 5a09cbe3 authored by Anthony Larcher

debug

parent c11e4144
@@ -147,7 +147,7 @@ class SideSampler(torch.utils.data.Sampler):
                 self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
                 self.segment_cursors[value] += 1
-        self.index_iterator = torch.repeat_interleave(self.index_iterator, self.num_replicas)
+        self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
         self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
...
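Note on the change above: numpy.repeat and torch.repeat_interleave both repeat every element of a 1-D sequence consecutively, so switching from the torch call to the numpy call keeps the replicated segment indices in the same order. A minimal sketch of the equivalence, with made-up index values:

import numpy
import torch

index_iterator = numpy.array([3, 7, 11])   # hypothetical segment indices
num_replicas = 2                            # number of augmentation replicas

np_out = numpy.repeat(index_iterator, num_replicas)                            # [ 3  3  7  7 11 11]
torch_out = torch.repeat_interleave(torch.tensor([3, 7, 11]), num_replicas)    # tensor([ 3,  3,  7,  7, 11, 11])

assert (np_out == torch_out.numpy()).all()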
@@ -79,7 +79,7 @@ __status__ = "Production"
 __docformat__ = 'reS'

-def seed_worker():
+def seed_worker(seed_val):
     """
     :param worker_id:
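The hunk above adds a seed_val parameter to seed_worker. For reference, the usual PyTorch reproducibility recipe is a worker-seeding function that receives the worker id and is passed to the DataLoader through worker_init_fn; a generic sketch of that standard pattern (not the body of the function in this repository):

import random
import numpy
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # derive a deterministic per-worker seed from the main-process seed
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))   # dummy data, illustration only
loader = DataLoader(dataset, batch_size=8, num_workers=2,
                    worker_init_fn=seed_worker, generator=g)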
@@ -869,7 +869,7 @@ def update_training_dictionary(dataset_description,
     dataset_opts["train"]["sampler"] = dict()
     dataset_opts["train"]["sampler"]["examples_per_speaker"] = 1
     dataset_opts["train"]["sampler"]["samples_per_speaker"] = 100
-    dataset_opts["train"]["sampler"]["augmentation_replicas"] = 1
+    dataset_opts["train"]["sampler"]["augmentation_replica"] = 1
     dataset_opts["train"]["transform_number"] = 2
     dataset_opts["train"]["transformation"] = dict()
     dataset_opts["train"]["transformation"]["pipeline"] = ""
@@ -1072,7 +1072,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                   seed=training_opts['torch_seed'],
                                   rank=local_rank,
                                   num_process=torch.cuda.device_count(),
-                                  num_replicas=dataset_opts["train"]["sampler"]["augmentation_replicas"]
+                                  num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                   )
     else:
         batch_size = dataset_opts["batch_size"]
@@ -1084,11 +1084,11 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                   seed=training_opts['torch_seed'],
                                   rank=0,
                                   num_process=torch.cuda.device_count(),
-                                  num_replicas=dataset_opts["train"]["sampler"]["augmentation_replicas"]
+                                  num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                   )

     training_loader = DataLoader(training_set,
-                                 batch_size=batch_size,
+                                 batch_size=batch_size * dataset_opts["train"]["sampler"]["augmentation_replica"],
                                  shuffle=False,
                                  drop_last=True,
                                  pin_memory=True,
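Since the SideSampler now emits every index augmentation_replica times (via numpy.repeat above), the DataLoader batch size is multiplied by the same factor, presumably so that each batch still covers the intended number of distinct segments. Rough sketch of the arithmetic with illustrative values:

dataset_opts = {"batch_size": 256,
                "train": {"sampler": {"augmentation_replica": 2}}}   # made-up numbers

batch_size = dataset_opts["batch_size"]
replicas = dataset_opts["train"]["sampler"]["augmentation_replica"]

effective_batch = batch_size * replicas   # 512 items loaded per step: 256 distinct segments x 2 replicas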
@@ -1118,7 +1118,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     return training_loader, validation_loader, side_sampler, tar_indices, non_indices


-def get_optimizer(model, model_opts, train_opts):
+def get_optimizer(model, model_opts, train_opts, training_loader):
     """

     :param model:
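The diff does not show why get_optimizer now receives training_loader; one common reason to pass a loader into an optimizer factory is to size a per-batch scheduler from the number of batches per epoch. A purely hypothetical sketch of that pattern (the optimizer choice, OneCycleLR, and the option names are assumptions, not the actual body of this function):

import torch

def get_optimizer(model, model_opts, train_opts, training_loader):
    # hypothetical body: the only point is the use of len(training_loader)
    lr = train_opts.get("lr", 1e-3)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        steps_per_epoch=len(training_loader),   # needs the loader
        epochs=train_opts.get("epochs", 100),
    )
    return optimizer, scheduler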
@@ -1336,7 +1336,7 @@ def xtrain(dataset_description,
         monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")

     # Create optimizer and scheduler
-    optimizer, scheduler = get_optimizer(model, model_opts, training_opts)
+    optimizer, scheduler = get_optimizer(model, model_opts, training_opts, training_loader)

     scaler = None
     if training_opts["mixed_precision"]:
@@ -1435,6 +1435,7 @@ def train_epoch(model,
     running_loss = 0.0
     for batch_idx, (data, target) in enumerate(training_loader):
         data = data.squeeze().to(device)
+        target = target.squeeze()
         target = target.to(device)

         optimizer.zero_grad(set_to_none=True)
...
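On the last hunk: squeezing the target before it is moved to the device drops a trailing singleton dimension, and losses such as torch.nn.CrossEntropyLoss expect class indices of shape (batch,) rather than (batch, 1). Small sketch:

import torch

target = torch.tensor([[3], [7], [1]])   # shape (3, 1), e.g. one label per row after collation
target = target.squeeze()                # shape (3,), the form expected for class indices
print(target.shape)                      # torch.Size([3])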