Commit c9f843ae authored by Anthony Larcher's avatar Anthony Larcher
Browse files

xv

parent ad033a0c
......@@ -76,8 +76,9 @@ class Xtractor(torch.nn.Module):
:param config:
"""
self.speaker_number = speaker_number
self.activation = torch.nn.ReLU()
if config is None:
self.sequence_network = nn.Sequential(OrderedDict([
self.sequence_network = torch.nn.Sequential(OrderedDict([
("conv1", torch.nn.Conv1d(30, 512, 5, dilation=1)),
("activation1", torch.nn.LeakyReLU(0.2)),
("norm1", torch.nn.BatchNorm1d(512)),
......@@ -95,11 +96,11 @@ class Xtractor(torch.nn.Module):
("norm5", torch.nn.BatchNorm1d(1536))
]))
self.before_speaker_embedding = nn.Sequential(OrderedDict([
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("linear6", torch.nn.linear(1536, 512))
]))
self.after_speaker_embedding = nn.Sequential(OrderedDict([
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
("norm6", torch.nn.BatchNorm1d(512)),
("linear7", torch.nn.linear(512, 512)),
......@@ -147,25 +148,48 @@ class Xtractor(torch.nn.Module):
# Create sequential object for the second part of the network
input_size = input_size * 2
embedding_layers = []
for k in cfg["embedding"].keys():
before_embedding_layers = []
for k in cfg["before_embedding"].keys():
if k.startswith("lin"):
if cfg["embedding"][k]["output"] == "speaker_number":
embedding_layers.append((k, torch.nn.linear(input_size, self.speaker_number)))
before_embedding_layers.append((k, torch.nn.linear(input_size, self.speaker_number)))
else:
embedding_layers.append((k, torch.nn.linear(input_size, cfg["embedding"][k]["output"])))
before_embedding_layers.append((k, torch.nn.linear(input_size, cfg["embedding"][k]["output"])))
input_size = cfg["embedding"][k]["output"]
elif k.startswith("activation"):
embedding_layers.append((k, self.activation))
before_embedding_layers.append((k, self.activation))
elif k.startswith('norm'):
embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
before_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
embedding_layers.append((k, torch.nn.Dropout(p=cfg["emebedding"][k])))
before_embedding_layers.append((k, torch.nn.Dropout(p=cfg["emebedding"][k])))
self.before_speaker_embedding = nn.Sequential(OrderedDict(before_embedding_layers))
# Create sequential object for the second part of the network
after_embedding_layers = []
for k in cfg["after_embedding"].keys():
if k.startswith("lin"):
if cfg["embedding"][k]["output"] == "speaker_number":
after_embedding_layers.append((k, torch.nn.linear(input_size, self.speaker_number)))
else:
after_embedding_layers.append((k, torch.nn.linear(input_size, cfg["embedding"][k]["output"])))
input_size = cfg["embedding"][k]["output"]
elif k.startswith("activation"):
after_embedding_layers.append((k, self.activation))
elif k.startswith('norm'):
after_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["emebedding"][k])))
self.after_embedding_layers = nn.Sequential(OrderedDict(after_embedding_layers))
self.before_speaker_embedding = nn.Sequential(OrderedDict(embedding_layers))
def produce_embeddings(self, x):
"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment