Commit 827a0b75 authored by Anthony Larcher
Browse files

Add noise augmentation to the x-vector IdMapSet

parent 63f9ccf6
......@@ -833,7 +833,8 @@ class FactorAnalyser:
scaling_factor=1.,
output_file_name=None,
save_partial=False,
save_final=True):
save_final=True,
num_thread=1):
"""
Train a simplified Probabilistic Linear Discriminant Analysis model (no within class covariance matrix
but full residual covariance matrix)
......@@ -904,7 +905,7 @@ class FactorAnalyser:
stat1=local_stat.stat1,
e_h=e_h,
e_hh=e_hh,
num_thread=1)
num_thread=num_thread)
# Accumulate for minimum divergence step
_R = numpy.sum(e_hh, axis=0) / session_per_model.shape[0]
......
......@@ -645,6 +645,13 @@ class IdMapSet(Dataset):
_transform.append(MFCC())
if "CMVN" in t:
_transform.append(CMVN())
if 'add_noise' in t:
self.add_noise = numpy.ones(self.idmap.leftids.shape[0], dtype=bool)
numpy.random.shuffle(self.add_noise)
_transform.append(AddNoise(noise_db_csv="list/musan.csv",
snr_min_max=[5.0, 15.0],
noise_root_path="./data/musan/"))
self.transforms = transforms.Compose(_transform)
def __getitem__(self, index):
......
......@@ -929,7 +929,6 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
running_loss = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
data = data.squeeze().to(device)
print(f"Shape of data: {data.shape}")
target = target.squeeze()
target = target.to(device)
optimizer.zero_grad()
......
......@@ -36,6 +36,7 @@ import numpy
import os
import scipy
import sys
import tqdm
import warnings
from sidekit.bosaris import IdMap
......@@ -890,7 +891,7 @@ class StatServer:
unique_speaker = numpy.unique(self.modelset)
W = numpy.zeros((vect_size, vect_size))
for speakerID in unique_speaker:
for speakerID in tqdm.tqdm(unique_speaker):
spk_ctr_vec = self.get_model_stat1(speakerID) \
- numpy.mean(self.get_model_stat1(speakerID), axis=0)
W += numpy.dot(spk_ctr_vec.transpose(), spk_ctr_vec)
......@@ -1505,7 +1506,7 @@ class StatServer:
if save_partial:
sidekit.sidekit_io.write_fa_hdf5((mean, V, None, None, sigma),
save_partial + "_{}_between_class.h5".format(it))
"Partial_plda_{}_between_class.h5".format(it))
return V, sigma
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment