Commit 47510e91 authored by Anthony Larcher

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents e1f5d369 0df15d91
@@ -221,7 +221,7 @@ class FeaturesServer(object):
                 feat = pca_dct(feat, self.dct_pca_config[0], self.dct_pca_config[1], self.dct_pca_config[2])
             elif self.sdc:
                 feat = shifted_delta_cepstral(feat, d=self.sdc_config[0], p=self.sdc_config[1], k=self.sdc_config[2])
         # Apply a mask on the features
         if self.mask is not None:
             feat = self._mask(feat)
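For context on the shifted_delta_cepstral call above: shifted delta cepstra (SDC) concatenate, for each frame, k delta vectors computed p frames apart, each delta spanning ±d frames. A minimal NumPy sketch of the idea, independent of sidekit's actual implementation (the function name and the edge clamping are illustrative):

import numpy as np

def sdc_sketch(cep, d=1, p=3, k=7):
    # cep: (n_frames, n_ceps) matrix of cepstral coefficients.
    # For each frame t, concatenate k deltas taken at offsets 0, p, ..., (k-1)*p,
    # each delta being cep[t+i*p+d] - cep[t+i*p-d], with indices clamped at the edges.
    n_frames, n_ceps = cep.shape
    out = np.zeros((n_frames, n_ceps * k))
    for t in range(n_frames):
        for i in range(k):
            plus = min(n_frames - 1, t + i * p + d)
            minus = min(n_frames - 1, max(0, t + i * p - d))
            out[t, i * n_ceps:(i + 1) * n_ceps] = cep[plus] - cep[minus]
    return out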
@@ -488,6 +488,7 @@ class FeaturesServer(object):
             feat, label = self.post_processing(feat, label, global_mean, global_std)
         else:
             feat, label = self.post_processing(feat, label)

         return feat, label

     def get_features_per_speaker(self, show, idmap, channel=0, input_feature_filename=None, label=None):
...
@@ -48,14 +48,7 @@ from .preprocessor import RawPreprocessor
 from .preprocessor import MfccFrontEnd
 from .preprocessor import MelSpecFrontEnd

-has_pyroom = True
-try:
-    import pyroomacoustics
-except ImportError:
-    has_pyroom = False
-if has_pyroom:
-    from .augmentation import AddReverb

 __author__ = "Anthony Larcher and Sylvain Meignier"
...
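The hunk above drops the optional-dependency guard, so pyroomacoustics is no longer probed and AddReverb is no longer conditionally imported at this level. Code that still wants reverberation augmentation only when the package is installed can reproduce the removed pattern at the call site; a minimal sketch (the absolute import path for AddReverb is an assumption based on the relative import in the removed lines):

has_pyroom = True
try:
    import pyroomacoustics  # optional dependency backing the reverberation augmentation
except ImportError:
    has_pyroom = False

if has_pyroom:
    # import path assumed from the removed relative import ".augmentation"
    from sidekit.nnet.augmentation import AddReverb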
@@ -1480,17 +1480,17 @@ def xtrain(dataset_description,
                                                     validation_non_indices,
                                                     training_opts["mixed_precision"])

-        test_eer = None
-        if training_opts["compute_test_eer"] and local_rank < 1:
-            test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)
+        #test_eer = None
+        #if training_opts["compute_test_eer"] and local_rank < 1:
+        #    test_eer = test_metrics(model, device, model_opts, dataset_opts, training_opts)

         monitor.update(test_eer=test_eer,
                        val_eer=val_eer,
                        val_loss=val_loss,
                        val_acc=val_acc)

-        if local_rank < 1:
-            monitor.display()
+        #if local_rank < 1:
+        #    monitor.display()

         # Save the current model and if needed update the best one
         # TODO: add an option that keeps the models at certain epochs (for example before the LR change)
@@ -1675,6 +1675,7 @@ def extract_embeddings(idmap_name,
                        model_filename,
                        data_root_name,
                        device,
+                       batch_size=1,
                        file_extension="wav",
                        transform_pipeline={},
                        sliding_window=False,
@@ -1700,6 +1701,10 @@ def extract_embeddings(idmap_name,
     :param mixed_precision:
     :return:
     """
+
+    if sliding_window:
+        batch_size = 1
+
     # Load the model
     if isinstance(model_filename, str):
         checkpoint = torch.load(model_filename, map_location=device)
@@ -1729,7 +1734,7 @@ def extract_embeddings(idmap_name,
                                  )

     dataloader = DataLoader(dataset,
-                            batch_size=1,
+                            batch_size=batch_size,
                             shuffle=False,
                             drop_last=False,
                             pin_memory=True,
...
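Net effect of the three hunks above: batch_size becomes a caller-facing parameter of extract_embeddings, forwarded to the DataLoader, and it is forced back to 1 whenever sliding_window=True, presumably because windowed segments of varying length cannot be collated into larger batches. A hedged usage sketch; the parameter names come from the signature in the diff, while the paths and values are illustrative:

import torch

embeddings = extract_embeddings("exp/test_idmap.h5",   # idmap_name, illustrative path
                                "exp/best_model.pt",   # model_filename, illustrative path
                                "data/",               # data_root_name, illustrative path
                                torch.device("cuda"),
                                batch_size=16,         # new parameter; reset to 1 if sliding_window=True
                                file_extension="wav",
                                transform_pipeline={},
                                sliding_window=False)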