Commit a6a7725f authored by Gaël Le Lan

augmentation

parent 295d9fa5
@@ -459,16 +459,6 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
    aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
    augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
    if "phone_filtering" in augmentations:
        speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
            speech,
            sample_rate,
            effects=[
                ["lowpass", "4000"],
                ["compand", "0.02,0.05", "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8", "-8", "-7", "0.05"],
                ["rate", "16000"],
            ])
    if "stretch" in augmentations:
        stretch = torchaudio.transforms.TimeStretch()
        rate = random.uniform(0.8, 1.2)
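As an aside, torchaudio's TimeStretch is a phase-vocoder transform that operates on a complex spectrogram, not on the raw waveform, so a full stretch pipeline needs an STFT round trip. A minimal sketch of that pattern, assuming torchaudio >= 0.10 and illustrative parameters (the input tensor is a stand-in):

import random
import torch
import torchaudio

speech = torch.randn(1, 48000)                    # stand-in for a 3 s, 16 kHz waveform
rate = random.uniform(0.8, 1.2)                   # same range as the stretch branch above
to_spec = torchaudio.transforms.Spectrogram(n_fft=512, power=None)   # complex STFT
stretch = torchaudio.transforms.TimeStretch(n_freq=257)              # n_freq = n_fft // 2 + 1
spec = to_spec(speech)
stretched = stretch(spec, rate)                   # phase vocoder at the sampled rate
speech = torchaudio.transforms.InverseSpectrogram(n_fft=512)(stretched)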
@@ -514,15 +514,29 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
            scale = snr * noise_power / speech_power
            speech = (scale * speech + noise) / 2
    if "phone_filtering" in augmentations:
        final_shape = speech.shape[1]
        speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
            speech,
            sample_rate,
            effects=[
                ["lowpass", "4000"],
                ["compand", "0.02,0.05", "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8", "-8", "-7", "0.05"],
                ["rate", "16000"],
            ])
        speech = speech[:, :final_shape]
    if "codec" in augmentations:
        final_shape = speech.shape[1]
        configs = [
            {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
            {"format": "gsm"},
            {"format": "mp3", "compression": -9},
            {"format": "vorbis", "compression": -1}
            ({"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, "8 bit mu-law"),
            ({"format": "gsm"}, "GSM-FR"),
            ({"format": "mp3", "compression": -9}, "MP3"),
            ({"format": "vorbis", "compression": -1}, "Vorbis")
        ]
        param, title = random.choice(configs)
        speech = torchaudio.functional.apply_codec(speech, sample_rate, **param)
        speech = speech[:, :final_shape]
    return speech
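The codec branch degrades the signal by round-tripping it through a lossy codec with torchaudio.functional.apply_codec (sox backend required), then trims back to the pre-codec length, since encoders such as GSM or MP3 can pad the output. A self-contained sketch mirroring the hunk above (the input tensor is a stand-in):

import random
import torch
import torchaudio

speech = torch.randn(1, 48000)                    # stand-in 16 kHz waveform
sample_rate = 16000
configs = [
    ({"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, "8 bit mu-law"),
    ({"format": "gsm"}, "GSM-FR"),
    ({"format": "mp3", "compression": -9}, "MP3"),
    ({"format": "vorbis", "compression": -1}, "Vorbis"),
]
param, title = random.choice(configs)
final_shape = speech.shape[1]
speech = torchaudio.functional.apply_codec(speech, sample_rate, **param)
speech = speech[:, :final_shape]                  # codecs may pad; restore the length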
@@ -204,6 +204,8 @@ class MelSpecFrontEnd(torch.nn.Module):
                                                             n_mels=self.melkwargs['n_mels'])
        self.CMVN = torch.nn.InstanceNorm1d(self.n_mels)
        self.time_masking = torchaudio.transforms.TimeMasking(time_mask_param=5)
        self.freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=10)

    def forward(self, x):
        """
@@ -219,10 +221,10 @@ class MelSpecFrontEnd(torch.nn.Module):
        out = self.MelSpec(out) + 1e-6
        out = torch.log(out)
        out = self.CMVN(out)
        out = self.freq_masking(out)
        out = self.time_masking(out)
        return out
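The forward pass now applies SpecAugment-style masking after CMVN: frequency masking blanks up to 10 consecutive mel bins and time masking up to 5 consecutive frames. A minimal sketch of the same chain (mel parameters mirror the halfresnet34 front end later in this commit; n_mels=80 is assumed from the pooling dimensions, and the input is a stand-in):

import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000,
                                           n_fft=1024,
                                           win_length=1024,
                                           hop_length=256,
                                           n_mels=80)
freq_masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=10)
time_masking = torchaudio.transforms.TimeMasking(time_mask_param=5)

speech = torch.randn(1, 16000)          # stand-in for 1 s of 16 kHz audio
out = torch.log(mel(speech) + 1e-6)     # log-mel, as in MelSpecFrontEnd.forward
out = freq_masking(out)                 # zeroes up to 10 consecutive mel bins
out = time_masking(out)                 # zeroes up to 5 consecutive frames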
class RawPreprocessor(torch.nn.Module):
"""
@@ -232,6 +232,10 @@ class SideSet(Dataset):
self.transform["add_noise"] = self.transformation["add_noise"]
if "add_reverb" in transforms:
self.transform["add_reverb"] = self.transformation["add_reverb"]
if "codec" in transforms:
self.transform["codec"] = []
if "phone_filtering" in transforms:
self.transform["phone_filtering"] = []
self.noise_df = None
if "add_noise" in self.transform:
@@ -533,12 +533,15 @@ class Xtractor(torch.nn.Module):
            self.before_speaker_embedding_weight_decay = 0.0000
            self.after_speaker_embedding_weight_decay = 0.0000

        elif model_archi == "halfresnet34":
            self.preprocessor = MelSpecFrontEnd()
            self.preprocessor = MelSpecFrontEnd(n_fft=1024,
                                                win_length=1024,
                                                hop_length=256)
            self.sequence_network = PreHalfResNet34()
            self.embedding_size = 256
            self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
                                                            out_features=self.embedding_size)
            self.stat_pooling = AttentivePooling(256, 80)
            self.stat_pooling = AttentivePooling(256, 80, global_context=True)
            self.loss = loss
            if self.loss == "aam":
                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
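global_context=True follows the attentive-statistics-pooling idea of conditioning the frame-level attention on utterance-level mean and standard deviation, as popularized by ECAPA-TDNN. A minimal sketch of that mechanism, assuming it is what sidekit's AttentivePooling implements (the class below is illustrative, not sidekit's code):

import torch

class AttentiveStatsPool(torch.nn.Module):
    """Illustrative attentive statistics pooling with optional global context."""
    def __init__(self, channels, hidden, global_context=True):
        super().__init__()
        self.global_context = global_context
        in_dim = channels * 3 if global_context else channels
        self.attention = torch.nn.Sequential(
            torch.nn.Conv1d(in_dim, hidden, kernel_size=1),
            torch.nn.Tanh(),
            torch.nn.Conv1d(hidden, channels, kernel_size=1))

    def forward(self, x):                       # x: (batch, channels, frames)
        if self.global_context:
            t = x.size(2)
            mean = x.mean(dim=2, keepdim=True).expand(-1, -1, t)
            std = x.std(dim=2, keepdim=True).expand(-1, -1, t)
            attn_in = torch.cat([x, mean, std], dim=1)
        else:
            attn_in = x
        alpha = torch.softmax(self.attention(attn_in), dim=2)   # per-frame weights
        mu = (alpha * x).sum(dim=2)                             # weighted mean
        var = (alpha * x * x).sum(dim=2) - mu * mu
        sigma = var.clamp(min=1e-8).sqrt()                      # weighted std
        return torch.cat([mu, sigma], dim=1)                    # (batch, 2 * channels)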
@@ -1038,6 +1041,7 @@ def xtrain(speaker_number,
    training_set = SideSet(dataset_yaml,
                           set_type="train",
                           chunk_per_segment=-1,
                           transform_number=2,
                           overlap=dataset_params['train']['overlap'],
                           dataset_df=training_df,
                           output_format="pytorch",
@@ -1051,8 +1055,8 @@ def xtrain(speaker_number,
    side_sampler = SideSampler(training_set.sessions['speaker_idx'],
                               speaker_number,
                               2,
                               128,
                               1,
                               256,
                               dataset_params["batch_size"])

    training_loader = DataLoader(training_set,
@@ -1106,7 +1110,7 @@ def xtrain(speaker_number,
    optimizer = _optimizer(param_list, **_options)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1 * training_loader.__len__(),
                                                gamma=0.95)
                                                gamma=0.90)
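With step_size set to one epoch's worth of batches and scheduler.step() assumed to be called once per batch, lowering gamma from 0.95 to 0.90 makes the learning rate decay by 10% per epoch instead of 5%. A minimal sketch of this pattern (model and step counts are stand-ins):

import torch

model = torch.nn.Linear(10, 2)                      # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
steps_per_epoch = 100                               # stand-in for training_loader.__len__()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=steps_per_epoch,
                                            gamma=0.90)
for step in range(2 * steps_per_epoch):
    optimizer.step()
    scheduler.step()    # stepped per batch: lr = 0.1 in epoch 1, 0.09 in epoch 2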
    if mixed_precision:
        scaler = torch.cuda.amp.GradScaler()
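For reference, the GradScaler created here is used together with autocast in the standard torch.cuda.amp recipe; a minimal self-contained sketch (requires a CUDA device; the model and loss are stand-ins):

import torch

model = torch.nn.Linear(80, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    x = torch.randn(8, 80, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(x).square().mean()   # stand-in loss
    scaler.scale(loss).backward()         # scale to avoid fp16 underflow
    scaler.step(optimizer)                # unscales grads before the optimizer step
    scaler.update()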