Commit add67511 authored by Anthony Larcher

minor modification

parent baae957c
@@ -337,19 +337,20 @@ class SincNet(torch.nn.Module):
                     bias=True,
                 )
             else:
+                print(f"sample_rate = {self.sample_rate}")
                 conv1d = SincConv1d(
-                    1,
-                    out_channels,
-                    kernel_size,
-                    sample_rate=self.sample_rate,
-                    min_low_hz=self.min_low_hz,
-                    min_band_hz=self.min_band_hz,
-                    stride=stride,
-                    padding=0,
-                    dilation=1,
-                    bias=False,
-                    groups=1,
-                )
+                    out_channels,
+                    kernel_size,
+                    sample_rate=self.sample_rate,
+                    in_channels=1,
+                    min_low_hz=self.min_low_hz,
+                    min_band_hz=self.min_band_hz,
+                    stride=stride,
+                    padding=0,
+                    dilation=1,
+                    bias=False,
+                    groups=1,
+                )
             self.conv1d_.append(conv1d)
             # 1D max-pooling
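The SincConv1d call above now passes in_channels by keyword instead of as the first positional argument. Below is a minimal sketch of exercising the layer on its own with the same keyword signature; the import path, the (batch, 1, samples) input layout, the 16 kHz rate and the filter-bank sizes are illustrative assumptions, not part of this commit:

import torch
from sidekit.nnet.sincnet import SincConv1d  # import path is an assumption

# Illustrative values only; the commit itself takes them from the SincNet config
sinc = SincConv1d(
    80,                   # out_channels (assumed)
    251,                  # kernel_size (assumed)
    sample_rate=16000,
    in_channels=1,
    min_low_hz=50,
    min_band_hz=50,
    stride=1,
    padding=0,
    dilation=1,
    bias=False,
    groups=1,
)

waveform = torch.randn(4, 1, 32000)   # (batch, channel, samples), assumed layout
features = sinc(waveform)             # -> (batch, 80, frames)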
@@ -375,6 +375,7 @@ class Xtractor(torch.nn.Module):
         self.preprocessor = None
         if "preprocessor" in cfg:
             if cfg['preprocessor']["type"] == "sincnet":
+                print(f"sample_rate = {cfg['preprocessor']['sample_rate']}")
                 self.preprocessor = SincNet(
                     waveform_normalize=cfg['preprocessor']["waveform_normalize"],
                     sample_rate=cfg['preprocessor']["sample_rate"],
@@ -398,6 +399,7 @@ class Xtractor(torch.nn.Module):
                                              padding=cfg['preprocessor']["padding"],
                                              dilation=cfg['preprocessor']["dilation"])
         self.feature_size = cfg["feature_size"]
+        self.preprocessor_weight_decay = 0.000
         """
         Prepare sequence network
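The new self.preprocessor_weight_decay attribute is only initialised here; this hunk does not show where it is consumed. One common pattern is to feed per-block weight decays into optimizer parameter groups, sketched below under that assumption (the helper name, optimizer choice and learning rate are arbitrary):

import torch

def build_optimizer(model, lr=0.001):
    # Hypothetical helper: give each block its own weight decay,
    # reusing the attributes set in Xtractor.__init__ in this commit.
    param_groups = []
    if model.preprocessor is not None:
        param_groups.append({"params": model.preprocessor.parameters(),
                             "weight_decay": model.preprocessor_weight_decay})
    param_groups.append({"params": model.stat_pooling.parameters(),
                         "weight_decay": model.stat_pooling_weight_decay})
    return torch.optim.Adam(param_groups, lr=lr)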
@@ -446,10 +448,12 @@ class Xtractor(torch.nn.Module):
         Pooling
         """
         self.stat_pooling = MeanStdPooling()
+        tmp_input_size = input_size * 2
         if cfg["stat_pooling"]["type"] == "GRU":
             self.stat_pooling = GruPooling(input_size=cfg["stat_pooling"]["input_size"],
                                            gru_node=cfg["stat_pooling"]["gru_node"],
                                            nb_gru_layer=cfg["stat_pooling"]["nb_gru_layer"])
+            tmp_input_size = cfg["stat_pooling"]["gru_node"]
         self.stat_pooling_weight_decay = cfg["stat_pooling"]["weight_decay"]
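The tmp_input_size bookkeeping introduced here records the dimension coming out of the pooling layer: mean/std pooling concatenates per-utterance mean and standard deviation (twice the frame-level size), whereas GruPooling emits gru_node features. A minimal sketch of that selection, assuming only the cfg keys visible in this hunk:

def pooling_output_size(input_size, stat_pooling_cfg):
    # GRU pooling: the utterance representation has gru_node dimensions
    if stat_pooling_cfg["type"] == "GRU":
        return stat_pooling_cfg["gru_node"]
    # Default mean/std pooling: mean and std are concatenated
    return input_size * 2

# e.g. pooling_output_size(512, {"type": "mean_std"}) == 1024
# ("mean_std" stands for any non-GRU pooling type here)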
@@ -457,7 +461,7 @@ class Xtractor(torch.nn.Module):
         Prepare last part of the network (after pooling)
         """
         # Create sequential object for the second part of the network
-        input_size = input_size * 2
+        input_size = tmp_input_size
         before_embedding_layers = []
         for k in cfg["before_embedding"].keys():
             if k.startswith("lin"):
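The loop at the end of this hunk collects the post-pooling linear layers from the cfg; its body is truncated by the hunk boundary. A hedged sketch of how such a config-driven build typically looks; the per-layer "input"/"output" keys and the layer sizes are a hypothetical layout for illustration, not the project's actual schema:

import collections
import torch

# Hypothetical cfg fragment, for illustration only
before_embedding_cfg = {
    "lin1": {"input": 1024, "output": 512},
    "lin2": {"input": 512, "output": 512},
}

before_embedding_layers = []
for k, layer_cfg in before_embedding_cfg.items():
    if k.startswith("lin"):
        before_embedding_layers.append(
            (k, torch.nn.Linear(layer_cfg["input"], layer_cfg["output"])))

before_embedding = torch.nn.Sequential(collections.OrderedDict(before_embedding_layers))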
@@ -649,6 +653,7 @@ def xtrain(speaker_number,
         if name.split(".")[0] in freeze_parts:
             param.requires_grad = False
+    print(model)
     if torch.cuda.device_count() > 1 and multi_gpu:
         print("Let's use", torch.cuda.device_count(), "GPUs!")