Commit f2160a93 authored by Anthony Larcher

debug

parent a24435dc
@@ -471,7 +471,7 @@ class ResNet(torch.nn.Module):
         n_mels = 80
         n_mfcc = 80

-        self.mfcc_transform = torchaudio.transforms.MFCC(
+        self.MFCC = torchaudio.transforms.MFCC(
             sample_rate=16000,
             n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})
@@ -501,7 +501,7 @@ class ResNet(torch.nn.Module):
         return torch.nn.Sequential(*layers)

     def forward(self, x):
-        out = self.mfcc_transform(x)
+        out = self.MFCC(x)
         out = self.CMVN(out)
         out = out.unsqueeze(1)
         out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
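For orientation, a minimal sketch of what this renamed front-end does to a batch of waveforms. The values mirror the constructor above; hop_length=512 is an assumption here, since this hunk does not show it.

    import torch
    import torchaudio

    # Sketch of the MFCC -> CMVN -> unsqueeze chain in ResNet.forward.
    # hop_length=512 is assumed; the hunk above does not show it.
    mfcc = torchaudio.transforms.MFCC(
        sample_rate=16000,
        n_mfcc=80,
        melkwargs={'n_fft': 2048, 'n_mels': 80, 'hop_length': 512})
    cmvn = torch.nn.InstanceNorm1d(80)

    waveform = torch.randn(4, 32000)   # (batch, samples): 2 s at 16 kHz
    out = mfcc(waveform)               # (batch, 80, frames)
    out = cmvn(out)                    # per-utterance mean/variance normalisation
    out = out.unsqueeze(1)             # (batch, 1, 80, frames), ready for Conv2d
    print(out.shape)                   # e.g. torch.Size([4, 1, 80, 63]) with centre padding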
@@ -528,17 +528,17 @@ class PreResNet34(torch.nn.Module):
         self.speaker_number = speaker_number

         # Feature extraction
-        n_fft = 2048
-        win_length = None
-        hop_length = 512
-        n_mels = 80
-        n_mfcc = 80
+        #n_fft = 2048
+        #win_length = None
+        #hop_length = 512
+        #n_mels = 80
+        #n_mfcc = 80

-        self.MFCC = torchaudio.transforms.MFCC(
-            sample_rate=16000,
-            n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})
+        #self.MFCC = torchaudio.transforms.MFCC(
+        #    sample_rate=16000,
+        #    n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})

-        self.CMVN = torch.nn.InstanceNorm1d(80)
+        #self.CMVN = torch.nn.InstanceNorm1d(80)

         self.conv1 = torch.nn.Conv2d(1, 128, kernel_size=3,
                                      stride=1, padding=1, bias=False)
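Note that the PreResNet34 front-end is commented out rather than deleted: the same MFCC/CMVN pair reappears inside Xtractor (see the @@ -388,6 +402,20 @@ hunk below), so feature extraction now happens once in the wrapper instead of inside each backbone.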
@@ -593,13 +593,13 @@ class PreFastResNet34(torch.nn.Module):
         win_length = None
         hop_length = 512
         n_mels = 80
-        n_mfcc = 80
+        n_mfcc = 60

         self.MFCC = torchaudio.transforms.MFCC(
             sample_rate=16000,
             n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})

-        self.CMVN = torch.nn.InstanceNorm1d(80)
+        self.CMVN = torch.nn.InstanceNorm1d(60)

         self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
                                      stride=(2, 1), padding=3, bias=False)
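The CMVN layer is updated together with n_mfcc because InstanceNorm1d's num_features should describe the coefficient axis. A small sketch of why InstanceNorm1d acts as utterance-level cepstral mean/variance normalisation (CMVN); with affine=False, some PyTorch versions do not strictly check num_features against the input, so the two values have to be kept in sync by hand as done here:

    import torch

    # InstanceNorm1d normalises each coefficient channel over time,
    # independently per utterance: utterance-level CMVN.
    cmvn = torch.nn.InstanceNorm1d(60)       # num_features = n_mfcc
    feats = torch.randn(8, 60, 200)          # (batch, n_mfcc, frames)
    out = cmvn(feats)
    print(out.mean(dim=-1).abs().max())      # ~0: zero mean per channel
    print(out.std(dim=-1).mean())            # ~1: unit variance per channel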
@@ -621,9 +621,7 @@ class PreFastResNet34(torch.nn.Module):
         return torch.nn.Sequential(*layers)

     def forward(self, x):
-        out = self.MFCC(x)
-        out = self.CMVN(out)
-        out = out.unsqueeze(1)
+        out = x.unsqueeze(1)
         out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
         out = self.layer1(out)
         out = self.layer2(out)
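After this change PreFastResNet34.forward consumes precomputed (batch, coefficients, frames) features instead of raw waveforms; the MFCC/CMVN modules built in its constructor are no longer called here, so the n_mfcc = 60 above only matters if that transform is reactivated. Extraction is instead handled by Xtractor (see the hunks below).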
@@ -401,7 +401,7 @@ class SpkSet(Dataset):
             self._spk_dict[speaker]['p'] = numpy.ones((self._spk_dict[speaker]['num_segs'],))/self._spk_dict[speaker]['num_segs']

         _transform = []
-        self.transform = None
+        self.transform = []
         if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
             self.transform = self.transformation["pipeline"].split(',')
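Switching the default from None to an empty list is what lets the augmentation code in the next hunk guard on len(self.transform) > 0 without a separate None check. A minimal sketch of the pipeline parsing, with hypothetical configuration values (only the keys mirror those used by SpkSet):

    # Hypothetical values; only the keys mirror those used by SpkSet.
    transformation = {"pipeline": "add_noise,filter", "noise_snr": [5, 10, 15]}

    transform = []
    if transformation["pipeline"] not in ('', None):
        transform = transformation["pipeline"].split(',')
    print(transform)   # ['add_noise', 'filter']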
@@ -489,34 +489,35 @@ class SpkSet(Dataset):
                                           frame_offset=start_frame,
                                           num_frames=self.sample_number)

-        # Select the data augmentation randomly
-        aug_idx = numpy.random.randint(0,len(self.transform), self.transform_number)
-        augmentations = list(numpy.array(self.transform)[aug_idx])
+        if len(self.transform) > 0:
+            # Select the data augmentation randomly
+            aug_idx = numpy.random.randint(0,len(self.transform), self.transform_number)
+            augmentations = list(numpy.array(self.transform)[aug_idx])

-        if "add_noise" in augmentations:
-            # Pick a SNR level
-            snr_db = random.choice(self.transformation["noise_snr"])
+            if "add_noise" in augmentations:
+                # Pick a SNR level
+                snr_db = random.choice(self.transformation["noise_snr"])

-            # Pick a file name from the noise_df
-            noise_fn = self.noise_root_db + "/" + random.choice(self.noise_df) + ".wav"
-            noise, noise_fs = torchaudio.load(noise_fn,
-                                              frame_offset=0,
-                                              num_frames=speech.shape[1])
-            speech_power = speech.norm(p=2)
-            noise_power = noise.norm(p=2)
+                # Pick a file name from the noise_df
+                noise_fn = self.noise_root_db + "/" + random.choice(self.noise_df) + ".wav"
+                noise, noise_fs = torchaudio.load(noise_fn,
+                                                  frame_offset=0,
+                                                  num_frames=speech.shape[1])
+                speech_power = speech.norm(p=2)
+                noise_power = noise.norm(p=2)

-            snr = math.exp(snr_db / 10)
-            scale = snr * noise_power / speech_power
-            speech = (scale * speech + noise) / 2
+                snr = math.exp(snr_db / 10)
+                scale = snr * noise_power / speech_power
+                speech = (scale * speech + noise) / 2

-        if "add_reverb" in augmentations:
-            pass
+            if "add_reverb" in augmentations:
+                pass

-        if "codec" in augmentations:
-            pass
+            if "codec" in augmentations:
+                pass

-        if "filter" in augmentations:
-            pass
+            if "filter" in augmentations:
+                pass

         speaker_idx = self._spk_dict[current_speaker]["speaker_idx"]
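The add_noise branch follows the mixing recipe from the torchaudio data-augmentation tutorial: scale the clean speech so its energy exceeds the noise by the target ratio, then average the two signals. A self-contained sketch; note that math.exp(snr_db / 10) is that tutorial's convention, whereas the textbook dB-to-power conversion would be 10 ** (snr_db / 10):

    import math
    import torch

    # Stand-in signals; in SpkSet both come from torchaudio.load.
    speech = torch.randn(1, 32000)
    noise = torch.randn(1, 32000)

    snr_db = 10
    speech_power = speech.norm(p=2)
    noise_power = noise.norm(p=2)

    snr = math.exp(snr_db / 10)             # e ** 1 ~= 2.72 for snr_db = 10
    scale = snr * noise_power / speech_power
    noisy = (scale * speech + noise) / 2    # scaled speech mixed with noise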
@@ -38,6 +38,7 @@
 import shutil
 import sys
 import time
 import torch
+import torchaudio
 import tqdm
 import yaml
@@ -311,8 +312,8 @@ class Xtractor(torch.nn.Module):
                  model_archi="xvector",
                  loss=None,
                  norm_embedding=False,
-                 aam_margin=0.5,
-                 aam_s=0.5):
+                 aam_margin=0.2,
+                 aam_s=30):
         """
         If config is None, default architecture is created

         :param model_archi:
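The new defaults match widely used additive-angular-margin (AAM / ArcFace) settings for speaker verification: margin m = 0.2 and scale s = 30. The old aam_s=0.5 was almost certainly unintended, since s multiplies the cosine logits and a value below 1 flattens the softmax. A minimal sketch of the AAM logit computation, following the common InsightFace-style formulation (an assumption, not necessarily sidekit's exact code):

    import torch

    def aam_logits(embeddings, weights, labels, m=0.2, s=30.0):
        """Additive angular margin: logit_y = s * cos(theta_y + m)."""
        cosine = torch.nn.functional.linear(
            torch.nn.functional.normalize(embeddings),
            torch.nn.functional.normalize(weights))          # (batch, n_classes)
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target = torch.cos(theta + m)                        # margin on the true class only
        one_hot = torch.nn.functional.one_hot(labels, cosine.size(1)).float()
        return s * (one_hot * target + (1 - one_hot) * cosine)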
@@ -333,9 +334,22 @@ class Xtractor(torch.nn.Module):
         else:
             self.loss = loss

-        self.feature_size = 30
+        self.feature_size = 80
         self.activation = torch.nn.LeakyReLU(0.2)

+        # Feature extraction
+        n_fft = 2048
+        win_length = None
+        hop_length = 128
+        n_mels = 80
+        n_mfcc = 80
+
+        self.MFCC = torchaudio.transforms.MFCC(
+            sample_rate=16000,
+            n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})
+
+        self.CMVN = torch.nn.InstanceNorm1d(80)
+
         self.preprocessor = None

         self.sequence_network = torch.nn.Sequential(OrderedDict([
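feature_size moves from 30 to 80 so the input dimension of the first TDNN layer matches the n_mfcc = 80 coefficients produced by the new in-model front-end. Note also the finer frame rate: hop_length = 128 gives 16000 / 128 = 125 frames per second, against the 512-sample hop used in the ResNet classes above.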
@@ -388,6 +402,20 @@ class Xtractor(torch.nn.Module):
         elif model_archi == "resnet34":

             self.input_nbdim = 2

+            # Feature extraction
+            n_fft = 2048
+            win_length = None
+            hop_length = 128
+            n_mels = 80
+            n_mfcc = 80
+
+            self.MFCC = torchaudio.transforms.MFCC(
+                sample_rate=16000,
+                n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})
+
+            self.CMVN = torch.nn.InstanceNorm1d(80)
+
             self.preprocessor = None

             self.sequence_network = PreResNet34()
@@ -414,6 +442,20 @@ class Xtractor(torch.nn.Module):
         elif model_archi == "fastresnet34":

             self.input_nbdim = 2

+            # Feature extraction
+            n_fft = 2048
+            win_length = None
+            hop_length = 128
+            n_mels = 80
+            n_mfcc = 80
+
+            self.MFCC = torchaudio.transforms.MFCC(
+                sample_rate=16000,
+                n_mfcc=n_mfcc, melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})
+
+            self.CMVN = torch.nn.InstanceNorm1d(80)
+
             self.preprocessor = None

             self.sequence_network = PreFastResNet34()
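The same 80-coefficient front-end is duplicated verbatim for the resnet34 and fastresnet34 branches, even though PreFastResNet34's own constructor now sets n_mfcc = 60; since its forward no longer calls that transform, the Xtractor-level MFCC/CMVN (applied in the last hunk below) is the one that actually runs.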
@@ -430,7 +472,7 @@ class Xtractor(torch.nn.Module):
                                       int(self.speaker_number),
                                       s = 30.0,
                                       m = 0.20,
-                                      easy_margin = True)
+                                      easy_margin = False)

         self.preprocessor_weight_decay = 0.000
         self.sequence_network_weight_decay = 0.000
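Switching easy_margin off changes behaviour near the decision boundary. In the common InsightFace-style ArcMarginProduct (a hedged sketch of the usual formulation, not necessarily sidekit's exact code), easy_margin=True applies the margin only where cos(theta) > 0, while easy_margin=False keeps the penalised logit monotonic past theta = pi - m, which is the standard choice:

    import math
    import torch

    def apply_margin(cosine, m=0.20, easy_margin=False):
        """Replace cos(theta) by cos(theta + m) on the target class."""
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
        phi = cosine * math.cos(m) - sine * math.sin(m)   # cos(theta + m)
        if easy_margin:
            # apply the margin only where theta < pi / 2
            return torch.where(cosine > 0, phi, cosine)
        # standard variant: stay monotonic where theta > pi - m
        th = math.cos(math.pi - m)
        mm = math.sin(math.pi - m) * m
        return torch.where(cosine > th, phi, cosine - mm)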
@@ -700,6 +742,11 @@ class Xtractor(torch.nn.Module):
         if self.preprocessor is not None:
             x = self.preprocessor(x)
+        else:
+            x = self.MFCC(x)
+            x = self.CMVN(x)
+            #x = x.unsqueeze(1)
+
         x = self.sequence_network(x)

         # Mean and Standard deviation pooling
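With this fallback, every architecture receives (batch, n_mfcc, frames) features from the shared MFCC/CMVN front-end whenever no trainable preprocessor is configured. The unsqueeze to four dimensions stays commented out because the ResNet backbones now add the channel axis themselves (out = x.unsqueeze(1) in PreFastResNet34.forward above), while the x-vector TDNN path consumes the three-dimensional tensor directly.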