Commit 3cfe03cc authored by Anthony Larcher

xtractorHot

parent 1b768454
@@ -105,8 +105,7 @@ class Xtractor(torch.nn.Module):
         seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
         # No batch-normalisation after this layer
         seg_emb_5 = self.seg_lin2(seg_emb_4)
-        result = torch.nn.functional.softmax(self.activation(seg_emb_5), dim=1)
-        #return seg_emb_5
+        result = self.activation(seg_emb_5)
         return result
 
     def LossFN(self, x, label):
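The hunk above drops the final softmax from `Xtractor.forward`, which now returns pre-softmax activations. A minimal sketch (not part of the commit; the batch shape and constructor arguments are illustrative assumptions) of how class posteriors can still be recovered at inference time:

```python
import torch

# Hypothetical inference snippet: after this change, Xtractor.forward
# returns pre-softmax activations, so posteriors are computed outside it.
model = Xtractor(100, 0.5)   # assumed: 100 speakers, dropout 0.5
model.eval()
x = torch.randn(4, 20, 300)  # assumed shape: (batch, 20 coefficients, 300 frames)
with torch.no_grad():
    logits = model(x)                                     # (4, 100)
    posteriors = torch.nn.functional.softmax(logits, dim=1)
    predicted = posteriors.argmax(dim=1)                  # speaker id per segment
```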
@@ -140,19 +139,102 @@ class Xtractor(torch.nn.Module):
         frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
         frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
         frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
 
         mean = torch.mean(frame_emb_4, dim=2)
         std = torch.std(frame_emb_4, dim=2)
         seg_emb = torch.cat([mean, std], dim=1)
 
         embedding_A = self.seg_lin0(seg_emb)
         embedding_B = self.seg_lin1(self.norm6(self.activation(embedding_A)))
         return embedding_A, embedding_B
+
+
+class XtractorHot(Xtractor):
+
+    def __init__(self, spk_number, dropout):
+        # super(Xtractor, self) deliberately skips Xtractor.__init__ in the
+        # MRO and initialises torch.nn.Module directly, since every layer is
+        # redefined below.
+        super(Xtractor, self).__init__()
+        self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
+        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
+        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
+        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
+        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
+        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
+        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
+        self.seg_lin1 = torch.nn.Linear(512, 512)
+        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
+        self.seg_lin2 = torch.nn.Linear(512, spk_number)
+        #
+        self.norm0 = torch.nn.BatchNorm1d(512)
+        self.norm1 = torch.nn.BatchNorm1d(512)
+        self.norm2 = torch.nn.BatchNorm1d(512)
+        self.norm3 = torch.nn.BatchNorm1d(512)
+        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
+        self.norm6 = torch.nn.BatchNorm1d(512)
+        self.norm7 = torch.nn.BatchNorm1d(512)
+        #
+        self.activation = torch.nn.LeakyReLU(0.2)
+
+    def forward(self, x):
+        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
+        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
+        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
+        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
+        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
+
+        # Pooling layer that computes the mean and standard deviation of the
+        # frame-level embeddings; its output is the first segment-level
+        # representation.
+        mean = torch.mean(frame_emb_4, dim=2)
+        std = torch.std(frame_emb_4, dim=2)
+        seg_emb_0 = torch.cat([mean, std], dim=1)
+
+        # Batch-normalisation after this layer
+        seg_emb_1 = self.dropout_lin0(seg_emb_0)
+        seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
+        # New layer with batch-normalisation
+        seg_emb_3 = self.dropout_lin1(seg_emb_2)
+        seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
+        # No batch-normalisation after this layer
+        seg_emb_5 = self.seg_lin2(seg_emb_4)
+        result = torch.nn.functional.softmax(self.activation(seg_emb_5), dim=1)
+        return result
+
+    def LossFN(self, x, label):
+        # Negative log-likelihood (base-10 logarithm) of the predicted
+        # posteriors x against one-hot label rows.
+        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(label)))
+        return loss
+
+    def init_weights(self):
+        """Initialize the weights and biases of the convolutional and linear layers."""
+        torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
+        torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
+
+        torch.nn.init.xavier_uniform_(self.seg_lin0.weight)
+        torch.nn.init.xavier_uniform_(self.seg_lin1.weight)
+        torch.nn.init.xavier_uniform_(self.seg_lin2.weight)
+
+        torch.nn.init.constant_(self.frame_conv0.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv1.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv2.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv3.bias, 0.1)
+        torch.nn.init.constant_(self.frame_conv4.bias, 0.1)
+        torch.nn.init.constant_(self.seg_lin0.bias, 0.1)
+        torch.nn.init.constant_(self.seg_lin1.bias, 0.1)
+        torch.nn.init.constant_(self.seg_lin2.bias, 0.1)
+
+    def extract(self, x):
+        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
+        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
+        frame_emb_2 = self.norm2(self.activation(self.frame_conv2(frame_emb_1)))
+        frame_emb_3 = self.norm3(self.activation(self.frame_conv3(frame_emb_2)))
+        frame_emb_4 = self.norm4(self.activation(self.frame_conv4(frame_emb_3)))
+
+        # The statistics must be pooled before the segment-level layers are
+        # applied; seg_emb is only defined from this point on.
+        mean = torch.mean(frame_emb_4, dim=2)
+        std = torch.std(frame_emb_4, dim=2)
+        seg_emb = torch.cat([mean, std], dim=1)
+
+        embedding_A = self.seg_lin0(seg_emb)
+        embedding_B = self.seg_lin1(self.norm6(self.activation(embedding_A)))
+        return embedding_A, embedding_B
 
 
 def xtrain(args):
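A quick smoke test of the class added above (not part of the commit; speaker count, dropout rate, and feature shape are illustrative assumptions), showing the posterior output of `forward` and the two embeddings returned by `extract`:

```python
import torch

# Hypothetical smoke test for XtractorHot: one forward pass and one
# embedding extraction on random features shaped (batch, features, frames).
model = XtractorHot(100, 0.5)  # assumed: 100 speakers, dropout 0.5
model.init_weights()
model.eval()
x = torch.randn(8, 20, 300)    # assumed: 20 coefficients over 300 frames
with torch.no_grad():
    posteriors = model(x)                        # (8, 100), rows sum to 1
    embedding_A, embedding_B = model.extract(x)  # (8, 512) each
print(posteriors.shape, embedding_A.shape, embedding_B.shape)
```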
@@ -169,11 +251,11 @@ def xtrain(args):
print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
args.lr = args.lr * 0.9
#args.lr = args.lr * 0.9
def xtrain_hot(args):
# Initialize a first model and save to disk
model = Xtractor(args.class_number, args.dropout)
model = XtractorHot(args.class_number, args.dropout)
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
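The manual decay `args.lr = args.lr * 0.9` is commented out in this hunk. If the same per-epoch schedule is wanted later, a minimal sketch using PyTorch's built-in scheduler (the Adam optimizer and epoch count are assumptions, not taken from this commit):

```python
import torch

# Hypothetical equivalent of the commented-out manual decay: multiply the
# learning rate by 0.9 once per epoch via a built-in scheduler.
model = XtractorHot(100, 0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
    # ... train one epoch and run cross-validation here ...
    scheduler.step()  # lr <- lr * 0.9, same effect as args.lr * 0.9
```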
@@ -185,7 +267,7 @@ def xtrain_hot(args):
print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
args.lr = args.lr * 0.9
#args.lr = args.lr * 0.9
def train_epoch(epoch, args, initial_model_file_name):
# Compute the megabatch number
@@ -288,7 +370,7 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
 
 def train_worker_hot(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
-    model = Xtractor(args.class_number, args.dropout)
+    model = XtractorHot(args.class_number, args.dropout)
     model.load_state_dict(torch.load(initial_model_file_name))
     model.train()
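For reference, a sketch of one optimisation step inside such a worker: `LossFN` pairs the model's softmax output with one-hot speaker labels. The batch contents, optimizer, and learning rate are placeholders, not the real `batch_list` data:

```python
import torch

# Hypothetical single training step with XtractorHot and its LossFN.
model = XtractorHot(100, 0.5)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 20, 300)                    # placeholder mini-batch
speaker = torch.randint(0, 100, (8,))          # placeholder speaker ids
label = torch.nn.functional.one_hot(speaker, num_classes=100).float()

optimizer.zero_grad()
posteriors = model(x)                          # (8, 100) softmax outputs
loss = model.LossFN(posteriors, label)         # -trace(log10(x) @ label.T)
loss.backward()
optimizer.step()
```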
@@ -410,7 +492,7 @@ def train_asynchronous_hot(epoch, args, initial_model_file_name, batch_file_list
     for p in processes:
         p.join()
 
-    av_model = Xtractor(args.class_number, args.dropout)
+    av_model = XtractorHot(args.class_number, args.dropout)
     tmp = av_model.state_dict()
 
     average_param = dict()
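`train_asynchronous_hot` goes on to average the parameters sent back by the workers into `av_model`. A minimal sketch of such a state-dict average (`worker_state_dicts` is an assumed name for the dicts collected from the workers, not the commit's exact code):

```python
import torch

# Hypothetical parameter averaging across worker models: element-wise mean
# of matching tensors from each worker's state_dict.
def average_state_dicts(worker_state_dicts):
    average_param = dict()
    for key in worker_state_dicts[0]:
        stacked = torch.stack([sd[key].float() for sd in worker_state_dicts])
        average_param[key] = stacked.mean(dim=0)
    return average_param

# Usage sketch: av_model.load_state_dict(average_state_dicts(worker_state_dicts))
```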
@@ -531,7 +613,7 @@ def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
     output_queue.put((cv_loader.__len__(), accuracy.cpu().numpy()))
 
 
 def cv_worker_hot(rank, args, current_model_file_name, batch_list, output_queue):
-    model = Xtractor(args.class_number, args.dropout)
+    model = XtractorHot(args.class_number, args.dropout)
     model.load_state_dict(torch.load(current_model_file_name))
     model.eval()
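`cv_worker_hot` mirrors `cv_worker` with the new class. For reference, a sketch of the accuracy such a cross-validation worker typically reports (the tensors are random placeholders, not the commit's data pipeline):

```python
import torch

# Hypothetical cross-validation accuracy: argmax of the model's posteriors
# compared against integer speaker labels.
model = XtractorHot(100, 0.5)
model.eval()
x = torch.randn(16, 20, 300)            # placeholder validation batch
label = torch.randint(0, 100, (16,))    # placeholder speaker labels
with torch.no_grad():
    prediction = model(x).argmax(dim=1)
accuracy = 100.0 * (prediction == label).float().mean().item()
print("*** Cross validation accuracy = {} %".format(accuracy))
```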