Commit ba6d19e2 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

check min duration for idmapset

parent 9777d7dc
......@@ -335,7 +335,8 @@ class SideSet(Dataset):
set_type="train",
chunk_per_segment=1,
overlap=0.,
dataset_df=None
dataset_df=None,
min_duration=0.165
):
"""
......@@ -351,6 +352,7 @@ class SideSet(Dataset):
self.sample_rate = int(dataset["sample_rate"])
self.data_file_extension = dataset["data_file_extension"]
self.transformation = ''
self.min_duration = min_duration
if set_type == "train":
......@@ -545,7 +547,14 @@ class IdMapSet(Dataset):
DataSet that provide data according to a sidekit.IdMap object
"""
def __init__(self, idmap_name, data_root_path, file_extension, transform_pipeline=None):
def __init__(self,
idmap_name,
data_root_path,
file_extension,
transform_pipeline=None,
sample_rate=16000,
min_duration=0.165
):
"""
:param data_root_name:
......@@ -560,6 +569,8 @@ class IdMapSet(Dataset):
self.file_extension = file_extension
self.len = self.idmap.leftids.shape[0]
self.transform_pipeline = transform_pipeline
self.min_duration = min_duration
self.sample_rate = sample_rate
_transform = []
if transform_pipeline is not None:
......@@ -583,17 +594,18 @@ class IdMapSet(Dataset):
start = 0.0
if self.idmap.start[index] is None and self.idmap.stop[index] is None:
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
sig, sample_rate = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
start = 0
stop = len(sig)
else:
nfo = soundfile.info(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
start = int(self.idmap.start[index])
stop = int(self.idmap.stop[index])
# add this in case the segment is too short
if stop - start < 14400:
middle = (stop - start)//2
start = max(0, int(middle - 0.45 * 16000))
stop = int(start + 0.9 * 16000)
if stop - start <= self.min_duration * nfo.sample_rate:
middle = (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.sample_rate / 2)))
stop = int(start + self.min_duration * nfo.sample_rate)
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=start,
stop=stop)
......
......@@ -555,6 +555,12 @@ class Xtractor(torch.nn.Module):
return x
def context_size(self):
    """Compute the temporal context (receptive field, in frames) of the
    convolutional front-end.

    Iterates over the modules of ``self.model.sequence_network`` and, for
    every convolutional layer (module whose registered name starts with
    ``"conv"``), adds its contribution ``dilation * (kernel_size - 1)``
    to the running total.

    :return: total context size in frames (int, >= 1)
    """
    context = 1
    for name, module in self.model.sequence_network.named_modules():
        # Only conv layers widen the receptive field; Conv1d stores
        # dilation/kernel_size as 1-tuples, hence the [0] indexing.
        if name.startswith("conv"):
            context += module.dilation[0] * (module.kernel_size[0] - 1)
    # BUG FIX: the original method had no return statement, so it always
    # returned None — callers such as extract_embeddings() compute
    # (model.context_size() - 1) * frame_shift and would crash on None.
    return context
def xtrain(speaker_number,
dataset_yaml,
......@@ -917,18 +923,32 @@ def extract_embeddings(idmap_name,
device,
file_extension="wav",
transform_pipeline=None,
frame_shift=0.01,
frame_duration=0.025,
num_thread=1):
# Load the model
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename)
model = Xtractor(speaker_number, model_archi=model_yaml)
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
if isinstance(idmap_name, IdMap):
idmap = idmap_name
else:
idmap = IdMap(idmap_name)
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
# Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap_name,
data_root_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline)
transform_pipeline=transform_pipeline,
frame_rate=frame_rate,
min_duration=model.context_size()
)
dataloader = DataLoader(dataset,
......@@ -939,13 +959,6 @@ def extract_embeddings(idmap_name,
num_workers=num_thread)
with torch.no_grad():
# Load the model
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename)
model = Xtractor(speaker_number, model_archi=model_yaml)
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
model.eval()
model.to(device)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment