Commit acb426e0 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

sidesampler

parent 8acddae7
...@@ -54,13 +54,15 @@ class SideSampler(torch.utils.data.Sampler): ...@@ -54,13 +54,15 @@ class SideSampler(torch.utils.data.Sampler):
Data Sampler used to generate uniformly distributed batches Data Sampler used to generate uniformly distributed batches
""" """
def __init__(self, data_source, def __init__(self,
data_source,
spk_count, spk_count,
examples_per_speaker, examples_per_speaker,
samples_per_speaker, samples_per_speaker,
batch_size, batch_size,
seed=0, seed=0,
rank=0, rank=0,
num_process=1,
num_replicas=1): num_replicas=1):
"""[summary] """[summary]
...@@ -70,6 +72,7 @@ class SideSampler(torch.utils.data.Sampler): ...@@ -70,6 +72,7 @@ class SideSampler(torch.utils.data.Sampler):
examples_per_speaker ([type]): [description] examples_per_speaker ([type]): [description]
samples_per_speaker ([type]): [description] samples_per_speaker ([type]): [description]
batch_size ([type]): [description] batch_size ([type]): [description]
num_replicas: number of GPUs for parallel computing
""" """
self.train_sessions = data_source self.train_sessions = data_source
self.labels_to_indices = dict() self.labels_to_indices = dict()
...@@ -79,29 +82,33 @@ class SideSampler(torch.utils.data.Sampler): ...@@ -79,29 +82,33 @@ class SideSampler(torch.utils.data.Sampler):
self.epoch = 0 self.epoch = 0
self.seed = seed self.seed = seed
self.rank = rank self.rank = rank
self.num_process = num_process
self.num_replicas = num_replicas self.num_replicas = num_replicas
assert batch_size % examples_per_speaker == 0 assert batch_size % examples_per_speaker == 0
assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_replicas == 0 assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0
self.batch_size = batch_size // examples_per_speaker self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)
# reference all segment indexes per speaker # reference all segment indexes per speaker
for idx in range(self.spk_count): for idx in range(self.spk_count):
self.labels_to_indices[idx] = list() self.labels_to_indices[idx] = list()
for idx, value in enumerate(self.train_sessions): for idx, value in enumerate(self.train_sessions):
self.labels_to_indices[value].append(idx) self.labels_to_indices[value].append(idx)
# suffle segments per speaker # shuffle segments per speaker
g = torch.Generator() g = torch.Generator()
g.manual_seed(self.seed + self.epoch) g.manual_seed(self.seed + self.epoch)
for idx, ldlist in enumerate(self.labels_to_indices.values()): for idx, ldlist in enumerate(self.labels_to_indices.values()):
ldlist = numpy.array(ldlist) ldlist = numpy.array(ldlist)
self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0]).numpy()] self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0], generator=g).numpy()]
self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int) self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)
def __iter__(self): def __iter__(self):
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
numpy.random.seed(self.seed + self.epoch)
# Generate batches per speaker # Generate batches per speaker
straight = numpy.arange(self.spk_count) straight = numpy.arange(self.spk_count)
indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=numpy.int) * straight indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=numpy.int) * straight
...@@ -132,20 +139,19 @@ class SideSampler(torch.utils.data.Sampler): ...@@ -132,20 +139,19 @@ class SideSampler(torch.utils.data.Sampler):
# we want to convert the speaker indexes into segment indexes # we want to convert the speaker indexes into segment indexes
self.index_iterator = numpy.zeros_like(batch_matrix) self.index_iterator = numpy.zeros_like(batch_matrix)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# keep track of next segment index to sample for each speaker # keep track of next segment index to sample for each speaker
for idx, value in enumerate(batch_matrix): for idx, value in enumerate(batch_matrix):
if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1: if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
self.labels_to_indices[value] = self.labels_to_indices[value][torch.randperm(self.labels_to_indices[value].shape[0])] self.labels_to_indices[value] = self.labels_to_indices[value][torch.randperm(self.labels_to_indices[value].shape[0], generator=g)]
self.segment_cursors[value] = 0 self.segment_cursors[value] = 0
self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]] self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
self.segment_cursors[value] += 1 self.segment_cursors[value] += 1
self.index_iterator = self.index_iterator.reshape(-1, self.num_replicas * self.examples_per_speaker)[:, self.rank * self.examples_per_speaker:(self.rank + 1) * self.examples_per_speaker].flatten()
return iter(self.index_iterator) self.index_iterator = torch.repeat_interleave(self.index_iterator, self.num_replicas)
self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
return iter(self.index_iterator)
def __len__(self) -> int: def __len__(self) -> int:
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_replicas return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_replicas
...@@ -353,8 +359,8 @@ class IdMapSet(Dataset): ...@@ -353,8 +359,8 @@ class IdMapSet(Dataset):
transform_pipeline={}, transform_pipeline={},
transform_number=1, transform_number=1,
sliding_window=False, sliding_window=False,
window_len=24000, window_len=3.,
window_shift=8000, window_shift=1.,
sample_rate=16000, sample_rate=16000,
min_duration=0.165 min_duration=0.165
): ):
...@@ -378,10 +384,6 @@ class IdMapSet(Dataset): ...@@ -378,10 +384,6 @@ class IdMapSet(Dataset):
self.window_len = window_len self.window_len = window_len
self.window_shift = window_shift self.window_shift = window_shift
self.transform_number = transform_number self.transform_number = transform_number
#if self.transformation is not None:
# self.transform_list = self.transformation.split(",")
self.noise_df = None self.noise_df = None
if "add_noise" in self.transformation: if "add_noise" in self.transformation:
......
...@@ -861,6 +861,7 @@ def update_training_dictionary(dataset_description, ...@@ -861,6 +861,7 @@ def update_training_dictionary(dataset_description,
dataset_opts["train"]["sampler"] = dict() dataset_opts["train"]["sampler"] = dict()
dataset_opts["train"]["sampler"]["examples_per_speaker"] = 1 dataset_opts["train"]["sampler"]["examples_per_speaker"] = 1
dataset_opts["train"]["sampler"]["samples_per_speaker"] = 100 dataset_opts["train"]["sampler"]["samples_per_speaker"] = 100
dataset_opts["train"]["sampler"]["augmentation_replicas"] = 1
dataset_opts["train"]["transform_number"] = 2 dataset_opts["train"]["transform_number"] = 2
dataset_opts["train"]["transformation"] = dict() dataset_opts["train"]["transformation"] = dict()
dataset_opts["train"]["transformation"]["pipeline"] = "" dataset_opts["train"]["transformation"]["pipeline"] = ""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment