Commit 2acf04a7 authored by Anthony Larcher
Browse files

debug

parent a681c617
......@@ -85,7 +85,8 @@ class SideSampler(torch.utils.data.Sampler):
self.num_process = num_process
self.num_replicas = num_replicas
assert batch_size % (examples_per_speaker * self.num_replicas) == 0
assert batch_size % examples_per_speaker == 0
#assert batch_size % (examples_per_speaker * self.num_replicas) == 0
assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0
self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)
......@@ -106,10 +107,6 @@ class SideSampler(torch.utils.data.Sampler):
def __iter__(self):
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
numpy.random.seed(self.seed + self.epoch)
# Generate batches per speaker
straight = numpy.arange(self.spk_count)
indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=numpy.int) * straight
......@@ -140,6 +137,9 @@ class SideSampler(torch.utils.data.Sampler):
# we want to convert the speaker indexes into segment indexes
self.index_iterator = numpy.zeros_like(batch_matrix)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# keep track of next segment index to sample for each speaker
for idx, value in enumerate(batch_matrix):
if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
......@@ -148,14 +148,15 @@ class SideSampler(torch.utils.data.Sampler):
self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
self.segment_cursors[value] += 1
self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
#self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker)[:, self.rank * self.examples_per_speaker:(self.rank + 1) * self.examples_per_speaker].flatten()
#self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
return iter(self.index_iterator)
def __len__(self) -> int:
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process
#return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_process
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment