Commit 8ccb5265 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

track bug

parent 80656610
......@@ -62,7 +62,6 @@ class SideSampler(torch.utils.data.Sampler):
batch_size,
seed=0,
rank=0,
num_process=1,
num_replicas=1):
"""[summary]
......@@ -82,33 +81,29 @@ class SideSampler(torch.utils.data.Sampler):
self.epoch = 0
self.seed = seed
self.rank = rank
self.num_process = num_process
self.num_replicas = num_replicas
assert batch_size % examples_per_speaker == 0
assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0
assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_replicas == 0
self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)
self.batch_size = batch_size // examples_per_speaker
# reference all segment indexes per speaker
for idx in range(self.spk_count):
self.labels_to_indices[idx] = list()
for idx, value in enumerate(self.train_sessions):
self.labels_to_indices[value].append(idx)
# shuffle segments per speaker
# shuffle segments per speaker
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
for idx, ldlist in enumerate(self.labels_to_indices.values()):
ldlist = numpy.array(ldlist)
self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0], generator=g).numpy()]
self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0]).numpy()]
self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)
def __iter__(self):
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
numpy.random.seed(self.seed + self.epoch)
def __iter__(self):
# Generate batches per speaker
straight = numpy.arange(self.spk_count)
indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=numpy.int) * straight
......@@ -139,6 +134,9 @@ class SideSampler(torch.utils.data.Sampler):
# we want to convert the speaker indexes into segment indexes
self.index_iterator = numpy.zeros_like(batch_matrix)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# keep track of next segment index to sample for each speaker
for idx, value in enumerate(batch_matrix):
if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
......@@ -146,15 +144,13 @@ class SideSampler(torch.utils.data.Sampler):
self.segment_cursors[value] = 0
self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
self.segment_cursors[value] += 1
self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()
self.index_iterator = self.index_iterator.reshape(-1, self.num_replicas * self.examples_per_speaker)[:, self.rank * self.examples_per_speaker:(self.rank + 1) * self.examples_per_speaker].flatten()
return iter(self.index_iterator)
def __len__(self) -> int:
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_replicas
def set_epoch(self, epoch: int) -> None:
    """Record the current training epoch for this sampler.

    Elsewhere in the sampler, torch Generators are seeded with
    ``self.seed + self.epoch``, so calling this once per epoch produces a
    different — but reproducible — shuffling order each epoch.
    (Same convention as ``torch.utils.data.distributed.DistributedSampler``.)

    :param epoch: current epoch number
    """
    self.epoch = epoch
......
......@@ -1559,9 +1559,11 @@ def extract_embeddings(idmap_name,
device,
file_extension="wav",
transform_pipeline="",
frame_shift=1.5,
frame_duration=3.,
sliding_window=False,
win_duration=3.,
win_shift=1.5,
num_thread=1,
sample_rate=16000,
mixed_precision=False):
"""
......@@ -1569,14 +1571,13 @@ def extract_embeddings(idmap_name,
:param model_filename:
:param data_root_name:
:param device:
:param model_yaml:
:param speaker_number:
:param file_extension:
:param transform_pipeline:
:param frame_shift:
:param frame_duration:
:param extract_after_pooling:
:param sliding_window:
:param win_duration:
:param win_shift:
:param num_thread:
:param sample_rate:
:param mixed_precision:
:return:
"""
......@@ -1595,21 +1596,27 @@ def extract_embeddings(idmap_name,
else:
idmap = IdMap(idmap_name)
if type(model) is Xtractor:
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
model_cs = model.context_size()
else:
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size()
#if type(model) is Xtractor:
# min_duration = (model.context_size() - 1) * win_shift + win_duration
# model_cs = model.context_size()
#else:
# min_duration = (model.module.context_size() - 1) * win_shift + win_duration
# model_cs = model.module.context_size()
# Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap_name,
data_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
min_duration=1.5
transform_number=0,
sliding_window=sliding_window,
window_len=win_duration,
window_shift=win_shift,
sample_rate=sample_rate,
min_duration=0.165
)
dataloader = DataLoader(dataset,
batch_size=1,
shuffle=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment