Commit b6038a2b authored by Anthony Larcher's avatar Anthony Larcher
Browse files

cleaning

parent 01af21ba
......@@ -43,6 +43,7 @@ from .pooling import GruPooling
from .res_net import ResBlock
from .res_net import PreResNet34
from .res_net import PreFastResNet34
from .res_net import PreHalfResNet34
from .sincnet import SincNet
from .preprocessor import RawPreprocessor
from .preprocessor import MfccFrontEnd
......
......@@ -457,6 +457,47 @@ class PreResNet34(torch.nn.Module):
return out
class PreHalfResNet34(torch.nn.Module):
"""
Networks that contains only the ResNet part until pooling, with NO classification layers
"""
def __init__(self, block=BasicBlock, num_blocks=[3, 4, 6, 3], speaker_number=10):
super(PreHalfResNet34, self).__init__()
self.in_planes = 32
self.speaker_number = speaker_number
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3,
stride=(1, 1), padding=1, bias=False)
self.bn1 = torch.nn.BatchNorm2d(32)
# With block = [3, 4, 6, 3]
self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=(1, 1))
self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=(2, 2))
self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=(2, 2))
self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=(2, 2))
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return torch.nn.Sequential(*layers)
def forward(self, x):
out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
class PreFastResNet34(torch.nn.Module):
"""
Networks that contains only the ResNet part until pooling, with NO classification layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment