Commit 7ce0edcc authored by Anthony Larcher

resnet34

parent 5c37ba16
@@ -165,9 +165,6 @@ if SIDEKIT_CONFIG["cuda"]:
if CUDA:
from .nnet import FForwardNetwork
from .nnet import kaldi_to_hdf5
from .nnet import XvectorMultiDataset
from .nnet import XvectorDataset
from .nnet import StatDataset
from .nnet import Xtractor
from .nnet import xtrain
from .nnet import extract_embeddings
...
@@ -31,7 +31,7 @@ Copyright 2014-2021 Anthony Larcher and Sylvain Meignier
from .augmentation import AddNoise
from .feed_forward import FForwardNetwork
from .feed_forward import kaldi_to_hdf5
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset, IdMapSet_per_speaker
from .xsets import IdMapSet_per_speaker
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling
from .res_net import ResBlock, ResNet18, PreResNet34
from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
...
@@ -35,7 +35,6 @@ import torch
import torch.optim as optim
import torch.multiprocessing as mp
from collections import OrderedDict
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.nn import Parameter
...
@@ -35,7 +35,6 @@ import torch.optim as optim
import torch.multiprocessing as mp
from torchvision import transforms
from collections import OrderedDict
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset, VoxDataset
from .xsets import FrequencyMask, CMVN, TemporalMask
from .sincnet import SincNet, SincConv1d
from ..bosaris import IdMap
@@ -534,6 +533,7 @@ class PreResNet34(torch.nn.Module):
return torch.nn.Sequential(*layers)
def forward(self, x):
x = x.unsqueeze(1)
out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
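The new x = x.unsqueeze(1) in PreResNet34.forward inserts a singleton channel axis so that a batch of 2-D feature maps (batch, frames, frequency) becomes the 4-D input torch.nn.Conv2d expects. A minimal shape check, with hypothetical sizes (200 frames of 30 coefficients, matching the MFCC defaults added further down):

import torch

x = torch.randn(8, 200, 30)                 # (batch, frames, features)
x = x.unsqueeze(1)                          # (8, 1, 200, 30): channel axis added
conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
out = torch.nn.functional.relu(conv1(x))    # (8, 16, 200, 30)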
...
@@ -565,8 +565,9 @@ class StatServerSet(Dataset):
:param input_filename:
"""
self.len = 0
self.Statserver = StatServer(input_filename)
self.data =
#self.data =
@@ -576,7 +577,7 @@
:param index:
:return:
"""
return
return 0
def __len__(self):
"""
@@ -696,7 +697,12 @@ class IdMapSet_per_speaker(Dataset):
file_extension,
transform_pipeline=None,
frame_rate=100,
min_duration=0.165
min_duration=0.165,
nb_filters=30,
nb_ceps=30,
lowfreq=133.333,
maxfreq=6855.4976,
n_fft=512
):
"""
@@ -720,7 +726,11 @@
self.output_im.rightids = self.output_im.leftids
self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.mfcc_nbfilter = nb_filters
self.mfcc_nceps = nb_ceps
self.lowfreq = lowfreq
self.maxfreq = maxfreq
self.n_fft = n_fft
_transform = []
@@ -730,7 +740,11 @@
if 'PreEmphasis' in t:
_transform.append(PreEmphasis())
if 'MFCC' in t:
_transform.append(MFCC())
_transform.append(MFCC(lowfreq=self.lowfreq,
maxfreq=self.maxfreq,
nlogfilt=self.mfcc_nbfilter,
nceps=self.mfcc_nceps,
n_fft=self.n_fft))
if "CMVN" in t:
_transform.append(CMVN())
if 'add_noise' in t:
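The MFCC stage of the transform pipeline is now built from the dataset's own spectral settings rather than library defaults. A minimal sketch of the assembly, assuming the PreEmphasis, MFCC and CMVN transforms already imported in this module and a comma-separated pipeline string (the split(',') convention is an assumption):

from torchvision import transforms

def build_pipeline(pipeline_str, lowfreq=133.333, maxfreq=6855.4976,
                   nb_filters=30, nb_ceps=30, n_fft=512):
    # PreEmphasis, MFCC and CMVN come from this module's transform classes;
    # their exact import paths are not visible in this hunk.
    _transform = []
    for t in pipeline_str.split(','):
        if 'PreEmphasis' in t:
            _transform.append(PreEmphasis())
        if 'MFCC' in t:
            _transform.append(MFCC(lowfreq=lowfreq,
                                   maxfreq=maxfreq,
                                   nlogfilt=nb_filters,
                                   nceps=nb_ceps,
                                   n_fft=n_fft))
        if 'CMVN' in t:
            _transform.append(CMVN())
    return transforms.Compose(_transform)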
...
@@ -381,9 +381,11 @@ class Xtractor(torch.nn.Module):
self.embedding_size = 256
self.loss = "aam"
self.after_speaker_embedding = ArcLinear(256,
int(self.speaker_number),
margin=aam_margin, s=aam_s)
self.after_speaker_embedding = ArcMarginProduct(256,
int(self.speaker_number),
m=aam_margin,
s=aam_s,
easy_margin=True)
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
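Switching the head from ArcLinear to ArcMarginProduct changes the loss to the additive-angular-margin (AAM) softmax form: embeddings and class weights are L2-normalized, the margin m is added to the angle of the target class, and the logits are scaled by s. A condensed sketch of the idea, not SIDEKIT's exact implementation (the default m and s below are placeholders for aam_margin and aam_s):

import math
import torch

class ArcMarginSketch(torch.nn.Module):
    """cos(theta + m) on the target class, scaled by s."""
    def __init__(self, in_features, out_features, m=0.2, s=30.0, easy_margin=True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.m, self.s, self.easy_margin = m, s, easy_margin

    def forward(self, x, target):
        # Cosine between normalized embeddings and normalized class weights.
        cosine = torch.nn.functional.linear(torch.nn.functional.normalize(x),
                                            torch.nn.functional.normalize(self.weight))
        sine = torch.sqrt((1.0 - cosine ** 2).clamp(0, 1))
        phi = cosine * math.cos(self.m) - sine * math.sin(self.m)  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)  # margin only where cos > 0
        one_hot = torch.nn.functional.one_hot(target, cosine.shape[1]).float()
        return self.s * (one_hot * phi + (1.0 - one_hot) * cosine)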
@@ -930,31 +932,27 @@ def xtrain(speaker_number,
param_list.append({'params': model.module.before_speaker_embedding.parameters(), 'weight_decay': model.module.before_speaker_embedding_weight_decay})
param_list.append({'params': model.module.after_speaker_embedding.parameters(), 'weight_decay': model.module.after_speaker_embedding_weight_decay})
optimizer = _optimizer(param_list, **_options)
#optimizer = torch.optim.SGD(params,
# lr=lr,
# momentum=0.9,
# weight_decay=0.0005)
#print(f"Learning rate = {lr}")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[50000, 60000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000, 150000],
gamma=0.1,
last_epoch=-1,
verbose=True)
best_accuracy = 0.0
best_accuracy_epoch = 1
curr_patience = patience
logging.critical("Compute EER before starting")
val_acc, val_loss, val_eer, test_eer = compute_metrics(model,
validation_loader,
device,
[validation_set.__len__(), embedding_size],
speaker_number,
model_archi)
#val_acc, val_loss, val_eer, test_eer = compute_metrics(model,
# validation_loader,
# device,
# [validation_set.__len__(), embedding_size],
# speaker_number,
# model_archi)
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
#logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
#logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Initial metrics - Test EER = {test_eer * 100} %")
for epoch in range(1, epochs + 1):
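ReduceLROnPlateau was stepped once per epoch on the validation loss; MultiStepLR with milestones between 50000 and 150000 only makes sense when stepped once per mini-batch, which is why scheduler is now handed to train_epoch and the per-epoch scheduler.step(val_loss) further down is commented out. A toy illustration of the decay, with small milestones standing in for the real ones and a linear layer standing in for the Xtractor:

import torch

net = torch.nn.Linear(10, 2)                  # stand-in model
opt = torch.optim.SGD(net.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[5, 8], gamma=0.1)

for step in range(10):                        # pretend each step is one mini-batch
    opt.step()
    sched.step()                              # lr: 0.1, then 0.01 after 5 steps, 0.001 after 8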
@@ -966,6 +964,7 @@ def xtrain(speaker_number,
epoch,
training_loader,
optimizer,
scheduler,
dataset_params["log_interval"],
device=device,
clipping=clipping,
@@ -982,7 +981,7 @@ def xtrain(speaker_number,
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
#scheduler.step(val_loss)
# remember best accuracy and save checkpoint
is_best = val_acc > best_accuracy
@@ -1026,7 +1025,7 @@ def xtrain(speaker_number,
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device, clipping=False, tb_writer=None):
def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval, device, clipping=False, tb_writer=None):
"""
:param model:
@@ -1061,7 +1060,6 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
else:
output = model(data, target=None)
#with GuruMeditation():
loss = torch.sum(weights * criterion(output, target))
if not torch.isnan(loss):
loss.backward()
@@ -1093,6 +1091,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
import sys
sys.exit()
running_loss = 0.0
scheduler.step()
return model
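With scheduler.step() now called inside train_epoch, the decay is driven by iteration count rather than epochs, consistent with the milestone values above. A condensed sketch of the per-batch order of operations (the loss weighting, interval logging and TensorBoard writing of the real loop are omitted; the target= keyword and the max-norm of 1.0 are assumptions):

def train_epoch_sketch(model, training_loader, optimizer, scheduler,
                       criterion, device, clipping=False):
    model.train()
    for data, target in training_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, target=target)  # AAM heads need the target at train time
        loss = criterion(output, target)
        if not torch.isnan(loss):            # NaN guard kept from the original loop
            loss.backward()
            if clipping:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()                     # one scheduler step per mini-batch
    return model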
@@ -1282,8 +1281,7 @@ class XtractorTop(torch.nn.Module):
def xtraintop(model,
dataset,
batch_size,
epochs=None,
opt,
epochs,
lr=None,
patience=None,
tmp_model_name=None,
@@ -1317,9 +1315,9 @@ def xtraintop(model,
embedding_size = model.embedding_size
# Create datasets
training_set =
training_set = None
validation_set =
validation_set = None
training_loader = DataLoader(training_set,
batch_size=batch_size,
@@ -1415,11 +1413,12 @@ def xtraintop(model,
best_accuracy = max(val_acc, best_accuracy)
if type(model) is Xtractor:
pass
# TODO write a save function
else:
# TODO write a save function
pass
if is_best:
best_accuracy_epoch = epoch
@@ -1522,6 +1521,11 @@ def extract_embeddings_per_speaker(idmap_name,
frame_shift=0.01,
frame_duration=0.025,
extract_after_pooling=False,
nb_filters=30,
nb_ceps=30,
lowfreq=133.333,
maxfreq=6855.4976,
n_fft=512,
num_thread=1):
# Load the model
if isinstance(model_filename, str):
@@ -1530,11 +1534,16 @@
speaker_number = checkpoint["speaker_number"]
if model_yaml is None:
model_archi = checkpoint["model_archi"]
else:
model_archi = model_yaml
model = Xtractor(speaker_number, model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
print(model)
if isinstance(idmap_name, IdMap):
idmap = idmap_name
else:
@@ -1548,7 +1557,12 @@
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=int(1. / frame_shift),
min_duration=(model.context_size() + 2) * frame_shift * 2
min_duration=(model.context_size() + 2) * frame_shift * 2,
nb_filters=nb_filters,
nb_ceps=nb_ceps,
lowfreq=lowfreq,
maxfreq=maxfreq,
n_fft=n_fft
)
dataloader = DataLoader(dataset,
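As a worked example of the (unchanged) min_duration formula: with frame_shift = 0.01 s and a hypothetical model.context_size() of 21 frames, segments must carry at least (21 + 2) * 0.01 * 2 = 0.46 s of signal.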
@@ -1565,7 +1579,10 @@
# Get the size of embeddings to extract
name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
if extract_after_pooling:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
else:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
# Create the StatServer
embeddings = StatServer()
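The corrected emb_size logic reads the right axis of the last before_speaker_embedding weight: a torch.nn.Linear(in_features, out_features) stores its weight as (out_features, in_features), so shape[0] is the embedding size at the layer output while shape[1] is the size of the pooled statistics fed into it, which is what comes out when extract_after_pooling bypasses the layer. A two-line check with hypothetical sizes:

import torch

layer = torch.nn.Linear(3000, 512)            # pooled stats -> x-vector
print(layer.state_dict()['weight'].shape)     # torch.Size([512, 3000]): (out, in)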
@@ -1582,8 +1599,8 @@
if data.shape[1] > 20000000:
data = data[..., :20000000]
vec = model(data.to(device), is_eval=True, extract_after_pooling=extract_after_pooling)
if model.loss == "aam":
vec = vec[1]
#if model.loss == "aam":
# vec = vec[1]
embeddings.stat1[idx, :] = vec.detach().cpu()
return embeddings
...