Commit 690ba2cc authored by Anthony Larcher

xvectors

parent c9f843ae
@@ -31,6 +31,8 @@ import pickle
 import torch
 import torch.optim as optim
 import torch.multiprocessing as mp
+import yaml
+from torchvision import transforms
 from collections import OrderedDict
 from .xsets import XvectorMultiDataset, StatDataset, VoxDataset
@@ -75,11 +77,15 @@ class Xtractor(torch.nn.Module):
         If config is None, default architecture is created
         :param config:
         """
         super(Xtractor, self).__init__()
         self.speaker_number = speaker_number
+        self.activation = torch.nn.ReLU()
+        self.feature_size = 24
         if config is None:
-            self.activation = torch.nn.ReLU()
             self.sequence_network = torch.nn.Sequential(OrderedDict([
-                ("conv1", torch.nn.Conv1d(30, 512, 5, dilation=1)),
+                ("conv1", torch.nn.Conv1d(self.feature_size, 512, 5, dilation=1)),
                 ("activation1", torch.nn.LeakyReLU(0.2)),
                 ("norm1", torch.nn.BatchNorm1d(512)),
                 ("conv2", torch.nn.Conv1d(512, 512, 3, dilation=2)),
@@ -88,25 +94,25 @@ class Xtractor(torch.nn.Module):
                 ("conv3", torch.nn.Conv1d(512, 512, 3, dilation=3)),
                 ("activation3", torch.nn.LeakyReLU(0.2)),
                 ("norm3", torch.nn.BatchNorm1d(512)),
-                ("conv4", torch.nn.Conv1d(512, 512)),
+                ("conv4", torch.nn.Conv1d(512, 512, 1)),
                 ("activation4", torch.nn.LeakyReLU(0.2)),
                 ("norm4", torch.nn.BatchNorm1d(512)),
-                ("conv5", torch.nn.Conv1d(512, 1536)),
+                ("conv5", torch.nn.Conv1d(512, 1536, 1)),
                 ("activation5", torch.nn.LeakyReLU(0.2)),
                 ("norm5", torch.nn.BatchNorm1d(1536))
             ]))
             self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
-                ("linear6", torch.nn.linear(1536, 512))
+                ("linear6", torch.nn.Linear(1536, 512))
             ]))
             self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
                 ("activation6", torch.nn.LeakyReLU(0.2)),
                 ("norm6", torch.nn.BatchNorm1d(512)),
-                ("linear7", torch.nn.linear(512, 512)),
+                ("linear7", torch.nn.Linear(512, 512)),
                 ("activation7", torch.nn.LeakyReLU(0.2)),
                 ("norm7", torch.nn.BatchNorm1d(512)),
-                ("linear8", torch.nn.linear(512, self.speaker_number))
+                ("linear8", torch.nn.Linear(512, self.speaker_number))
             ]))
         else:
@@ -132,7 +138,7 @@ class Xtractor(torch.nn.Module):
             segmental_layers = []
             for k in cfg["segmental"].keys():
                 if k.startswith("conv"):
-                    segmental_layers.append((k, torch.nn.Conv2d(input_size,
+                    segmental_layers.append((k, torch.nn.Conv1d(input_size,
                                                                 cfg["segmental"][k]["output_channels"],
                                                                 cfg["segmental"][k]["kernel_size"],
                                                                 cfg["segmental"][k]["dilation"])))
@@ -144,18 +150,18 @@ class Xtractor(torch.nn.Module):
                 elif k.startswith('norm'):
                     segmental_layers.append((k, torch.nn.BatchNorm1d(input_size)))
-            self.sequence_network = nn.Sequential(OrderedDict(segmental_layers))
+            self.sequence_network = torch.nn.Sequential(OrderedDict(segmental_layers))
             # Create sequential object for the second part of the network
             input_size = input_size * 2
             before_embedding_layers = []
             for k in cfg["before_embedding"].keys():
                 if k.startswith("lin"):
-                    if cfg["embedding"][k]["output"] == "speaker_number":
-                        before_embedding_layers.append((k, torch.nn.linear(input_size, self.speaker_number)))
+                    if cfg["before_embedding"][k]["output"] == "speaker_number":
+                        before_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
                     else:
-                        before_embedding_layers.append((k, torch.nn.linear(input_size, cfg["embedding"][k]["output"])))
-                        input_size = cfg["embedding"][k]["output"]
+                        before_embedding_layers.append((k, torch.nn.Linear(input_size, cfg["before_embedding"][k]["output"])))
+                        input_size = cfg["before_embedding"][k]["output"]
                 elif k.startswith("activation"):
                     before_embedding_layers.append((k, self.activation))
@@ -164,19 +170,19 @@ class Xtractor(torch.nn.Module):
                     before_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
                 elif k.startswith('dropout'):
-                    before_embedding_layers.append((k, torch.nn.Dropout(p=cfg["emebedding"][k])))
+                    before_embedding_layers.append((k, torch.nn.Dropout(p=cfg["before_embedding"][k])))
-            self.before_speaker_embedding = nn.Sequential(OrderedDict(before_embedding_layers))
+            self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(before_embedding_layers))
             # Create sequential object for the second part of the network
             after_embedding_layers = []
             for k in cfg["after_embedding"].keys():
                 if k.startswith("lin"):
-                    if cfg["embedding"][k]["output"] == "speaker_number":
-                        after_embedding_layers.append((k, torch.nn.linear(input_size, self.speaker_number)))
+                    if cfg["after_embedding"][k]["output"] == "speaker_number":
+                        after_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
                     else:
-                        after_embedding_layers.append((k, torch.nn.linear(input_size, cfg["embedding"][k]["output"])))
-                        input_size = cfg["embedding"][k]["output"]
+                        after_embedding_layers.append((k, torch.nn.Linear(input_size, cfg["after_embedding"][k]["output"])))
+                        input_size = cfg["after_embedding"][k]["output"]
                 elif k.startswith("activation"):
                     after_embedding_layers.append((k, self.activation))
@@ -185,45 +191,15 @@ class Xtractor(torch.nn.Module):
                     after_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
                 elif k.startswith('dropout'):
-                    after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["emebedding"][k])))
-            self.after_embedding_layers = nn.Sequential(OrderedDict(after_embedding_layers))
+                    after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["after_embedding"][k])))

-    def produce_embeddings(self, x):
-        """
-        :param x:
-        :return:
-        """
-        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
-        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
-        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
-        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
-        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
-        mean = torch.mean(frame_emb_4, dim=2)
-        std = torch.std(frame_emb_4, dim=2)
-        seg_emb = torch.cat([mean, std], dim=1)
-        embedding_a = self.seg_lin0(seg_emb)
-        return embedding_a
+            self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
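
Aside: the config-driven branch above expects the YAML file to provide "segmental", "before_embedding" and "after_embedding" sections. A minimal sketch of the dict such a file would load into, with key names inferred from the cfg[...] accesses above and purely illustrative values (not taken from this commit):

    # hypothetical cfg, as the dict a yaml load would return for such a file;
    # section/key names inferred from the cfg[...] accesses above
    cfg = {
        "segmental": {
            "conv1": {"output_channels": 512, "kernel_size": 5, "dilation": 1},
            "activation1": None,   # keys starting with "activation" map to self.activation
            "norm1": None,         # keys starting with "norm" map to BatchNorm1d
        },
        "before_embedding": {
            "linear6": {"output": 512},
            "dropout6": 0.25,      # dropout entries hold the probability directly
        },
        "after_embedding": {
            "linear7": {"output": 512},
            "linear8": {"output": "speaker_number"},  # resolved to self.speaker_number
        },
    }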
     def forward(self, x, is_eval=False):
         """
         :param x:
         :return:
-        seg_emb_0 = self.produce_embeddings(x)
-        # batch-normalisation after this layer
-        seg_emb_1 = self.norm6(self.activation(seg_emb_0))
-        # new layer with batch Normalization
-        seg_emb_2 = self.norm7(self.activation(self.seg_lin1(self.dropout_lin1(seg_emb_1))))
-        # No batch-normalisation after this layer
-        result = self.seg_lin2(seg_emb_2)
-        return result
         """
+        x = self.sequence_network(x)
@@ -239,28 +215,6 @@ class Xtractor(torch.nn.Module):
         x = self.after_speaker_embedding(x)
         return x
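
After the refactor, a forward pass is just the three Sequential blocks applied in order. A minimal usage sketch, assuming the default (config=None) architecture and that the temporal pooling elided from this hunk produces the 1536-dimensional input expected by linear6:

    # sketch only -- shapes assume the default architecture with 24-dim features
    import torch

    model = Xtractor(1000)               # 1000 training speakers
    model.eval()
    features = torch.randn(8, 24, 300)   # (batch, feature_size=24, frames)
    with torch.no_grad():
        scores = model(features)         # (8, 1000) per-speaker class scores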
-    def init_weights(self):
-        """
-        Initialize the x-vector extract weights and biaises
-        """
-        torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
-        torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
-        torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
-        torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
-        torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
-        torch.nn.init.xavier_uniform(self.seg_lin0.weight)
-        torch.nn.init.xavier_uniform(self.seg_lin1.weight)
-        torch.nn.init.xavier_uniform(self.seg_lin2.weight)
-        torch.nn.init.constant(self.frame_conv0.bias, 0.1)
-        torch.nn.init.constant(self.frame_conv1.bias, 0.1)
-        torch.nn.init.constant(self.frame_conv2.bias, 0.1)
-        torch.nn.init.constant(self.frame_conv3.bias, 0.1)
-        torch.nn.init.constant(self.frame_conv4.bias, 0.1)
-        torch.nn.init.constant(self.seg_lin0.bias, 0.1)
-        torch.nn.init.constant(self.seg_lin1.bias, 0.1)
-        torch.nn.init.constant(self.seg_lin2.bias, 0.1)
 def xtrain(args):
     """
@@ -278,8 +232,10 @@ def xtrain(args):
         model.train()
     else:
         # Initialize a first model and save to disk
-        model = Xtractor(args.class_number, args.dropout)
-        model.init_weights()
+        if args.yaml is None:
+            model = Xtractor(args.class_number)
+        else:
+            model = Xtractor(args.class_number, args.yaml)
         model.train()

     if torch.cuda.device_count() > 1:
@@ -307,30 +263,7 @@ def xtrain(args):
     current_model_file_name = "initial_model"
     torch.save(model.state_dict(), current_model_file_name)

-    if type(model) is Xtractor:
-        optimizer = torch.optim.SGD([
-            {'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
-            {'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
-            {'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
-        ],
-            lr=args.lr, momentum=0.9)
-    else:
-        optimizer = torch.optim.SGD([
-            {'params': model.module.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.module.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.module.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.module.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.module.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
-            {'params': model.module.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
-            {'params': model.module.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
-            {'params': model.module.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
-        ],
-            lr=args.lr, momentum=0.9)
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
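
The per-layer weight-decay groups (args.l2_frame / args.l2_seg) are replaced by a single group over model.parameters(), which also works unchanged when the model is wrapped in DataParallel. With ReduceLROnPlateau in 'min' mode the scheduler is typically stepped on a validation metric once per epoch; a sketch with hypothetical helper names (not from this file):

    # illustrative coupling of the SGD optimizer and ReduceLROnPlateau scheduler
    for epoch in range(args.epochs):                  # args.epochs is assumed here
        train_loss = train_epoch(model, optimizer)    # hypothetical helper
        val_loss = cross_validation(model)            # hypothetical helper
        scheduler.step(val_loss)                      # lowers lr when val_loss plateaus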