Commit 82aa956a authored by Anthony Larcher's avatar Anthony Larcher
Browse files

minor

parent 2952f1a2
......@@ -69,7 +69,7 @@ class Xtractor(torch.nn.Module):
"""
Class that defines an x-vector extractor based on 5 convolutional layers and a mean standard deviation pooling
"""
def __init__(self, spk_number, dropout):
def __init__(self, spk_number, dropout, activation='LeakyReLU'):
super(Xtractor, self).__init__()
self.frame_conv0 = torch.nn.Conv1d(30, 512, 5, dilation=1)
self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
......@@ -90,7 +90,19 @@ class Xtractor(torch.nn.Module):
self.norm6 = torch.nn.BatchNorm1d(512)
self.norm7 = torch.nn.BatchNorm1d(512)
#
self.activation = torch.nn.LeakyReLU(0.2)
if activation == 'LeakyReLU':
self.activation = torch.nn.LeakyReLU(0.2)
elif activation == 'ReLU':
self.activation = torch.nn.ReLU()
elif activation == 'PReLU':
self.activation = torch.nn.PReLU()
elif activation == 'ReLU6':
self.activation = torch.nn.ReLU6()
elif activation == 'SELU':
self.activation = torch.nn.SELU()
else:
raise ValueError("Activation function is not implemented")
def produce_embeddings(self, x):
"""
......@@ -123,7 +135,7 @@ class Xtractor(torch.nn.Module):
# new layer with batch Normalization
seg_emb_2 = self.norm7(self.activation(self.seg_lin1(self.dropout_lin1(seg_emb_1))))
# No batch-normalisation after this layer
result = self.activation(self.seg_lin2(seg_emb_2))
result = self.seg_lin2(seg_emb_2)
return result
def extract(self, x):
......@@ -178,6 +190,7 @@ def xtrain(args):
else:
# Initialize a first model and save to disk
model = Xtractor(args.class_number, args.dropout)
model.init_weights()
model.train()
if torch.cuda.device_count() > 1:
......@@ -205,17 +218,30 @@ def xtrain(args):
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
optimizer = torch.optim.SGD([
{'params': model.module.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
{'params': model.module.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
{'params': model.module.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
],
lr=args.lr, momentum=0.9)
if type(model) is Xtractor:
optimizer = torch.optim.SGD([
{'params': model.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
{'params': model.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
{'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
],
lr=args.lr, momentum=0.9)
else:
optimizer = torch.optim.SGD([
{'params': model.module.frame_conv0.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv1.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv2.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv3.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.frame_conv4.parameters(), 'weight_decay': args.l2_frame},
{'params': model.module.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
{'params': model.module.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
{'params': model.module.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
],
lr=args.lr, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment