Commit 86a34b33 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent 7e812325
......@@ -434,14 +434,16 @@ class IdMapSet(Dataset):
if self.sliding_window:
speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
middle_points = numpy.arange(start + self.window_len / 2,
start + duration - self.window_len / 2,
self.window_shift)
starts = middle_points - self.window_shift / 2
stops = middle_points + self.window_shift / 2
starts[0] = start
start = starts
stops[-1] = start + duration
#middle_points = numpy.arange(start + self.window_len / 2,
# start + duration - self.window_len / 2,
# self.window_shift)
#starts = middle_points - self.window_shift / 2
#stops = middle_points + self.window_shift / 2
#starts[0] = start
#stops[-1] = start + duration
#stop = stops
#start = starts
stop = start + duration
else:
stop = start + duration
......
......@@ -1351,14 +1351,14 @@ def xtrain(dataset_description,
model = get_network(model_opts, local_rank)
embedding_size = model.embedding_size
aam_scheduler = None
if model.loss == "aam":
aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
final_margin=0.5,
final_steps_nb=120000,
update_frequency=25000,
mode='exp',
Tau=1,
verbose=True)
#if model.loss == "aam":
# aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
# final_margin=0.5,
# final_steps_nb=120000,
# update_frequency=25000,
# mode='exp',
# Tau=1,
# verbose=True)
# Set the device and manage parallel processing
torch.cuda.set_device(local_rank)
......@@ -1436,8 +1436,8 @@ def xtrain(dataset_description,
optimizer,
scheduler,
device,
scaler=scaler,
aam_scheduler=aam_scheduler)
scaler=scaler)
# aam_scheduler=aam_scheduler)
# Cross validation
if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
......@@ -1726,6 +1726,9 @@ def extract_embeddings(idmap_name,
modelset.extend(mod * data.shape[0])
segset.extend(seg * data.shape[0])
if sliding_window:
if vec.shape[0] > 1:
import ipdb
ipdb.set_trace()
starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift))
else:
starts.append(start)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment