Commit 50d0d1f3 authored by Anthony Larcher

merge

parents 0fa85281 754a4f9c
@@ -170,7 +170,6 @@ if CUDA:
from .nnet import extract_embeddings
from .nnet import extract_sliding_embedding
from .nnet import ResBlock
from .nnet import ResNet18
from .nnet import SincNet
else:
......
@@ -31,11 +31,14 @@ Copyright 2014-2021 Anthony Larcher and Sylvain Meignier
from .augmentation import AddNoise
from .feed_forward import FForwardNetwork
from .feed_forward import kaldi_to_hdf5
from .xsets import IdMapSet_per_speaker, SpkSet
from .xsets import IdMapSetPerSpeaker
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling
from .res_net import ResBlock, ResNet18, PreResNet34
from .res_net import ResBlock, PreResNet34
from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
from .sincnet import SincNet
from .preprocessor import RawPreprocessor
from .preprocessor import MfccFrontEnd
from .preprocessor import MelSpecFrontEnd
has_pyroom = True
try:
......
@@ -272,3 +272,38 @@ class ArcMarginProduct(torch.nn.Module):
output = output * self.s
return output
class SoftmaxAngularProto(torch.nn.Module):
# from https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py
def __init__(self, spk_count, init_w=10.0, init_b=-5.0, **kwargs):
super(SoftmaxAngularProto, self).__init__()
self.test_normalize = True
self.w = torch.nn.Parameter(torch.tensor(init_w))
self.b = torch.nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()
self.cce_backend = torch.nn.Sequential(OrderedDict([
("linear8", torch.nn.Linear(256, spk_count))
]))
def forward(self, x, target=None):
assert x.size()[1] >= 2
cce_prediction = self.cce_backend(x)
if target==None:
return cce_prediction
x = x.reshape(-1,2,x.size()[-1]).squeeze(1)
out_anchor = torch.mean(x[:,1:,:],1)
out_positive = x[:,0,:]
cos_sim_matrix = torch.nn.functional.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
torch.clamp(self.w, 1e-6)
cos_sim_matrix = cos_sim_matrix * self.w + self.b
return cos_sim_matrix, cce_prediction
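Note: the new SoftmaxAngularProto head returns both a cosine-similarity matrix between positives and anchors (scaled by the learnable w and shifted by b) and a plain cross-entropy prediction from the linear backend. A minimal usage sketch follows; the import path and the 8-speaker / 256-dimensional batch layout are assumptions for illustration only.

```python
import torch
from sidekit.nnet.loss import SoftmaxAngularProto  # assumed import path

n_spk, emb_dim = 8, 256                    # embeddings must be 256-dim to match the linear backend
head = SoftmaxAngularProto(spk_count=100)

# Two utterances per speaker, laid out consecutively: forward() reshapes to
# (n_spk, 2, emb_dim) and treats the first of each pair as the positive and
# the second as the anchor.
x = torch.randn(n_spk * 2, emb_dim)
labels = torch.arange(n_spk).repeat_interleave(2)   # speaker label of each utterance

cos_sim, cce_pred = head(x, target=labels)
print(cos_sim.shape)    # torch.Size([8, 8])    positives scored against anchors
print(cce_pred.shape)   # torch.Size([16, 100]) logits for the cross-entropy branch
```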
@@ -42,12 +42,11 @@ import yaml
from collections import OrderedDict
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .augmentation import PreEmphasis
from .xsets import SideSet
from .xsets import FileSet
from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .xsets import IdMapSetPerSpeaker
from .xsets import SideSampler
from .res_net import RawPreprocessor
from .res_net import ResBlockWFMS
from .res_net import ResBlock
from .res_net import PreResNet34
@@ -214,6 +213,8 @@ class MelSpecFrontEnd(torch.nn.Module):
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if x.dim() == 1:
x = x.unsqueeze(0)
out = self.PreEmphasis(x)
out = self.MelSpec(out)+1e-6
out = torch.log(out)
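The added dim() check lets the mel front-end accept a single 1-D waveform as well as a batched [batch, samples] tensor. A rough stand-in sketch using a torchaudio MelSpectrogram; the actual n_fft/n_mels/pre-emphasis settings of MelSpecFrontEnd are not shown in this hunk and are assumed here.

```python
import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80)  # assumed settings

def mel_front_end(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        if x.dim() == 1:       # single waveform -> promote to a batch of one
            x = x.unsqueeze(0)
        out = mel(x) + 1e-6    # small offset so the log never sees zero
        return torch.log(out)

print(mel_front_end(torch.randn(16000)).shape)     # torch.Size([1, 80, 81])
print(mel_front_end(torch.randn(4, 16000)).shape)  # torch.Size([4, 80, 81])
```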
@@ -222,7 +223,7 @@ class MelSpecFrontEnd(torch.nn.Module):
class (torch.nn.Module):
class RawPreprocessor(torch.nn.Module):
"""
"""
......
@@ -36,8 +36,6 @@ import torch.optim as optim
import torch.multiprocessing as mp
from torchvision import transforms
from collections import OrderedDict
from .xsets import FrequencyMask, CMVN, TemporalMask
from .sincnet import SincNet, SincConv1d
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.utils.data import DataLoader
......
@@ -69,7 +69,8 @@ class SideSampler(torch.utils.data.Sampler):
self.spk_count = spk_count
self.examples_per_speaker = examples_per_speaker
self.samples_per_speaker = samples_per_speaker
self.batch_size = batch_size
assert batch_size % examples_per_speaker == 0
self.batch_size = batch_size//examples_per_speaker
# reference all segment indexes per speaker
for idx in range(self.spk_count):
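With this change batch_size is passed as utterances per batch and stored as the number of distinct speakers drawn per batch; the assertion guarantees the split is exact. A quick worked example with assumed configuration values:

```python
batch_size = 128           # utterances per batch (assumed config value)
examples_per_speaker = 2   # segments drawn for each sampled speaker

assert batch_size % examples_per_speaker == 0
speakers_per_batch = batch_size // examples_per_speaker
print(speakers_per_batch)  # 64 -> each batch holds 64 speakers x 2 segments
```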
@@ -220,16 +221,18 @@ class SideSet(Dataset):
self.sessions = pandas.DataFrame.from_dict(df_dict)
self.len = len(self.sessions)
self.transform = []
self.transform = dict()
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
transforms = self.transformation["pipeline"].split(',')
if "add_noise" in transforms:
self.transform["add_noise"] = self.transformation["add_noise"]
if "add_reverb" in transforms:
self.transform["add_reverb"] = self.transformation["add_reverb"]
self.noise_df = None
if "add_noise" in self.transform:
# Load the noise dataset, filter according to the duration
noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
self.noise_df = tmp_df['file_id'].tolist()
self.noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
self.rir_df = None
if "add_reverb" in self.transform:
@@ -246,7 +249,7 @@ class SideSet(Dataset):
current_session = self.sessions.iloc[index]
nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
start_frame = int(current_session['start'] * self.sample_rate)
start_frame = int(current_session['start'])
if start_frame + self.sample_number >= nfo.frames:
start_frame = numpy.min(nfo.frames - self.sample_number - 1)
@@ -306,13 +309,15 @@ class IdMapSet(Dataset):
self.data_path = data_path
self.file_extension = file_extension
self.len = self.idmap.leftids.shape[0]
self.transform_pipeline = transform_pipeline
self.transformation = transform_pipeline
self.min_duration = min_duration
self.sample_rate = frame_rate
self.transform = []
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
if (len(self.transformation) > 0):
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
self.noise_df = None
if "add_noise" in self.transform:
@@ -342,14 +347,16 @@ class IdMapSet(Dataset):
stop = len(speech)
else:
nfo = soundfile.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
start = int(self.idmap.start[index])
stop = int(self.idmap.stop[index])
conversion_rate = nfo.samplerate // self.sample_rate
start = int(self.idmap.start[index]) * conversion_rate
stop = int(self.idmap.stop[index]) * conversion_rate
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.samplerate:
middle = start + (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=stop - start)
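The IdMap start/stop fields are expressed in feature frames at self.sample_rate (i.e. 1/frame_shift), while torchaudio.load expects offsets in audio samples, hence the rescaling by conversion_rate. A worked example with assumed values:

```python
# Assumed values for illustration: 16 kHz audio files, 100 feature frames per second.
file_samplerate = 16000                           # nfo.samplerate
frame_rate = 100                                  # self.sample_rate = int(1 / frame_shift)
conversion_rate = file_samplerate // frame_rate   # 160 audio samples per feature frame

start_frame, stop_frame = 250, 750                # IdMap start/stop, in frames (2.5 s .. 7.5 s)
start = start_frame * conversion_rate             # 40000  -> frame_offset for torchaudio.load
stop = stop_frame * conversion_rate               # 120000
print(start, stop, stop - start)                  # 40000 120000 80000 (5 s of audio)
```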
@@ -364,6 +371,8 @@ class IdMapSet(Dataset):
noise_df=self.noise_df,
rir_df=self.rir_df)
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
@@ -403,7 +412,7 @@ class IdMapSetPerSpeaker(Dataset):
self.data_root_path = data_root_path
self.file_extension = file_extension
self.len = len(set(self.idmap.leftids))
self.transform_pipeline = transform_pipeline
self.transformation = transform_pipeline
self.min_duration = min_duration
self.sample_rate = frame_rate
self.speaker_list = list(set(self.idmap.leftids))
@@ -414,8 +423,9 @@ class IdMapSetPerSpeaker(Dataset):
self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.transform = []
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
if (len(self.transformation) > 0):
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
self.noise_df = None
if "add_noise" in self.transform:
@@ -458,7 +468,7 @@ class IdMapSetPerSpeaker(Dataset):
tmp_data.append(speech)
speech = torch.cat(tmp_data, dim=0)
speech = torch.cat(tmp_data, dim=1)
speech += 10e-6 * torch.randn(speech.shape)
if len(self.transform) > 0:
@@ -469,6 +479,8 @@ class IdMapSetPerSpeaker(Dataset):
noise_df=self.noise_df,
rir_df=self.rir_df)
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
def __len__(self):
......
@@ -59,7 +59,7 @@ from ..bosaris import Ndx
from ..statserver import StatServer
from ..iv_scoring import cosine_scoring
from .sincnet import SincNet
from .loss import ArcLinear
from .loss import SoftmaxAngularProto, ArcLinear
from .loss import l2_norm
from .loss import ArcMarginProduct
@@ -241,7 +241,8 @@ def test_metrics(model,
xv_stat,
Ndx(ndx_test_filename),
wccn=None,
check_missing=True)
check_missing=True,
device=device)
tar, non = scores.get_tar_non(Key(key_test_filename))
test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
@@ -442,12 +443,16 @@ class Xtractor(torch.nn.Module):
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
self.loss = "aam"
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
@@ -734,7 +739,7 @@ class Xtractor(torch.nn.Module):
else:
return self.after_speaker_embedding(x), x
elif self.loss == "aam":
elif self.loss in ['aam', 'aps']:
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
@@ -814,7 +819,7 @@ def xtrain(speaker_number,
if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
if model_name is None:
model = Xtractor(speaker_number, model_yaml)
model = Xtractor(speaker_number, model_yaml, loss=loss)
else:
logging.critical(f"*** Load model from = {model_name}")
@@ -874,8 +879,8 @@ def xtrain(speaker_number,
else:
# Load the model
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = Xtractor(speaker_number, model_yaml)
checkpoint = torch.load(model_name, map_location=device)
model = Xtractor(speaker_number, model_yaml, loss=loss)
"""
Here we remove all layers that we don't want to reload
@@ -902,6 +907,9 @@ def xtrain(speaker_number,
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
embedding_size = model.embedding_size
@@ -1115,15 +1123,25 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
with torch.cuda.amp.autocast():
if loss_criteria == 'aam':
output, _ = model(data, target=target)
loss = criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
cos_sim_matx, output = output_tuple
loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
else:
output, _ = model(data, target=None)
loss = criterion(output, target)
loss = criterion(output, target)
else:
if loss_criteria == 'aam':
output, _ = model(data, target=target)
loss = criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
cos_sim_matx, output = output_tuple
loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
else:
output, _ = model(data, target=None)
loss = criterion(output, target)
loss = criterion(output, target)
if not torch.isnan(loss):
if scaler is not None:
scaler.scale(loss).backward()
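For the 'aps' branch the model returns (cos_sim_matrix, cce_prediction), and the angular-prototypical term uses arange targets because, after the pairwise reshape in SoftmaxAngularProto, the matching anchor for the positive in row i sits in column i. A small sketch of that term alone, with assumed shapes:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()

# Assume a batch of 16 utterances = 8 speakers x (positive + anchor), so the
# cosine-similarity matrix coming out of the model is 8 x 8.
batch = 16
cos_sim_matx = torch.randn(batch // 2, batch // 2)

# Row i must score highest in column i (same speaker), hence identity targets.
aps_targets = torch.arange(0, batch // 2)
aps_loss = criterion(cos_sim_matx, aps_targets)
print(aps_loss.item())
```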
@@ -1189,8 +1207,10 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
target = target.squeeze().to(device)
data = data.squeeze().to(device)
with torch.cuda.amp.autocast(enabled=mixed_precision):
if loss_criteria == "aam":
if loss_criteria == 'aam':
batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
elif loss_criteria == 'aps':
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
else:
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
batch_embeddings = l2_norm(batch_embeddings)
@@ -1200,15 +1220,9 @@ def cross_validation(model, validation_loader, device, validation_shape, mixed_p
embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.detach().cpu()
classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
if classes.shape[0] > 2e4:
local_device = "cpu"
else:
local_device = device
mask = ((torch.ger(classes.to(local_device).float() + 1,
(1 / (classes.to(local_device).float() + 1))) == 1).long() * 2 - 1).float().cpu()
mask = mask.numpy()
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu()
scores = scores.numpy()
#print(classes.shape[0])
local_device = "cpu" if embeddings.shape[0] > 3e4 else device
scores = torch.mm(embeddings.to(local_device), embeddings.to(local_device).T).cpu().numpy()
scores = scores[numpy.tril_indices(scores.shape[0], -1)]
mask = mask[numpy.tril_indices(mask.shape[0], -1)]
negatives = scores[numpy.argwhere(mask == -1)][:, 0].astype(float)
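Validation scoring now falls back to the CPU only for large sets (more than 3e4 embeddings), builds the full cosine-score matrix with one matrix product of the l2-normalised embeddings, and keeps the strict lower triangle so every unordered pair is scored once. A toy illustration of that pair bookkeeping; the sizes and the same-speaker mask construction below are assumptions for the sketch, not the exact code path.

```python
import numpy
import torch

# Five l2-normalised embeddings of dimension 4, with their speaker labels.
embeddings = torch.nn.functional.normalize(torch.randn(5, 4), dim=1)
classes = torch.tensor([0, 0, 1, 1, 2])

scores = torch.mm(embeddings, embeddings.T).numpy()                 # 5 x 5 cosine scores
same_spk = (classes.unsqueeze(0) == classes.unsqueeze(1)).numpy()   # True where labels match

tril = numpy.tril_indices(scores.shape[0], -1)   # strict lower triangle: each pair once
pair_scores = scores[tril]
pair_target = same_spk[tril]
positives = pair_scores[pair_target]             # target (same-speaker) trials
negatives = pair_scores[~pair_target]            # non-target trials
print(len(positives), len(negatives))            # 2 8
```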
@@ -1226,9 +1240,8 @@ def extract_embeddings(idmap_name,
model_filename,
data_root_name,
device,
model_yaml=None,
file_extension="wav",
transform_pipeline=None,
transform_pipeline={},
frame_shift=0.01,
frame_duration=0.025,
extract_after_pooling=False,
@@ -1255,10 +1268,7 @@ def extract_embeddings(idmap_name,
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"]
if model_yaml is None:
model_archi = checkpoint["model_archi"]
else:
model_archi = model_yaml
model_archi = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
else:
@@ -1269,18 +1279,18 @@ def extract_embeddings(idmap_name,
else:
idmap = IdMap(idmap_name)
if type(model) is Xtractor:
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
model_cs = model.context_size()
else:
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size()
if type(model) is Xtractor:
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
model_cs = model.context_size()
else:
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size()
# Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap_name,
data_root_path=data_root_name,
data_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=int(1. / frame_shift),
@@ -1339,7 +1349,6 @@ def extract_embeddings_per_speaker(idmap_name,
model_filename,
data_root_name,
device,
model_yaml=None,
file_extension="wav",
transform_pipeline=None,
frame_shift=0.01,
@@ -1350,10 +1359,7 @@ def extract_embeddings_per_speaker(idmap_name,
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename)
if model_yaml is None:
model_archi = checkpoint["model_archi"]
else:
model_archi = model_yaml
model_archi = checkpoint["model_archi"]
model = Xtractor(checkpoint["speaker_number"], model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
@@ -1388,7 +1394,7 @@ def extract_embeddings_per_speaker(idmap_name,
if extract_after_pooling:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
else:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
emb_size = model.embedding_size
# Create the StatServer
embeddings = StatServer()
@@ -1404,6 +1410,7 @@ def extract_embeddings_per_speaker(idmap_name,
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
if data.shape[1] > 20000000:
data = data[..., :20000000]
print(f"Shape of data: {data.shape}")
vec = model(data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
@@ -1413,13 +1420,12 @@ def extract_sliding_embedding(idmap_name,
window_length,
sample_rate,
overlap,
speaker_number,
model_filename,
model_yaml,
data_root_name ,
device,
file_extension="wav",
transform_pipeline=None):
transform_pipeline=None,
num_thread=1):
"""
:param idmap_name:
@@ -1470,12 +1476,11 @@ def extract_sliding_embedding(idmap_name,
assert sliding_idmap.validate()
embeddings = extract_embeddings(sliding_idmap,
speaker_number,
model_filename,
model_yaml,
data_root_name,
device,
file_extension=file_extension,
transform_pipeline=transform_pipeline)
transform_pipeline=transform_pipeline,
num_thread=num_thread)
return embeddings