Commit 6f9f9cb9 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug extract_embeddings

parent 002e973d
......@@ -419,7 +419,7 @@ class IdMapSet(Dataset):
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
duration = int(speech.shape[1] - start)
else:
duration = int(self.idmap.stop[index] * 0.01) * self.sample_rate - start
duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
# add this in case the segment is too short
if duration <= self.min_duration * self.sample_rate:
middle = start + duration // 2
......@@ -434,6 +434,16 @@ class IdMapSet(Dataset):
if self.sliding_window:
speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
middle_points = numpy.arange(start + self.window_len / 2,
start + duration - self.window_len / 2,
self.window_shift)
starts = middle_points - self.window_shift / 2
stops = middle_points + self.window_shift / 2
starts[0] = start
start = starts
stops[-1] = start + duration
else:
stop = start + duration
if len(self.transformation.keys()) > 0:
speech = data_augmentation(speech,
......@@ -445,7 +455,7 @@ class IdMapSet(Dataset):
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, start + duration
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
def __len__(self):
"""
......
......@@ -1260,6 +1260,45 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
}, training_monitor.is_best, filename=training_opts["tmp_model_name"], best_filename=training_opts["best_model_name"])
class AAMScheduler():
    """
    Scheduler that progressively increases the additive angular margin (AAM)
    of the loss during training. For now we only update the margin.
    """
    def __init__(self, original_margin, final_margin, final_steps_nb, update_frequency, mode='lin', Tau=1, verbose=True):
        """
        :param original_margin: margin used at the beginning of training
        :param final_margin: margin targeted at the end of the schedule
        :param final_steps_nb: number of steps over which the margin evolves
        :param update_frequency: the margin is recomputed every
            `update_frequency` calls to step(); in between, the last
            computed value is kept
        :param mode: 'lin' for a linear increase; any other value selects an
            exponential saturation towards final_margin
        :param Tau: time constant, currently unused (the exponential mode
            hard-codes final_steps_nb / 7); kept for backward compatibility
        :param verbose: verbosity flag
        """
        self.current_margin = original_margin
        self.original_margin = original_margin
        self.final_margin = final_margin
        self.final_steps_nb = final_steps_nb
        self.update_frequency = update_frequency
        self.mode = mode
        self.Tau = Tau
        self.verbose = verbose
        self._counter = 0

    def step(self):
        """
        Advance the scheduler by one step and return the current margin.

        The margin is only recomputed when the internal counter is a
        multiple of `update_frequency`.

        :return: the (possibly updated) current margin
        """
        self._counter += 1
        if self._counter % self.update_frequency == 0:
            # update the parameters
            if self.mode == "lin":
                # linear interpolation between original and final margin
                self.current_margin = self.original_margin + \
                                      (self.final_margin - self.original_margin) * \
                                      (self._counter / self.final_steps_nb)
            else:
                # exponential saturation towards final_margin
                # (time constant hard-coded to final_steps_nb / 7)
                self.current_margin = self.original_margin + \
                                      (self.final_margin - self.original_margin) * \
                                      (1 - numpy.exp(-self._counter / (self.final_steps_nb / 7)))
        return self.current_margin

    # Backward-compatible alias: the method was originally (mis)named
    # __step__, while the training loop calls scheduler.step().
    __step__ = step
def xtrain(dataset_description,
model_description,
training_description,
......@@ -1311,6 +1350,15 @@ def xtrain(dataset_description,
# Initialize the model
model = get_network(model_opts, local_rank)
embedding_size = model.embedding_size
aam_scheduler = None
if model.loss == "aam":
aam_scheduler = AAMScheduler(model_opts["loss"]["aam_margin"],
final_margin=0.5,
final_steps_nb=120000,
update_frequency=25000,
mode='exp',
Tau=1,
verbose=True)
# Set the device and manage parallel processing
torch.cuda.set_device(local_rank)
......@@ -1347,7 +1395,7 @@ def xtrain(dataset_description,
print("Train on a single GPU")
# Initialise data loaders
training_loader, validation_loader, \
training_loader, validation_loader,\
sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
training_opts,
model_opts)
......@@ -1388,7 +1436,8 @@ def xtrain(dataset_description,
optimizer,
scheduler,
device,
scaler=scaler)
scaler=scaler,
aam_scheduler=aam_scheduler)
# Cross validation
if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
......@@ -1435,7 +1484,8 @@ def train_epoch(model,
scheduler,
device,
scaler=None,
clipping=False):
clipping=False,
aam_scheduler=None):
"""
:param model:
......@@ -1447,6 +1497,7 @@ def train_epoch(model,
:param device:
:param scaler:
:param clipping:
:param aam_scheduler:
:return:
"""
model.train()
......@@ -1526,6 +1577,8 @@ def train_epoch(model,
scheduler.step(training_monitor.best_eer)
else:
scheduler.step()
if aam_scheduler is not None:
model.after_speaker_embedding.margin = aam_scheduler.step()
return model
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
......@@ -1670,9 +1723,9 @@ def extract_embeddings(idmap_name,
_, vec = model(x=td.to(device), is_eval=True)
embed.append(vec.detach().cpu())
modelset.extend(mod * data.shape[0])
segset.extend(seg * data.shape[0])
starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift , win_shift))
modelset.extend(mod * data.shape[0])
segset.extend(seg * data.shape[0])
starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift))
embeddings = StatServer()
embeddings.modelset = numpy.array(modelset).astype('>U')
......@@ -1680,7 +1733,6 @@ def extract_embeddings(idmap_name,
embeddings.start = numpy.array(starts)
embeddings.stop = numpy.array(starts) + win_duration
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
print(f"type = {type(embed)}, {type(embed[0])}")
embeddings.stat1 = numpy.concatenate(embed)
return embeddings
......@@ -1854,8 +1906,8 @@ def extract_sliding_embedding(idmap_name,
for td in tmp_data:
vec = model(x=td.to(device), is_eval=True)
embeddings.append(vec.detach().cpu())
modelset += [mod,] * data.shape[0]
segset += [seg,] * data.shape[0]
modelset += [mod, ] * data.shape[0]
segset += [seg, ] * data.shape[0]
starts += [numpy.arange(start, start + embeddings.shape[0] * window_shift , window_shift),]
#REPRENDRE ICI
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment