Commit bcfc2b77 authored by Anthony Larcher

fastresnet34

parent ddfd222e
@@ -579,6 +579,59 @@ class PreResNet34(torch.nn.Module):
        return out


class PreFastResNet34(torch.nn.Module):
    """
    Network that contains only the ResNet part until pooling, with NO classification layers
    """
    def __init__(self, block=BasicBlock, num_blocks=[3, 4, 6, 3], speaker_number=10):
        super(PreFastResNet34, self).__init__()
        self.in_planes = 16
        self.speaker_number = speaker_number

        # Feature extraction: 80 MFCC computed on 16 kHz signals
        n_fft = 2048
        win_length = None  # None lets torchaudio default to n_fft
        hop_length = 512
        n_mels = 80
        n_mfcc = 80

        self.MFCC = torchaudio.transforms.MFCC(
            sample_rate=16000,
            n_mfcc=n_mfcc,
            melkwargs={'n_fft': n_fft, 'n_mels': n_mels, 'hop_length': hop_length})

        # Cepstral mean and variance normalization, computed per utterance
        self.CMVN = torch.nn.InstanceNorm1d(80)

        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
                                     stride=(2, 1), padding=3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(16)

        # With num_blocks = [3, 4, 6, 3]
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=(2, 2))
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=(2, 2))
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=(1, 1))

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        out = self.MFCC(x)
        out = self.CMVN(out)
        out = out.unsqueeze(1)
        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Merge channel and frequency dimensions: (batch, 128, 10, frames) -> (batch, 1280, frames)
        out = torch.flatten(out, start_dim=1, end_dim=2)
        return out


def ResNet34():
    return ResNet(BasicBlock, [3, 1, 3, 1, 5, 1, 2])
......
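For readers checking the tensor bookkeeping, here is a minimal, hypothetical smoke test (not part of the commit) that runs the new front-end end to end. It assumes the module's existing BasicBlock and a working torchaudio install; the 1280 figure comes from 128 output channels times the 10 frequency bins left after the three stride-2 reductions of the 80 MFCC bins.

import torch

# Hypothetical check: 2 seconds of 16 kHz audio through PreFastResNet34.
model = PreFastResNet34()
model.eval()

signal = torch.randn(4, 32000)   # (batch, samples)
with torch.no_grad():
    frames = model(signal)

# Frequency axis: 80 -> 40 (conv1) -> 20 (layer2) -> 10 (layer3);
# layer4 keeps 10 bins and outputs 128 channels, so 128 * 10 = 1280.
print(frames.shape)              # expected: torch.Size([4, 1280, T])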
@@ -412,6 +412,31 @@ class Xtractor(torch.nn.Module):
            self.before_speaker_embedding_weight_decay = 0.00
            self.after_speaker_embedding_weight_decay = 0.00

        elif model_archi == "fastresnet34":
            self.input_nbdim = 2
            self.preprocessor = None

            self.sequence_network = PreFastResNet34()

            # MeanStdPooling concatenates the mean and standard deviation of the
            # 1280-dim frame representation, hence 2 * 1280 = 2560 input features
            self.before_speaker_embedding = torch.nn.Linear(in_features=2560,
                                                            out_features=256)

            self.stat_pooling = MeanStdPooling()
            self.stat_pooling_weight_decay = 0

            self.embedding_size = 256

            self.loss = "aam"
            self.after_speaker_embedding = ArcMarginProduct(256,
                                                            int(self.speaker_number),
                                                            s=30.0,
                                                            m=0.20,
                                                            easy_margin=True)

            self.preprocessor_weight_decay = 0.000
            self.sequence_network_weight_decay = 0.000
            self.stat_pooling_weight_decay = 0.000
            self.before_speaker_embedding_weight_decay = 0.00
            self.after_speaker_embedding_weight_decay = 0.00

        elif model_archi == "rawnet2":
......
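As a sanity check on the new branch's dimensions, here is a hypothetical sketch (not part of the commit) of the pooling and projection step. It assumes MeanStdPooling concatenates the per-frame mean and standard deviation over time, which is why the linear layer needs 2 * 1280 = 2560 inputs to produce the 256-dim embedding fed to ArcMarginProduct.

import torch

# Hypothetical dimension walk-through for the fastresnet34 branch.
frames = torch.randn(4, 1280, 63)        # (batch, features, frames) from PreFastResNet34

mean = frames.mean(dim=2)                # (4, 1280)
std = frames.std(dim=2)                  # (4, 1280)
pooled = torch.cat([mean, std], dim=1)   # (4, 2560), as MeanStdPooling would return

before_embedding = torch.nn.Linear(2560, 256)
embedding = before_embedding(pooled)     # (4, 256) speaker embedding
print(embedding.shape)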