Commit 3a2b9935 authored by Gaël Le Lan

SideSet refactor

- snr fix
- random offset for noise segment selection
- MelSpecFrontEnd defaults
- device parameter for xtrain()
- reduce tar/non count for validation (WIP)
parent 0d43b570
@@ -488,12 +488,13 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noise_df
     # Pick a noise type
     noise = torch.zeros_like(speech)
     noise_idx = random.randrange(3)
+    # speech
     if noise_idx == 0:
         # Pick a SNR level
         # TODO make SNRs configurable by noise type
         snr_db = random.randint(13, 20)
-        pick_count = random.randint(5, 10)
+        pick_count = random.randint(3, 7)
         index_list = random.choices(range(noise_df.loc['speech'].shape[0]), k=pick_count)
         for idx in index_list:
             noise_row = noise_df.loc['speech'].iloc[idx]
@@ -512,7 +513,7 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noise_df
     speech_power = speech.norm(p=2)
     noise_power = noise.norm(p=2)
-    snr = math.exp(snr_db / 10)
+    snr = 10 ** (snr_db / 20)
     scale = snr * noise_power / speech_power
     speech = (scale * speech + noise) / 2
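A note on the SNR fix itself: `math.exp(snr_db / 10)` converts decibels with base e instead of base 10, so the requested SNR was never actually applied; `10 ** (snr_db / 20)` is the standard dB-to-amplitude-ratio conversion, which matches the L2 norms used here. A minimal sketch (synthetic tensors, not part of the commit) checking that the new scaling hits the requested SNR; the final `/ 2` in the mix rescales speech and noise equally, so it does not change the ratio:

```python
# Minimal check: scaling by snr = 10 ** (snr_db / 20) on L2 norms
# yields the requested SNR in dB.
import math
import torch

torch.manual_seed(0)
speech = torch.randn(1, 16000)
noise = torch.randn(1, 16000)

snr_db = 15
snr = 10 ** (snr_db / 20)                      # amplitude ratio, not math.exp(snr_db / 10)
scale = snr * noise.norm(p=2) / speech.norm(p=2)
scaled_speech = scale * speech                 # speech scaled up relative to noise

measured = 20 * math.log10(scaled_speech.norm(p=2) / noise.norm(p=2))
print(round(measured, 3))                      # 15.0
```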
@@ -533,7 +534,12 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
     noise_duration = noise_row['duration']
     noise_file_id = noise_row['file_id']
-    frame_offset = noise_start * sample_rate
+    if noise_duration * sample_rate > speech_shape[1]:
+        # It is recommended to split noise files (especially the speech noise type) into shorter subfiles:
+        # when frame_offset is too high, loading the segment can take much longer
+        frame_offset = random.randrange(noise_start * sample_rate, int((noise_start + noise_duration) * sample_rate - speech_shape[1]))
+    else:
+        frame_offset = noise_start * sample_rate
     noise_fn = data_path + "/" + noise_file_id + ".wav"
     if noise_duration * sample_rate > speech_shape[1]:
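The new branch draws `frame_offset` uniformly inside the annotated noise segment, so each epoch crops a different window instead of always starting at `noise_start`. A minimal sketch of the same idea; the function name and the torchaudio-based loading are assumptions, since the actual loading code sits outside this hunk:

```python
# Sketch (assumed names): pick a random crop of a long noise segment so that
# every epoch sees a different noise window.
import random

import torchaudio

def random_noise_crop(noise_fn, noise_start, noise_duration, sample_rate, num_frames):
    seg_frames = int(noise_duration * sample_rate)
    if seg_frames > num_frames:
        # any start that still leaves num_frames before the segment end
        frame_offset = random.randrange(int(noise_start * sample_rate),
                                        int((noise_start + noise_duration) * sample_rate) - num_frames)
    else:
        frame_offset = int(noise_start * sample_rate)
    noise_seg, _ = torchaudio.load(noise_fn, frame_offset=frame_offset, num_frames=num_frames)
    return noise_seg
```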
@@ -167,6 +167,7 @@ class SideSet(Dataset):
             self.transformation = dataset["eval"]["transformation"]
         self.sample_number = int(self.duration * self.sample_rate)
+        self.overlap = int(overlap * self.sample_rate)

         # Load the dataset description as pandas.dataframe
         if dataset_df is None:
@@ -187,6 +188,8 @@ class SideSet(Dataset):
         # Create lists for each column of the dataframe
         df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
+        df_dict["file_start"] = list()
+        df_dict["file_duration"] = list()

         # For each segment, get all possible segments with the current overlap
         for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
@@ -195,8 +198,8 @@ class SideSet(Dataset):
             # Compute possible starts
             possible_starts = numpy.arange(0,
                                            int(self.sample_rate * (current_session.duration - self.duration)),
-                                           self.sample_number - int(self.sample_rate * overlap)
-                                           )
+                                           self.sample_number
+                                           ) + int(self.sample_rate * (current_session.duration % self.duration / 2))
             possible_starts += int(self.sample_rate * current_session.start)

             # Select max(seg_nb, possible_segments) segments
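The new `possible_starts` tiles non-overlapping chunks (stride `self.sample_number`) and shifts the whole grid by half of the leftover `duration % self.duration`, so the chunks sit centered in the session rather than flush against its start. A quick numeric check with hypothetical values:

```python
# Centering check: a 10.4 s session cut into 2 s chunks leaves 0.4 s over,
# so the grid is shifted by 0.2 s and both ends get an equal margin.
import numpy

sample_rate = 16000
duration = 2.0              # chunk length in seconds
session_duration = 10.4     # leftover = 10.4 % 2.0 = 0.4 s

sample_number = int(duration * sample_rate)
starts = numpy.arange(0,
                      int(sample_rate * (session_duration - duration)),
                      sample_number
                      ) + int(sample_rate * (session_duration % duration / 2))
print(starts / sample_rate)  # [0.2 2.2 4.2 6.2 8.2] -> last chunk ends at 10.2 s
```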
@@ -214,6 +217,8 @@ class SideSet(Dataset):
                 df_dict["file_id"].append(current_session.file_id)
                 df_dict["start"].append(starts[ii])
                 df_dict["duration"].append(self.duration)
+                df_dict["file_start"].append(current_session.start)
+                df_dict["file_duration"].append(current_session.duration)
                 df_dict["speaker_idx"].append(current_session.speaker_idx)
                 df_dict["gender"].append(current_session.gender)
@@ -223,14 +228,15 @@ class SideSet(Dataset):
         self.transform = dict()
         if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
             transforms = self.transformation["pipeline"].split(',')
-            if "add_noise" in transforms:
+            if "add_noise" in transforms:
                 self.transform["add_noise"] = self.transformation["add_noise"]
-            if "add_reverb" in transforms:
+            if "add_reverb" in transforms:
                 self.transform["add_reverb"] = self.transformation["add_reverb"]

         self.noise_df = None
         if "add_noise" in self.transform:
             noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
+            noise_df = noise_df.loc[noise_df.duration > self.duration]
             self.noise_df = noise_df.set_index(noise_df.type)

         self.rir_df = None
@@ -249,7 +255,15 @@ class SideSet(Dataset):
         # TODO is this required ?
         nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
-        start_frame = int(current_session['start'])
+        original_start = int(current_session['start'])
+        lowest_shift = self.overlap/2
+        highest_shift = self.overlap/2
+        if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
+            lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
+        if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
+            highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
+        start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))

         if start_frame + self.sample_number >= nfo.frames:
             start_frame = numpy.min(nfo.frames - self.sample_number - 1)
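Each chunk start is now jittered by up to `self.overlap / 2` samples in either direction, with the shift clamped so the crop cannot leave the `[file_start, file_start + file_duration]` window. A standalone numeric sketch of the same bounds logic (all values are hypothetical):

```python
# Bounded random shift: a chunk near the session head can only shift left
# as far as the session start; same idea at the tail.
import random

sample_rate = 16000
sample_number = 2 * sample_rate          # 2 s chunk
overlap = sample_rate                    # 1 s shift budget
file_start, file_duration = 0.0, 10.4    # session bounds in seconds

original_start = 3200                    # nominal chunk start (0.2 s)
lowest_shift = overlap / 2
highest_shift = overlap / 2
if original_start < file_start * sample_rate + sample_number / 2:
    lowest_shift = int(original_start - file_start * sample_rate)
if original_start + sample_number > (file_start + file_duration) * sample_rate - sample_number / 2:
    highest_shift = int((file_start + file_duration) * sample_rate - (original_start + sample_number))

start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
print(0 <= start_frame and start_frame + sample_number <= file_duration * sample_rate)  # True
```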
@@ -300,7 +300,7 @@ class AttentivePooling(torch.nn.Module):
     """
     Mean and Standard deviation attentive pooling
     """
-    def __init__(self, num_channels):
+    def __init__(self, num_channels, n_mels):
         """
         """
@@ -308,11 +308,11 @@ class AttentivePooling(torch.nn.Module):
         # TODO Make convolution parameters configurable
         super(AttentivePooling, self).__init__()
         self.attention = torch.nn.Sequential(
-            torch.nn.Conv1d(num_channels * 10, 128, kernel_size=1),
+            torch.nn.Conv1d(num_channels * (n_mels//8), num_channels//32, kernel_size=1),
             torch.nn.ReLU(),
-            torch.nn.BatchNorm1d(128),
+            torch.nn.BatchNorm1d(num_channels//32),
             torch.nn.Tanh(),
-            torch.nn.Conv1d(128, num_channels * 10, kernel_size=1),
+            torch.nn.Conv1d(num_channels//32, num_channels * (n_mels//8), kernel_size=1),
             torch.nn.Softmax(dim=2),
         )
         #self.global_context = MeanStdPooling()
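With the halfresnet34 values below (`num_channels=256`, `n_mels=64`), the attention input is 256 * (64 // 8) = 2048 channels and the bottleneck shrinks from the hard-coded 128 to 256 // 32 = 8. Weighted mean plus standard deviation pooling then yields 2 * 2048 = 4096 features, which is where the new `in_features=4096` comes from (the old `num_channels * 10` matched the previous 80-mel default, hence 5120). A shape check under these assumptions; the mean/std weighting shown is the usual attentive-statistics form, since the forward pass is outside this diff:

```python
# Dimension check (assumed shapes) for AttentivePooling(256, 64).
import torch

num_channels, n_mels, frames = 256, 64, 200
# ResNet output (batch, 256, 64/8, T) flattened to (batch, 2048, T)
x = torch.randn(4, num_channels * (n_mels // 8), frames)

attention = torch.nn.Sequential(
    torch.nn.Conv1d(num_channels * (n_mels // 8), num_channels // 32, kernel_size=1),
    torch.nn.ReLU(),
    torch.nn.BatchNorm1d(num_channels // 32),
    torch.nn.Tanh(),
    torch.nn.Conv1d(num_channels // 32, num_channels * (n_mels // 8), kernel_size=1),
    torch.nn.Softmax(dim=2),
)
w = attention(x)                                  # (4, 2048, 200) attention weights
mu = torch.sum(x * w, dim=2)
std = torch.sqrt(torch.clamp(torch.sum(x ** 2 * w, dim=2) - mu ** 2, min=1e-8))
pooled = torch.cat([mu, std], dim=1)
print(pooled.shape)                               # torch.Size([4, 4096]) -> Linear(4096, 512)
```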
@@ -516,14 +516,15 @@ class Xtractor(torch.nn.Module):
             self.before_speaker_embedding_weight_decay = 0.00
             self.after_speaker_embedding_weight_decay = 0.00

         elif model_archi == "halfresnet34":
-            self.preprocessor = MelSpecFrontEnd()
+            self.preprocessor = MelSpecFrontEnd(n_fft=512, win_length=400, hop_length=160, n_mels=64)
+            #self.preprocessor = MelSpecFrontEnd()
             self.sequence_network = PreHalfResNet34()
             self.embedding_size = 512
-            self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
+            self.before_speaker_embedding = torch.nn.Linear(in_features = 4096,
                                                             out_features = self.embedding_size)
-            self.stat_pooling = AttentivePooling(256)
+            self.stat_pooling = AttentivePooling(256, 64)
             self.stat_pooling_weight_decay = 0

             self.loss = loss
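The explicit `MelSpecFrontEnd` arguments pin the frontend to 25 ms analysis windows with a 10 ms hop at 16 kHz (100 frames per second) and 64 mel bands. Assuming the frontend wraps a torchaudio mel spectrogram, a quick sanity check of those defaults:

```python
# Frontend sanity check (assumes 16 kHz audio): 25 ms windows, 10 ms hop.
import torch
import torchaudio

sample_rate = 16000
mel = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                           n_fft=512,
                                           win_length=400,   # 25 ms
                                           hop_length=160,   # 10 ms -> 100 frames/s
                                           n_mels=64)
feats = mel(torch.randn(1, sample_rate * 2))  # 2 s of audio
print(feats.shape)                             # torch.Size([1, 64, 201])
```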
@@ -855,6 +856,7 @@ def xtrain(speaker_number,
            tmp_model_name=None,
            best_model_name=None,
            multi_gpu=True,
+           device=None,
            mixed_precision=False,
            clipping=False,
            opt=None,
@@ -896,7 +898,8 @@ def xtrain(speaker_number,
     logging.critical(f"Use {num_thread} cpus")
     logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")

-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # Use a predefined architecture
     if model_yaml in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
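The new `device` argument lets a caller pin training to a specific device while keeping the old auto-detection as the fallback. A hypothetical call (the speaker count and config paths are made up; the keyword names follow the parameters referenced in the function body):

```python
# Hypothetical usage: pin training to the second GPU instead of the default.
import torch

xtrain(speaker_number=1211,
       dataset_yaml="cfg/dataset.yaml",
       model_yaml="halfresnet34",
       device=torch.device("cuda:1"))   # falls back to cuda/cpu when omitted
```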
@@ -1015,7 +1018,7 @@ def xtrain(speaker_number,
     First we load the dataframe from the CSV file in order to split it for training and validation purposes.
     Then we provide those two
     """
-    training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"], stratify=df["speaker_idx"])
+    training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])  #, stratify=df["speaker_idx"])

     torch.manual_seed(dataset_params['seed'])
@@ -1029,6 +1032,7 @@ def xtrain(speaker_number,
     validation_set = SideSet(dataset_yaml,
                              set_type="validation",
+                             chunk_per_segment=2,
                              dataset_df=validation_df,
                              output_format="pytorch")
@@ -1086,7 +1090,7 @@ def xtrain(speaker_number,
     optimizer = _optimizer(param_list, **_options)

     scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                step_size=12 * training_loader.__len__(),
+                                                step_size=5 * training_loader.__len__(),
                                                 gamma=0.75)

     if mixed_precision:
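The schedule now decays the learning rate by `gamma=0.75` every 5 epochs' worth of optimizer steps instead of every 12; the `* training_loader.__len__()` factor is needed because `StepLR` counts steps, not epochs. Illustrative numbers:

```python
# Illustrative: StepLR counts optimizer steps, so step_size is scaled by the
# number of batches per epoch; the lr is multiplied by 0.75 every 5 epochs.
import torch

steps_per_epoch = 1000  # hypothetical len(training_loader)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=5 * steps_per_epoch, gamma=0.75)
# epoch  5: lr = 0.1 * 0.75    = 0.075
# epoch 10: lr = 0.1 * 0.75**2 = 0.05625  (previously one decay every 12 epochs)
```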
@@ -1107,6 +1111,10 @@ def xtrain(speaker_number,
     non_indices = torch.tril(~mask, -1).numpy()

     tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
     non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[1-tar_non_ratio, tar_non_ratio])
+    tar_indices *= numpy.random.choice([False, True], size=tar_indices.shape, p=[0.9, 0.1])
+    non_indices *= numpy.random.choice([False, True], size=non_indices.shape, p=[0.9, 0.1])
+    logging.critical("val tar count : {:d}, non count : {:d}".format(numpy.sum(tar_indices), numpy.sum(non_indices)))
+
     for epoch in range(1, epochs + 1):
         # Process one epoch and return the current model
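On top of the existing tar/non balancing, both trial lists are now randomly thinned to roughly 10% to keep validation scoring cheap (the WIP item in the commit message). A synthetic illustration of the effect on counts:

```python
# Synthetic illustration: keep ~10% of target and non-target validation
# trials before scoring, to speed up EER computation.
import numpy

same_speaker = numpy.random.rand(500, 500) < 0.1       # pretend trial mask
tar_indices = numpy.tril(same_speaker, -1)
non_indices = numpy.tril(~same_speaker, -1)

tar_indices &= numpy.random.choice([False, True], size=tar_indices.shape, p=[0.9, 0.1])
non_indices &= numpy.random.choice([False, True], size=non_indices.shape, p=[0.9, 0.1])
print(numpy.sum(tar_indices), numpy.sum(non_indices))  # roughly 10% of each survive
```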
@@ -1321,10 +1329,10 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_indices
     positives = scores[tar_indices]

     # Faster EER computation available here : https://github.com/gl3lan/fast_eer
-    equal_error_rate = eer(negatives, positives)
-    #pmiss, pfa = rocch(positives, negatives)
-    #equal_error_rate = rocch2eer(pmiss, pfa)
+    #equal_error_rate = eer(negatives, positives)
+    pmiss, pfa = rocch(positives, negatives)
+    equal_error_rate = rocch2eer(pmiss, pfa)

     return (100. * accuracy.cpu().numpy() / validation_shape[0],
             loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
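Validation EER now comes from the ROC convex hull (`rocch` / `rocch2eer`) instead of the commented-out `eer` helper. A minimal usage sketch with synthetic scores; the import path assumes these helpers live in sidekit's bosaris detplot module, as in current sidekit releases:

```python
# Minimal sketch: EER from the ROC convex hull, with synthetic scores.
import numpy
from sidekit.bosaris.detplot import rocch, rocch2eer  # assumed import path

rng = numpy.random.default_rng(0)
positives = rng.normal(1.0, 1.0, 2000)    # synthetic target scores
negatives = rng.normal(-1.0, 1.0, 20000)  # synthetic non-target scores

pmiss, pfa = rocch(positives, negatives)
print(rocch2eer(pmiss, pfa))              # ~0.16 for these synthetic scores
```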