Commit d23472b2 authored by Anthony Larcher

Add SideSampler

parent f5175cba
......@@ -518,12 +518,89 @@ class SpkSet(Dataset):
        write_batch(batch_idx, data, target, batch_fn_format)
class SideSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, spk_count, examples_per_speaker, samples_per_speaker, batch_size):
        """Sampler that yields segment indices grouped so that no speaker
        appears twice within the same mini-batch.

        Args:
            data_source (list): speaker index of each training segment.
            spk_count (int): total number of speakers in the training set.
            examples_per_speaker (int): number of segments drawn each time a speaker is sampled.
            samples_per_speaker (int): number of times each speaker is sampled per epoch.
            batch_size (int): number of segments per mini-batch.
        """
        self.train_sessions = data_source
        self.labels_to_indices = dict()
        self.spk_count = spk_count
        self.examples_per_speaker = examples_per_speaker
        self.samples_per_speaker = samples_per_speaker
        self.batch_size = batch_size

        # reference all segment indexes per speaker
        for idx in range(self.spk_count):
            self.labels_to_indices[idx] = list()
        for idx, value in enumerate(self.train_sessions):
            self.labels_to_indices[value].append(idx)

        # shuffle segments per speaker
        for ldlist in self.labels_to_indices.values():
            random.shuffle(ldlist)

        # numpy.int is deprecated; use the builtin int instead
        self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=int)
    def __iter__(self):
        # Generate batches per speaker
        straight = numpy.arange(self.spk_count)
        indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=int) * straight
        batch_cursor = 0
        # each line of "indices" represents all speaker indexes (shuffled in a different way)
        for idx in range(self.samples_per_speaker):
            if batch_cursor == 0:
                indices[idx, :] = numpy.random.permutation(straight)
            else:
                # if one batch is split between the end of the previous line and the beginning
                # of the current line, make sure no speaker is present twice in this batch
                probs = numpy.ones_like(straight)
                probs[indices[idx - 1, -batch_cursor:]] = 0
                probs = probs / numpy.sum(probs)
                indices[idx, :self.batch_size - batch_cursor] = numpy.random.choice(self.spk_count,
                                                                                    self.batch_size - batch_cursor,
                                                                                    replace=False,
                                                                                    p=probs)
                probs = numpy.ones_like(straight)
                probs[indices[idx, :self.batch_size - batch_cursor]] = 0
                to_pick = numpy.sum(probs).astype(int)
                probs = probs / numpy.sum(probs)
                indices[idx, self.batch_size - batch_cursor:] = numpy.random.choice(self.spk_count,
                                                                                    to_pick,
                                                                                    replace=False,
                                                                                    p=probs)
                # sanity check: each line must remain a permutation of all speaker indexes
                assert numpy.sum(indices[idx, :]) == numpy.sum(straight)
            batch_cursor = (batch_cursor + indices.shape[1]) % self.batch_size

        # now we have the speaker indexes to sample in batches
        batch_matrix = numpy.repeat(indices, self.examples_per_speaker, axis=1).flatten()

        # we want to convert the speaker indexes into segment indexes
        self.index_iterator = numpy.zeros_like(batch_matrix)

        # keep track of the next segment index to sample for each speaker
        for idx, value in enumerate(batch_matrix):
            if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
                random.shuffle(self.labels_to_indices[value])
                self.segment_cursors[value] = 0
            self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
            self.segment_cursors[value] += 1
        return iter(self.index_iterator)
    def __len__(self) -> int:
        return self.samples_per_speaker * self.spk_count * self.examples_per_speaker
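A minimal usage sketch of SideSampler, assuming it is importable from the package as shown further down in this commit; all values are illustrative, not part of the diff:

# Usage sketch (illustrative values); requires SideSampler, e.g.
# from sidekit.nnet import SideSampler
import random
import torch
from torch.utils.data import DataLoader, TensorDataset

spk_count = 10
# speaker index of each of 200 segments; every speaker occurs at least once
speaker_of_segment = [i % spk_count for i in range(200)]
random.shuffle(speaker_of_segment)

sampler = SideSampler(speaker_of_segment, spk_count,
                      examples_per_speaker=1,
                      samples_per_speaker=100,
                      batch_size=5)

# one epoch yields samples_per_speaker * spk_count * examples_per_speaker indices
assert len(sampler) == 100 * 10 * 1

# a dummy dataset standing in for SideSet: (features, speaker) pairs
dataset = TensorDataset(torch.randn(200, 4), torch.tensor(speaker_of_segment))
loader = DataLoader(dataset, batch_size=5, shuffle=False, drop_last=True, sampler=sampler)
for features, speakers in loader:
    # no speaker appears twice in a batch
    assert speakers.unique().numel() == speakers.numel()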
class SideSet(Dataset):
    def __init__(self,
                 data_set_yaml,
                 set_type="train",
                 chunk_per_segment=1,
                 transform_number=1,
                 overlap=0.,
                 dataset_df=None,
                 min_duration=0.165,
......@@ -546,6 +623,9 @@ class SideSet(Dataset):
        self.min_duration = min_duration
        self.output_format = output_format
        self.transform_number = transform_number
        self.noise_root_db = dataset["train"]["transformation"]["noise_root_db"]

        if set_type == "train":
            self.duration = dataset["train"]["duration"]
            self.transformation = dataset["train"]["transformation"]
......@@ -607,69 +687,19 @@ class SideSet(Dataset):
        self.len = len(self.sessions)

        self.transform = []
        if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
            self.transform = self.transformation["pipeline"].split(',')

        if "add_noise" in self.transform:
            # Load the noise dataset, filter according to the duration
            noise_df = pandas.read_csv(self.transformation["noise_db_csv"])
            tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
            self.noise_df = tmp_df['file_id'].tolist()

        if "add_reverb" in self.transform:
            # load the RIR database
            pass
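A self-contained sketch of what this parsing produces; the dictionary keys follow those referenced in the code above, the values are illustrative:

# Illustrative pipeline parsing, mirroring the logic above
transformation = {"pipeline": "add_noise,add_reverb",
                  "noise_db_csv": "list/noise.csv",   # assumed paths
                  "noise_root_db": "/data/noise",
                  "noise_snr": [5, 10, 15]}
transform = []
if (transformation["pipeline"] != '') and (transformation["pipeline"] is not None):
    transform = transformation["pipeline"].split(',')
print(transform)   # ['add_noise', 'add_reverb']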
    def __getitem__(self, index):
        """
......@@ -685,29 +715,49 @@ class SideSet(Dataset):
        start_frame = numpy.min(nfo.frames - self.sample_number - 1)
        stop_frame = start_frame + self.sample_number

        speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
                                            frame_offset=start_frame,
                                            num_frames=self.sample_number)

        # add a tiny dither to avoid numerical issues on silent segments
        # (rand must span the full waveform shape, not just the channel axis)
        speech += 10e-6 * (torch.rand(speech.shape) - 0.5)

        if len(self.transform) > 0:
            # Select the data augmentations randomly
            aug_idx = numpy.random.randint(0, len(self.transform), self.transform_number)
            augmentations = list(numpy.array(self.transform)[aug_idx])

            if "add_noise" in augmentations:
                # Pick an SNR level
                snr_db = random.choice(self.transformation["noise_snr"])

                # Pick a file name from the noise_df
                noise_fn = self.noise_root_db + "/" + random.choice(self.noise_df) + ".wav"
                noise, noise_fs = torchaudio.load(noise_fn,
                                                  frame_offset=0,
                                                  num_frames=speech.shape[1])

                speech_power = speech.norm(p=2)
                noise_power = noise.norm(p=2)
                # convert the SNR from dB to an amplitude ratio: 10 ** (dB / 20), not math.exp
                snr = 10 ** (snr_db / 20)
                scale = snr * noise_power / speech_power
                speech = (scale * speech + noise) / 2

            if "add_reverb" in augmentations:
                pass

            if "codec" in augmentations:
                pass

            if "filter" in augmentations:
                pass

        speaker_idx = current_session["speaker_idx"]

        if self.output_format == "pytorch":
            return speech, torch.tensor(speaker_idx)
        else:
            return speech, speaker_idx
    def __len__(self):
        """
......
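A self-contained sketch of the additive-noise augmentation used in __getitem__ above, following the usual L2-norm recipe; the file names are illustrative:

# Mixing noise into speech at a target SNR (sketch; speech.wav / noise.wav are placeholders)
import torch
import torchaudio

speech, fs = torchaudio.load("speech.wav")                       # (channels, samples)
noise, _ = torchaudio.load("noise.wav", num_frames=speech.shape[1])

snr_db = 5
speech_power = speech.norm(p=2)
noise_power = noise.norm(p=2)
snr = 10 ** (snr_db / 20)                                        # dB -> amplitude ratio
scale = snr * noise_power / speech_power
noisy = (scale * speech + noise) / 2                             # speech ~snr_db above the noise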
......@@ -48,6 +48,7 @@ from .xsets import FileSet
from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .xsets import SpkSet
from .xsets import SideSampler
from .res_net import RawPreprocessor, ResBlockWFMS, ResBlock, PreResNet34, PreFastResNet34
from ..bosaris import IdMap
from ..bosaris import Key
......@@ -182,7 +183,7 @@ def test_metrics(model,
idmap_test_filename = 'h5f/idmap_test.h5'
ndx_test_filename = 'h5f/ndx_test.h5'
key_test_filename = 'h5f/key_test.h5'
data_root_name='/data/larcher/voxceleb1/test/wav'
transform_pipeline = dict()
#mfcc_config = dict()
......@@ -501,7 +502,7 @@ class Xtractor(torch.nn.Module):
self.embedding_size = 256
self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
out_features = self.embedding_size)
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
......@@ -987,12 +988,21 @@ def xtrain(speaker_number,
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
#training_set = SpkSet(dataset_yaml,
# set_type="train",
# dataset_df=training_df,
# overlap=dataset_params['train']['overlap'],
# output_format=output_format,
# windowed=True)
training_set = SideSet(dataset_yaml,
set_type="train",
chunk_per_segment=-1,
overlap=dataset_params['train']['overlap'],
dataset_df=training_df,
output_format=output_format,
)
validation_set = SideSet(dataset_yaml,
set_type="validation",
......@@ -1014,11 +1024,19 @@ def xtrain(speaker_number,
batch_size = dataset_params["batch_size"]
print(f"Size of batches = {batch_size}")
side_sampler = SideSampler(training_set.sessions['speaker_idx'],
speaker_number,
1,
100,
batch_size)
training_loader = DataLoader(training_set,
batch_size=batch_size,
shuffle=False,
drop_last=True,
pin_memory=True,
sampler=side_sampler,
num_workers=num_thread,
persistent_workers=True)
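Note that shuffle must be False here: DataLoader raises an error when shuffle=True is combined with a custom sampler, and SideSampler already randomises the order of the segments.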
......@@ -1201,7 +1219,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
accuracy += (torch.argmax(output.data, 1) == target).sum()
if batch_idx % log_interval == 0:
batch_size = target.shape[0]
logging.critical('{}, Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
time.strftime('%H:%M:%S', time.localtime()),
epoch, batch_idx + 1, training_loader.__len__(),
100. * batch_idx / training_loader.__len__(), loss.item(),
......