Commit f391adb6 authored by Anthony Larcher

debug

parents 54484b38 86a34b33
@@ -434,14 +434,16 @@ class IdMapSet(Dataset):
         if self.sliding_window:
             speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
-            middle_points = numpy.arange(start + self.window_len / 2,
-                                         start + duration - self.window_len / 2,
-                                         self.window_shift)
-            starts = middle_points - self.window_shift / 2
-            stops = middle_points + self.window_shift / 2
-            starts[0] = start
-            start = starts
-            stops[-1] = start + duration
+            #middle_points = numpy.arange(start + self.window_len / 2,
+            #                             start + duration - self.window_len / 2,
+            #                             self.window_shift)
+            #starts = middle_points - self.window_shift / 2
+            #stops = middle_points + self.window_shift / 2
+            #starts[0] = start
+            #stops[-1] = start + duration
+            #stop = stops
+            #start = starts
+            stop = start + duration
         else:
             stop = start + duration
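This hunk disables the per-window start/stop bookkeeping (middle_points, starts, stops) and keeps only stop = start + duration; the overlapping windows themselves still come from torch.Tensor.unfold. A minimal sketch, not part of the commit, of what that unfold call produces; the sample rate, window length and shift below are assumed illustrative values, not SIDEKIT defaults:

# Sketch only: how unfold() cuts a 1-D waveform into overlapping windows,
# as in the sliding_window branch above. Values are assumptions.
import torch

sample_rate = 16000
window_len = 3 * sample_rate       # assumed 3 s windows
window_shift = sample_rate         # assumed 1 s shift

speech = torch.randn(1, 10 * sample_rate)     # fake 10 s signal, shape (1, 160000)
frames = speech.squeeze().unfold(0, window_len, window_shift)
print(frames.shape)                # torch.Size([8, 48000]): 8 overlapping windows

With the change above, stop is simply start + duration in both branches, so only unfold determines the per-window boundaries.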
@@ -1354,14 +1354,14 @@ def xtrain(dataset_description,
     model = get_network(model_opts, local_rank)
     embedding_size = model.embedding_size
     aam_scheduler = None
-    if model.loss == "aam":
-        aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
-                                     final_margin=0.5,
-                                     final_steps_nb=120000,
-                                     update_frequency=25000,
-                                     mode='exp',
-                                     Tau=1,
-                                     verbose=True)
+    #if model.loss == "aam":
+    #    aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
+    #                                 final_margin=0.5,
+    #                                 final_steps_nb=120000,
+    #                                 update_frequency=25000,
+    #                                 mode='exp',
+    #                                 Tau=1,
+    #                                 verbose=True)
     # Set the device and manage parallel processing
     torch.cuda.set_device(local_rank)

@@ -1439,8 +1439,8 @@ def xtrain(dataset_description,
                                           optimizer,
                                           scheduler,
                                           device,
-                                          scaler=scaler,
-                                          aam_scheduler=aam_scheduler)
+                                          scaler=scaler)
+                                          # aam_scheduler=aam_scheduler)
         # Cross validation
         if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
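These two hunks disable margin scheduling for the AAM loss: both the AAMScheduler construction and the aam_scheduler keyword passed into the epoch training call are commented out. As a rough, hypothetical illustration of the kind of exponential margin ramp the commented-out arguments (final_margin=0.5, final_steps_nb=120000, update_frequency=25000, mode='exp', Tau=1) appear to configure; this is not SIDEKIT's AAMScheduler code, and the initial margin is an assumption:

# Hypothetical sketch of an exponential AAM-margin ramp; not the library code.
import math

def margin_at_step(step,
                   initial_margin=0.2,      # assumed starting margin
                   final_margin=0.5,
                   final_steps_nb=120000,
                   update_frequency=25000,
                   tau=1.0):
    # Only move the margin on multiples of update_frequency, capped at final_steps_nb.
    effective_step = min(step - step % update_frequency, final_steps_nb)
    progress = effective_step / final_steps_nb
    # Exponential interpolation between initial and final margin, shaped by tau.
    weight = (math.exp(tau * progress) - 1.0) / (math.exp(tau) - 1.0)
    return initial_margin + weight * (final_margin - initial_margin)

for s in (0, 25000, 60000, 120000, 200000):
    print(s, round(margin_at_step(s), 3))   # ramps from 0.2 toward 0.5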