Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Ambuj Mehrish
sidekit
Commits
6d6f89a9
Commit
6d6f89a9
authored
Apr 05, 2019
by
Anthony Larcher
Browse files
Clean 1-hot version
parent
ee06b1e4
Changes
1
Hide whitespace changes
Inline
Side-by-side
nnet/xvector.py
View file @
6d6f89a9
...
...
@@ -109,10 +109,6 @@ class Xtractor(torch.nn.Module):
result
=
self
.
activation
(
seg_emb_5
)
return
result
def
LossFN
(
self
,
x
,
label
):
loss
=
-
torch
.
trace
(
torch
.
mm
(
torch
.
log10
(
x
),
torch
.
t
(
label
)))
return
loss
def
init_weights
(
self
):
"""
"""
...
...
@@ -155,30 +151,6 @@ class Xtractor(torch.nn.Module):
return
seg_emb_1
,
seg_emb_2
,
seg_emb_3
,
seg_emb_4
,
seg_emb_5
,
seg_emb_6
class
XtractorHot
(
Xtractor
):
def
__init__
(
self
,
spk_number
,
dropout
):
super
(
Xtractor
,
self
).
__init__
()
self
.
frame_conv0
=
torch
.
nn
.
Conv1d
(
20
,
512
,
5
,
dilation
=
1
)
self
.
frame_conv1
=
torch
.
nn
.
Conv1d
(
512
,
512
,
3
,
dilation
=
2
)
self
.
frame_conv2
=
torch
.
nn
.
Conv1d
(
512
,
512
,
3
,
dilation
=
3
)
self
.
frame_conv3
=
torch
.
nn
.
Conv1d
(
512
,
512
,
1
)
self
.
frame_conv4
=
torch
.
nn
.
Conv1d
(
512
,
3
*
512
,
1
)
self
.
seg_lin0
=
torch
.
nn
.
Linear
(
3
*
512
*
2
,
512
)
self
.
dropout_lin0
=
torch
.
nn
.
Dropout
(
p
=
dropout
)
self
.
seg_lin1
=
torch
.
nn
.
Linear
(
512
,
512
)
self
.
dropout_lin1
=
torch
.
nn
.
Dropout
(
p
=
dropout
)
self
.
seg_lin2
=
torch
.
nn
.
Linear
(
512
,
spk_number
)
#
self
.
norm0
=
torch
.
nn
.
BatchNorm1d
(
512
)
self
.
norm1
=
torch
.
nn
.
BatchNorm1d
(
512
)
self
.
norm2
=
torch
.
nn
.
BatchNorm1d
(
512
)
self
.
norm3
=
torch
.
nn
.
BatchNorm1d
(
512
)
self
.
norm4
=
torch
.
nn
.
BatchNorm1d
(
3
*
512
)
self
.
norm6
=
torch
.
nn
.
BatchNorm1d
(
512
)
self
.
norm7
=
torch
.
nn
.
BatchNorm1d
(
512
)
#
self
.
activation
=
torch
.
nn
.
LeakyReLU
(
0.2
)
def
forward
(
self
,
x
):
frame_emb_0
=
self
.
norm0
(
self
.
activation
(
self
.
frame_conv0
(
x
)))
frame_emb_1
=
self
.
norm1
(
self
.
activation
(
self
.
frame_conv1
(
frame_emb_0
)))
...
...
@@ -557,8 +529,6 @@ def extract_parallel(args, fs_params):
for
p
in
processes
:
p
.
join
()
print
(
"Process parallel fini"
)
return
x_server_1
,
x_server_2
,
x_server_3
,
x_server_4
,
x_server_5
,
x_server_6
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment