Commit 6a35ac35 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents d89161f4 0c053d7f
......@@ -478,6 +478,7 @@ class IdMapSetPerSpeaker(Dataset):
data_root_path,
file_extension,
transform_pipeline={},
transform_number=1,
frame_rate=100,
min_duration=0.165
):
......@@ -494,7 +495,6 @@ class IdMapSetPerSpeaker(Dataset):
self.data_root_path = data_root_path
self.file_extension = file_extension
self.len = len(set(self.idmap.leftids))
self.transformation = transform_pipeline
self.min_duration = min_duration
self.sample_rate = frame_rate
self.speaker_list = list(set(self.idmap.leftids))
......@@ -503,13 +503,8 @@ class IdMapSetPerSpeaker(Dataset):
self.output_im.rightids = self.output_im.leftids
self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.transform = []
#if (len(self.transformation) > 0):
# if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
# self.transform_list = self.transformation["pipeline"].split(',')
if self.transformation is not None:
self.transform_list = self.transformation.split(",")
self.transformation = transform_pipeline
self.transform_number = transform_number
self.noise_df = None
if "add_noise" in self.transform:
......@@ -555,10 +550,10 @@ class IdMapSetPerSpeaker(Dataset):
speech = torch.cat(tmp_data, dim=1)
speech += 10e-6 * torch.randn(speech.shape)
if len(self.transform) > 0:
if len(self.transformation.keys()) > 0:
speech = data_augmentation(speech,
speech_fs,
self.transform,
self.transformation,
self.transform_number,
noise_df=self.noise_df,
rir_df=self.rir_df)
......
......@@ -1787,10 +1787,7 @@ def extract_embeddings_per_speaker(idmap_name,
# Get the size of embeddings to extract
name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
if extract_after_pooling:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
else:
emb_size = model.embedding_size
emb_size = model.embedding_size
# Create the StatServer
embeddings = StatServer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment