Commit baae957c authored by Anthony Larcher
Browse files

min duration in IdMapSet

parent d6d68feb
......@@ -552,6 +552,7 @@ class IdMapSet(Dataset):
data_root_path,
file_extension,
transform_pipeline=None,
frame_rate=100,
min_duration=0.165
):
"""
......@@ -569,7 +570,7 @@ class IdMapSet(Dataset):
self.len = self.idmap.leftids.shape[0]
self.transform_pipeline = transform_pipeline
self.min_duration = min_duration
self.sample_rate = sample_rate
self.sample_rate = frame_rate
_transform = []
if transform_pipeline is not None:
......@@ -601,10 +602,10 @@ class IdMapSet(Dataset):
start = int(self.idmap.start[index])
stop = int(self.idmap.stop[index])
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.sample_rate:
if stop - start <= self.min_duration * nfo.samplerate:
middle = (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.sample_rate / 2)))
stop = int(start + self.min_duration * nfo.sample_rate)
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=start,
stop=stop)
......
......@@ -557,10 +557,15 @@ class Xtractor(torch.nn.Module):
def context_size(self):
# Accumulate a context value over the sequence network: start at 1 and, for
# every sub-module whose name starts with "conv", add
# dilation[0] * (kernel_size[0] - 1). The accumulated value is returned.
#
# NOTE(review): this listing is a diff rendering with the +/- markers and the
# indentation stripped. Judging by the hunk header (10 lines before, 15 after),
# the first loop below (over self.model.sequence_network) is presumably the
# pre-commit version that was removed, and the if/else that follows is its
# replacement -- the two are not meant to run one after the other.
context = 1
for name, module in self.model.sequence_network.named_modules():
if name.startswith("conv"):
context += module.dilation[0] * (module.kernel_size[0] - 1)
# NOTE(review): if this is a method of Xtractor (as the hunk header suggests),
# isinstance(self, Xtractor) is always true here, so the else branch reading
# self.module.sequence_network looks unreachable. Presumably it was intended
# for a wrapped model (e.g. torch.nn.DataParallel) -- confirm against callers.
if isinstance(self, Xtractor):
for name, module in self.sequence_network.named_modules():
if name.startswith("conv"):
context += module.dilation[0] * (module.kernel_size[0] - 1)
else:
for name, module in self.module.sequence_network.named_modules():
if name.startswith("conv"):
context += module.dilation[0] * (module.kernel_size[0] - 1)
return context
def xtrain(speaker_number,
dataset_yaml,
......@@ -946,7 +951,7 @@ def extract_embeddings(idmap_name,
data_root_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=frame_rate,
frame_rate=int(1 / frame_shift),
min_duration=model.context_size()
)
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment