Anthony Larcher / sidekit · Commits

Commit fa238b80
authored Nov 13, 2021 by Le Lan Gaël

fixed correlation pooling

parent 9d1a30bb

Changes: 3 files
.gitignore (at fa238b80)

    *.pyc
    *.DS_Store
    docs
    .vscode/settings.json
    .gitignore
nnet/pooling.py (at fa238b80)

...
@@ -87,16 +87,20 @@ class ChannelWiseCorrPooling(torch.nn.Module):
         self.mask = self.mask.to(x.device)
         if self.training:
             x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device), p=self.channels_dropout)
-        #[B, C, Fr, T, m]
-        x = x.reshape(x.shape[0], x.shape[1], self.groups, x.shape[-2], self.merge_freqs_count)
-        #[B, Fr, C, T, m]
-        x = x.permute(0, 2, 1, 3, 4)
-        #[B, Fr*C, T, m]
-        x = x.flatten(start_dim=1, end_dim=2)
+        #[B, T, C, F]
+        x = x.permute(0, 2, 1, 3)
+        #[B, T, C, Fr, f]
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
+        #[B, T, f, Fr, C]
+        x = x.permute(0, 1, 4, 3, 2)
+        #[B, T, f, Fr*C]
+        x = x.flatten(start_dim=3, end_dim=4)
+        #[B, Fr*C, T, f]
+        x = x.permute(0, 3, 1, 2)
         #[B, Fr*C', T, f]
         x = self.L_proj(x)
-        #[B, Fr*C', T, m]
-        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
+        #[B, Fr, C', Tr]
+        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
         x -= torch.mean(x, axis=-1, keepdims=True)
         out = x / (torch.std(x, axis=-1, keepdims=True) + 1e-5)
         #[B, C', C']
...
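For readers tracing the fix, here is a minimal shape walk-through of the new permute/reshape chain, assuming the input feature map is laid out as [B, C, T, F]; the concrete sizes, including groups (Fr) and merge_freqs_count (f), are illustrative only, not taken from the repository:

    import torch

    B, C, T, F = 2, 256, 100, 10      # hypothetical batch, channels, frames, frequency bins
    groups, merge_freqs_count = 5, 2  # Fr groups of f consecutive bins each; F == Fr * f

    x = torch.randn(B, C, T, F)
    x = x.permute(0, 2, 1, 3)                               # [B, T, C, F]
    x = x.reshape(x.shape[0], x.shape[1], x.shape[-2],
                  groups, merge_freqs_count)                # [B, T, C, Fr, f]
    x = x.permute(0, 1, 4, 3, 2)                            # [B, T, f, Fr, C]
    x = x.flatten(start_dim=3, end_dim=4)                   # [B, T, f, Fr*C]
    x = x.permute(0, 3, 1, 2)                               # [B, Fr*C, T, f]
    print(x.shape)                                          # torch.Size([2, 1280, 100, 2])

After this chain each frequency group travels as a contiguous block of C channels along dimension 1, which is the layout the following L_proj projection (presumably a grouped map from Fr*C to Fr*C') consumes.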
nnet/xvector.py (at fa238b80)

...
@@ -545,10 +545,10 @@ class Xtractor(torch.nn.Module):
         #self.embedding_size = 256
         #self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
         #                                                out_features = self.embedding_size)
-        self.before_speaker_embedding = torch.nn.Linear(in_features = int(48*47*5/2),
+        self.before_speaker_embedding = torch.nn.Linear(in_features = int(64*63*5/2),
                                                         out_features = self.embedding_size)
-        self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=48)
+        self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=64)
         self.loss = loss
         if self.loss == "aam":
...
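The two changed lines are consistent with each other: with out_channels=C', the channel-wise correlation pooling yields C'*(C'-1)/2 upper-triangular correlation coefficients per frequency group, so raising out_channels from 48 to 64 changes the flattened size fed to before_speaker_embedding. A quick check of the arithmetic; reading the factor 5 as the number of frequency groups is an assumption, not stated in this diff:

    out_channels = 64                                   # C', from the ChannelWiseCorrPooling call
    groups = 5                                          # assumed number of frequency groups
    upper_tri = out_channels * (out_channels - 1) // 2  # off-diagonal correlations per group
    print(upper_tri * groups)                           # 10080
    print(int(64 * 63 * 5 / 2))                         # 10080, the new in_features
    print(int(48 * 47 * 5 / 2))                         # 5640, the old in_features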
@@ -1134,7 +1134,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                      num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                      )
     else:
-        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
+        batch_size = dataset_opts["batch_size"] # // dataset_opts["train"]["sampler"]["examples_per_speaker"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
...
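This hunk stops pre-dividing the configured batch size by examples_per_speaker, keeping the old expression as a trailing comment. A sketch of the effect with illustrative numbers (256 and 2 are hypothetical, not from any shipped config):

    dataset_opts = {"batch_size": 256,
                    "train": {"sampler": {"examples_per_speaker": 2}}}

    old_bs = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
    new_bs = dataset_opts["batch_size"]
    print(old_bs, new_bs)  # 128 256

The sampler therefore now works with the configured batch_size directly rather than the per-speaker-reduced value.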
@@ -1594,7 +1594,7 @@ def train_epoch(model,
             loss += criterion(output, target)
         elif loss_criteria == 'aps':
             output_tuple, _ = model(data, target=target)
-            loss, output = output_tuple
+            loss, no_margin_output = output_tuple
         else:
             output, _ = model(data, target=None)
             loss = criterion(output, target)
...
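In the 'aps' branch, the second element of the model's output tuple is now bound to no_margin_output instead of output. A minimal sketch of the rename; the (loss, logits) layout is inferred from this hunk alone:

    output_tuple = (0.42, "margin-free logits")  # stand-in values for (loss, logits)

    loss, output = output_tuple             # before: logits bound to `output`
    loss, no_margin_output = output_tuple   # after: `output` is no longer set in this branch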