Commit def54548 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

back

parent 2237df98
......@@ -173,6 +173,8 @@ def data_augmentation(speech,
rir_nfo = rir_df.iloc[random.randrange(rir_df.shape[0])].file_id
rir_fn = transform_dict["add_reverb"]["data_path"] + rir_nfo # TODO harmonize with noise
rir, rir_fs = torchaudio.load(rir_fn)
assert rir_fs == sample_rate
#rir = rir[rir_nfo[1], :] #keep selected channel
speech = torch.tensor(signal.convolve(speech, rir, mode='full')[:, :speech.shape[1]])
if "add_noise" in augmentations:
......
......@@ -85,12 +85,12 @@ class SideSampler(torch.utils.data.Sampler):
self.num_process = num_process
self.num_replicas = num_replicas
#assert batch_size % (examples_per_speaker * self.num_replicas) == 0
assert batch_size % examples_per_speaker == 0
assert batch_size % (examples_per_speaker * self.num_replicas) == 0
#assert batch_size % examples_per_speaker == 0
assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0
#self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)
self.batch_size = batch_size // self.examples_per_speaker
self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)
#self.batch_size = batch_size // self.examples_per_speaker
# reference all segment indexes per speaker
for idx in range(self.spk_count):
......@@ -149,16 +149,17 @@ class SideSampler(torch.utils.data.Sampler):
self.segment_cursors[value] = 0
self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
self.segment_cursors[value] += 1
self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker)[:, self.rank * self.examples_per_speaker:(self.rank + 1) * self.examples_per_speaker].flatten()
#self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker)[:, self.rank * self.examples_per_speaker:(self.rank + 1) * self.examples_per_speaker].flatten()
#self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
#self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
return iter(self.index_iterator)
def __len__(self) -> int:
#return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_process
#return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_process
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process
def set_epoch(self, epoch: int) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment