Commit ab937d19 authored by Anthony Larcher
parents abb09f26 84c4972a
@@ -98,14 +98,14 @@ class Xtractor(torch.nn.Module):
         seg_emb_0 = torch.cat([mean, std], dim=1)
         # No batch-normalisation after this layer
         seg_emb_1 = self.dropout_lin0(seg_emb_0)
-        seg_emb_1 = self.activation(self.seg_lin0(seg_emb_1))
+        seg_emb_2 = self.activation(self.seg_lin0(seg_emb_1))
         # new layer with batch Normalization
-        seg_emb_2 = self.dropout_lin1(seg_emb_1)
-        seg_emb_3 = self.norm6(self.activation(self.seg_lin1(seg_emb_2)))
+        seg_emb_3 = self.dropout_lin1(seg_emb_2)
+        seg_emb_4 = self.norm6(self.activation(self.seg_lin1(seg_emb_3)))
         # No batch-normalisation after this layer
-        seg_emb_4 = self.activation(self.seg_lin2(seg_emb_3))
-        #seg_emb_3 = self.seg_lin2(seg_emb_2)
-        return seg_emb_4
+        seg_emb_5 = self.activation(self.seg_lin2(seg_emb_4))
+        seg_output = torch.nn.functional.softmax(seg_emb_5, dim=1)
+        return seg_output
+
+    def LossFN(self, x, lable):
+        loss = - torch.trace(torch.mm(torch.log10(x), torch.t(lable)))
......
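For context, the change above stops returning the raw seg_lin2 embedding and instead returns a softmax over its activations, and it introduces a LossFN method that scores those probabilities against one-hot labels as the negative trace of log10(x) · labelᵀ. The sketch below illustrates how such an output and loss would fit together; it is not taken from the repository. The layer sizes, the LeakyReLU activation, and the stand-in tensors are assumptions for illustration only, and the final line shows the conventional PyTorch equivalent (cross-entropy over pre-softmax logits, natural log instead of log10), which is numerically safer than taking the log of a softmax.

import torch

# Hypothetical stand-ins for the tail of the Xtractor forward pass; the real
# layer names (seg_lin2, activation, ...) and dimensions live in the repository
# and are only approximated here.
seg_lin2 = torch.nn.Linear(512, 10)        # assumed final classification layer
activation = torch.nn.LeakyReLU(0.2)       # assumed activation

seg_emb_4 = torch.randn(8, 512)            # fake batch of pooled embeddings
seg_emb_5 = activation(seg_lin2(seg_emb_4))
seg_output = torch.nn.functional.softmax(seg_emb_5, dim=1)

# One-hot labels for the fake batch.
label = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()

# Loss in the spirit of LossFN above: (log10(x) @ label.T)[i, i] is the
# log10-probability of sample i's true class, so the negative trace sums
# -log10 p(true class) over the batch.
loss = -torch.trace(torch.mm(torch.log10(seg_output), torch.t(label)))

# Conventional equivalent (up to a factor of ln 10): cross-entropy over the
# pre-softmax logits, which avoids taking the log of an underflowed softmax.
ce = torch.nn.functional.cross_entropy(seg_emb_5, label.argmax(dim=1), reduction="sum")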