Commit 94b6f695 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

AAM loss

parent e32dd028
......@@ -56,6 +56,7 @@ def crop(signal, duration):
return chunk
class AddNoise(object):
"""
......@@ -97,7 +98,7 @@ class AddNoise(object):
# Load noise from file
if not fs == self.sample_rate:
print("Problem") # todo
print("Problem") # todo
duration = noise_signal.shape[0]
......@@ -113,7 +114,7 @@ class AddNoise(object):
# Todo Downsample if needed
# if sample_rate > fs:
#
noise = normalize(noise)
noises.append(noise.squeeze())
......@@ -128,6 +129,76 @@ class AddNoise(object):
return data.squeeze(), sample[1], sample[2], sample[3], sample[4], sample[5]
class AddNoiseFromSilence(object):
    """Data-augmentation transform that mixes noise into a waveform.

    Noise segments are drawn at random from a noise database (described by a
    CSV file with ``type``, ``file_id`` and ``duration`` columns) and added to
    the input signal at an SNR drawn uniformly from a configured range.
    Audio is read from ``noise_root_path/<file_id>.wav``.
    """
    def __init__(self, noise_db_csv, snr_min_max, noise_root_path, sample_rate=16000):
        """
        :param noise_db_csv: path of the CSV describing the noise database
        :param snr_min_max: pair (min, max) giving the SNR bounds in dB
        :param noise_root_path: directory that contains the noise wav files
        :param sample_rate: expected sample rate of the noise files
        """
        self.snr_min, self.snr_max = snr_min_max[0], snr_min_max[1]
        self.noise_root_path = noise_root_path
        self.sample_rate = sample_rate
        noise_df = pandas.read_csv(noise_db_csv)
        # One Noise descriptor per row of the database.
        self.noises = [Noise(type=entry["type"], file_id=entry["file_id"], duration=entry["duration"])
                       for _, entry in noise_df.iterrows()]

    def __call__(self, sample):
        """Apply the augmentation when the flag in sample[4] is truthy.

        :param sample: tuple whose first element is the waveform and whose
                       fifth element (index 4) enables the noise addition
        :return: 6-tuple mirroring the input, with the (possibly noisy)
                 squeezed signal in first position
        """
        data = sample[0]
        if sample[4]:
            remaining = len(data)
            # Accumulate randomly chosen noise files until they cover the
            # full duration of the original waveform.
            pieces = []
            while remaining > 0:
                chosen = random.choice(self.noises)
                noise_signal, fs = soundfile.read(self.noise_root_path + "/" + chosen.file_id + ".wav")
                # Load noise from file
                if not fs == self.sample_rate:
                    print("Problem")  # todo
                n_samples = noise_signal.shape[0]
                if n_samples > remaining:
                    # Noise file longer than what is still needed: crop it.
                    piece = crop(noise_signal, remaining)
                    remaining = 0
                else:
                    # Otherwise take the whole file and keep accumulating.
                    piece = noise_signal
                    remaining -= n_samples
                # Todo Downsample if needed
                # if sample_rate > fs:
                #
                pieces.append(normalize(piece).squeeze())
            # concatenate
            noise = numpy.hstack(pieces)
            # Draw an SNR uniformly in [snr_min, snr_max] and convert it to a
            # linear gain: alpha = 10 ** (-snr / 20).
            snr = (self.snr_max - self.snr_min) * numpy.random.random_sample() + self.snr_min
            alpha = numpy.exp(-numpy.log(10) * snr / 20)
            data = normalize(data) + alpha * noise
        return data.squeeze(), sample[1], sample[2], sample[3], sample[4], sample[5]
class AddReverb(object):
......
......@@ -37,21 +37,20 @@ import shutil
import time
import torch
import torch.optim as optim
import torch.multiprocessing as mp
import yaml
from torchvision import transforms
from collections import OrderedDict
from .xsets import XvectorMultiDataset, StatDataset, VoxDataset, SideSet
from .xsets import SideSet
from .xsets import IdMapSet
from .xsets import FrequencyMask, CMVN, TemporalMask, MFCC
from .res_net import RawPreprocessor, ResBlockWFMS
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .sincnet import SincNet, SincConv1d
from .sincnet import SincNet
#from torch.utils.tensorboard import SummaryWriter
from .loss import ArcLinear
import tqdm
......@@ -241,7 +240,13 @@ class Xtractor(torch.nn.Module):
Class that defines an x-vector extractor based on 5 convolutional layers and a mean standard deviation pooling
"""
def __init__(self, speaker_number, model_archi="xvector", norm_embedding=False):
def __init__(self,
speaker_number,
model_archi="xvector",
loss="cce",
norm_embedding=False,
aam_margin=0.5,
aam_s=0.5):
"""
If config is None, default architecture is created
:param model_archi:
......@@ -251,6 +256,10 @@ class Xtractor(torch.nn.Module):
self.feature_size = None
self.norm_embedding = norm_embedding
self.loss = loss
if loss not in ["cce", 'aam']:
raise NotImplementedError(f"The valid loss are for now cce and aam ")
if model_archi == "xvector":
self.feature_size = 30
self.activation = torch.nn.LeakyReLU(0.2)
......@@ -281,15 +290,20 @@ class Xtractor(torch.nn.Module):
("linear6", torch.nn.Linear(3072, 512))
]))
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
("norm6", torch.nn.BatchNorm1d(512)),
("dropout6", torch.nn.Dropout(p=0.05)),
("linear7", torch.nn.Linear(512, 512)),
("activation7", torch.nn.LeakyReLU(0.2)),
("norm7", torch.nn.BatchNorm1d(512)),
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
if self.loss == "aam":
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("arclinear8", ArcLinear(512, int(self.speaker_number), margin=aam_margin, s=aam_s))
]))
elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
("norm6", torch.nn.BatchNorm1d(512)),
("dropout6", torch.nn.Dropout(p=0.05)),
("linear7", torch.nn.Linear(512, 512)),
("activation7", torch.nn.LeakyReLU(0.2)),
("norm7", torch.nn.BatchNorm1d(512)),
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
self.sequence_network_weight_decay = 0.0002
self.before_speaker_embedding_weight_decay = 0.002
......@@ -299,7 +313,7 @@ class Xtractor(torch.nn.Module):
filts = [128, [128, 128], [128, 256], [256, 256]]
self.norm_embedding = True
self.preprocessor = RawPreprocessor(nb_samp=32000,
self.preprocessor = RawPreprocessor(nb_samp=48000,
in_channels=1,
out_channels=filts[0],
kernel_size=3)
......@@ -320,9 +334,15 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding = torch.nn.Linear(in_features = 1024,
out_features = 1024)
self.after_speaker_embedding = torch.nn.Linear(in_features = 1024,
out_features = int(self.speaker_number),
bias = True)
if self.loss == "aam":
if loss == 'aam':
self.after_speaker_embedding = ArcLinear(1024,
int(self.speaker_number),
margin=aam_margin, s=aam_s)
                elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Linear(in_features = 1024,
out_features = int(self.speaker_number),
bias = True)
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
......@@ -387,7 +407,12 @@ class Xtractor(torch.nn.Module):
# Create sequential object for the first part of the network
segmental_layers = []
for k in cfg["segmental"].keys():
if k.startswith("conv"):
if k.startswith("lin"):
segmental_layers.append((k, torch.nn.Linear(input_size,
cfg["segmental"][k]["output"])))
input_size = cfg["segmental"][k]["output"]
elif k.startswith("conv"):
segmental_layers.append((k, torch.nn.Conv1d(input_size,
cfg["segmental"][k]["output_channels"],
kernel_size=cfg["segmental"][k]["kernel_size"],
......@@ -450,6 +475,18 @@ class Xtractor(torch.nn.Module):
cfg["after_embedding"][k]["output"])))
input_size = cfg["after_embedding"][k]["output"]
elif k.startswith('arc'):
if cfg["after_embedding"][k]["output"] == "speaker_number":
after_embedding_layers.append(
(k, ArcLinear(input_size, self.speaker_number, margin=aam_margin, s=aam_s)))
else:
after_embedding_layers.append(
(k, ArcLinear(input_size,
self.speaker_number,
margin=aam_margin,
s=aam_s)))
input_size = self.speaker_number
elif k.startswith("activation"):
after_embedding_layers.append((k, self.activation))
......@@ -463,7 +500,7 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
def forward(self, x, is_eval=False):
def forward(self, x, is_eval=False, target=None):
"""
:param x:
......@@ -477,9 +514,6 @@ class Xtractor(torch.nn.Module):
x = self.sequence_network(x)
# Mean and Standard deviation pooling
#mean = torch.mean(x, dim=2)
#std = torch.std(x, dim=2)
#x = torch.cat([mean, std], dim=1)
x = self.stat_pooling(x)
x = self.before_speaker_embedding(x)
......@@ -490,7 +524,14 @@ class Xtractor(torch.nn.Module):
x_norm = x.norm(p=2,dim=1, keepdim=True) / 10.
x = torch.div(x, x_norm)
x = self.after_speaker_embedding(x)
if self.loss == "cce":
x = self.after_speaker_embedding(x)
elif self.loss == "aam":
if not is_eval:
x = self.after_speaker_embedding(x,target=target)
else:
x = self.after_speaker_embedding(x, target=None)
return x
......@@ -500,6 +541,10 @@ def xtrain(speaker_number,
lr=0.01,
model_yaml=None,
model_name=None,
loss="cce",
aam_margin=0.5,
aam_s=30,
patience=10,
tmp_model_name=None,
best_model_name=None,
multi_gpu=True,
......@@ -516,10 +561,17 @@ def xtrain(speaker_number,
:param lr:
:param model_yaml:
:param model_name:
:param loss:
:param aam_margin:
:param aam_s:
:param patience:
:param tmp_model_name:
:param best_model_name:
:param multi_gpu:
:param clipping:
:param opt:
:param reset_parts:
:param freeze_parts:
:param num_thread:
:return:
"""
......@@ -528,10 +580,10 @@ def xtrain(speaker_number,
#writer = SummaryWriter("runs/xvectors_experiments_2")
writer = None
t= time.localtime()
logging.critical(f"Start process at {time.strftime('%H:%M:%S', t)}")
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
if model_name is None:
# Initialize a first model
......@@ -613,15 +665,15 @@ def xtrain(speaker_number,
"""
Set the training options
"""
if opt == 'sgd':
_optimizer = torch.optim.SGD
_options = {'lr': lr, 'momentum': 0.9}
elif opt == 'adam':
if opt == 'adam':
_optimizer = torch.optim.Adam
_options = {'lr': lr}
elif opt == 'rmsprop':
_optimizer = torch.optim.RMSprop
_options = {'lr': lr}
else: # opt == 'sgd'
_optimizer = torch.optim.SGD
_options = {'lr': lr, 'momentum': 0.9}
params = [
{
......@@ -637,23 +689,47 @@ def xtrain(speaker_number,
},
]
#optimizer = torch.optim.Adam(params,
# lr=0.001,
# weight_decay=0.0001,
# amsgrad=1)
optimizer = torch.optim.SGD(params,
lr=lr,
momentum=0.9,
weight_decay=0.0005)
print(f"Learning rate = {lr}")
if type(model) is Xtractor:
optimizer = _optimizer([
{'params': model.preprocessor.parameters(),
'weight_decay': model.preprocessor_weight_decay},
{'params': model.sequence_network.parameters(),
'weight_decay': model.sequence_network_weight_decay},
{'params': model.stat_pooling.parameters(),
'weight_decay': model.stat_pooling_weight_decay},
{'params': model.before_speaker_embedding.parameters(),
'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(),
'weight_decay': model.after_speaker_embedding_weight_decay}],
**_options
)
else:
optimizer = _optimizer([
{'params': model.module.sequence_network.parameters(),
'weight_decay': model.module.sequence_network_weight_decay},
{'params': model.module.before_speaker_embedding.parameters(),
'weight_decay': model.module.before_speaker_embedding_weight_decay},
{'params': model.module.after_speaker_embedding.parameters(),
'weight_decay': model.module.after_speaker_embedding_weight_decay}],
**_options
)
#optimizer = torch.optim.SGD(params,
# lr=lr,
# momentum=0.9,
# weight_decay=0.0005)
#print(f"Learning rate = {lr}")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
best_accuracy = 0.0
best_accuracy_epoch = 1
curr_patience = patience
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
if curr_patience == 0:
print(f"Stopping at epoch {epoch} for cause of patience")
break
model = train_epoch(model,
epoch,
training_loader,
......@@ -665,8 +741,7 @@ def xtrain(speaker_number,
# Add the cross validation here
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
t= time.localtime()
logging.critical(f"***{time.strftime('%H:%M:%S', t)} Cross validation accuracy = {accuracy} %")
logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Cross validation accuracy = {accuracy} %")
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
......@@ -694,6 +769,9 @@ def xtrain(speaker_number,
if is_best:
best_accuracy_epoch = epoch
curr_patience = patience
else:
curr_patience -= 1
#writer.close()
for ii in range(torch.cuda.device_count()):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment