Commit d89161f4 authored by Anthony Larcher

cleaning

parent f391adb6
@@ -167,7 +167,6 @@ if CUDA:
     from .nnet import Xtractor
     from .nnet import xtrain
     from .nnet import extract_embeddings
-    from .nnet import extract_sliding_embedding
     from .nnet import ResBlock
     from .nnet import SincNet
...
@@ -36,7 +36,6 @@ from .xsets import SideSampler
 from .xvector import Xtractor
 from .xvector import xtrain
 from .xvector import extract_embeddings
-from .xvector import extract_sliding_embedding
 from .pooling import MeanStdPooling
 from .pooling import AttentivePooling
 from .pooling import GruPooling
...
@@ -162,9 +162,9 @@ class MelSpecFrontEnd(torch.nn.Module):
                  n_fft=1024,
                  f_min=90,
                  f_max=7600,
-                 win_length=1024,
+                 win_length=400,
                  window_fn=torch.hann_window,
-                 hop_length=256,
+                 hop_length=160,
                  power=2.0,
                  n_mels=80):
...
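The front-end change above replaces the 1024-sample window and 256-sample hop with win_length=400 and hop_length=160, which at the 16 kHz sample rate used elsewhere in this commit corresponds to the conventional 25 ms analysis window with a 10 ms shift. A minimal sketch of the same settings with torchaudio.transforms.MelSpectrogram (MelSpecFrontEnd wraps more than this, so the snippet is illustrative only):

```python
import torch
import torchaudio

# Mel filterbank settings matching the values changed in this commit.
# At 16 kHz, win_length=400 is a 25 ms window and hop_length=160 a 10 ms shift.
mel_spec = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=1024,
    f_min=90,
    f_max=7600,
    win_length=400,
    window_fn=torch.hann_window,
    hop_length=160,
    power=2.0,
    n_mels=80,
)

waveform = torch.randn(1, 16000)   # one second of dummy audio
features = mel_spec(waveform)      # shape (1, 80, 101): 80 mel bands, ~100 frames/s
print(features.shape)
```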
@@ -71,6 +71,8 @@ from .loss import AngularProximityMagnet
 os.environ['MKL_THREADING_LAYER'] = 'GNU'
+#torch.backends.cudnn.benchmark = True
 __license__ = "LGPL"
 __author__ = "Anthony Larcher"
 __copyright__ = "Copyright 2015-2021 Anthony Larcher"
@@ -457,8 +459,7 @@ class Xtractor(torch.nn.Module):
             self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
                                                             out_features=self.embedding_size)
-            self.stat_pooling = MeanStdPooling()
-            self.stat_pooling_weight_decay = 0
+            self.stat_pooling = AttentivePooling(256, 80, global_context=True)
             self.loss = "aam"
             self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
@@ -467,11 +468,11 @@ class Xtractor(torch.nn.Module):
                                                             m = 0.20,
                                                             easy_margin = False)
-            self.preprocessor_weight_decay = 0.000
-            self.sequence_network_weight_decay = 0.000
-            self.stat_pooling_weight_decay = 0.000
-            self.before_speaker_embedding_weight_decay = 0.00
-            self.after_speaker_embedding_weight_decay = 0.00
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.0002
 
         elif model_archi == "fastresnet34":
             self.preprocessor = MelSpecFrontEnd()
@@ -505,8 +506,8 @@ class Xtractor(torch.nn.Module):
         elif model_archi == "halfresnet34":
             self.preprocessor = MelSpecFrontEnd(n_fft=1024,
-                                                win_length=1024,
-                                                hop_length=256,
+                                                win_length=400,
+                                                hop_length=160,
                                                 n_mels=80)
             self.sequence_network = PreHalfResNet34()
             self.embedding_size = 256
@@ -766,7 +767,7 @@ class Xtractor(torch.nn.Module):
             self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
 
-    def forward(self, x, is_eval=False, target=None, extract_after_pooling=False):
+    def forward(self, x, is_eval=False, target=None, norm_embedding=True):
         """
 
         :param x:
@@ -781,12 +782,9 @@ class Xtractor(torch.nn.Module):
         # Mean and Standard deviation pooling
         x = self.stat_pooling(x)
 
-        if extract_after_pooling:
-            return x
-
         x = self.before_speaker_embedding(x)
 
-        if self.norm_embedding:
+        if norm_embedding:
             x = l2_norm(x)
 
         if self.loss == "cce":
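With this change, embedding normalization is decided per call through the norm_embedding argument rather than through a self.norm_embedding attribute, and the extract_after_pooling early exit is dropped along with extract_sliding_embedding. The l2_norm helper is defined elsewhere in sidekit and not shown in this diff; the sketch below assumes the standard row-wise L2 normalization that the name suggests:

```python
import torch

def l2_norm(x, eps=1e-10):
    """Row-wise L2 normalization, as assumed for the l2_norm helper used above
    (the actual helper lives elsewhere in sidekit and is not part of this diff)."""
    return x / (x.norm(p=2, dim=1, keepdim=True) + eps)

emb = torch.randn(4, 256)      # a batch of 4 speaker embeddings
emb_n = l2_norm(emb)
print(emb_n.norm(dim=1))       # all close to 1.0
```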
@@ -1199,7 +1197,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
                                                       max_lr=train_opts["lr"],
-                                                      step_size_up=model_opts["speaker_number"] * 2,
+                                                      step_size_up=model_opts["speaker_number"] * 8,
                                                       step_size_down=None,
                                                       cycle_momentum=cycle_momentum,
                                                       mode="triangular2")
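In "triangular2" mode, CyclicLR ramps the learning rate from base_lr to max_lr over step_size_up optimizer steps, ramps back down, and halves the peak on each new cycle, so raising step_size_up from 2 to 8 times the speaker number makes each half-cycle four times longer. A self-contained sketch with placeholder values standing in for train_opts["lr"] and model_opts["speaker_number"]:

```python
import torch

# Placeholder model and SGD optimizer (momentum required for cycle_momentum=True).
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                              base_lr=1e-8,
                                              max_lr=1e-3,            # stands in for train_opts["lr"]
                                              step_size_up=1000 * 8,  # stands in for speaker_number * 8
                                              step_size_down=None,
                                              cycle_momentum=True,
                                              mode="triangular2")

for step in range(5):
    optimizer.step()       # one optimizer step per scheduler step
    scheduler.step()
print(scheduler.get_last_lr())
```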
@@ -1649,7 +1647,8 @@ def extract_embeddings(idmap_name,
                        win_shift=1.5,
                        num_thread=1,
                        sample_rate=16000,
-                       mixed_precision=False):
+                       mixed_precision=False,
+                       norm_embeddings=True):
     """
 
     :param idmap_name:
@@ -1721,7 +1720,7 @@ def extract_embeddings(idmap_name,
             with torch.cuda.amp.autocast(enabled=mixed_precision):
                 tmp_data = torch.split(data,data.shape[0]//(max(1, data.shape[0]//100)))
                 for td in tmp_data:
-                    _, vec = model(x=td.to(device), is_eval=True)
+                    _, vec = model(x=td.to(device), is_eval=True, norm_embedding=norm_embeddings)
                     embed.append(vec.detach().cpu())
 
             modelset.extend(mod * data.shape[0])
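The loop above forwards each file in pieces: torch.split divides the batch into max(1, N // 100) roughly equal chunks so that each autocast forward pass stays small, and the new norm_embeddings flag is passed through to forward() as norm_embedding. A small sketch of just the chunking arithmetic, with dummy data:

```python
import torch

# Same chunking as the extraction loop: split N segments into
# max(1, N // 100) roughly equal pieces before running the model.
data = torch.randn(250, 48000)   # e.g. 250 fixed-length segments
split_size = data.shape[0] // max(1, data.shape[0] // 100)
for td in torch.split(data, split_size):
    print(td.shape)              # here: two pieces of 125 segments each
```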
@@ -1773,7 +1772,7 @@ def extract_embeddings_per_speaker(idmap_name,
                              file_extension=file_extension,
                              transform_pipeline=transform_pipeline,
                              frame_rate=sample_rate,
-                             min_duration=(model.context_size() + 2) * frame_shift * 2)
+                             min_duration=1.)
 
     dataloader = DataLoader(dataset,
                             batch_size=1,
...