Commit 86a34b33 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent 7e812325
...@@ -434,14 +434,16 @@ class IdMapSet(Dataset): ...@@ -434,14 +434,16 @@ class IdMapSet(Dataset):
if self.sliding_window: if self.sliding_window:
speech = speech.squeeze().unfold(0, self.window_len, self.window_shift) speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
middle_points = numpy.arange(start + self.window_len / 2, #middle_points = numpy.arange(start + self.window_len / 2,
start + duration - self.window_len / 2, # start + duration - self.window_len / 2,
self.window_shift) # self.window_shift)
starts = middle_points - self.window_shift / 2 #starts = middle_points - self.window_shift / 2
stops = middle_points + self.window_shift / 2 #stops = middle_points + self.window_shift / 2
starts[0] = start #starts[0] = start
start = starts #stops[-1] = start + duration
stops[-1] = start + duration #stop = stops
#start = starts
stop = start + duration
else: else:
stop = start + duration stop = start + duration
......
...@@ -1351,14 +1351,14 @@ def xtrain(dataset_description, ...@@ -1351,14 +1351,14 @@ def xtrain(dataset_description,
model = get_network(model_opts, local_rank) model = get_network(model_opts, local_rank)
embedding_size = model.embedding_size embedding_size = model.embedding_size
aam_scheduler = None aam_scheduler = None
if model.loss == "aam": #if model.loss == "aam":
aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"], # aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
final_margin=0.5, # final_margin=0.5,
final_steps_nb=120000, # final_steps_nb=120000,
update_frequency=25000, # update_frequency=25000,
mode='exp', # mode='exp',
Tau=1, # Tau=1,
verbose=True) # verbose=True)
# Set the device and manage parallel processing # Set the device and manage parallel processing
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
...@@ -1436,8 +1436,8 @@ def xtrain(dataset_description, ...@@ -1436,8 +1436,8 @@ def xtrain(dataset_description,
optimizer, optimizer,
scheduler, scheduler,
device, device,
scaler=scaler, scaler=scaler)
aam_scheduler=aam_scheduler) # aam_scheduler=aam_scheduler)
# Cross validation # Cross validation
if math.fmod(epoch, training_opts["validation_frequency"]) == 0: if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
...@@ -1726,6 +1726,9 @@ def extract_embeddings(idmap_name, ...@@ -1726,6 +1726,9 @@ def extract_embeddings(idmap_name,
modelset.extend(mod * data.shape[0]) modelset.extend(mod * data.shape[0])
segset.extend(seg * data.shape[0]) segset.extend(seg * data.shape[0])
if sliding_window: if sliding_window:
if vec.shape[0] > 1:
import ipdb
ipdb.set_trace()
starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift)) starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift))
else: else:
starts.append(start) starts.append(start)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment