Commit 7aa95240 authored by Anthony Larcher

multiembeddings: extract and store six intermediate segment-level embeddings instead of two

parent 23f92f53
@@ -143,12 +143,17 @@ class Xtractor(torch.nn.Module):
         mean = torch.mean(frame_emb_4, dim=2)
         std = torch.std(frame_emb_4, dim=2)
-        seg_emb = torch.cat([mean, std], dim=1)
-        embedding_A = self.seg_lin0(seg_emb)
-        embedding_B = self.seg_lin1(self.norm6(self.activation(embedding_A)))
+        seg_emb_0 = torch.cat([mean, std], dim=1)
+        # batch-normalisation after this layer
+        seg_emb_1 = self.seg_lin0(seg_emb_0)
+        seg_emb_2 = self.activation(seg_emb_1)
+        seg_emb_3 = self.norm6(seg_emb_2)
+        seg_emb_4 = self.seg_lin1(seg_emb_3)
+        seg_emb_5 = self.activation(seg_emb_4)
+        seg_emb_6 = self.norm7(seg_emb_5)
-        return embedding_A, embedding_B
+        return seg_emb_1, seg_emb_2, seg_emb_3, seg_emb_4, seg_emb_5, seg_emb_6


 class XtractorHot(Xtractor):
     def __init__(self, spk_number, dropout):
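With this change the segment-level head exposes every intermediate representation instead of only the two final embeddings. A minimal sketch of what a caller now receives, assuming an `Xtractor` instance `model` and a feature batch `features` (both names are illustrative, not part of the commit):

    # Hypothetical call; `model` and `features` are assumed names.
    seg_1, seg_2, seg_3, seg_4, seg_5, seg_6 = model.extract(features)
    # seg_1: seg_lin0 output, pre-activation   (dim = seg_lin0 out_features)
    # seg_2: activation(seg_1)                 (same dim as seg_1)
    # seg_3: norm6(seg_2)                      (same dim as seg_1)
    # seg_4: seg_lin1(seg_3), pre-activation   (dim = seg_lin1 out_features)
    # seg_5: activation(seg_4)                 (same dim as seg_4)
    # seg_6: norm7(seg_5)                      (same dim as seg_4)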
@@ -660,8 +665,12 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
     emb_b_size = model.seg_lin1.weight.data.shape[0]
     # Create one numpy array per embedding level to store all x-vectors
-    emb_A = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
-    emb_B = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
+    # Levels 1-3 come out of seg_lin0, so they keep emb_a_size;
+    # levels 4-6 come out of seg_lin1, so they have emb_b_size.
+    emb_1 = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
+    emb_2 = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
+    emb_3 = numpy.zeros((idmap.leftids.shape[0], emb_a_size)).astype(numpy.float32)
+    emb_4 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
+    emb_5 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
+    emb_6 = numpy.zeros((idmap.leftids.shape[0], emb_b_size)).astype(numpy.float32)
     # Send the model to the selected device
     model.to(device)
@@ -673,11 +682,15 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, output_queue):
             # Skip segments shorter than 20 frames
             if list(data.shape)[2] < 20:
                 pass
             else:
-                A, B = model.extract(data.to(device))
-                emb_A[idx, :] = A.detach().cpu()
-                emb_B[idx, :] = B.detach().cpu()
+                seg_1, seg_2, seg_3, seg_4, seg_5, seg_6 = model.extract(data.to(device))
+                emb_1[idx, :] = seg_1.detach().cpu()
+                emb_2[idx, :] = seg_2.detach().cpu()
+                emb_3[idx, :] = seg_3.detach().cpu()
+                emb_4[idx, :] = seg_4.detach().cpu()
+                emb_5[idx, :] = seg_5.detach().cpu()
+                emb_6[idx, :] = seg_6.detach().cpu()

-    output_queue.put((segment_indices, emb_A, emb_B))
+    output_queue.put((segment_indices, emb_1, emb_2, emb_3, emb_4, emb_5, emb_6))
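One design note on the six parallel buffers: the producer and the consumer must agree on the tuple layout. A hedged alternative sketch, illustrative only and not what the commit does, that keeps the per-level buffers in a list so both sides stay in sync if more levels are ever exposed; it reuses `idmap`, `model`, `data`, `idx` and `segment_indices` from the function above:

    # Sizes per level: 1-3 from seg_lin0, 4-6 from seg_lin1.
    sizes = (emb_a_size,) * 3 + (emb_b_size,) * 3
    embs = [numpy.zeros((idmap.leftids.shape[0], s), dtype=numpy.float32)
            for s in sizes]
    ...
    # Inside the batch loop, fill every level in one pass:
    for level, seg in enumerate(model.extract(data.to(device))):
        embs[level][idx, :] = seg.detach().cpu()
    output_queue.put((segment_indices, *embs))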
def extract_parallel(args, fs_params, dataset):
@@ -693,10 +706,19 @@ def extract_parallel(args, fs_params, dataset):
     idmap = IdMap(idmap_name)
-    x_server_A = StatServer(idmap, 1, emb_a_size)
-    x_server_B = StatServer(idmap, 1, emb_b_size)
-    x_server_A.stat0 = numpy.ones(x_server_A.stat0.shape)
-    x_server_B.stat0 = numpy.ones(x_server_B.stat0.shape)
+    # One StatServer per embedding level; levels 1-3 have emb_a_size,
+    # levels 4-6 have emb_b_size (matching the arrays in extract_idmap).
+    x_server_1 = StatServer(idmap, 1, emb_a_size)
+    x_server_2 = StatServer(idmap, 1, emb_a_size)
+    x_server_3 = StatServer(idmap, 1, emb_a_size)
+    x_server_4 = StatServer(idmap, 1, emb_b_size)
+    x_server_5 = StatServer(idmap, 1, emb_b_size)
+    x_server_6 = StatServer(idmap, 1, emb_b_size)
+    # stat0 is set to ones so that each stat1 row can be read directly
+    # as one segment's embedding
+    x_server_1.stat0 = numpy.ones(x_server_1.stat0.shape)
+    x_server_2.stat0 = numpy.ones(x_server_2.stat0.shape)
+    x_server_3.stat0 = numpy.ones(x_server_3.stat0.shape)
+    x_server_4.stat0 = numpy.ones(x_server_4.stat0.shape)
+    x_server_5.stat0 = numpy.ones(x_server_5.stat0.shape)
+    x_server_6.stat0 = numpy.ones(x_server_6.stat0.shape)

     # Split the indices
     mega_batch_size = idmap.leftids.shape[0] // args.num_processes
@@ -725,16 +747,20 @@ def extract_parallel(args, fs_params, dataset):
     # Get the x-vectors and fill the StatServers
     for ii in range(args.num_processes):
-        indices, A, B = output_queue.get()
-        x_server_A.stat1[indices, :] = A
-        x_server_B.stat1[indices, :] = B
+        indices, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6 = output_queue.get()
+        x_server_1.stat1[indices, :] = seg_1
+        x_server_2.stat1[indices, :] = seg_2
+        x_server_3.stat1[indices, :] = seg_3
+        x_server_4.stat1[indices, :] = seg_4
+        x_server_5.stat1[indices, :] = seg_5
+        x_server_6.stat1[indices, :] = seg_6

     for p in processes:
         p.join()

     print("Parallel extraction finished")

-    return x_server_A, x_server_B
+    return x_server_1, x_server_2, x_server_3, x_server_4, x_server_5, x_server_6
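A hedged sketch of how the new return value might be consumed; the variable names and output path are assumptions, and `StatServer.write` is used on the assumption that SIDEKIT's HDF5 writer applies:

    # Hypothetical caller: persist one HDF5 file per embedding level.
    x_servers = extract_parallel(args, fs_params, dataset)
    for level, xs in enumerate(x_servers, start=1):
        xs.write("xvector_level_{}.h5".format(level))  # assumed SIDEKIT API

A typical choice would be to keep a single level for scoring, e.g. the post-activation output of seg_lin1 (x_server_5), but returning all six makes it possible to compare levels.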