Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Ambuj Mehrish
sidekit
Commits
7aa95240
Commit
7aa95240
authored
Mar 21, 2019
by
Anthony Larcher
Browse files
mutliembeddings
parent
23f92f53
Changes
1
Hide whitespace changes
Inline
Side-by-side
nnet/xvector.py
View file @
7aa95240
...
...
@@ -143,12 +143,17 @@ class Xtractor(torch.nn.Module):
mean
=
torch
.
mean
(
frame_emb_4
,
dim
=
2
)
std
=
torch
.
std
(
frame_emb_4
,
dim
=
2
)
seg_emb
=
torch
.
cat
([
mean
,
std
],
dim
=
1
)
embedding_A
=
self
.
seg_lin0
(
seg_emb
)
embedding_B
=
self
.
seg_lin1
(
self
.
norm6
(
self
.
activation
(
embedding_A
)))
seg_emb_0
=
torch
.
cat
([
mean
,
std
],
dim
=
1
)
# batch-normalisation after this layer
seg_emb_1
=
self
.
seg_lin0
(
seg_emb_0
)
seg_emb_2
=
self
.
activation
(
seg_emb_1
)
seg_emb_3
=
self
.
norm6
(
seg_emb_2
)
seg_emb_4
=
self
.
seg_lin1
(
seg_emb_3
)
seg_emb_5
=
self
.
activation
(
seg_emb_4
)
seg_emb_6
=
self
.
norm7
(
seg_emb_5
)
return
embedding_A
,
embedding_B
return
seg_emb_1
,
seg_emb_2
,
seg_emb_3
,
seg_emb_4
,
seg_emb_5
,
seg_emb_6
class
XtractorHot
(
Xtractor
):
def
__init__
(
self
,
spk_number
,
dropout
):
...
...
@@ -660,8 +665,12 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, outpu
emb_b_size
=
model
.
seg_lin1
.
weight
.
data
.
shape
[
0
]
# Create a Tensor to store all x-vectors on the GPU
emb_A
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_a_size
)).
astype
(
numpy
.
float32
)
emb_B
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_b_size
)).
astype
(
numpy
.
float32
)
emb_1
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_a_size
)).
astype
(
numpy
.
float32
)
emb_2
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_b_size
)).
astype
(
numpy
.
float32
)
emb_3
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_b_size
)).
astype
(
numpy
.
float32
)
emb_4
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_b_size
)).
astype
(
numpy
.
float32
)
emb_5
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_b_size
)).
astype
(
numpy
.
float32
)
emb_6
=
numpy
.
zeros
((
idmap
.
leftids
.
shape
[
0
],
emb_b_size
)).
astype
(
numpy
.
float32
)
# Send on selected device
model
.
to
(
device
)
...
...
@@ -673,11 +682,15 @@ def extract_idmap(args, device_ID, segment_indices, fs_params, idmap_name, outpu
if
list
(
data
.
shape
)[
2
]
<
20
:
pass
else
:
A
,
B
=
model
.
extract
(
data
.
to
(
device
))
emb_A
[
idx
,
:]
=
A
.
detach
().
cpu
()
emb_B
[
idx
,
:]
=
B
.
detach
().
cpu
()
seg_1
,
seg_2
,
seg_3
,
seg_4
,
seg_5
,
seg_6
=
model
.
extract
(
data
.
to
(
device
))
emb_1
[
idx
,
:]
=
seg_1
.
detach
().
cpu
()
emb_2
[
idx
,
:]
=
seg_2
.
detach
().
cpu
()
emb_3
[
idx
,
:]
=
seg_3
.
detach
().
cpu
()
emb_4
[
idx
,
:]
=
seg_4
.
detach
().
cpu
()
emb_5
[
idx
,
:]
=
seg_5
.
detach
().
cpu
()
emb_6
[
idx
,
:]
=
seg_6
.
detach
().
cpu
()
output_queue
.
put
((
segment_indices
,
emb_
A
,
emb_
B
))
output_queue
.
put
((
segment_indices
,
emb_
1
,
emb_
2
,
emb_3
,
emb_4
,
emb_5
,
emb_6
))
def
extract_parallel
(
args
,
fs_params
,
dataset
):
...
...
@@ -693,10 +706,19 @@ def extract_parallel(args, fs_params, dataset):
idmap
=
IdMap
(
idmap_name
)
x_server_A
=
StatServer
(
idmap
,
1
,
emb_a_size
)
x_server_B
=
StatServer
(
idmap
,
1
,
emb_b_size
)
x_server_A
.
stat0
=
numpy
.
ones
(
x_server_A
.
stat0
.
shape
)
x_server_B
.
stat0
=
numpy
.
ones
(
x_server_B
.
stat0
.
shape
)
x_server_1
=
StatServer
(
idmap
,
1
,
emb_a_size
)
x_server_2
=
StatServer
(
idmap
,
1
,
emb_b_size
)
x_server_3
=
StatServer
(
idmap
,
1
,
emb_b_size
)
x_server_4
=
StatServer
(
idmap
,
1
,
emb_b_size
)
x_server_5
=
StatServer
(
idmap
,
1
,
emb_b_size
)
x_server_6
=
StatServer
(
idmap
,
1
,
emb_b_size
)
x_server_1
.
stat0
=
numpy
.
ones
(
x_server_1
.
stat0
.
shape
)
x_server_2
.
stat0
=
numpy
.
ones
(
x_server_2
.
stat0
.
shape
)
x_server_3
.
stat0
=
numpy
.
ones
(
x_server_3
.
stat0
.
shape
)
x_server_4
.
stat0
=
numpy
.
ones
(
x_server_4
.
stat0
.
shape
)
x_server_5
.
stat0
=
numpy
.
ones
(
x_server_5
.
stat0
.
shape
)
x_server_6
.
stat0
=
numpy
.
ones
(
x_server_6
.
stat0
.
shape
)
# Split the indices
mega_batch_size
=
idmap
.
leftids
.
shape
[
0
]
//
args
.
num_processes
...
...
@@ -725,16 +747,20 @@ def extract_parallel(args, fs_params, dataset):
# Get the x-vectors and fill the StatServer
for
ii
in
range
(
args
.
num_processes
):
indices
,
A
,
B
=
output_queue
.
get
()
x_server_A
.
stat1
[
indices
,
:]
=
A
x_server_B
.
stat1
[
indices
,
:]
=
B
indices
,
seg_1
,
seg_2
,
seg_3
,
seg_4
,
seg_5
,
seg_6
=
output_queue
.
get
()
x_server_1
.
stat1
[
indices
,
:]
=
seg_1
x_server_2
.
stat1
[
indices
,
:]
=
seg_2
x_server_3
.
stat1
[
indices
,
:]
=
seg_3
x_server_4
.
stat1
[
indices
,
:]
=
seg_4
x_server_5
.
stat1
[
indices
,
:]
=
seg_5
x_server_6
.
stat1
[
indices
,
:]
=
seg_6
for
p
in
processes
:
p
.
join
()
print
(
"Process parallel fini"
)
return
x_server_
A
,
x_server_
B
return
x_server_
1
,
x_server_
2
,
x_server_3
,
x_server_4
,
x_server_5
,
x_server_6
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment