Commit ebdf1b53 authored by Anthony Larcher
parents 8e8fb525 9b6795f6
*.pyc
*.DS_Store
docs
.vscode
.gitignore
.history
......@@ -173,6 +173,9 @@ def data_augmentation(speech,
aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
if "none" in augmentations:
pass
if "stretch" in augmentations:
stretch = torchaudio.transforms.TimeStretch()
rate = random.uniform(0.8,1.2)
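Note that torchaudio's TimeStretch operates on complex STFT spectrograms, not raw waveforms, so stretching a waveform needs an STFT/ISTFT round trip. A minimal sketch of that round trip follows (illustrative only, assuming torchaudio >= 0.9; n_fft and the helper name are ours, not the toolkit's):

import random
import torch
import torchaudio

def stretch_waveform(speech, n_fft=512):
    """Time-stretch a [channels, samples] waveform by a random rate in [0.8, 1.2]."""
    rate = random.uniform(0.8, 1.2)
    hop_length = n_fft // 2
    window = torch.hann_window(n_fft)
    # Complex spectrogram: [channels, n_fft // 2 + 1, frames]
    spec = torch.stft(speech, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)
    stretcher = torchaudio.transforms.TimeStretch(hop_length=hop_length,
                                                  n_freq=n_fft // 2 + 1)
    # Phase-vocoder stretch, then back to the time domain
    return torch.istft(stretcher(spec, rate), n_fft=n_fft,
                       hop_length=hop_length, window=window)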
......@@ -261,6 +264,7 @@ def data_augmentation(speech,
final_shape = speech.shape[1]
configs = [
({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
({"format": "gsm"}, "GSM-FR"),
({"format": "mp3", "compression": -9}, "MP3"),
({"format": "vorbis", "compression": -1}, "Vorbis")
......
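For context, one of the configs above can be applied with torchaudio.functional.apply_codec; this is a hedged sketch of the intended use, not the commit's code (the 16 kHz rate is an assumption, and some formats such as GSM may additionally require resampling to 8 kHz):

import random
import torchaudio

param, title = random.choice(configs)   # e.g. ({"format": "gsm"}, "GSM-FR")
speech = torchaudio.functional.apply_codec(speech, sample_rate=16000, **param)
# Codecs may pad the signal; crop back to the original length
speech = speech[:, :final_shape]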
......@@ -55,14 +55,68 @@ class MeanStdPooling(torch.nn.Module):
def forward(self, x):
"""
:param x:
:param x: [B, C*F, T]
:return:
"""
if len(x.shape) == 4:
# [B, C, F, T]
x = x.permute(0, 1, 3, 2)
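# [B, C*F, T]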
x = x.flatten(start_dim=1, end_dim=2)
# [B, C*F]
mean = torch.mean(x, dim=2)
# [B, C*F]
std = torch.std(x, dim=2)
# [B, 2*C*F]
return torch.cat([mean, std], dim=1)
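A quick shape check (illustrative values, assuming the default constructor): a 4D [B, C, T, F] feature map is flattened to [B, C*F, T] before pooling, so the pooled output is [B, 2*C*F].

import torch

pool = MeanStdPooling()
x = torch.randn(8, 256, 38, 10)            # [B=8, C=256, T=38, F=10]
assert pool(x).shape == (8, 2 * 256 * 10)  # mean and std concatenated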
class ChannelWiseCorrPooling(torch.nn.Module):
def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
super(ChannelWiseCorrPooling, self).__init__()
self.channels_dropout = channels_dropout
self.merge_freqs_count = 2
assert in_freqs % self.merge_freqs_count == 0
self.groups = in_freqs//self.merge_freqs_count
self.out_channels = out_channels
self.out_dim = int(self.out_channels*(self.out_channels-1)/2)*self.groups
self.L_proj = torch.nn.Conv2d(in_channels*self.groups, out_channels*self.groups, kernel_size=(1, 1), groups=self.groups)
#self.L_proj = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).bool()
def forward(self, x):
"""
:param x: [B, C, T, F]
:return:
"""
batch_size = x.shape[0]
num_locations = x.shape[-1] * x.shape[-2] / self.groups
self.mask = self.mask.to(x.device)
if self.training:
x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device), p=self.channels_dropout)
#[B, T, C, F]
x = x.permute(0, 2, 1, 3)
#[B, T, C, Fr, f]
x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
#[B, T, f, Fr, C]
x = x.permute(0, 1, 4, 3, 2)
#[B, T, f, Fr*C]
x = x.flatten(start_dim=3, end_dim=4)
#[B, Fr*C, T, f]
x = x.permute(0, 3, 1, 2)
#[B, Fr*C', T, f]
x = self.L_proj(x)
#[B, Fr, C', Tr]
x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
x -= torch.mean(x, dim=-1, keepdim=True)
out = x / (torch.std(x, dim=-1, keepdim=True) + 1e-5)
#[B, Fr, C', C']
out = torch.einsum('abci,abdi->abcd', out, out)
#[B, Fr*C'*(C'-1)/2]
out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
out = out / num_locations
return out
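An illustrative shape check for the pooling above: with in_channels=256, out_channels=64 and in_freqs=10 (five groups of two frequency bins), out_dim is 64*63/2 * 5 = 10080, which matches the int(64*63*5/2) input dimension used by the experimental before_speaker_embedding layer further down.

import torch

pool = ChannelWiseCorrPooling(in_channels=256, out_channels=64, in_freqs=10)
pool.eval()                                # skip the channel dropout branch
x = torch.randn(8, 256, 38, 10)            # [B, C, T, F]
assert pool(x).shape == (8, pool.out_dim)  # out_dim == 10080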
class AttentivePooling(torch.nn.Module):
"""
Mean and Standard deviation attentive pooling
......@@ -94,9 +148,14 @@ class AttentivePooling(torch.nn.Module):
def forward(self, x):
"""
:param x:
:param x: [B, C*F, T]
:return:
"""
if len(x.shape) == 4:
# [B, C, F, T]
x = x.permute(0, 1, 3, 2)
# [B, C*F, T]
x = x.flatten(start_dim=1, end_dim=2)
if self.global_context:
w = self.attention(torch.cat([x, self.gc(x).unsqueeze(2).repeat(1, 1, x.shape[-1])], dim=1))
else:
......
......@@ -480,17 +480,18 @@ class PreResNet34(torch.nn.Module):
:param x:
:return:
"""
out = x.unsqueeze(1)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.layer5(out)
out = self.layer6(out)
out = self.layer7(out)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
x = x.unsqueeze(1)
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = self.layer6(x)
x = self.layer7(x)
return x
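Illustrative sketch of the new input handling (shapes assumed from a mel front-end): the permute turns [B, 1, F, T] into [B, 1, T, F], and the channels_last conversion changes only the memory layout, not the logical shape; flattening is now left to the pooling layer.

import torch

x = torch.randn(8, 80, 200)                  # [B, F=80 mel bands, T=200 frames]
x = x.unsqueeze(1).permute(0, 1, 3, 2)       # [B, 1, T, F]
x = x.to(memory_format=torch.channels_last)  # NHWC layout for faster cuDNN kernels
assert x.shape == (8, 1, 200, 80)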
class PreHalfResNet34(torch.nn.Module):
......@@ -535,16 +536,15 @@ class PreHalfResNet34(torch.nn.Module):
:param x:
:return:
"""
out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
x = x.unsqueeze(1)
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
class PreFastResNet34(torch.nn.Module):
......@@ -557,7 +557,7 @@ class PreFastResNet34(torch.nn.Module):
self.speaker_number = speaker_number
self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
stride=(2, 1), padding=3, bias=False)
stride=(1, 2), padding=3, bias=False)
self.bn1 = torch.nn.BatchNorm2d(16)
# With block = [3, 4, 6, 3]
......@@ -589,16 +589,15 @@ class PreFastResNet34(torch.nn.Module):
:param x:
:return:
"""
out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2)
return out
x = x.unsqueeze(1)
x = x.permute(0, 1, 3, 2)
x = x.to(memory_format=torch.channels_last)
x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def ResNet34():
......
......@@ -106,7 +106,7 @@ class SideSampler(torch.utils.data.Sampler):
self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=int)
def __iter__(self):
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
numpy.random.seed(self.seed + self.epoch)
......@@ -175,7 +175,7 @@ class SideSet(Dataset):
overlap=0.,
dataset_df=None,
min_duration=0.165,
output_format="pytorch",
output_format="pytorch"
):
"""
......@@ -269,6 +269,8 @@ class SideSet(Dataset):
self.transform["codec"] = []
if "phone_filtering" in transforms:
self.transform["phone_filtering"] = []
if "stretch" in transforms:
self.transform["stretch"] = []
self.noise_df = None
if "add_noise" in self.transform:
......@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
start = int(self.idmap.start[index] * 0.01 * self.sample_rate)
if self.idmap.stop[index] is None:
nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
duration = int(speech.shape[1] - start)
else:
duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
# TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
assert nfo.sample_rate == self.sample_rate
conversion_rate = nfo.sample_rate // self.sample_rate
duration = (int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start)
# extend the window if the segment is shorter than min_duration
if duration <= self.min_duration * self.sample_rate:
middle = start + duration // 2
start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
duration = int(self.min_duration * self.sample_rate)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=duration)
frame_offset=start * conversion_rate,
num_frames=duration * conversion_rate)
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
#speech += 10e-6 * torch.randn(speech.shape)
......
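Worked example of the offset arithmetic above, assuming 16 kHz audio and idmap boundaries expressed in 10 ms units (hence the 0.01 factor):

sample_rate = 16000
start_cs, stop_cs = 150, 450                          # 1.5 s to 4.5 s
start = int(start_cs * 0.01 * sample_rate)            # 24000 samples
duration = int(stop_cs * 0.01 * sample_rate) - start  # 48000 samples, i.e. 3 s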
......@@ -36,12 +36,12 @@ import shutil
import torch
import tqdm
import yaml
#torch.autograd.set_detect_anomaly(True)
from collections import OrderedDict
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .pooling import MeanStdPooling
from .pooling import AttentivePooling
from .pooling import AttentivePooling, ChannelWiseCorrPooling
from .pooling import GruPooling
from .preprocessor import MfccFrontEnd
from .preprocessor import MelSpecFrontEnd
......@@ -522,6 +522,35 @@ class Xtractor(torch.nn.Module):
self.stat_pooling = AttentivePooling(256, 80, global_context=True)
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
emb_dim=self.embedding_size)
self.preprocessor_weight_decay = 0.00002
self.sequence_network_weight_decay = 0.00002
self.stat_pooling_weight_decay = 0.00002
self.before_speaker_embedding_weight_decay = 0.00002
self.after_speaker_embedding_weight_decay = 0.000
elif model_archi == "experimental":
self.preprocessor = MelSpecFrontEnd()
self.sequence_network = PreHalfResNet34()
self.embedding_size = embedding_size
#self.embedding_size = 256
#self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
# out_features = self.embedding_size)
self.before_speaker_embedding = torch.nn.Linear(in_features = int(64*63*5/2),
out_features = self.embedding_size)
self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=64)
self.loss = loss
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
......@@ -788,7 +817,6 @@ class Xtractor(torch.nn.Module):
# Mean and Standard deviation pooling
x = self.stat_pooling(x)
x = self.before_speaker_embedding(x)
if norm_embedding:
......@@ -1005,7 +1033,7 @@ def get_network(model_opts, local_rank):
:return:
"""
if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34", "experimental"]:
model = Xtractor(model_opts["speaker_number"], model_opts["model_type"], loss=model_opts["loss"]["type"], embedding_size=model_opts["embedding_size"])
else:
# Custom type of model
......@@ -1035,24 +1063,9 @@ def get_network(model_opts, local_rank):
if name.split(".")[0] in model_opts["reset_parts"]:
param.requires_grad = False
if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
if local_rank < 1:
logging.info(model)
logging.info("Model_parameters_count: {:d}".format(
sum(p.numel()
for p in model.sequence_network.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
#if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
# model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
# print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
return model
......@@ -1080,7 +1093,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
training_df, validation_df = train_test_split(df,
test_size=dataset_opts["validation_ratio"],
stratify=stratify)
# TODO
torch.manual_seed(training_opts['torch_seed'] + local_rank)
torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)
......@@ -1090,7 +1104,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
transform_number=dataset_opts['train']['transform_number'],
overlap=dataset_opts['train']['overlap'],
dataset_df=training_df,
output_format="pytorch",
output_format="pytorch"
)
validation_set = SideSet(dataset_opts,
......@@ -1106,20 +1120,20 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
if training_opts["multi_gpu"]:
assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
assert dataset_opts["batch_size"] % samples_per_speaker == 0
batch_size = dataset_opts["batch_size"]//(torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
batch_size=batch_size,
batch_size=batch_size*torch.cuda.device_count(),
seed=training_opts['torch_seed'],
rank=local_rank,
num_process=torch.cuda.device_count(),
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
)
else:
batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
batch_size = dataset_opts["batch_size"]
side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
spk_count=model_opts["speaker_number"],
examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
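Worked example of the revised batch-size arithmetic (values illustrative): the DataLoader batch is now divided only by the GPU count, while the sampler receives the full global batch.

global_batch, gpus, examples_per_speaker = 256, 4, 2
loader_batch = global_batch // gpus                          # 64 segments per process
sampler_batch = loader_batch * gpus                          # 256, as passed to SideSampler
speakers_per_process = loader_batch // examples_per_speaker  # 32 speakers per batch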
......@@ -1380,8 +1394,18 @@ def xtrain(dataset_description,
# Initialize the model
model = get_network(model_opts, local_rank)
if local_rank < 1:
monitor.logger.info(model)
monitor.logger.info("Model_parameters_count: {:d}".format(
sum(p.numel()
for p in model.sequence_network.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.before_speaker_embedding.parameters()
if p.requires_grad) + \
sum(p.numel()
for p in model.stat_pooling.parameters()
if p.requires_grad)))
embedding_size = model.embedding_size
aam_scheduler = None
......@@ -1569,7 +1593,7 @@ def train_epoch(model,
loss += criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
loss, output = output_tuple
loss, no_margin_output = output_tuple
else:
output, _ = model(data, target=None)
loss = criterion(output, target)
......@@ -1603,7 +1627,7 @@ def train_epoch(model,
if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
batch_size = target.shape[0]
training_monitor.update(training_loss=loss.item(),
training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
training_acc=100.0 * accuracy / ((batch_idx + 1) * batch_size))
training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
training_monitor.current_epoch,
......