Commit 86d80832 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

xvectors

parent ab937d19
......@@ -82,6 +82,7 @@ class Xtractor(torch.nn.Module):
self.norm3 = torch.nn.BatchNorm1d(512)
self.norm4 = torch.nn.BatchNorm1d(1500)
self.norm6 = torch.nn.BatchNorm1d(512)
self.norm7 = torch.nn.BatchNorm1d(512)
#
self.activation = torch.nn.LeakyReLU(0.2)
......@@ -96,16 +97,16 @@ class Xtractor(torch.nn.Module):
mean = torch.mean(frame_emb_4, dim=2)
std = torch.std(frame_emb_4, dim=2)
seg_emb_0 = torch.cat([mean, std], dim=1)
# No batch-normalisation after this layer
# batch-normalisation after this layer
seg_emb_1 = self.dropout_lin0(seg_emb_0)
seg_emb_2 = self.activation(self.seg_lin0(seg_emb_1))
seg_emb_2 = self.norm6(self.activation(self.seg_lin0(seg_emb_1)))
# new layer with batch Normalization
seg_emb_3 = self.dropout_lin1(seg_emb_2)
seg_emb_4 = self.norm6(self.activation(self.seg_lin1(seg_emb_3)))
seg_emb_4 = self.norm7(self.activation(self.seg_lin1(seg_emb_3)))
# No batch-normalisation after this layer
seg_emb_5 = self.activation(self.seg_lin2(seg_emb_4))
seg_output = torch.nn.functional.softmax(seg_emb_5, dim=1)
return seg_output
seg_emb_5 = self.seg_lin2(seg_emb_4)
#seg_output = torch.nn.LogSoftmax(seg_emb_5)
return seg_emb_5
def LossFN(self, x, lable):
loss = - torch.trace(torch.mm(torch.log10(x), torch.t(lable)))
......@@ -114,23 +115,34 @@ class Xtractor(torch.nn.Module):
def init_weights(self):
"""
"""
torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.seg_lin0.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.seg_lin1.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.seg_lin2.weight, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv0.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv1.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv2.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv3.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv4.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.seg_lin0.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.seg_lin1.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.seg_lin2.bias, mean=-0.5, std=1.)
torch.nn.init.normal_(self.frame_conv0.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv1.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv2.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv3.weight, mean=-0.5, std=0.1)
torch.nn.init.normal_(self.frame_conv4.weight, mean=-0.5, std=0.1)
torch.nn.init.xavier_uniform(self.seg_lin0.weight)
torch.nn.init.xavier_uniform(self.seg_lin1.weight)
torch.nn.init.xavier_uniform(self.seg_lin2.weight)
#torch.nn.init.normal_(self.seg_lin0.weight, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.seg_lin1.weight, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.seg_lin2.weight, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.frame_conv0.bias, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.frame_conv1.bias, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.frame_conv2.bias, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.frame_conv3.bias, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.frame_conv4.bias, mean=-0.5, std=1.)
torch.nn.init.constant(self.frame_conv0.bias, 0.1)
torch.nn.init.constant(self.frame_conv1.bias, 0.1)
torch.nn.init.constant(self.frame_conv2.bias, 0.1)
torch.nn.init.constant(self.frame_conv3.bias, 0.1)
torch.nn.init.constant(self.frame_conv4.bias, 0.1)
#torch.nn.init.normal_(self.seg_lin0.bias, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.seg_lin1.bias, mean=-0.5, std=1.)
#torch.nn.init.normal_(self.seg_lin2.bias, mean=-0.5, std=1.)
torch.nn.init.constant(self.seg_lin0.bias, 0.1)
torch.nn.init.constant(self.seg_lin1.bias, 0.1)
torch.nn.init.constant(self.seg_lin2.bias, 0.1)
def extract(self, x):
frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
......@@ -219,7 +231,8 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
# criterion = torch.nn.CrossEntropyLoss()
#criterion = torch.nn.NLLLoss()
#criterion = torch.nn.CrossEntropyLoss()
accuracy = 0.0
for batch_idx, (data, target) in enumerate(train_loader):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment