Commit 748a7a3e authored by Anthony Larcher's avatar Anthony Larcher
Browse files

cleaning

parent ae14ad83
......@@ -120,7 +120,6 @@ class SideSampler(torch.utils.data.Sampler):
if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
random.shuffle(self.labels_to_indices[value])
self.segment_cursors[value] = 0
self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
self.segment_cursors[value] += 1
return iter(self.index_iterator)
......
......@@ -223,10 +223,10 @@ def test_metrics(model,
[type]: [description]
"""
idmap_test_filename = 'h5f/vox1_test_cleaned_idmap.h5'
ndx_test_filename = 'h5f/vox1_test_cleaned_ndx.h5'
key_test_filename = 'h5f/vox1_test_cleaned_key.h5'
data_root_name='/hdd/data/vox1/test/wav'
idmap_test_filename = '/lium/raid01_c/larcher/data/allies_dev_verif_idmap.h5'
ndx_test_filename = '/lium/raid01_c/larcher/data/allies_dev_verif_ndx.h5'
key_test_filename = '/lium/raid01_c/larcher/data/allies_dev_verif_key.h5'
data_root_name='/lium/corpus/base/ALLIES/wav'
transform_pipeline = dict()
......@@ -234,10 +234,10 @@ def test_metrics(model,
model_filename=model,
data_root_name=data_root_name,
device=device,
loss="aam",
transform_pipeline=transform_pipeline,
num_thread=num_thread,
mixed_precision=mixed_precision,
backward=False)
mixed_precision=mixed_precision)
tar, non = cosine_scoring(xv_stat,
xv_stat,
......@@ -891,7 +891,7 @@ class Xtractor(torch.nn.Module):
def update_training_dictionary(dataset_description,
model_description,
kwargs)
kwargs):
"""
speaker_number,
dataset_yaml,
......@@ -1562,7 +1562,7 @@ def xtrain(speaker_number,
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
"""
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"]) #, stratify=df["speaker_idx"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"] , stratify=df["speaker_idx"])
torch.manual_seed(dataset_params['seed'])
......@@ -1602,6 +1602,7 @@ def xtrain(speaker_number,
num_workers=num_thread,
persistent_workers=False)
"""
Set the training options
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment