Commit d89161f4 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

cleaning

parent f391adb6
......@@ -167,7 +167,6 @@ if CUDA:
from .nnet import Xtractor
from .nnet import xtrain
from .nnet import extract_embeddings
from .nnet import extract_sliding_embedding
from .nnet import ResBlock
from .nnet import SincNet
......
......@@ -36,7 +36,6 @@ from .xsets import SideSampler
from .xvector import Xtractor
from .xvector import xtrain
from .xvector import extract_embeddings
from .xvector import extract_sliding_embedding
from .pooling import MeanStdPooling
from .pooling import AttentivePooling
from .pooling import GruPooling
......
......@@ -162,9 +162,9 @@ class MelSpecFrontEnd(torch.nn.Module):
n_fft=1024,
f_min=90,
f_max=7600,
win_length=1024,
win_length=400,
window_fn=torch.hann_window,
hop_length=256,
hop_length=160,
power=2.0,
n_mels=80):
......
......@@ -71,6 +71,8 @@ from .loss import AngularProximityMagnet
os.environ['MKL_THREADING_LAYER'] = 'GNU'
#torch.backends.cudnn.benchmark = True
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2021 Anthony Larcher"
......@@ -457,8 +459,7 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
out_features=self.embedding_size)
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
self.stat_pooling = AttentivePooling(256, 80, global_context=True)
self.loss = "aam"
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
......@@ -467,11 +468,11 @@ class Xtractor(torch.nn.Module):
m = 0.20,
easy_margin = False)
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
self.stat_pooling_weight_decay = 0.000
self.before_speaker_embedding_weight_decay = 0.00
self.after_speaker_embedding_weight_decay = 0.00
self.preprocessor_weight_decay = 0.00002
self.sequence_network_weight_decay = 0.00002
self.stat_pooling_weight_decay = 0.00002
self.before_speaker_embedding_weight_decay = 0.00002
self.after_speaker_embedding_weight_decay = 0.0002
elif model_archi == "fastresnet34":
self.preprocessor = MelSpecFrontEnd()
......@@ -505,8 +506,8 @@ class Xtractor(torch.nn.Module):
elif model_archi == "halfresnet34":
self.preprocessor = MelSpecFrontEnd(n_fft=1024,
win_length=1024,
hop_length=256,
win_length=400,
hop_length=160,
n_mels=80)
self.sequence_network = PreHalfResNet34()
self.embedding_size = 256
......@@ -766,7 +767,7 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
def forward(self, x, is_eval=False, target=None, extract_after_pooling=False):
def forward(self, x, is_eval=False, target=None, norm_embedding=True):
"""
:param x:
......@@ -781,12 +782,9 @@ class Xtractor(torch.nn.Module):
# Mean and Standard deviation pooling
x = self.stat_pooling(x)
if extract_after_pooling:
return x
x = self.before_speaker_embedding(x)
if self.norm_embedding:
if norm_embedding:
x = l2_norm(x)
if self.loss == "cce":
......@@ -1199,7 +1197,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
base_lr=1e-8,
max_lr=train_opts["lr"],
step_size_up=model_opts["speaker_number"] * 2,
step_size_up=model_opts["speaker_number"] * 8,
step_size_down=None,
cycle_momentum=cycle_momentum,
mode="triangular2")
......@@ -1649,7 +1647,8 @@ def extract_embeddings(idmap_name,
win_shift=1.5,
num_thread=1,
sample_rate=16000,
mixed_precision=False):
mixed_precision=False,
norm_embeddings=True):
"""
:param idmap_name:
......@@ -1721,7 +1720,7 @@ def extract_embeddings(idmap_name,
with torch.cuda.amp.autocast(enabled=mixed_precision):
tmp_data = torch.split(data,data.shape[0]//(max(1, data.shape[0]//100)))
for td in tmp_data:
_, vec = model(x=td.to(device), is_eval=True)
_, vec = model(x=td.to(device), is_eval=True, norm_embedding=norm_embeddings)
embed.append(vec.detach().cpu())
modelset.extend(mod * data.shape[0])
......@@ -1773,7 +1772,7 @@ def extract_embeddings_per_speaker(idmap_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=sample_rate,
min_duration=(model.context_size() + 2) * frame_shift * 2)
min_duration=1.)
dataloader = DataLoader(dataset,
batch_size=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment