Commit c324e7e1 authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

AttentivePooling bugfix

parent 4d7d3827
...@@ -308,11 +308,11 @@ class AttentivePooling(torch.nn.Module): ...@@ -308,11 +308,11 @@ class AttentivePooling(torch.nn.Module):
# TODO Make convolution parameters configurable # TODO Make convolution parameters configurable
super(AttentivePooling, self).__init__() super(AttentivePooling, self).__init__()
self.attention = torch.nn.Sequential( self.attention = torch.nn.Sequential(
torch.nn.Conv1d(num_channels * (n_mels//8), num_channels//32, kernel_size=1), torch.nn.Conv1d(num_channels * (n_mels//8), num_channels//2, kernel_size=1),
torch.nn.ReLU(), torch.nn.ReLU(),
torch.nn.BatchNorm1d(num_channels//32), torch.nn.BatchNorm1d(num_channels//2),
torch.nn.Tanh(), torch.nn.Tanh(),
torch.nn.Conv1d(num_channels//32, num_channels * (n_mels//8), kernel_size=1), torch.nn.Conv1d(num_channels//2, num_channels * (n_mels//8), kernel_size=1),
torch.nn.Softmax(dim=2), torch.nn.Softmax(dim=2),
) )
#self.global_context = MeanStdPooling() #self.global_context = MeanStdPooling()
...@@ -516,15 +516,14 @@ class Xtractor(torch.nn.Module): ...@@ -516,15 +516,14 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding_weight_decay = 0.00 self.before_speaker_embedding_weight_decay = 0.00
self.after_speaker_embedding_weight_decay = 0.00 self.after_speaker_embedding_weight_decay = 0.00
elif model_archi == "halfresnet34": elif model_archi == "halfresnet34":
self.preprocessor = MelSpecFrontEnd(n_fft=512, win_length=400, hop_length=160, n_mels=64) self.preprocessor = MelSpecFrontEnd()
#self.preprocessor = MelSpecFrontEnd()
self.sequence_network = PreHalfResNet34() self.sequence_network = PreHalfResNet34()
self.embedding_size = 512 self.embedding_size = 512
self.before_speaker_embedding = torch.nn.Linear(in_features = 4096, self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
out_features = self.embedding_size) out_features = self.embedding_size)
self.stat_pooling = AttentivePooling(256, 64) self.stat_pooling = AttentivePooling(256, 80)
self.stat_pooling_weight_decay = 0 self.stat_pooling_weight_decay = 0
self.loss = loss self.loss = loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment