Commit 6a35ac35 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents d89161f4 0c053d7f
...@@ -478,6 +478,7 @@ class IdMapSetPerSpeaker(Dataset): ...@@ -478,6 +478,7 @@ class IdMapSetPerSpeaker(Dataset):
data_root_path, data_root_path,
file_extension, file_extension,
transform_pipeline={}, transform_pipeline={},
transform_number=1,
frame_rate=100, frame_rate=100,
min_duration=0.165 min_duration=0.165
): ):
...@@ -494,7 +495,6 @@ class IdMapSetPerSpeaker(Dataset): ...@@ -494,7 +495,6 @@ class IdMapSetPerSpeaker(Dataset):
self.data_root_path = data_root_path self.data_root_path = data_root_path
self.file_extension = file_extension self.file_extension = file_extension
self.len = len(set(self.idmap.leftids)) self.len = len(set(self.idmap.leftids))
self.transformation = transform_pipeline
self.min_duration = min_duration self.min_duration = min_duration
self.sample_rate = frame_rate self.sample_rate = frame_rate
self.speaker_list = list(set(self.idmap.leftids)) self.speaker_list = list(set(self.idmap.leftids))
...@@ -503,13 +503,8 @@ class IdMapSetPerSpeaker(Dataset): ...@@ -503,13 +503,8 @@ class IdMapSetPerSpeaker(Dataset):
self.output_im.rightids = self.output_im.leftids self.output_im.rightids = self.output_im.leftids
self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O") self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O") self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.transformation = transform_pipeline
self.transform = [] self.transform_number = transform_number
#if (len(self.transformation) > 0):
# if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
# self.transform_list = self.transformation["pipeline"].split(',')
if self.transformation is not None:
self.transform_list = self.transformation.split(",")
self.noise_df = None self.noise_df = None
if "add_noise" in self.transform: if "add_noise" in self.transform:
...@@ -555,10 +550,10 @@ class IdMapSetPerSpeaker(Dataset): ...@@ -555,10 +550,10 @@ class IdMapSetPerSpeaker(Dataset):
speech = torch.cat(tmp_data, dim=1) speech = torch.cat(tmp_data, dim=1)
speech += 10e-6 * torch.randn(speech.shape) speech += 10e-6 * torch.randn(speech.shape)
if len(self.transform) > 0: if len(self.transformation.keys()) > 0:
speech = data_augmentation(speech, speech = data_augmentation(speech,
speech_fs, speech_fs,
self.transform, self.transformation,
self.transform_number, self.transform_number,
noise_df=self.noise_df, noise_df=self.noise_df,
rir_df=self.rir_df) rir_df=self.rir_df)
......
...@@ -1787,10 +1787,7 @@ def extract_embeddings_per_speaker(idmap_name, ...@@ -1787,10 +1787,7 @@ def extract_embeddings_per_speaker(idmap_name,
# Get the size of embeddings to extract # Get the size of embeddings to extract
name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight' name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
if extract_after_pooling: emb_size = model.embedding_size
emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
else:
emb_size = model.embedding_size
# Create the StatServer # Create the StatServer
embeddings = StatServer() embeddings = StatServer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment