Commit 9d1a30bb authored by Le Lan Gaël's avatar Le Lan Gaël
Browse files

Channel Wise Correlation Pooling

parent 88f4d2b9
......@@ -63,6 +63,49 @@ class MeanStdPooling(torch.nn.Module):
return torch.cat([mean, std], dim=1)
class ChannelWiseCorrPooling(torch.nn.Module):
    """
    Channel-wise correlation pooling.

    The input channels are projected down with a grouped 1x1 convolution,
    then, for each frequency group, the correlation matrix between the
    projected channels is computed over all (time, merged-frequency)
    locations.  The strictly lower triangle of every correlation matrix is
    flattened into the pooled output vector.

    Expected input: [B, C, T, F] with F == in_freqs.
    Output: [B, groups * out_channels * (out_channels - 1) / 2]
    """

    def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
        """
        :param in_channels: number of input channels C
        :param out_channels: number of channels C' after the 1x1 projection
        :param in_freqs: number of input frequency bins F; must be even
            (pairs of adjacent bins are merged into one group)
        :param channels_dropout: probability of dropping an entire input
            channel during training
        """
        super(ChannelWiseCorrPooling, self).__init__()
        self.channels_dropout = channels_dropout
        # Adjacent frequency bins are merged in pairs into one group each.
        self.merge_freqs_count = 2
        assert in_freqs % self.merge_freqs_count == 0
        self.groups = in_freqs // self.merge_freqs_count
        self.out_channels = out_channels
        # One strictly-lower-triangular correlation matrix per frequency group.
        self.out_dim = int(self.out_channels * (self.out_channels - 1) / 2) * self.groups
        # Grouped 1x1 conv: an independent C -> C' projection per frequency group.
        self.L_proj = torch.nn.Conv2d(in_channels * self.groups,
                                      out_channels * self.groups,
                                      kernel_size=(1, 1),
                                      groups=self.groups)
        # Boolean mask selecting the strictly lower triangle (diagonal excluded).
        self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).bool()

    def forward(self, x):
        """
        :param x: [B, C, T, F] tensor
        :return: [B, out_dim] tensor of flattened correlation coefficients
        """
        batch_size = x.shape[0]
        # Number of (time, merged-freq) locations each correlation sums over:
        # T * F / groups == T * merge_freqs_count.
        num_locations = x.shape[-1] * x.shape[-2] / self.groups
        self.mask = self.mask.to(x.device)
        if self.training:
            # Channel-wise dropout: zero whole channels; F.dropout rescales the
            # surviving channels by 1/(1-p).  Out-of-place multiply so the
            # caller's input tensor is never modified (the original in-place
            # `x *= ...` mutated upstream data).
            x = x * torch.nn.functional.dropout(
                torch.ones((1, x.shape[1], 1, 1), device=x.device),
                p=self.channels_dropout)
        # [B, C, Fr, T, m]: split F into Fr groups of m merged bins
        x = x.reshape(x.shape[0], x.shape[1], self.groups, x.shape[-2], self.merge_freqs_count)
        # [B, Fr, C, T, m]
        x = x.permute(0, 2, 1, 3, 4)
        # [B, Fr*C, T, m]: layout matching the grouped convolution
        x = x.flatten(start_dim=1, end_dim=2)
        x = self.L_proj(x)
        # [B, Fr, C', T*m]: all locations of a group along the last axis
        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
        # Standardize each channel over its locations (out-of-place; torch-native
        # dim/keepdim kwargs instead of the numpy aliases axis/keepdims).
        x = x - torch.mean(x, dim=-1, keepdim=True)
        out = x / (torch.std(x, dim=-1, keepdim=True) + 1e-5)
        # [B, Fr, C', C'] correlation matrices
        out = torch.einsum('abci,abdi->abcd', out, out)
        # [B, Fr*C'*(C'-1)/2]: keep the strictly lower triangle only
        out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
        out = out / num_locations
        return out
class AttentivePooling(torch.nn.Module):
"""
Mean and Standard deviation attentive pooling
......@@ -94,7 +137,7 @@ class AttentivePooling(torch.nn.Module):
def forward(self, x):
"""
:param x:
:param x: [B, C*F, T] Tensor
:return:
"""
if self.global_context:
......
......@@ -535,16 +535,12 @@ class PreHalfResNet34(torch.nn.Module):
:param x:
:return:
"""
out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
class PreFastResNet34(torch.nn.Module):
......@@ -589,15 +585,11 @@ class PreFastResNet34(torch.nn.Module):
:param x:
:return:
"""
out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
......
......@@ -36,12 +36,12 @@ import shutil
import torch
import tqdm
import yaml
#torch.autograd.set_detect_anomaly(True)
from collections import OrderedDict
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .pooling import MeanStdPooling
from .pooling import AttentivePooling
from .pooling import AttentivePooling, ChannelWiseCorrPooling
from .pooling import GruPooling
from .preprocessor import MfccFrontEnd
from .preprocessor import MelSpecFrontEnd
......@@ -537,6 +537,34 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding_weight_decay = 0.00002
self.after_speaker_embedding_weight_decay = 0.000
elif model_archi == "experimental":
self.preprocessor = MelSpecFrontEnd()
self.sequence_network = PreHalfResNet34()
self.embedding_size = embedding_size
#self.embedding_size = 256
#self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
# out_features = self.embedding_size)
self.before_speaker_embedding = torch.nn.Linear(in_features = int(48*47*5/2),
out_features = self.embedding_size)
self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=48)
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
self.preprocessor_weight_decay = 0.00002
self.sequence_network_weight_decay = 0.00002
self.stat_pooling_weight_decay = 0.00002
self.before_speaker_embedding_weight_decay = 0.00002
self.after_speaker_embedding_weight_decay = 0.0002
elif model_archi == "rawnet2":
if loss not in ["cce", 'aam']:
......@@ -784,11 +812,17 @@ class Xtractor(torch.nn.Module):
if self.preprocessor is not None:
x = self.preprocessor(x, is_eval)
x = x.unsqueeze(1)
# Does not work for FastResNet34 !
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = self.sequence_network(x)
#x = x.to(memory_format=torch.contiguous_format)
#print(x.shape)
#x = torch.flatten(x, start_dim=1, end_dim=2)
#print(x.shape)
# Mean and Standard deviation pooling
x = self.stat_pooling(x)
x = self.before_speaker_embedding(x)
if norm_embedding:
......@@ -1003,7 +1037,7 @@ def get_network(model_opts, local_rank):
:return:
"""
if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34", "experimental"]:
model = Xtractor(model_opts["speaker_number"], model_opts["model_type"], loss=model_opts["loss"]["type"], embedding_size=model_opts["embedding_size"])
else:
# Custom type of model
......@@ -1031,24 +1065,9 @@ def get_network(model_opts, local_rank):
if name.split(".")[0] in model_opts["reset_parts"]:
param.requires_grad = False
if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
if local_rank < 1:
logging.info(model)
logging.info("Model_parameters_count: {:d}".format(
sum(p.numel()
for p in model.sequence_network.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
#if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
# model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
# print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
return model
......@@ -1108,7 +1127,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
batch_size=batch_size,
batch_size=batch_size*torch.cuda.device_count(),
seed=training_opts['torch_seed'],
rank=local_rank,
num_process=torch.cuda.device_count(),
......@@ -1376,8 +1395,18 @@ def xtrain(dataset_description,
# Initialize the model
model = get_network(model_opts, local_rank)
if local_rank < 1:
if local_rank < 1:
monitor.logger.info(model)
monitor.logger.info("Model_parameters_count: {:d}".format(
sum(p.numel()
for p in model.sequence_network.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
embedding_size = model.embedding_size
aam_scheduler = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment