Commit 19d860e6 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

clean

parent 3cf12ccc
...@@ -488,6 +488,7 @@ class FeaturesServer(object): ...@@ -488,6 +488,7 @@ class FeaturesServer(object):
feat, label = self.post_processing(feat, label, global_mean, global_std) feat, label = self.post_processing(feat, label, global_mean, global_std)
else: else:
feat, label = self.post_processing(feat, label) feat, label = self.post_processing(feat, label)
return feat, label return feat, label
def get_features_per_speaker(self, show, idmap, channel=0, input_feature_filename=None, label=None): def get_features_per_speaker(self, show, idmap, channel=0, input_feature_filename=None, label=None):
......
...@@ -1687,6 +1687,7 @@ def extract_embeddings(idmap_name, ...@@ -1687,6 +1687,7 @@ def extract_embeddings(idmap_name,
model_filename, model_filename,
data_root_name, data_root_name,
device, device,
batch_size=1,
file_extension="wav", file_extension="wav",
transform_pipeline={}, transform_pipeline={},
sliding_window=False, sliding_window=False,
...@@ -1712,6 +1713,10 @@ def extract_embeddings(idmap_name, ...@@ -1712,6 +1713,10 @@ def extract_embeddings(idmap_name,
:param mixed_precision: :param mixed_precision:
:return: :return:
""" """
if sliding_window:
batch_size = 1
# Load the model # Load the model
if isinstance(model_filename, str): if isinstance(model_filename, str):
checkpoint = torch.load(model_filename, map_location=device) checkpoint = torch.load(model_filename, map_location=device)
...@@ -1741,7 +1746,7 @@ def extract_embeddings(idmap_name, ...@@ -1741,7 +1746,7 @@ def extract_embeddings(idmap_name,
) )
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=1, batch_size=batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
pin_memory=True, pin_memory=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment