Commit 1e3cda4b authored by Anthony Larcher's avatar Anthony Larcher
Browse files

add dropout

parent fa14d338
......@@ -63,7 +63,7 @@ def split_file_list(batch_files, num_processes):
class Xtractor(torch.nn.Module):
def __init__(self, spk_number):
def __init__(self, spk_number, dropout):
super(Xtractor, self).__init__()
self.frame_conv0 = torch.nn.Conv1d(20, 512, 5)
self.frame_conv1 = torch.nn.Conv1d(512, 512, 3, dilation=2)
......@@ -71,7 +71,9 @@ class Xtractor(torch.nn.Module):
self.frame_conv3 = torch.nn.Conv1d(512, 512, 1)
self.frame_conv4 = torch.nn.Conv1d(512, 1500, 1)
self.seg_lin0 = torch.nn.Linear(3000, 512)
self.dropout_lin0 = torch.nn.Dropout(p=dropout)
self.seg_lin1 = torch.nn.Linear(512, 512)
self.dropout_lin1 = torch.nn.Dropout(p=dropout)
self.seg_lin2 = torch.nn.Linear(512, spk_number)
#
self.norm0 = torch.nn.BatchNorm1d(512)
......@@ -81,7 +83,7 @@ class Xtractor(torch.nn.Module):
self.norm4 = torch.nn.BatchNorm1d(1500)
self.norm6 = torch.nn.BatchNorm1d(512)
#
self.activation = torch.nn.Softplus()
self.activation = torch.nn.LeakyReLU(0.2)
def forward(self, x):
frame_emb_0 = self.norm0(self.activation(self.frame_conv0(x)))
......@@ -95,13 +97,15 @@ class Xtractor(torch.nn.Module):
std = torch.std(frame_emb_4, dim=2)
seg_emb_0 = torch.cat([mean, std], dim=1)
# No batch-normalisation after this layer
seg_emb_1 = self.activation(self.seg_lin0(seg_emb_0))
seg_emb_1 = self.dropout_lin0(seg_emb_0)
seg_emb_1 = self.activation(self.seg_lin0(seg_emb_1))
# new layer with batch Normalization
seg_emb_2 = self.norm6(self.activation(self.seg_lin1(seg_emb_1)))
seg_emb_2 = self.dropout_lin1(seg_emb_1)
seg_emb_3 = self.norm6(self.activation(self.seg_lin1(seg_emb_2)))
# No batch-normalisation after this layer
# seg_emb_3 = self.activation(self.seg_lin2(seg_emb_2))
seg_emb_3 = self.seg_lin2(seg_emb_2)
return seg_emb_3
seg_emb_4 = self.seg_lin2(seg_emb_3)
return seg_emb_4
def LossFN(self, x, lable):
loss = - torch.trace(torch.mm(torch.log10(x), torch.t(lable)))
......@@ -151,7 +155,7 @@ class Xtractor(torch.nn.Module):
def xtrain(args):
# Initialize a first model and save to disk
model = Xtractor(args.class_number)
model = Xtractor(args.class_number, args.dropout)
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
......@@ -162,6 +166,9 @@ def xtrain(args):
accuracy = cross_validation(args, current_model_file_name)
print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
args.lr = args.lr * 0.9
def train_epoch(epoch, args, initial_model_file_name):
# Compute the megabatch number
......@@ -189,7 +196,7 @@ def train_epoch(epoch, args, initial_model_file_name):
def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_queue):
model = Xtractor(args.class_number)
model = Xtractor(args.class_number, args.dropout)
model.load_state_dict(torch.load(initial_model_file_name))
model.train()
......@@ -208,7 +215,7 @@ def train_worker(rank, epoch, args, initial_model_file_name, batch_list, output_
{'params': model.seg_lin0.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin1.parameters(), 'weight_decay': args.l2_seg},
{'params': model.seg_lin2.parameters(), 'weight_decay': args.l2_seg}
])
], lr=args.lr)
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
......@@ -263,7 +270,7 @@ def train_asynchronous(epoch, args, initial_model_file_name, batch_file_list, me
for p in processes:
p.join()
av_model = Xtractor(args.class_number)
av_model = Xtractor(args.class_number, args.dropout)
tmp = av_model.state_dict()
average_param = dict()
......@@ -327,7 +334,7 @@ def cross_validation(args, current_model_file_name):
def cv_worker(rank, args, current_model_file_name, batch_list, output_queue):
model = Xtractor(args.class_number)
model = Xtractor(args.class_number, args.dropout)
model.load_state_dict(torch.load(current_model_file_name))
model.eval()
......@@ -361,7 +368,7 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, outpu
# Load the model
model_file_name = '/'.join([args.model_path, args.model_name])
model = Xtractor(args.class_number)
model = Xtractor(args.class_number, args.dropout)
model.load_state_dict(torch.load(model_file_name))
model.eval()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment