diff --git a/nnet/augmentation.py b/nnet/augmentation.py
index 34ffcc8ddd7e85ce3e946c76786fedd27a10ae37..f17eabf899eff65cec65c71814eb77876829db5c 100644
--- a/nnet/augmentation.py
+++ b/nnet/augmentation.py
@@ -169,6 +169,7 @@ class AddNoise(object):

         noises = []
         left = original_duration
+
         while left > 0:
             # select noise file at random
             file = random.choice(self.noises)
@@ -433,7 +434,12 @@ if has_pyroom:
            return data, sample[1], sample[2], sample[3] , sample[4], sample[5]


-def data_augmentation(speech, sample_rate, transform_dict, transform_number, noise_df=None, rir_df=None):
+def data_augmentation(speech,
+                      sample_rate,
+                      transform_dict,
+                      transform_number,
+                      noise_df=None,
+                      rir_df=None):
     """

    :param speech:
@@ -469,6 +475,16 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
             ["rate", "16000"],
         ])

+    if "filtering" in augmentations:
+        effects = [
+            ["bandpass","2000","3500"],
+            ["bandstop","200","500"]]
+        speech,sample_rate = torchaudio.sox_effects.apply_effects_tensor(
+            speech,
+            sample_rate,
+            effects = [effects[random.randint(0,1)]],
+        )
+
     if "stretch" in augmentations:
         strech = torchaudio.functional.TimeStretch()
         rate = random.uniform(0.8,1.2)
@@ -476,52 +492,54 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi

     if "add_reverb" in augmentations:
         rir_nfo = random.randrange(len(rir_df))
-        rir_fn = transform_dict["add_noise"]["data_path"] + "/" + rir_nfo + ".wav"
+        rir_fn = transform_dict["add_reverb"]["data_path"] + "/" + rir_nfo + ".wav"
         rir, rir_fs = torchaudio.load(rir_fn)
         rir = rir[rir_nfo[1], :] #keep selected channel
         speech_ = torch.nn.functional.pad(speech, (rir.shape[1]-1, 0))
         speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]

     if "add_noise" in augmentations:
-        # Pick a noise sample from the noise_df
-        noise_row = noise_df.iloc[random.randrange(noise_df.shape[0])]
-
-        noise_type = noise_row['type']
-        noise_start = noise_row['start']
-        noise_duration = noise_row['duration']
-        noise_file_id = noise_row['file_id']
-
-        # Pick a SNR level
-        # TODO make SNRs configurable by noise type
-        if noise_type == 'music':
+        # Pick a noise type
+        noise = torch.zeros_like(speech)
+        noise_idx = random.randrange(4)
+
+        # speech
+        if noise_idx == 0:
+            # Pick a SNR level
+            # TODO make SNRs configurable by noise type
+            snr_db = random.randint(13, 20)
+            pick_count = random.randint(3, 7)
+            index_list = random.choices(range(noise_df.loc['speech'].shape[0]), k=pick_count)
+            for idx in index_list:
+                noise_row = noise_df.loc['speech'].iloc[idx]
+                noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
+            noise /= pick_count
+        # music
+        elif noise_idx == 1:
             snr_db = random.randint(5, 15)
-        elif noise_type == 'noise':
+            noise_row = noise_df.loc['music'].iloc[random.randrange(noise_df.loc['music'].shape[0])]
+            noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
+        # noise
+        elif noise_idx == 2:
             snr_db = random.randint(0, 15)
-        else:
-            snr_db = random.randint(13, 20)
-
-        if noise_duration * sample_rate > speech.shape[1]:
-            # We force frame_offset to stay in the 20 first seconds of the file, otherwise it takes too long to load
-            frame_offset = random.randrange(noise_start * sample_rate, min(int(20*sample_rate), int((noise_start + noise_duration) * sample_rate - speech.shape[1])))
-        else:
-            frame_offset = noise_start * sample_rate
+            noise_row = noise_df.loc['noise'].iloc[random.randrange(noise_df.loc['noise'].shape[0])]
+            noise += load_noise_seg(noise_row, speech.shape, sample_rate, transform_dict["add_noise"]["data_path"])
+        # babble noise with different volume
+        elif noise_idx == 3:
+            snr_db = random.randint(13,20)
+            ns = random.randint(5,10) # Randomly select 5 to 10 speakers
+            noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_df[noise_df["type"] == "speech"].sample(ns,replace=False)["file_id"].values + ".wav"
+            noise = torch.zeros(1,speech.shape[1])
+            for idx in range(ns):
+                noise_,noise_fs = torchaudio.load(noise_fn[idx],frame_offset=0,num_frames=speech.shape[1])
+                transform = torchaudio.transforms.Vol(gain=random.randint(5,15),gain_type='db') # Randomly select volume level (5-15 dB)
+                noise += transform(noise_)
+            noise /= ns

-        noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_file_id + ".wav"
-        if noise_duration * sample_rate > speech.shape[1]:
-            noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech.shape[1]))
-        else:
-            noise, noise_fs = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
         speech_power = speech.norm(p=2)
         noise_power = noise.norm(p=2)
-
-        #if numpy.random.randint(0, 2) == 1:
-        #    noise = torch.flip(noise, dims=[0, 1])
-
-        if noise.shape[1] < speech.shape[1]:
-            noise = torch.tensor(numpy.resize(noise.numpy(), speech.shape))
-
-        snr = math.exp(snr_db / 10)
+        snr = 10 ** (snr_db / 20)
         scale = snr * noise_power / speech_power
         speech = (scale * speech + noise) / 2
@@ -537,6 +555,31 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi

     return speech


+def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
+    noise_start = noise_row['start']
+    noise_duration = noise_row['duration']
+    noise_file_id = noise_row['file_id']
+
+    if noise_duration * sample_rate > speech_shape[1]:
+        # It is recommended to split noise files (especially speech noise type) in shorter subfiles
+        # When frame_offset is too high, loading the segment can take much longer
+        frame_offset = random.randrange(noise_start * sample_rate, int((noise_start + noise_duration) * sample_rate - speech_shape[1]))
+    else:
+        frame_offset = noise_start * sample_rate
+
+    noise_fn = data_path + "/" + noise_file_id + ".wav"
+    if noise_duration * sample_rate > speech_shape[1]:
+        noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
+    else:
+        noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
+
+    #if numpy.random.randint(0, 2) == 1:
+    #    noise = torch.flip(noise, dims=[0, 1])
+
+    if noise_seg.shape[1] < speech_shape[1]:
+        noise_seg = torch.tensor(numpy.resize(noise_seg.numpy(), speech_shape))
+    return noise_seg
+
 """
 It might not be 100% on topic, but maybe this is interesting for you anyway. If you do not need to do real time processing, things can be made more easy. Limiting and dynamic compression can be seen as applying a dynamic transfer function. This function just maps input to output values. A linear function then returns the original audio and a "curved" function does compression or expansion.
 Applying a transfer function is as simple as
diff --git a/nnet/xsets.py b/nnet/xsets.py
index 891cf549d3762e423588166cb8fcfb9efb69c619..c9d3d29750bcd7816ee136da0b36a901952c8699 100644
--- a/nnet/xsets.py
+++ b/nnet/xsets.py
@@ -168,6 +168,7 @@ class SideSet(Dataset):
                 self.transformation = dataset["eval"]["transformation"]

         self.sample_number = int(self.duration * self.sample_rate)
+        self.overlap = int(overlap * self.sample_rate)

         # Load the dataset description as pandas.dataframe
         if dataset_df is None:
@@ -188,16 +189,18 @@ class SideSet(Dataset):

         # Create lists for each column of the dataframe
         df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
+        df_dict["file_start"] = list()
+        df_dict["file_duration"] = list()

         # For each segment, get all possible segments with the current overlap
-        for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1):
+        for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
             current_session = tmp_sessions.iloc[idx]

             # Compute possible starts
             possible_starts = numpy.arange(0,
                                            int(self.sample_rate * (current_session.duration - self.duration)),
-                                           self.sample_number - int(self.sample_rate * overlap)
-                                           )
+                                           self.sample_number
+                                           ) + int(self.sample_rate * (current_session.duration % self.duration / 2))
             possible_starts += int(self.sample_rate * current_session.start)

             # Select max(seg_nb, possible_segments) segments
@@ -206,7 +209,7 @@ class SideSet(Dataset):
                 chunk_nb = len(possible_starts)
             else:
                 chunk_nb = min(len(possible_starts), chunk_per_segment)
-            starts = numpy.random.permutation(possible_starts)[:chunk_nb] / self.sample_rate
+            starts = numpy.random.permutation(possible_starts)[:chunk_nb]

             # Once we know how many segments are selected, create the other fields to fill the DataFrame
             for ii in range(chunk_nb):
@@ -215,6 +218,8 @@ class SideSet(Dataset):
                 df_dict["file_id"].append(current_session.file_id)
                 df_dict["start"].append(starts[ii])
                 df_dict["duration"].append(self.duration)
+                df_dict["file_start"].append(current_session.start)
+                df_dict["file_duration"].append(current_session.duration)
                 df_dict["speaker_idx"].append(current_session.speaker_idx)
                 df_dict["gender"].append(current_session.gender)

@@ -231,8 +236,9 @@ class SideSet(Dataset):

         self.noise_df = None
         if "add_noise" in self.transform:
-            # Load the noise dataset, filter according to the duration
-            self.noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
+            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
+            noise_df = noise_df.loc[noise_df.duration > self.duration]
+            self.noise_df = noise_df.set_index(noise_df.type)

         self.rir_df = None
         if "add_reverb" in self.transform:
@@ -249,7 +255,18 @@ class SideSet(Dataset):
         current_session = self.sessions.iloc[index]

         nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
-        start_frame = int(current_session['start'])
+        original_start = int(current_session['start'])
+        if self.overlap > 0:
+            lowest_shift = self.overlap/2
+            highest_shift = self.overlap/2
+            if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
+                lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
+            if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
+                highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
+            start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
+        else:
+            start_frame = original_start
+
         if start_frame + self.sample_number >= nfo.frames:
             start_frame = numpy.min(nfo.frames - self.sample_number - 1)
@@ -292,7 +309,8 @@ class IdMapSet(Dataset):
                  idmap_name,
                  data_path,
                  file_extension,
-                 transform_pipeline="",
+                 transform_pipeline={},
+                 transform_number=1,
                  sliding_window=False,
                  window_len=24000,
                  window_shift=8000,
@@ -318,20 +336,21 @@ class IdMapSet(Dataset):
         self.sliding_window = sliding_window
         self.window_len = window_len
         self.window_shift = window_shift
+        self.transform_number = transform_number
+
-        self.transform = []
-        if self.transformation is not None:
-            self.transform_list = self.transformation.split(",")
+        #if self.transformation is not None:
+        #    self.transform_list = self.transformation.split(",")

         self.noise_df = None
-        if "add_noise" in self.transform:
+        if "add_noise" in self.transformation:
             # Load the noise dataset, filter according to the duration
             noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
-            tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
-            self.noise_df = tmp_df['file_id'].tolist()
+            #tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
+            self.noise_df = noise_df.set_index(noise_df.type)

         self.rir_df = None
-        if "add_reverb" in self.transform:
+        if "add_reverb" in self.transformation:
             # load the RIR database
             tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
             self.rir_df = zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist())
@@ -344,18 +363,19 @@ class IdMapSet(Dataset):
         """
         if self.idmap.start[index] is None:
             start = 0
+        else:
+            start = int(self.idmap.start[index]) * 160

         if self.idmap.stop[index] is None:
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
-            duration = speech.shape[1] - start
+            duration = int(speech.shape[1] - start)
         else:
-            start = int(self.idmap.start[index])
-            duration = int(self.idmap.stop[index]) - start
+            duration = int(self.idmap.stop[index]) * 160 - start

         # add this in case the segment is too short
         if duration <= self.min_sample_nb:
             middle = start + duration // 2
             start = max(0, int(middle - (self.min_sample_nb / 2)))
-            duration = self.min_sample_nb
+            duration = int(self.min_sample_nb)
         speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
                                             frame_offset=start,
@@ -366,10 +386,10 @@ class IdMapSet(Dataset):
         if self.sliding_window:
             speech = speech.squeeze().unfold(0,self.window_len,self.window_shift)

-        if len(self.transform) > 0:
+        if len(self.transformation.keys()) > 0:
             speech = data_augmentation(speech,
                                        speech_fs,
-                                       self.transform,
+                                       self.transformation,
                                        self.transform_number,
                                        noise_df=self.noise_df,
                                        rir_df=self.rir_df)
diff --git a/nnet/xvector.py b/nnet/xvector.py
index abc5f55921a91af1576a4f9502136f65ad549425..7e6897579df450aff20d4dfd479953661f8f42bb 100755
--- a/nnet/xvector.py
+++ b/nnet/xvector.py
@@ -202,7 +202,10 @@ def eer(negatives, positives):

 def test_metrics(model,
                  device,
-                 speaker_number,
+                 idmap_test_filename,
+                 ndx_test_filename,
+                 key_test_filename,
+                 data_root_name,
                  num_thread,
                  mixed_precision):
     """Compute model metrics
@@ -221,10 +224,10 @@ def test_metrics(model,
     Returns:
         [type]: [description]
     """
-    idmap_test_filename = 'h5f/idmap_test.h5'
-    ndx_test_filename = 'h5f/ndx_test.h5'
-    key_test_filename = 'h5f/key_test.h5'
-    data_root_name='/lium/scratch/larcher/voxceleb1/test/wav'
+    #idmap_test_filename = 'h5f/idmap_test.h5'
+    #ndx_test_filename = 'h5f/ndx_test.h5'
+    #key_test_filename = 'h5f/key_test.h5'
+    #data_root_name='/lium/scratch/larcher/voxceleb1/test/wav'

     transform_pipeline = dict()
@@ -409,6 +412,7 @@ class Xtractor(torch.nn.Module):
             self.preprocessor = MelSpecFrontEnd(n_mels=80)
             self.sequence_network = PreResNet34()
+            self.embedding_size = 256

             self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
                                                             out_features=self.embedding_size)
@@ -445,8 +449,8 @@ class Xtractor(torch.nn.Module):
         if self.loss == "aam":
             self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                             int(self.speaker_number),
-                                                            s = 30,
-                                                            m = 0.2,
+                                                            s = 20,
+                                                            m = 0.3,
                                                             easy_margin = False)

         elif self.loss == 'aps':
@@ -822,7 +826,7 @@ def xtrain(speaker_number,
     else:
         logging.critical(f"*** Load model from = {model_name}")
         checkpoint = torch.load(model_name)
-        model = Xtractor(speaker_number, model_yaml)
+        model = Xtractor(speaker_number, model_yaml, loss=loss)

         """
         Here we remove all layers that we don't want to reload
@@ -836,7 +840,6 @@ def xtrain(speaker_number,
         new_model_dict.update(pretrained_dict)
         model.load_state_dict(new_model_dict)
-
     # Freeze required layers
     for name, param in model.named_parameters():
         if name.split(".")[0] in freeze_parts:
@@ -979,7 +982,7 @@ def xtrain(speaker_number,
     elif opt == 'rmsprop':
         _optimizer = torch.optim.RMSprop
         _options = {'lr': lr}
-    else: # opt == 'sgd'
+    else:  # opt == 'sgd'
         _optimizer = torch.optim.SGD
         _options = {'lr': lr, 'momentum': 0.9}
@@ -1001,9 +1004,9 @@ def xtrain(speaker_number,
         param_list.append({'params': model.module.after_speaker_embedding.parameters(),
                            'weight_decay': model.module.after_speaker_embedding_weight_decay})
     optimizer = _optimizer(param_list, **_options)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                step_size=20 * training_loader.__len__(),
-                                                gamma=0.5)
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+                                                     milestones=[10000,50000,100000],
+                                                     gamma=0.5)

     if mixed_precision:
         scaler = torch.cuda.amp.GradScaler()
@@ -1046,7 +1049,15 @@ def xtrain(speaker_number,
            logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Cross validation accuracy = {val_acc} %, EER = {val_eer * 100} %")

            if compute_test_eer:
-                test_eer = test_metrics(model, device, speaker_number, num_thread, mixed_precision)
+                test_eer = test_metrics(model,
+                                        device,
+                                        idmap_test_filename=dataset_params["test_set"]["idmap_test_filename"],
+                                        ndx_test_filename=dataset_params["test_set"]["ndx_test_filename"],
+                                        key_test_filename=dataset_params["test_set"]["key_test_filename"],
+                                        data_root_name=dataset_params["test_set"]["data_root_name"],
+                                        num_thread=num_thread,
+                                        mixed_precision=mixed_precision)
+
                logging.critical(f"***{time.strftime('%H:%M:%S', time.localtime())} Training metrics - Test EER = {test_eer * 100} %")

            # remember best accuracy and save checkpoint
@@ -1215,7 +1226,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
            data = data.squeeze().to(device)
            with torch.cuda.amp.autocast(enabled=mixed_precision):
                if loss_criteria == 'aam':
-                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
+                    batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
                elif loss_criteria == 'aps':
                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
                else:
@@ -1225,7 +1236,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
            accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
            loss += criterion(batch_predictions, target)
            embeddings[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0],:] = batch_embeddings.detach().cpu()
-            #classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
+            classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
            #print(classes.shape[0])

    local_device = "cpu" if embeddings.shape[0] > 3e4 else device
@@ -1246,6 +1257,7 @@ def extract_embeddings(idmap_name,
                       model_filename,
                       data_root_name,
                       device,
+                       loss,
                       file_extension="wav",
                       transform_pipeline="",
                       frame_shift=0.01,
@@ -1275,7 +1287,7 @@ def extract_embeddings(idmap_name,
        checkpoint = torch.load(model_filename, map_location=device)
        speaker_number = checkpoint["speaker_number"]
        model_archi = checkpoint["model_archi"]
-        model = Xtractor(speaker_number, model_archi=model_archi)
+        model = Xtractor(speaker_number, model_archi=model_archi, loss=loss)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename
@@ -1341,6 +1353,7 @@ def extract_embeddings(idmap_name,

        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                      desc='xvector extraction',
                                                                      mininterval=1)):
+            if data.shape[1] > 20000000: data = data[...,:20000000]
            with torch.cuda.amp.autocast(enabled=mixed_precision):
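
Note on the SNR handling changed above: in data_augmentation() the patch replaces snr = math.exp(snr_db / 10) with snr = 10 ** (snr_db / 20), i.e. the SNR drawn in dB is now converted to an amplitude ratio before the waveforms are combined. Below is a minimal, self-contained sketch of that scaling; the helper name mix_at_snr and the random test tensors are illustrative only and are not part of the patched codebase.

    import torch

    def mix_at_snr(speech, noise, snr_db):
        # Same arithmetic as the patched data_augmentation():
        # the dB value becomes an amplitude ratio (10 ** (snr_db / 20)),
        # the ratio is applied through the L2 norms of the two signals,
        # and the scaled speech and the noise are summed and halved.
        speech_power = speech.norm(p=2)
        noise_power = noise.norm(p=2)
        snr = 10 ** (snr_db / 20)
        scale = snr * noise_power / speech_power
        return (scale * speech + noise) / 2

    # Illustrative usage: one second of 16 kHz audio mixed with white noise at 10 dB SNR.
    speech = torch.randn(1, 16000)
    noise = torch.randn(1, 16000)
    augmented = mix_at_snr(speech, noise, snr_db=10.0)

As in the patched code, it is the speech term that gets scaled, so the absolute level of the output varies with the SNR that was drawn.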