Commit 0ff214ca authored by Anthony Larcher's avatar Anthony Larcher
Browse files

new resnet34 in xtractor

parent eb10ef30
......@@ -32,8 +32,8 @@ from .augmentation import AddNoise
from .feed_forward import FForwardNetwork
from .feed_forward import kaldi_to_hdf5
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset, IdMapSet_per_speaker
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding
from .res_net import ResBlock, ResNet18
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling
from .res_net import ResBlock, ResNet18, PreResNet34
from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
from .sincnet import SincNet
......
......@@ -245,7 +245,7 @@ class ResBlock(torch.nn.Module):
"""
"""
def __init__(self, in_channels, out_channels, is_first=False):
def __init__(self, in_channels, out_channels, stride, is_first=False):
"""
:param filter_size:
......@@ -275,11 +275,13 @@ class ResBlock(torch.nn.Module):
self.conv1 = torch.nn.Conv2d(in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(3,3),
stride=stride,
padding=1,
padding_mode='zeros',
dilation=1)
self.conv2= torch.nn.Conv2d(in_channels=self.out_channels,
out_channels=self.out_channels,
stride=stride,
kernel_size=(3,3),
padding=1,
padding_mode='zeros',
......@@ -322,7 +324,7 @@ class ResNet18(torch.nn.Module):
def __init__(self, spk_number,
entry_conv_kernel_size=(3,3),
entry_conv_out_channels=128,
megablock_out_channels=(128, 128, 256, 512),
megablock_out_channels=(64, 128, 256, 512),
megablock_size=(2, 2, 2, 2),
block_type = ResBlock
):
......@@ -344,7 +346,7 @@ class ResNet18(torch.nn.Module):
self.entry_conv = torch.nn.Conv2d(in_channels=1,
out_channels=entry_conv_out_channels,
kernel_size=entry_conv_kernel_size,
padding=1,
padding=(1),
stride=1)
self.entry_batch_norm = torch.nn.BatchNorm2d(entry_conv_out_channels)
self.top_channel_number = entry_conv_out_channels
......@@ -356,8 +358,8 @@ class ResNet18(torch.nn.Module):
self.top_channel_number = mb_out
# Top layers for classification and embeddings extraction
self.top_lin1 = torch.nn.Linear(204800, 2560) # a modifier pour voir la taille
self.top_batch_norm1 = torch.nn.BatchNorm1d(2560)
self.top_lin1 = torch.nn.Linear(megablock_out_channels[-1] * 2 * 40, 512) # a modifier pour voir la taille
self.top_batch_norm1 = torch.nn.BatchNorm1d(512)
self.top_lin2 = torch.nn.Linear(512, spk_number)
def forward(self, x):
......@@ -372,13 +374,10 @@ class ResNet18(torch.nn.Module):
for layer in self.mega_blocks:
x = layer(x)
print(f"x.shape = {x.shape}")
# Pooling done as for x-vectors
mean = torch.mean(x, dim=2)
print(f"shape of mean: {mean.shape}")
mean = torch.flatten(mean, 1)
print(f"shape of mean: {mean.shape}")
std = torch.std(x, dim=2)
std = torch.flatten(std, 1)
x = torch.cat([mean, std], dim=1)
......@@ -399,6 +398,159 @@ class ResNet18(torch.nn.Module):
return torch.nn.Sequential(*rblocks)
class BasicBlock(torch.nn.Module):
    """
    Standard ResNet basic residual block: two 3x3 convolutions, each followed
    by batch normalization, with a ReLU non-linearity and an identity (or 1x1
    projection) shortcut connection.
    """
    # Output channel multiplier; 1 for basic blocks (4 for bottlenecks).
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        :param in_planes: number of input channels
        :param planes: number of output channels (before expansion)
        :param stride: stride of the first convolution; the same stride is
            applied to the shortcut projection when downsampling
        """
        super(BasicBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,
                                     stride=1, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case project with a strided 1x1 convolution + batch norm.
        self.shortcut = torch.nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, self.expansion * planes,
                                kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        """
        :param x: input tensor of shape (batch, in_planes, H, W)
        :return: tensor of shape (batch, planes, H/stride, W/stride)
        """
        # Fully qualified torch.nn.functional, consistent with Bottleneck:
        # the short alias ``F`` is not imported in this module.
        out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.nn.functional.relu(out)
        return out
class Bottleneck(torch.nn.Module):
    """
    ResNet bottleneck residual block: a 1x1 channel-reduction convolution,
    a 3x3 convolution (optionally strided) and a 1x1 expansion convolution,
    each followed by batch normalization, with a shortcut connection.
    """
    # Bottleneck blocks expand the channel count by this factor.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """
        :param in_planes: number of input channels
        :param planes: bottleneck width; the block outputs
            ``expansion * planes`` channels
        :param stride: stride of the middle 3x3 convolution
        """
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,
                                     stride=stride, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv3 = torch.nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(out_planes)

        # Identity shortcut unless the shape changes, in which case project
        # with a strided 1x1 convolution followed by batch normalization.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = torch.nn.Sequential()
        else:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, out_planes,
                                kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        """
        :param x: input tensor of shape (batch, in_planes, H, W)
        :return: tensor of shape (batch, expansion * planes, H/stride, W/stride)
        """
        y = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        y = torch.nn.functional.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = y + self.shortcut(x)
        return torch.nn.functional.relu(y)
class ResNet(torch.nn.Module):
    """
    ResNet speaker-embedding network: an entry convolution followed by seven
    residual stages, mean/std statistics pooling and a linear layer producing
    a 256-dimensional embedding.
    """

    def __init__(self, block, num_blocks, speaker_number=10):
        """
        :param block: residual block class (e.g. BasicBlock or Bottleneck)
        :param num_blocks: sequence of 7 ints, number of blocks per stage
        :param speaker_number: number of target speakers (stored for use by
            classification heads built on top of this network)
        """
        super(ResNet, self).__init__()
        self.in_planes = 128
        self.speaker_number = speaker_number
        self.conv1 = torch.nn.Conv2d(1, 128, kernel_size=3,
                                     stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(128)

        # With num_blocks = [3, 1, 3, 1, 5, 1, 2]
        self.layer1 = self._make_layer(block, 128, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=1)
        self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2)
        self.layer5 = self._make_layer(block, 256, num_blocks[4], stride=1)
        self.layer6 = self._make_layer(block, 256, num_blocks[5], stride=2)
        # BUG FIX: layer7 previously reused num_blocks[5]; the seventh entry
        # of num_blocks was silently ignored.
        self.layer7 = self._make_layer(block, 256, num_blocks[6], stride=1)
        self.stat_pooling = MeanStdPooling()
        # NOTE(review): 5120 input features assumes a fixed acoustic feature
        # dimension upstream — confirm against the feature extraction config.
        self.before_embedding = torch.nn.Linear(5120, 256)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Stack ``num_blocks`` residual blocks; only the first block may
        downsample (stride > 1), the remaining blocks use stride 1.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: input tensor of shape (batch, 1, features, frames)
        :return: speaker embedding of shape (batch, 256)
        """
        # Fully qualified torch.nn.functional: the alias ``F`` is not
        # imported in this module.
        out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        # Merge channel and frequency dimensions before pooling over time.
        out = torch.flatten(out, start_dim=1, end_dim=2)
        out = self.stat_pooling(out)
        out = self.before_embedding(out)
        return out
class PreResNet34(torch.nn.Module):
    """
    Networks that contains only the ResNet part until pooling, with NO
    classification layers: an entry convolution followed by seven residual
    stages, returning a frame-level feature map for an external pooling layer.
    """

    def __init__(self, block=BasicBlock, num_blocks=(3, 1, 3, 1, 5, 1, 2), speaker_number=10):
        """
        :param block: residual block class (defaults to BasicBlock)
        :param num_blocks: sequence of 7 ints, number of blocks per stage
            (tuple default avoids the shared-mutable-default pitfall)
        :param speaker_number: number of target speakers (stored for use by
            classification heads built on top of this network)
        """
        super(PreResNet34, self).__init__()
        self.in_planes = 128
        self.speaker_number = speaker_number
        self.conv1 = torch.nn.Conv2d(1, 128, kernel_size=3,
                                     stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(128)

        # With num_blocks = [3, 1, 3, 1, 5, 1, 2]
        self.layer1 = self._make_layer(block, 128, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=1)
        self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2)
        self.layer5 = self._make_layer(block, 256, num_blocks[4], stride=1)
        self.layer6 = self._make_layer(block, 256, num_blocks[5], stride=2)
        # BUG FIX: layer7 previously reused num_blocks[5]; the seventh entry
        # of num_blocks was silently ignored.
        self.layer7 = self._make_layer(block, 256, num_blocks[6], stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Stack ``num_blocks`` residual blocks; only the first block may
        downsample (stride > 1), the remaining blocks use stride 1.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: input tensor of shape (batch, 1, features, frames)
        :return: frame-level features with channel and frequency dimensions
            merged, ready for statistics pooling
        """
        # Fully qualified torch.nn.functional: the alias ``F`` is not
        # imported in this module.
        out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        # Merge channel and frequency dimensions before external pooling.
        out = torch.flatten(out, start_dim=1, end_dim=2)
        return out
def ResNet34(speaker_number=10):
    """
    Build a ResNet-34-style speaker embedding network.

    :param speaker_number: number of target speakers, forwarded to ResNet
        (defaults to 10, matching ResNet's own default, so existing
        zero-argument callers are unaffected)
    :return: a ResNet built from BasicBlock with the [3, 1, 3, 1, 5, 1, 2]
        per-stage block layout
    """
    return ResNet(BasicBlock, [3, 1, 3, 1, 5, 1, 2], speaker_number=speaker_number)
def restrain(args):
"""
Initialize and train an ResNet for Speaker Recognition
......
......@@ -48,7 +48,7 @@ from .xsets import SideSet
from .xsets import FileSet
from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .res_net import RawPreprocessor, ResBlockWFMS, ResBlock
from .res_net import RawPreprocessor, ResBlockWFMS, ResBlock, PreResNet34
from ..bosaris import IdMap
from ..bosaris import Key
from ..bosaris import Ndx
......@@ -365,6 +365,38 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding_weight_decay = 0.002
self.after_speaker_embedding_weight_decay = 0.002
elif model_archi == "resnet34":
self.preprocessor = None
self.sequence_network = PreResNet34()
self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
out_features = 256)
self.stat_pooling = MeanStdPooling()
self.stat_pooling_weight_decay = 0
if loss not in ["cce", 'aam']:
raise NotImplementedError(f"The valid loss are for now cce and aam ")
else:
self.loss = loss
if self.loss == "aam":
if loss == 'aam':
self.after_speaker_embedding = ArcLinear(256,
int(self.speaker_number),
margin=aam_margin, s=aam_s)
elif self.loss == "cce":
self.after_speaker_embedding = torch.nn.Linear(in_features = 256,
out_features = int(self.speaker_number),
bias = True)
self.preprocessor_weight_decay = 0.000
self.sequence_network_weight_decay = 0.000
self.stat_pooling_weight_decay = 0.000
self.before_speaker_embedding_weight_decay = 0.00
self.after_speaker_embedding_weight_decay = 0.00
elif model_archi == "rawnet2":
if loss not in ["cce", 'aam']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment