Commit d89bf646 authored by Anthony Larcher

debug

parent b7df1163
@@ -31,7 +31,7 @@ Copyright 2014-2021 Anthony Larcher and Sylvain Meignier
 from .augmentation import AddNoise
 from .feed_forward import FForwardNetwork
 from .feed_forward import kaldi_to_hdf5
-from .xsets import IdMapSet_per_speaker
+from .xsets import IdMapSet_per_speaker, SpkSet
 from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling
 from .res_net import ResBlock, ResNet18, PreResNet34
 from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
@@ -237,40 +237,41 @@ class ArcMarginProduct(torch.nn.Module):
         super(ArcMarginProduct, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.s = s
-        self.m = m
+        self.s = torch.tensor(s)
+        self.m = torch.tensor(m)
         self.weight = Parameter(torch.FloatTensor(out_features, in_features))
         torch.nn.init.xavier_uniform_(self.weight)
-        self.easy_margin = easy_margin
-        self.cos_m = math.cos(m)
-        self.sin_m = math.sin(m)
-        self.th = math.cos(math.pi - m)
-        self.mm = math.sin(math.pi - m) * m
+        self.easy_margin = torch.tensor(easy_margin)
+        self.cos_m = torch.tensor(math.cos(m))
+        self.sin_m = torch.tensor(math.sin(m))
+        self.th = torch.tensor(math.cos(math.pi - m))
+        self.mm = torch.tensor(math.sin(math.pi - m) * m)
 
     def forward(self, input, target):
-        # --------------------------- cos(theta) & phi(theta) ---------------------------
-        cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight))
-        if target is None:
-            return cosine * self.s
-        else:
-            sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
-            phi = cosine * self.cos_m - sine * self.sin_m
-            if self.easy_margin:
-                phi = torch.where(cosine > 0, phi, cosine)
-            else:
-                phi = torch.where(cosine > self.th, phi, cosine - self.mm)
-            # --------------------------- convert label to one-hot ---------------------------
-            # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
-            one_hot = torch.zeros(cosine.size(), device='cuda')
-            one_hot.scatter_(1, target.view(-1, 1).long(), 1)
-            # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
-            output = (one_hot * phi) + (
-                    (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
-            output *= self.s
-            # print(output)
-            return output
+        with torch.cuda.amp.autocast(enabled=False):
+            # --------------------------- cos(theta) & phi(theta) ---------------------------
+            cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input), torch.nn.functional.normalize(self.weight))
+            if target is None:
+                return cosine * self.s
+            else:
+                sine = torch.sqrt(torch.tensor(1.0) - torch.pow(cosine, 2))
+                phi = cosine * self.cos_m - sine * self.sin_m
+                if self.easy_margin:
+                    phi = torch.where(cosine > torch.tensor(0), phi, cosine)
+                else:
+                    phi = torch.where(cosine > self.th, phi, cosine - self.mm)
+                # --------------------------- convert label to one-hot ---------------------------
+                # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
+                one_hot = torch.zeros(cosine.size(), device='cuda')
+                one_hot.scatter_(1, target.view(-1, 1).long(), 1)
+                # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
+                output = (one_hot * phi) + (
+                        (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
+                output *= self.s
+                # print(output)
+                return output
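Note on the hunk above: this is the additive angular margin (ArcFace) softmax, where the target-class logit becomes s * cos(theta + m) while the others stay s * cos(theta); wrapping the whole computation in torch.cuda.amp.autocast(enabled=False) keeps the sqrt/where chain in float32, which is numerically safer under mixed-precision training. A minimal self-contained sketch of the same margin computation (function and tensor names are illustrative, not part of the file):

import math
import torch

def aam_logits(cosine, target, s=30.0, m=0.20):
    # cosine: (batch, n_classes) similarities between L2-normalised embeddings
    # and L2-normalised class weights; target: (batch,) integer class indices.
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=0.0))
    phi = cosine * math.cos(m) - sine * math.sin(m)          # cos(theta + m)
    # keep the penalty monotonic once theta + m would exceed pi
    phi = torch.where(cosine > math.cos(math.pi - m), phi, cosine - math.sin(math.pi - m) * m)
    one_hot = torch.zeros_like(cosine).scatter_(1, target.view(-1, 1).long(), 1.0)
    return s * (one_hot * phi + (1.0 - one_hot) * cosine)

logits = aam_logits(torch.rand(8, 100) * 2 - 1, torch.randint(0, 100, (8,)))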
@@ -30,7 +30,6 @@ import copy
 import glob
 import h5py
 import numpy
-import multiprocessing
 import pandas
 import os
 import pickle
@@ -64,7 +63,6 @@ __docformat__ = 'reStructuredText'
 wav_type = "float32"  # can be int16, float64, int32 or float32
 
-torch.multiprocessing.set_sharing_strategy('file_system')
 
 def write_batch(batch_idx, data, label, batch_fn_format):
     """
@@ -394,7 +392,6 @@ class SpkSet(Dataset):
                 _transform.append(PreEmphasis())
             if 'add_noise' in t:
                 _transform.append(AddNoise(noise_db_csv=self.transformation["noise_db_csv"],
                                            snr_min_max=self.transformation["noise_snr"],
                                            noise_root_path=self.transformation["noise_root_db"]))
@@ -406,9 +403,6 @@ class SpkSet(Dataset):
             except ImportError:
                 has_pyroom = False
             if has_pyroom:
-                #self.add_reverb[:int(self.len * self.transformation["reverb_file_ratio"])] = 1
-                #numpy.random.shuffle(self.add_reverb)
                 _transform.append(AddReverb(depth=self.transformation["reverb_depth"],
                                             width=self.transformation["reverb_width"],
                                             height=self.transformation["reverb_height"],
@@ -574,7 +568,7 @@ class SideSet(Dataset):
         weight_dict = dict()
 
         # For each segment, get all possible segments with the current overlap
-        for idx in tqdm.trange(len(tmp_sessions)):
+        for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1):
             current_session = tmp_sessions.iloc[idx]
 
             # Compute possible starts
             possible_starts = numpy.arange(0,
@@ -837,26 +831,26 @@ class IdMapSet(Dataset):
         _transform = []
         if transform_pipeline is not None:
-            trans = transform_pipeline.split(",")
-            for t in trans:
-                if 'PreEmphasis' in t:
-                    _transform.append(PreEmphasis())
-                if 'MFCC' in t:
-                    _transform.append(MFCC())
-                if "CMVN" in t:
-                    _transform.append(MFCC(lowfreq=transform_pipeline['MFCC']['lowfreq'],
-                                           maxfreq=transform_pipeline['MFCC']['maxfreq'],
-                                           nlogfilt=transform_pipeline['MFCC']['nb_filters'],
-                                           win_time=transform_pipeline['MFCC']['win_time'],
-                                           nceps=transform_pipeline['MFCC']['nb_ceps'],
-                                           shift=transform_pipeline['MFCC']['shift'],
-                                           n_fft=transform_pipeline['MFCC']['n_fft']))
-                if 'add_noise' in t:
-                    self.add_noise = numpy.ones(self.idmap.leftids.shape[0], dtype=bool)
-                    numpy.random.shuffle(self.add_noise)
-                    _transform.append(AddNoise(noise_db_csv="list/musan.csv",
-                                               snr_min_max=[5.0, 15.0],
-                                               noise_root_path="./data/musan/"))
+            trans = transform_pipeline.keys()
+            if 'PreEmphasis' in transform_pipeline.keys():
+                _transform.append(PreEmphasis())
+            if 'MFCC' in transform_pipeline.keys():
+                _transform.append(MFCC(lowfreq=transform_pipeline['MFCC']['lowfreq'],
+                                       maxfreq=transform_pipeline['MFCC']['maxfreq'],
+                                       nlogfilt=transform_pipeline['MFCC']['nb_filters'],
+                                       win_time=transform_pipeline['MFCC']['win_time'],
+                                       nceps=transform_pipeline['MFCC']['nb_ceps'],
+                                       shift=transform_pipeline['MFCC']['shift'],
+                                       n_fft=transform_pipeline['MFCC']['n_fft']))
+            if 'CMVN' in transform_pipeline.keys():
+                _transform.append(CMVN())
+            if 'add_noise' in transform_pipeline.keys():
+                self.add_noise = numpy.ones(self.idmap.leftids.shape[0], dtype=bool)
+                numpy.random.shuffle(self.add_noise)
+                _transform.append(AddNoise(noise_db_csv="list/musan.csv",
+                                           snr_min_max=[5.0, 15.0],
+                                           noise_root_path="./data/musan/"))
 
         self.transforms = transforms.Compose(_transform)
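The change above moves IdMapSet from a comma-separated transform_pipeline string ("MFCC,CMVN") to a dictionary keyed by transform name, whose values carry the parameters. A minimal sketch of the new configuration shape (parameter values copied from the test_metrics hunk further down; the 'PreEmphasis' entry is an assumed no-parameter placeholder):

transform_pipeline = {
    'PreEmphasis': {},          # no parameters, presence alone enables it
    'MFCC': {'nb_filters': 40, 'nb_ceps': 30,
             'lowfreq': 133.333, 'maxfreq': 6855.4976,
             'win_time': 0.025, 'shift': 0.01, 'n_fft': 512},
    'CMVN': {},                 # cepstral mean/variance normalisation
}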
@@ -65,8 +65,6 @@ from .loss import l2_norm
 from .loss import ArcMarginProduct
 
 os.environ['MKL_THREADING_LAYER'] = 'GNU'
 
 __license__ = "LGPL"
@@ -163,12 +161,9 @@ def plot_classes_preds(model, speech, labels):
     return fig
 
 
-def compute_metrics(model,
-                    validation_loader,
-                    device,
-                    val_embs_shape,
-                    speaker_number,
-                    model_archi):
+def test_metrics(model,
+                 device,
+                 speaker_number):
     """Compute model metrics
 
     Args:
@@ -185,25 +180,41 @@ def compute_metrics(model,
     Returns:
         [type]: [description]
     """
-    val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, val_embs_shape)
-    #xv_stat = extract_embeddings(idmap_name='h5f/idmap_test.h5',
-    #                             speaker_number=speaker_number,
-    #                             model_filename=model,
-    #                             model_yaml=model_archi,
-    #                             data_root_name="data/vox1/wav/" ,
-    #                             device=device,
-    #                             transform_pipeline="MFCC,CMVN")
-    #scores = cosine_scoring(xv_stat, xv_stat,
-    #                        Ndx('h5f/ndx_test.h5'),
-    #                        wccn=None, check_missing=True)
-    #tar, non = scores.get_tar_non(Key('h5f/key_test.h5'))
-    #pmiss, pfa = rocch(numpy.array(tar).astype(numpy.double), numpy.array(non).astype(numpy.double))
-    #test_eer = rocch2eer(pmiss, pfa)
-    test_eer = 0.0
-    return val_acc, val_loss, val_eer, test_eer
+    idmap_test_filename = 'h5f/idmap_test.h5'
+    ndx_test_filename = 'h5f/ndx_test.h5'
+    key_test_filename = 'h5f/key_test.h5'
+    data_root_name='/ssd/rsgb7088/vox1/test/wav'
+
+    transform_pipeline = dict()
+    mfcc_config = dict()
+    mfcc_config['nb_filters'] = 40
+    mfcc_config['nb_ceps'] = 30
+    mfcc_config['lowfreq'] = 133.333
+    mfcc_config['maxfreq'] = 6855.4976
+    mfcc_config['win_time'] = 0.025
+    mfcc_config['shift'] = 0.01
+    mfcc_config['n_fft'] = 512
+    transform_pipeline['MFCC'] = mfcc_config
+    transform_pipeline['CMVN'] = {}
+
+    xv_stat = extract_embeddings(idmap_name=idmap_test_filename,
+                                 speaker_number=speaker_number,
+                                 model_filename=model,
+                                 data_root_name=data_root_name,
+                                 device=device,
+                                 transform_pipeline=transform_pipeline)
+
+    scores = cosine_scoring(xv_stat,
+                            xv_stat,
+                            Ndx(ndx_test_filename),
+                            wccn=None,
+                            check_missing=True)
+
+    tar, non = scores.get_tar_non(Key(key_test_filename))
+
+    test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
+
+    return test_eer
 
 
 def get_lr(optimizer):
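With wccn=None, cosine_scoring in the new test_metrics reduces to plain cosine similarity between enrolment and test x-vectors. A transparent numpy sketch of that scoring rule (matrix names are illustrative; rows are embeddings):

import numpy

def cosine_scores(enroll, test):
    # Normalise each row to unit length; every trial score is then a dot product.
    e = enroll / numpy.linalg.norm(enroll, axis=1, keepdims=True)
    t = test / numpy.linalg.norm(test, axis=1, keepdims=True)
    return e @ t.T   # scores[i, j] = cos(enroll_i, test_j)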
@@ -674,9 +685,9 @@ class Xtractor(torch.nn.Module):
         elif self.loss == "aam":
             if not is_eval:
-                x = self.after_speaker_embedding(l2_norm(x), target=target), l2_norm(x)
+                x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=target), torch.nn.functional.normalize(x, dim=1)
             else:
-                x = self.after_speaker_embedding(l2_norm(x), target=None), l2_norm(x)
+                x = self.after_speaker_embedding(torch.nn.functional.normalize(x, dim=1), target=None), torch.nn.functional.normalize(x, dim=1)
 
         return x
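This swap assumes torch.nn.functional.normalize(x, dim=1) is a drop-in for the l2_norm helper: both map each row to unit L2 norm. A quick check of that equivalence (l2_norm_ref is a hypothetical re-statement of the helper, not the sidekit source):

import torch

def l2_norm_ref(x, eps=1e-12):
    # divide each row by its L2 norm, guarding against zero vectors
    return x / x.norm(p=2, dim=1, keepdim=True).clamp(min=eps)

x = torch.randn(8, 256)
assert torch.allclose(l2_norm_ref(x), torch.nn.functional.normalize(x, dim=1), atol=1e-6)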
@@ -936,16 +947,16 @@ def xtrain(speaker_number,
     curr_patience = patience
 
     logging.critical("Compute EER before starting")
-    #val_acc, val_loss, val_eer, test_eer = compute_metrics(model,
-    #                                                       validation_loader,
-    #                                                       device,
-    #                                                       [validation_set.__len__(), embedding_size],
-    #                                                       speaker_number,
-    #                                                       model_archi)
-    #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
-    #logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
+    val_acc, val_loss, val_eer = cross_validation(model,
+                                                  validation_loader,
+                                                  device,
+                                                  [validation_set.__len__(),
+                                                   embedding_size])
+    test_eer = test_metrics(model, device, speaker_number)
+    logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
+    logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
 
     for epoch in range(1, epochs + 1):
         # Process one epoch and return the current model
@@ -963,53 +974,52 @@ def xtrain(speaker_number,
                                   tb_writer=writer)
 
         # Add the cross validation here
-        val_acc, val_loss, val_eer, test_eer = compute_metrics(model,
-                                                               validation_loader,
-                                                               device,
-                                                               [validation_set.__len__(), embedding_size],
-                                                               speaker_number,
-                                                               model_archi)
-
-        # Decrease learning rate according to the scheduler policy
-        #scheduler.step(val_loss)
-        logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
-        logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
-
-        # remember best accuracy and save checkpoint
-        is_best = val_acc > best_accuracy
-        best_accuracy = max(val_acc, best_accuracy)
+        if math.fmod(epoch, 136) == 0:
+            val_acc, val_loss, val_eer = cross_validation(model, validation_loader, device, [validation_set.__len__(), embedding_size])
+            logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
+        test_eer = test_metrics(model, device, speaker_number)
+
+        # Decrease learning rate according to the scheduler policy
+        #scheduler.step(val_loss)
+        logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")
+
+        # remember best accuracy and save checkpoint
+        is_best = val_acc > best_accuracy
+        best_accuracy = max(val_acc, best_accuracy)
 
         if tmp_model_name is None:
             tmp_model_name = "tmp_model"
         if best_model_name is None:
             best_model_name = "best_model"
 
         if type(model) is Xtractor:
             save_checkpoint({
                 'epoch': epoch,
                 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'accuracy': best_accuracy,
                 'scheduler': scheduler,
                 'speaker_number' : speaker_number,
                 'model_archi': model_archi
             }, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
         else:
             save_checkpoint({
                 'epoch': epoch,
                 'model_state_dict': model.module.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'accuracy': best_accuracy,
                 'scheduler': scheduler,
                 'speaker_number': speaker_number,
                 'model_archi': model_archi
             }, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
 
         if is_best:
             best_accuracy_epoch = epoch
             curr_patience = patience
         else:
             curr_patience -= 1
 
     #writer.close()
 
     for ii in range(torch.cuda.device_count()):
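save_checkpoint is sidekit's own helper; a plausible minimal equivalent is shown here only to make the checkpoint flow above concrete (the real implementation may differ): serialise the training state every epoch and copy it aside when the accuracy improved.

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pt', best_filename='model_best.pt'):
    # persist the full training state (model, optimizer, scheduler, metadata)
    torch.save(state, filename)
    if is_best:
        # keep a separate copy of the best-performing epoch
        shutil.copyfile(filename, best_filename)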
@@ -1017,7 +1027,7 @@ def xtrain(speaker_number,
     logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
 
 
-def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval, device, clipping=False, tb_writer=None):
+def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval, device, clipping=False):
     """
 
     :param model:
@@ -1030,7 +1040,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
     :return:
     """
     model.train()
-    criterion = torch.nn.CrossEntropyLoss(reduction='none')
+    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
 
     if isinstance(model, Xtractor):
         loss_criteria = model.loss
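Changing reduction from 'none' to 'mean' pairs with the removal of the per-sample weights in the batch loop below: 'none' returns one loss per sample (presumably so the old code could weight them before reducing), while 'mean' returns a scalar ready for backward(). A short sketch of the two conventions (tensors are illustrative):

import torch

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
weights = torch.rand(4)

# old convention: per-sample losses, weighted and reduced by hand
per_sample = torch.nn.CrossEntropyLoss(reduction='none')(logits, target)
old_loss = (per_sample * weights).mean()

# new convention: a single scalar straight from the criterion
new_loss = torch.nn.CrossEntropyLoss(reduction='mean')(logits, target)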
@@ -1039,11 +1049,10 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
     accuracy = 0.0
     running_loss = 0.0
-    for batch_idx, (data, target, weights) in enumerate(training_loader):
+    for batch_idx, (data, target) in enumerate(training_loader):
         data = data.squeeze().to(device)
         target = target.squeeze()
         target = target.to(device)
-        weights = weights.squeeze().to(device)
         optimizer.zero_grad()
 
         with torch.cuda.amp.autocast(enabled=False):
@@ -1052,7 +1061,6 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
             else:
                 output = model(data, target=None)
-            #with GuruMeditation():
             loss = criterion(output, target)
 
         if not torch.isnan(loss):
             loss.backward()
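For context, the surrounding update step follows a common guarded pattern: skip backpropagation when the loss came back NaN, and clip gradients only when clipping is set. A condensed sketch (the max_norm value is illustrative, not taken from the file):

import torch

def guarded_step(model, criterion, optimizer, output, target, clipping=False):
    loss = criterion(output, target)
    if not torch.isnan(loss):          # skip diverged batches entirely
        loss.backward()
        if clipping:                   # optional gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    return loss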
@@ -1142,7 +1150,6 @@ def cross_validation(model, validation_loader, device, validation_shape):
     # Faster EER computation available here : https://github.com/gl3lan/fast_eer
     equal_error_rate = eer(negatives, positives)
-    #pmiss, pfa = rocch(numpy.array(positives).astype(numpy.double), numpy.array(negatives).astype(numpy.double))
 
     return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
            loss.cpu().numpy() / ((batch_idx + 1) * batch_size), equal_error_rate
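eer(negatives, positives) is the equal error rate: the operating point where the false-acceptance rate on non-target scores meets the miss rate on target scores. A slow but transparent numpy re-statement of the definition (the fast_eer link above is an optimised version of the same quantity):

import numpy

def naive_eer(negatives, positives):
    # Sweep every observed score as a decision threshold and keep the point
    # where false-acceptance and miss rates are closest; report their mean.
    best_gap, eer_value = numpy.inf, 1.0
    for th in numpy.sort(numpy.concatenate([negatives, positives])):
        pfa = numpy.mean(negatives >= th)    # non-targets accepted
        pmiss = numpy.mean(positives < th)   # targets rejected
        if abs(pfa - pmiss) < best_gap:
            best_gap, eer_value = abs(pfa - pmiss), (pfa + pmiss) / 2.0
    return eer_value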
@@ -1494,10 +1501,10 @@ def extract_embeddings(idmap_name,
     # Process the data
     with torch.no_grad():
-        for idx, (data, mod, seg, start, stop) in tqdm.tqdm(enumerate(dataloader)):
+        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
             if data.shape[1] > 20000000:
                 data = data[...,:20000000]
-            preds, vec = model(data.to(device), is_eval=True, extract_after_pooling=extract_after_pooling)
+            preds, vec = model(data.to(device), is_eval=True)
             embeddings.stat1[idx, :] = vec.detach().cpu()
 
     return embeddings
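The tqdm change here is not only cosmetic: wrapping the DataLoader itself lets tqdm read len(dataloader), so the bar gets a total, a percentage and an ETA, whereas wrapping enumerate(...) hands tqdm a generator with no length. A runnable illustration with a stand-in iterable:

import tqdm

dataloader = range(100)   # stand-in for the torch DataLoader

# no total shown: enumerate(...) exposes no __len__ to tqdm
for idx, batch in tqdm.tqdm(enumerate(dataloader)):
    pass

# total, percentage and ETA: the wrapped iterable reports its length
for idx, batch in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
    pass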
@@ -1588,10 +1595,10 @@ def extract_embeddings_per_speaker(idmap_name,
     # Process the data
     with torch.no_grad():
-        for idx, (data, mod, seg, start, stop) in tqdm.tqdm(enumerate(dataloader)):
+        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
             if data.shape[1] > 20000000:
                 data = data[..., :20000000]
-            vec = model(data.to(device), is_eval=True, extract_after_pooling=extract_after_pooling)
+            vec = model(data.to(device), is_eval=True)
             #if model.loss == "aam":
             #    vec = vec[1]
             embeddings.stat1[idx, :] = vec.detach().cpu()