Commit 37d9ba54 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

pooling

parent d0e0a61d
......@@ -174,6 +174,8 @@ class Xtractor(torch.nn.Module):
("norm5", torch.nn.BatchNorm1d(1536))
]))
self.stat_pooling = MeanStdPooling()
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("linear6", torch.nn.Linear(3072, 512))
]))
......@@ -290,6 +292,15 @@ class Xtractor(torch.nn.Module):
self.sequence_network = torch.nn.Sequential(OrderedDict(segmental_layers))
self.sequence_network_weight_decay = cfg["segmental"]["weight_decay"]
"""
Pooling
"""
self.stat_pooling = MeanStdPooling()
if cfg["stat_pooling"]["type"] == "GRU":
self.stat_pooling = GruPooling(input_size=cfg["stat_pooling"]["input_size"],
gru_node=cfg["stat_pooling"]["gru_node"],
nb_gru_layer=cfg["stat_pooling"]["nb_gru_layer"])
"""
Prepapre last part of the network (after pooling)
"""
......@@ -353,9 +364,10 @@ class Xtractor(torch.nn.Module):
x = self.sequence_network(x)
# Mean and Standard deviation pooling
mean = torch.mean(x, dim=2)
std = torch.std(x, dim=2)
x = torch.cat([mean, std], dim=1)
#mean = torch.mean(x, dim=2)
#std = torch.std(x, dim=2)
#x = torch.cat([mean, std], dim=1)
x = self.stat_pooling(x)
x = self.before_speaker_embedding(x)
if is_eval:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment