Commit 6d6f89a9 authored by Anthony Larcher

Clean 1-hot version

parent ee06b1e4
@@ -109,10 +109,6 @@ class Xtractor(torch.nn.Module):
        result = self.activation(seg_emb_5)
        return result

    def LossFN(self, x, label):
        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(label)))
        return loss

    def init_weights(self):
        """
        """
@@ -155,30 +151,6 @@ class Xtractor(torch.nn.Module):
        return seg_emb_1, seg_emb_2, seg_emb_3, seg_emb_4, seg_emb_5, seg_emb_6

class XtractorHot(Xtractor):

    def __init__(self, spk_number, dropout):
        super(Xtractor, self).__init__()
        self.frame_conv0 = torch.nn.Conv1d(20, 512, 5, dilation=1)
        self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
        self.frame_conv2 = torch.nn.Conv1d(512, 512, 3, dilation=3)
        self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
        self.frame_conv4 = torch.nn.Conv1d(512, 3 * 512, 1)
        self.seg_lin0 = torch.nn.Linear(3 * 512 * 2, 512)
        self.dropout_lin0 = torch.nn.Dropout(p=dropout)
        self.seg_lin1 = torch.nn.Linear(512, 512)
        self.dropout_lin1 = torch.nn.Dropout(p=dropout)
        self.seg_lin2 = torch.nn.Linear(512, spk_number)
        #
        self.norm0 = torch.nn.BatchNorm1d(512)
        self.norm1 = torch.nn.BatchNorm1d(512)
        self.norm2 = torch.nn.BatchNorm1d(512)
        self.norm3 = torch.nn.BatchNorm1d(512)
        self.norm4 = torch.nn.BatchNorm1d(3 * 512)
        self.norm6 = torch.nn.BatchNorm1d(512)
        self.norm7 = torch.nn.BatchNorm1d(512)
        #
        self.activation = torch.nn.LeakyReLU(0.2)

    def forward(self, x):
        frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
        frame_emb_1 = self.norm1(self.activation(self.frame_conv1(frame_emb_0)))
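The XtractorHot class shown in this hunk follows the usual x-vector layout: dilated Conv1d layers over 20-dimensional input features, a statistics-pooling step, then segment-level linear layers. The pooling explains the 3 * 512 * 2 input size of seg_lin0: frame_conv4 emits 3 * 512 channels per frame, and the mean and standard deviation over time are concatenated. A minimal sketch of that step, with illustrative names only (the commit does not show the repository's pooling code):

import torch

def stat_pooling(frame_emb):
    # frame_emb: (batch, 3 * 512, n_frames), as produced by frame_conv4.
    mean = torch.mean(frame_emb, dim=2)
    std = torch.std(frame_emb, dim=2)
    # Concatenating mean and std doubles the channel count, hence 3 * 512 * 2.
    return torch.cat([mean, std], dim=1)

frames = torch.randn(8, 3 * 512, 200)      # 8 segments of 200 frames
assert stat_pooling(frames).shape == (8, 3 * 512 * 2)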
@@ -557,8 +529,6 @@ def extract_parallel(args, fs_params):
    for p in processes:
        p.join()
    print("Process parallel fini")
    return x_server_1, x_server_2, x_server_3, x_server_4, x_server_5, x_server_6
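extract_parallel fans the extraction out over several processes and only returns the six x_server objects once every worker has been joined. A minimal fork/join sketch of that pattern using the standard library; worker and its toy payload are placeholders, not the actual extraction routine:

import multiprocessing

def worker(rank, queue):
    # Placeholder for the per-process extraction work.
    queue.put((rank, rank * rank))

if __name__ == "__main__":
    queue = multiprocessing.Queue()
    processes = [multiprocessing.Process(target=worker, args=(r, queue)) for r in range(4)]
    for p in processes:
        p.start()
    # Drain the queue before joining so no child blocks on a full pipe.
    results = [queue.get() for _ in processes]
    for p in processes:
        p.join()
    print("Parallel extraction finished:", sorted(results))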