Commit 8e62d9aa authored by Anthony Larcher's avatar Anthony Larcher
Browse files

merge

parent 013117aa
......@@ -189,15 +189,16 @@ class Ndx:
"""
with h5py.File(input_file_name, "r") as f:
ndx = Ndx()
ndx.modelset = f.get("modelset")[()]
ndx.segset = f.get("segset")[()]
ndx.modelset = f["modelset"][()]
ndx.segset = f["segset"][()]
# if running python 3, need a conversion to unicode
if sys.version_info[0] == 3:
ndx.modelset = ndx.modelset.astype('U100', copy=False)
ndx.segset = ndx.segset.astype('U100', copy=False)
ndx.modelset = ndx.modelset.astype('U100')
ndx.segset = ndx.segset.astype('U100')
ndx.trialmask = f.get("trial_mask")[()].astype('bool')
ndx.trialmask = numpy.zeros((ndx.modelset.shape[0], ndx.segset.shape[0]), dtype=numpy.bool)
f["trial_mask"].read_direct(ndx.trialmask)
assert ndx.validate(), "Error: wrong Ndx format"
return ndx
......
......@@ -163,12 +163,18 @@ class Scores:
:return: a vector of target scores.
:return: a vector of non-target scores.
"""
new_score = self.align_with_ndx(key)
tarndx = key.tar & new_score.scoremask
nonndx = key.non & new_score.scoremask
tar = new_score.scoremat[tarndx]
non = new_score.scoremat[nonndx]
return tar, non
if (key.modelset == self.modelset).all() \
and (key.segset == self.segset).all() \
and self.scoremask.shape[0] == key.tar.shape[0] \
and self.scoremask.shape[1] == key.tar.shape[1]:
return self.scoremat[key.tar & self.scoremask], self.scoremat[key.non & self.scoremask]
else:
new_score = self.align_with_ndx(key)
tarndx = key.tar & new_score.scoremask
nonndx = key.non & new_score.scoremask
tar = new_score.scoremat[tarndx]
non = new_score.scoremat[nonndx]
return tar, non
def align_with_ndx(self, ndx):
"""The ordering in the output Scores object corresponds to ndx, so
......
......@@ -451,8 +451,8 @@ class Xtractor(torch.nn.Module):
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 20,
m = 0.3,
s = 30,
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
......@@ -747,7 +747,7 @@ class Xtractor(torch.nn.Module):
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
x = self.after_speaker_embedding(x, target=target), torch.nn.functional.normalize(x, dim=1)
return x
......@@ -938,7 +938,7 @@ def xtrain(speaker_number,
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
"""
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"], stratify=df["speaker_idx"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"]) #, stratify=df["speaker_idx"])
torch.manual_seed(dataset_params['seed'])
......@@ -1108,7 +1108,8 @@ def xtrain(speaker_number,
'accuracy': best_accuracy,
'scheduler': scheduler,
'speaker_number' : speaker_number,
'model_archi': model_archi
'model_archi': model_archi,
'loss': loss
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
else:
save_checkpoint({
......@@ -1118,7 +1119,8 @@ def xtrain(speaker_number,
'accuracy': best_accuracy,
'scheduler': scheduler,
'speaker_number': speaker_number,
'model_archi': model_archi
'model_archi': model_archi,
'loss': loss
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
if is_best:
......@@ -1317,7 +1319,7 @@ def extract_embeddings(idmap_name,
checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"]
model_archi = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi, loss=loss)
model = Xtractor(speaker_number, model_archi=model_archi, loss=checkpoint["loss"])
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment