Commit cd2f701c authored by Anthony Larcher's avatar Anthony Larcher
Browse files

refactoring

parent 0f31dc10
......@@ -188,7 +188,7 @@ class RawPreprocessor(torch.nn.Module):
"""
"""
def __init__(self, nb_samp, in_channels, filts, first_conv):
def __init__(self, nb_samp, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50, sample_rate=16000):
"""
:param nb_samp:
......@@ -199,10 +199,18 @@ class RawPreprocessor(torch.nn.Module):
super(RawPreprocessor, self).__init__()
self.ln = LayerNorm(nb_samp)
self.first_conv = SincConv1d(in_channels = in_channels,
out_channels = filts,
kernel_size = first_conv
out_channels = out_channels,
kernel_size = kernel_size,
sample_rate = sample_rate,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
groups=groups,
min_low_hz=min_low_hz,
min_band_hz=min_band_hz
)
self.first_bn = torch.nn.BatchNorm1d(num_features = filts)
self.first_bn = torch.nn.BatchNorm1d(num_features = out_channels)
self.lrelu = torch.nn.LeakyReLU()
self.lrelu_keras = torch.nn.LeakyReLU(negative_slope = 0.3)
......@@ -356,16 +364,13 @@ class ResNet18(torch.nn.Module):
:param x:
:return:
"""
print("entree: {}".format(x.shape))
x = self.entry_conv(x)
x = self.entry_batch_norm(x)
x = self.activation(x)
print("Avant resblocks: {}".format(x.shape))
for layer in self.mega_blocks:
x = layer(x)
print("Avant pooling: {}".format(x.shape))
# Pooling done as for x-vectors
mean = torch.mean(x, dim=2)
mean = torch.flatten(mean, 1)
......
......@@ -413,7 +413,7 @@ class SideSet(Dataset):
self.len = len(self.sessions)
_transform = []
if not self.transformation["pipeline"] == '':
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
trans = self.transformation["pipeline"].split(',')
self.add_noise = numpy.zeros(self.len, dtype=bool)
......@@ -473,7 +473,7 @@ class SideSet(Dataset):
:return:
"""
# Check the size of the file
nfo = soundfile.info("{self.data_path}/{self.sessions.iloc[index]['file_id']}{self.data_file_extension}")
nfo = soundfile.info(f"{self.data_path}/{self.sessions.iloc[index]['file_id']}{self.data_file_extension}")
start_frame = int(self.sessions.iloc[index]['start'] * self.sample_rate)
if start_frame + self.sample_number >= nfo.frames:
start_frame = numpy.min(nfo.frames - self.sample_number - 1)
......
......@@ -202,10 +202,10 @@ class Xtractor(torch.nn.Module):
filts = [128, [128, 128], [128, 256], [256, 256]]
self.norm_embedding = True
self.preprocessor = RawPreprocessor(nb_samp=48000,
self.preprocessor = RawPreprocessor(nb_samp=32000,
in_channels=1,
filts=filts[0],
first_conv=3)
out_channels=filts[0],
kernel_size=3)
self.sequence_network = torch.nn.Sequential(OrderedDict([
("block0", ResBlockWFMS(nb_filts=filts[1], first=True)),
......@@ -259,11 +259,14 @@ class Xtractor(torch.nn.Module):
)
self.feature_size = self.preprocessor.dimension
elif cfg['preprocessor']["type"] == "rawnet2":
self.preprocessor = RawPreprocessor(nb_samp=48000,
self.preprocessor = RawPreprocessor(nb_samp=int(cfg['preprocessor']["sampling_frequency"] * cfg['preprocessor']["duration"]),
in_channels=1,
filts=128,
first_conv=3)
self.feature_size = 128
out_channels=cfg["feature_size"],
kernel_size=cfg['preprocessor']["kernel_size"],
stride=cfg['preprocessor']["stride"],
padding=cfg['preprocessor']["padding"],
dilation=cfg['preprocessor']["dilation"])
self.feature_size = cfg["feature_size"]
"""
Prepare sequence network
......@@ -362,6 +365,7 @@ class Xtractor(torch.nn.Module):
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
def forward(self, x, is_eval=False):
"""
......@@ -421,7 +425,7 @@ def xtrain(speaker_number,
:param num_thread:
:return:
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
if model_name is None:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment