Commit c85ffa25 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

MeanStdPooling

parent 04511e7f
......@@ -80,6 +80,26 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
if is_best:
shutil.copyfile(filename, best_filename)
class MeanStdPooling(torch.nn.Module):
"""
Mean and Standard deviation pooling
"""
def __init__(self):
"""
"""
super(MeanStdPooling, self).__init__()
pass
def forward(self, x):
"""
:param x:
:return:
"""
mean = torch.mean(x, dim=2)
std = torch.std(x, dim=2)
return torch.cat([mean, std], dim=1)
class GruPooling(torch.nn.Module):
"""
......@@ -226,7 +246,10 @@ class Xtractor(torch.nn.Module):
)
self.feature_size = self.preprocessor.dimension
elif cfg['preprocessor']["type"] == "rawnet2":
self.preprocessor =
self.preprocessor = RawPreprocessor(nb_samp=48000,
in_channels=1,
filts=128,
first_conv=3)
"""
Prepare sequence network
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment