Commit 49ac37f2 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debg

parent 8ccb5265
...@@ -412,7 +412,7 @@ class IdMapSet(Dataset): ...@@ -412,7 +412,7 @@ class IdMapSet(Dataset):
else: else:
duration = int(self.idmap.stop[index] * 0.01) * self.sample_rate - start duration = int(self.idmap.stop[index] * 0.01) * self.sample_rate - start
# add this in case the segment is too short # add this in case the segment is too short
if duration <= self.self.min_duration * self.sample_rate: if duration <= self.min_duration * self.sample_rate:
middle = start + duration // 2 middle = start + duration // 2
start = max(0, int(middle - (self.min_sample_nb / 2))) start = max(0, int(middle - (self.min_sample_nb / 2)))
duration = int(self.min_sample_nb) duration = int(self.min_sample_nb)
......
...@@ -1585,8 +1585,8 @@ def extract_embeddings(idmap_name, ...@@ -1585,8 +1585,8 @@ def extract_embeddings(idmap_name,
if isinstance(model_filename, str): if isinstance(model_filename, str):
checkpoint = torch.load(model_filename, map_location=device) checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"] speaker_number = checkpoint["speaker_number"]
model_archi = checkpoint["model_archi"] model_opts = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi, loss=checkpoint["loss"]) model = Xtractor(speaker_number, model_archi=model_opts["model_type"], loss=model_opts["loss"]["type"])
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
else: else:
model = model_filename model = model_filename
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment