Commit 795e4559 authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

debug

parent e193510d
......@@ -485,15 +485,41 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
speech = torch.nn.functional.conv1d(speech_[None, ...], rir[None, ...])[0]
if "add_noise" in augmentations:
# Pick a noise sample from the noise_df
noise_row = noise_df.sample()
noise_type = noise_row['type']
noise_start = noise_row['start']
noise_duration = noise_row['duration']
assert noise_type in ['music', 'noise', 'speech']
# Pick a SNR level
snr_db = random.choice(transform_dict["add_noise"]["snr"])
if noise_type == 'music':
snr_db = random.choice(transform_dict["add_noise"]["snr"])
elif noise_type == 'noise':
snr_db = numpy.random.randint(0, 16)
else:
snr_db = numpy.random.randint(13, 20)
if noise_duration * sample_rate > speech.shape[1]:
frame_offset = 0
else:
frame_offset = numpy.random.randint(noise_start * sample_rate, speech.shape[1] - noise_duration * sample_rate)
# Pick a file name from the noise_df
noise_fn = transform_dict["add_noise"]["data_path"] + "/" + random.choice(noise_df) + ".wav"
noise, noise_fs = torchaudio.load(noise_fn, frame_offset=0, num_frames=speech.shape[1])
noise_fn = transform_dict["add_noise"]["data_path"] + "/" + noise_row['file_id'] + ".wav"
noise, noise_fs = torchaudio.load(noise_fn, frame_offset=frame_offset, num_frames=frame_offset + noise_duration * sample_rate)
speech_power = speech.norm(p=2)
noise_power = noise.norm(p=2)
if numpy.random.randint(0, 2) == 1:
noise = torch.flip(noise, dims=[0, 1])
if noise.shape[1] < speech.shape[1]:
noise = torch.tensor(numpy.resize(noise.numpy(), speech.shape))
snr = math.exp(snr_db / 10)
scale = snr * noise_power / speech_power
speech = (scale * speech + noise) / 2
......
......@@ -287,6 +287,24 @@ class ResBlock(torch.nn.Module):
return out
class SELayer(torch.nn.Module):
    """Squeeze-and-Excitation channel-attention block.

    Each channel of the input feature map is squeezed to a scalar by global
    average pooling, the resulting vector is passed through a two-layer
    bottleneck MLP (hidden size ``channel // reduction``) ending in a sigmoid,
    and the input is rescaled channel-wise by the resulting gates in (0, 1).
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Squeeze: collapse every (H, W) plane to one value per channel.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP producing one sigmoid gate per channel.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(channel, channel // reduction, bias=False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(channel // reduction, channel, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        # Broadcast the per-channel gates over the spatial dimensions.
        return x * gates.expand_as(x)
class BasicBlock(torch.nn.Module):
expansion = 1
......@@ -299,6 +317,8 @@ class BasicBlock(torch.nn.Module):
stride=1, padding=1, bias=False)
self.bn2 = torch.nn.BatchNorm2d(planes)
self.se = SELayer(planes)
self.shortcut = torch.nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = torch.nn.Sequential(
......@@ -310,6 +330,7 @@ class BasicBlock(torch.nn.Module):
def forward(self, x):
    """Residual forward pass with squeeze-and-excitation recalibration.

    Main branch: conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> SE gating.
    The shortcut branch of the input is then added before the final ReLU.
    """
    relu = torch.nn.functional.relu
    residual = self.shortcut(x)
    y = relu(self.bn1(self.conv1(x)))
    y = self.bn2(self.conv2(y))
    # Channel-wise attention applied before the residual addition.
    y = self.se(y)
    y = y + residual
    return relu(y)
......@@ -489,6 +510,7 @@ class PreFastResNet34(torch.nn.Module):
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
def ResNet34():
......
......@@ -227,9 +227,7 @@ class SideSet(Dataset):
self.noise_df = None
if "add_noise" in self.transform:
# Load the noise dataset, filter according to the duration
noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
self.noise_df = tmp_df['file_id'].tolist()
self.noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
self.rir_df = None
if "add_reverb" in self.transform:
......@@ -291,7 +289,8 @@ class IdMapSet(Dataset):
file_extension,
transform_pipeline={},
frame_rate=100,
min_duration=0.165
min_duration=0.165,
backward=False
):
"""
......@@ -309,7 +308,7 @@ class IdMapSet(Dataset):
self.transformation = transform_pipeline
self.min_duration = min_duration
self.sample_rate = frame_rate
self.backward = backward
self.transform = []
if (len(self.transformation) > 0):
......@@ -368,7 +367,10 @@ class IdMapSet(Dataset):
noise_df=self.noise_df,
rir_df=self.rir_df)
speech = speech.squeeze()
if self.backward:
speech = torch.flip(speech, [0, 1]).squeeze()
else:
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment