Commit d59dbae1 authored by Anthony Larcher

learning strategy

parent 2d5520b9
@@ -389,7 +389,7 @@ class SpkSet(Dataset):
         self._spk_dict = spk_dict
         self._spk_index = list(spk_dict.keys())
-        self.len = len(self._spk_index)
+        self.len = 10 * len(self._spk_index)
         for idx, speaker in enumerate(self._spk_index):
             self._spk_dict[speaker]['num_segs'] = len(self._spk_dict[speaker]['segments'])
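With `self.len` inflated to ten times the number of speakers, one DataLoader epoch over SpkSet requests roughly ten samples per speaker instead of one; the `__getitem__` change in the next hunk folds the index back onto the real speaker list. A minimal, hypothetical sketch of the idea (toy class and names, not the SpkSet code):

import torch

# Toy illustration of the oversampling trick: advertise 10x the real size,
# then fold the index back with a modulo so every speaker is visited ~10
# times per DataLoader epoch.
class OversampledSpeakers(torch.utils.data.Dataset):
    def __init__(self, speakers, factor=10):
        self.speakers = speakers
        self.factor = factor

    def __len__(self):
        return self.factor * len(self.speakers)

    def __getitem__(self, index):
        return self.speakers[index % len(self.speakers)]

loader = torch.utils.data.DataLoader(OversampledSpeakers(["spk_a", "spk_b"]), shuffle=True)
print(len(list(loader)))  # 20 items: each of the 2 speakers drawn 10 times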
@@ -451,7 +451,7 @@ class SpkSet(Dataset):
         :return:
         """
-        current_speaker = self._spk_index[index]
+        current_speaker = self._spk_index[index % len(self._spk_index)]
         segment_index = numpy.random.choice(self._spk_dict[current_speaker]['num_segs'], p=self._spk_dict[current_speaker]['p'])
         self._spk_dict[current_speaker]['p'][segment_index] /= 2
         current_segment = self._spk_dict[current_speaker]['segments'][segment_index]
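Halving `p[segment_index]` after each draw makes already-seen segments progressively less likely, so repeated draws for the same speaker spread across that speaker's segments. A minimal sketch of that behaviour (standalone toy code; the renormalisation step is an assumption, since numpy.random.choice requires probabilities summing to 1 and the surrounding SpkSet code is not shown here):

import numpy

def draw_segment(p):
    # p: 1-D array of per-segment probabilities for one speaker, summing to 1
    idx = numpy.random.choice(len(p), p=p)
    p[idx] /= 2          # penalise the segment that was just drawn
    p /= p.sum()         # keep p a valid distribution for the next call (assumption)
    return idx

p = numpy.ones(4) / 4                      # start uniform over 4 segments
draws = [draw_segment(p) for _ in range(8)]
print(draws, p)                            # later draws favour less-seen segments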
@@ -908,19 +908,19 @@ def xtrain(speaker_number,
     training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
     torch.manual_seed(dataset_params['seed'])
-    #training_set = SpkSet(dataset_yaml,
-    #                      set_type="train",
-    #                      dataset_df=training_df,
-    #                      overlap=dataset_params['train']['overlap'],
-    #                      output_format="pytorch",
-    #                      windowed=True)
-    training_set = SpkSet(dataset_yaml,
-                          set_type="train",
-                          dataset_df=training_df,
-                          overlap=dataset_params['train']['overlap'],
-                          output_format="pytorch",
-                          windowed=True)
+    training_set = SideSet(dataset_yaml,
+                           set_type="train",
+                           overlap=dataset_params['train']['overlap'],
+                           dataset_df=training_df,
+                           output_format="pytorch",
+                           )
+    #training_set = SideSet(dataset_yaml,
+    #                       set_type="train",
+    #                       overlap=dataset_params['train']['overlap'],
+    #                       dataset_df=training_df,
+    #                       output_format="pytorch",
+    #                       )
     validation_set = SideSet(dataset_yaml,
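Whichever dataset class ends up active for training, xtrain consumes it through a standard PyTorch DataLoader further down. A hypothetical sketch of that wrapping (the stand-in dataset, batch size and worker count are placeholders, not values from this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; in xtrain this would be the SpkSet/SideSet built above.
training_set = TensorDataset(torch.randn(32, 10), torch.randint(0, 4, (32,)))
training_loader = DataLoader(training_set,
                             batch_size=8,        # placeholder values,
                             shuffle=True,        # not taken from this commit
                             drop_last=True,
                             num_workers=0)
for features, labels in training_loader:
    pass  # one pass over the (toy) training epoch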
@@ -991,8 +991,8 @@ def xtrain(speaker_number,
     optimizer = _optimizer(param_list, **_options)
     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
-                                                     milestones=numpy.arange(50000,160000,10000),
-                                                     gamma=0.1,
+                                                     milestones=numpy.arange(50,10000,10),
+                                                     gamma=0.95,
                                                      last_epoch=-1,
                                                      verbose=False)
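The schedule moves from large drops (gamma 0.1 at every 10000th step from 50000 to 150000) to a much gentler but more frequent decay: gamma 0.95 at every 10th scheduler step between 50 and 9990, i.e. the learning rate shrinks by about 5% every ten steps once step 50 is passed. A self-contained sketch of the new behaviour (the linear model, SGD optimizer and base learning rate are placeholders):

import numpy
import torch

model = torch.nn.Linear(4, 2)                              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # placeholder optimizer/lr
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=numpy.arange(50, 10000, 10),
                                                 gamma=0.95,
                                                 last_epoch=-1)

for step in range(150):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())   # ~0.01 * 0.95**11: milestones 50, 60, ..., 150 were hit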
@@ -1037,7 +1037,7 @@ def xtrain(speaker_number,
         print("end of train epoch")
         # Add the cross validation here
-        if math.fmod(epoch, 5) == 0:
+        if math.fmod(epoch, 100) == 0:
            val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size], mixed_precision)
            test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
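With the guard changed from 5 to 100, cross-validation and the test-set EER are computed only every 100th epoch instead of every 5th. A tiny sketch of the cadence (math.fmod returns a float, so the comparison with 0 behaves like epoch % 100 == 0 for integer epochs):

import math

for epoch in range(1, 301):
    if math.fmod(epoch, 100) == 0:
        print(f"epoch {epoch}: run cross-validation and test metrics")
# prints for epochs 100, 200 and 300 only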