Commit 84c4972a authored by Anthony Larcher's avatar Anthony Larcher
Browse files

new definition of embeddings

parent ceee256b
...@@ -98,14 +98,14 @@ class Xtractor(torch.nn.Module): ...@@ -98,14 +98,14 @@ class Xtractor(torch.nn.Module):
seg_emb_0 = torch.cat([mean, std], dim=1) seg_emb_0 = torch.cat([mean, std], dim=1)
# No batch-normalisation after this layer # No batch-normalisation after this layer
seg_emb_1 = self.dropout_lin0(seg_emb_0) seg_emb_1 = self.dropout_lin0(seg_emb_0)
seg_emb_1 = self.activation(self.seg_lin0(seg_emb_1)) seg_emb_2 = self.activation(self.seg_lin0(seg_emb_1))
# new layer with batch Normalization # new layer with batch Normalization
seg_emb_2 = self.dropout_lin1(seg_emb_1) seg_emb_3 = self.dropout_lin1(seg_emb_2)
seg_emb_3 = self.norm6(self.activation(self.seg_lin1(seg_emb_2))) seg_emb_4 = self.norm6(self.activation(self.seg_lin1(seg_emb_3)))
# No batch-normalisation after this layer # No batch-normalisation after this layer
seg_emb_4 = self.activation(self.seg_lin2(seg_emb_3)) seg_emb_5 = self.activation(self.seg_lin2(seg_emb_4))
#seg_emb_3 = self.seg_lin2(seg_emb_2) seg_output = torch.nn.functional.softmax(seg_emb_5, dim=1)
return seg_emb_4 return seg_output
def LossFN(self, x, lable): def LossFN(self, x, lable):
loss = - torch.trace(torch.mm(torch.log10(x), torch.t(lable))) loss = - torch.trace(torch.mm(torch.log10(x), torch.t(lable)))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment