Anthony Larcher / sidekit · Commits

Commit c4684601, authored Jan 03, 2022 by Anthony Larcher

update

Parents: 92f95258, ebdf1b53
Changes: 6 files
.gitignore

```diff
 *.pyc
 *.DS_Store
 docs
-.vscode
+.gitignore
+.vscode
+.history
```
nnet/augmentation.py

```diff
@@ -173,6 +173,9 @@ def data_augmentation(speech,
     aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
     augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
     if "none" in augmentations:
         pass
 
+    if "stretch" in augmentations:
+        strech = torchaudio.functional.TimeStretch()
+        rate = random.uniform(0.8, 1.2)
```
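Note that torchaudio implements `TimeStretch` in `torchaudio.transforms`, and it operates on complex STFT frames rather than raw waveforms (`torchaudio.functional` only exposes `phase_vocoder`). A minimal waveform-domain sketch of the same idea, using sox's `tempo` effect with a rate drawn as above; this helper is illustrative and not part of the commit:

```python
import torch
import torchaudio

def stretch_waveform(speech: torch.Tensor, sample_rate: int, rate: float) -> torch.Tensor:
    """Speed a [1, T] waveform up or down by `rate` without changing its pitch.

    The sox `tempo` effect returns roughly T / rate samples.
    """
    stretched, _ = torchaudio.sox_effects.apply_effects_tensor(
        speech, sample_rate, [["tempo", f"{rate:.2f}"]])
    return stretched
```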
```diff
@@ -261,6 +264,7 @@ def data_augmentation(speech,
+        final_shape = speech.shape[1]
         configs = [
             ({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
             ({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
             ({"format": "gsm"}, "GSM-FR"),
             ({"format": "mp3", "compression": -9}, "MP3"),
             ({"format": "vorbis", "compression": -1}, "Vorbis")
```
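For reference, each `(param, title)` pair above is the kind of keyword set accepted by `torchaudio.functional.apply_codec`. A minimal sketch of applying one randomly chosen codec and restoring the original length; the helper name and trimming policy are assumptions, not sidekit code:

```python
import random
import torch
import torchaudio

def random_codec(speech: torch.Tensor, sample_rate: int, configs) -> torch.Tensor:
    """Pass a [1, T] waveform through one randomly chosen codec config."""
    param, title = random.choice(configs)
    # Note: the GSM codec expects 8 kHz input.
    coded = torchaudio.functional.apply_codec(speech, sample_rate, **param)
    # Codecs such as GSM-FR can change the sample count: trim back to the input length.
    final_shape = speech.shape[1]
    return coded[:, :final_shape]
```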
nnet/pooling.py

```diff
@@ -55,14 +55,68 @@ class MeanStdPooling(torch.nn.Module):
     def forward(self, x):
         """
-        :param x:
+        :param x: [B, C*F, T]
         :return:
         """
+        if len(x.shape) == 4:
+            # [B, C, F, T]
+            x = x.permute(0, 1, 3, 2)
+            x = x.flatten(start_dim=1, end_dim=2)
         # [B, C*F]
         mean = torch.mean(x, dim=2)
         # [B, C*F]
         std = torch.std(x, dim=2)
         # [B, 2*C*F]
         return torch.cat([mean, std], dim=1)
 
 
+class ChannelWiseCorrPooling(torch.nn.Module):
+
+    def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
+        super(ChannelWiseCorrPooling, self).__init__()
+        self.channels_dropout = channels_dropout
+        self.merge_freqs_count = 2
+        assert in_freqs % self.merge_freqs_count == 0
+        self.groups = in_freqs // self.merge_freqs_count
+        self.out_channels = out_channels
+        self.out_dim = int(self.out_channels * (self.out_channels - 1) / 2) * self.groups
+        self.L_proj = torch.nn.Conv2d(in_channels * self.groups,
+                                      out_channels * self.groups,
+                                      kernel_size=(1, 1),
+                                      groups=self.groups)
+        #self.L_proj = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
+        self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).type(torch.BoolTensor)
+
+    def forward(self, x):
+        """
+        :param x: [B, C, T, F]
+        :return:
+        """
+        batch_size = x.shape[0]
+        num_locations = x.shape[-1] * x.shape[-2] / self.groups
+        self.mask = self.mask.to(x.device)
+        if self.training:
+            x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device),
+                                             p=self.channels_dropout)
+        # [B, T, C, F]
+        x = x.permute(0, 2, 1, 3)
+        # [B, T, C, Fr, f]
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
+        # [B, T, f, Fr, C]
+        x = x.permute(0, 1, 4, 3, 2)
+        # [B, T, f, Fr*C]
+        x = x.flatten(start_dim=3, end_dim=4)
+        # [B, Fr*C, T, f]
+        x = x.permute(0, 3, 1, 2)
+        # [B, Fr*C', T, f]
+        x = self.L_proj(x)
+        # [B, Fr, C', Tr]
+        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
+        x -= torch.mean(x, axis=-1, keepdims=True)
+        out = x / (torch.std(x, axis=-1, keepdims=True) + 1e-5)
+        # [B, C', C']
+        out = torch.einsum('abci,abdi->abcd', out, out)
+        # [B, C'*(C'-1)/2]
+        out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
+        out = out / num_locations
+        return out
+
+
 class AttentivePooling(torch.nn.Module):
     """
     Mean and Standard deviation attentive pooling
```
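A quick shape check for the new layer (values chosen to match its defaults; the import path assumes the class lands in `sidekit.nnet.pooling` as above). With `in_freqs=10` merged two at a time there are 5 groups, and each group keeps the strictly lower triangle of a 64x64 correlation matrix, i.e. 64*63/2 = 2016 values, for an output of 5 * 2016 = 10080 features:

```python
import torch
from sidekit.nnet.pooling import ChannelWiseCorrPooling  # class added above

pool = ChannelWiseCorrPooling(in_channels=256, out_channels=64, in_freqs=10)
x = torch.randn(8, 256, 200, 10)   # [B, C, T, F]
out = pool(x)
print(out.shape, pool.out_dim)     # torch.Size([8, 10080]) 10080
```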
```diff
@@ -94,9 +148,14 @@ class AttentivePooling(torch.nn.Module):
     def forward(self, x):
         """
-        :param x:
+        :param x: [B, C*F, T]
         :return:
         """
+        if len(x.shape) == 4:
+            # [B, C, F, T]
+            x = x.permute(0, 1, 3, 2)
+            # [B, C*F, T]
+            x = x.flatten(start_dim=1, end_dim=2)
         if self.global_context:
             w = self.attention(torch.cat([x, self.gc(x).unsqueeze(2).repeat(1, 1, x.shape[-1])], dim=1))
         else:
```
nnet/res_net.py

```diff
@@ -480,17 +480,18 @@ class PreResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = self.layer5(out)
-        out = self.layer6(out)
-        out = self.layer7(out)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.layer6(x)
+        x = self.layer7(x)
+        return x
 
 
 class PreHalfResNet34(torch.nn.Module):
```
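The rewritten forward passes now keep activations in `torch.channels_last` memory format through the convolutional stack, which lets cuDNN pick faster NHWC kernels on recent GPUs; the shape is unchanged, only the strides are. A minimal illustration (tensor sizes assumed):

```python
import torch

x = torch.randn(8, 1, 80, 200)               # [B, 1, F, T] after unsqueeze/permute
x = x.to(memory_format=torch.channels_last)  # same shape, NHWC strides
print(x.shape, x.is_contiguous(memory_format=torch.channels_last))
```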
```diff
@@ -535,16 +536,15 @@ class PreHalfResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = out.contiguous(memory_format=torch.channels_last)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.contiguous(memory_format=torch.contiguous_format)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
 
 
 class PreFastResNet34(torch.nn.Module):
```
```diff
@@ -557,7 +557,7 @@ class PreFastResNet34(torch.nn.Module):
         self.speaker_number = speaker_number
         self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
-                                     stride=(2, 1), padding=3, bias=False)
+                                     stride=(1, 2), padding=3, bias=False)
         self.bn1 = torch.nn.BatchNorm2d(16)
 
         # With block = [3, 4, 6, 3]
```
```diff
@@ -589,16 +589,15 @@ class PreFastResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = out.contiguous(memory_format=torch.channels_last)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.contiguous(memory_format=torch.contiguous_format)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
 
 
 def ResNet34():
```
nnet/xsets.py

```diff
@@ -176,7 +176,7 @@ class SideSet(Dataset):
                  overlap=0.,
                  dataset_df=None,
                  min_duration=0.165,
-                 output_format="pytorch",
+                 output_format="pytorch"):
         """
```
```diff
@@ -275,6 +275,8 @@ class SideSet(Dataset):
                     self.transform["codec"] = []
                 if "phone_filtering" in transforms:
                     self.transform["phone_filtering"] = []
+                if "stretch" in transforms:
+                    self.transform["stretch"] = []
 
         self.noise_df = None
         if "add_noise" in self.transform:
```
```diff
@@ -429,20 +431,29 @@ class IdMapSet(Dataset):
         start = int(self.idmap.start[index] * 0.01 * self.sample_rate)
 
         if self.idmap.stop[index] is None:
             nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             if nfo.sample_rate != self.sample_rate:
                 speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
             duration = int(speech.shape[1] - start)
         else:
-            duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
-            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
-                                                frame_offset=start,
-                                                num_frames=duration)
+            # TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
+            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+            assert nfo.sample_rate == self.sample_rate
+            conversion_rate = nfo.sample_rate // self.sample_rate
+            duration = (int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start)
+            # add this in case the segment is too short
+            if duration <= self.min_duration * self.sample_rate:
+                middle = start + duration // 2
+                start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
+                duration = int(self.min_duration * self.sample_rate)
+            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
+                                                frame_offset=start * conversion_rate,
+                                                num_frames=duration * conversion_rate)
+            if nfo.sample_rate != self.sample_rate:
+                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
 
-        speech += 10e-6 * torch.randn(speech.shape)
+        # speech += 10e-6 * torch.randn(speech.shape)
 
         if self.sliding_window:
             speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
```
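The new `else` branch recenters any requested segment that is shorter than `min_duration` around its midpoint before loading. A worked example with assumed values (16 kHz audio, idmap bounds in 10 ms units, the 0.165 s default):

```python
sample_rate = 16000
min_duration = 0.165                        # seconds

start, stop = 120, 125                      # idmap bounds, 10 ms units
start = int(start * 0.01 * sample_rate)     # 19200 samples
duration = int(stop * 0.01 * sample_rate) - start   # 800 samples = 50 ms

if duration <= min_duration * sample_rate:          # 800 <= 2640
    middle = start + duration // 2                  # 19600
    start = int(max(0, int(middle - (min_duration * sample_rate / 2))))  # 18280
    duration = int(min_duration * sample_rate)      # 2640 samples = 165 ms

print(start, duration)                              # 18280 2640
```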
```diff
@@ -575,7 +586,7 @@ class IdMapSetPerSpeaker(Dataset):
             tmp_data.append(speech)
 
         speech = torch.cat(tmp_data, dim=1)
-        speech += 10e-6 * torch.randn(speech.shape)
+        # speech += 10e-6 * torch.randn(speech.shape)
 
         if len(self.transformation.keys()) > 0:
             speech = data_augmentation(speech,
```
nnet/xvector.py

```diff
@@ -37,12 +37,12 @@ import shutil
 import torch
 import tqdm
 import yaml
 
 #torch.autograd.set_detect_anomaly(True)
 
 from collections import OrderedDict
 from torch.utils.data import DataLoader
 from sklearn.model_selection import train_test_split
 from .pooling import MeanStdPooling
-from .pooling import AttentivePooling
+from .pooling import AttentivePooling, ChannelWiseCorrPooling
 from .pooling import GruPooling
 from .preprocessor import MfccFrontEnd
 from .preprocessor import MelSpecFrontEnd
```
```diff
@@ -510,8 +510,8 @@ class Xtractor(torch.nn.Module):
         if self.loss == "aam":
             self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                             int(self.speaker_number),
                                                             s=30,
                                                             m=0.2,
                                                             easy_margin=False)
 
         elif self.loss == 'aps':
```
```diff
@@ -533,9 +533,6 @@ class Xtractor(torch.nn.Module):
             self.sequence_network = PreHalfResNet34()
             self.embedding_size = embedding_size
-            #self.embedding_size = 256
-            #self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
-            #                                                out_features = self.embedding_size)
             self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
                 ("lin_be", torch.nn.Linear(in_features=5120,
                                            out_features=self.embedding_size,
                                            bias=False)),
```
```diff
@@ -544,6 +541,35 @@ class Xtractor(torch.nn.Module):
             self.stat_pooling = AttentivePooling(256, 80, global_context=True)
             self.loss = loss
             if self.loss == "aam":
                 self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                                 int(self.speaker_number),
                                                                 s=30,
                                                                 m=0.2,
                                                                 easy_margin=False)
             elif self.loss == 'aps':
                 self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
                                                                    emb_dim=self.embedding_size)
             self.preprocessor_weight_decay = 0.00002
             self.sequence_network_weight_decay = 0.00002
             self.stat_pooling_weight_decay = 0.00002
             self.before_speaker_embedding_weight_decay = 0.00002
             self.after_speaker_embedding_weight_decay = 0.000
 
+        elif model_archi == "experimental":
+
+            self.preprocessor = MelSpecFrontEnd()
+            self.sequence_network = PreHalfResNet34()
+            self.embedding_size = embedding_size
+            #self.embedding_size = 256
+            #self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
+            #                                                out_features = self.embedding_size)
+            self.before_speaker_embedding = torch.nn.Linear(in_features=int(64 * 63 * 5 / 2),
+                                                            out_features=self.embedding_size)
+            self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=64)
+            self.loss = loss
+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
```
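The hard-coded `in_features=int(64 * 63 * 5 / 2)` of the experimental head matches the output dimension of `ChannelWiseCorrPooling` with its defaults, as a quick check shows:

```python
out_channels = 64                 # C'
groups = 10 // 2                  # 5 groups of 2 merged frequency bins
out_dim = int(out_channels * (out_channels - 1) / 2) * groups
assert out_dim == int(64 * 63 * 5 / 2) == 10080
```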
```diff
@@ -811,7 +837,6 @@ class Xtractor(torch.nn.Module):
         # Mean and Standard deviation pooling
         x = self.stat_pooling(x)
 
         x = self.before_speaker_embedding(x)
 
         if norm_embedding:
```
```diff
@@ -1005,6 +1030,8 @@ def update_training_dictionary(dataset_description,
     fill_dict(model_opts, tmp_model_dict)
     fill_dict(training_opts, tmp_train_dict)
 
+    print(model_opts)
+
     # Overwrite with manually given parameters
     if "lr" in kwargs:
         training_opts["lr"] = kwargs['lr']
```
```diff
@@ -1032,7 +1059,7 @@ def get_network(model_opts, local_rank):
     :return: the neural network
     """
-    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
+    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34", "experimental"]:
         model = Xtractor(model_opts["speaker_number"],
                          model_opts["model_type"],
                          loss=model_opts["loss"]["type"],
                          embedding_size=model_opts["embedding_size"])
     else:
         # Custom type of model
```
```diff
@@ -1042,11 +1069,11 @@ def get_network(model_opts, local_rank):
     if model_opts["initial_model_name"] is not None:
         if os.path.isfile(model_opts["initial_model_name"]):
             logging.critical(f"*** Load model from = {model_opts['initial_model_name']}")
-            checkpoint = torch.load(model_opts["initial_model_name"])
+            checkpoint = torch.load(model_opts["initial_model_name"],
+                                    map_location={"cuda:0": "cuda:%d" % local_rank})
 
             """
             Here we remove all layers that we don't want to reload
             """
             pretrained_dict = checkpoint["model_state_dict"]
             for part in model_opts["reset_parts"]:
```
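The extra `map_location` argument remaps tensors that were saved from `cuda:0` onto each worker's own device, so every distributed process no longer deserializes onto GPU 0 first. The same idiom in isolation (file name hypothetical):

```python
import torch

local_rank = 1  # e.g. assigned by the distributed launcher
checkpoint = torch.load("initial_model.pt",
                        map_location={"cuda:0": "cuda:%d" % local_rank})
```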
```diff
@@ -1061,24 +1088,9 @@ def get_network(model_opts, local_rank):
             if name.split(".")[0] in model_opts["reset_parts"]:
                 param.requires_grad = False
 
-    if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
-        model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
-        print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
-
-    if local_rank < 1:
-        logging.info(model)
-        logging.info("Model_parameters_count: {:d}".format(
-            sum(p.numel() for p in model.sequence_network.parameters() if p.requires_grad) + \
-            sum(p.numel() for p in model.before_speaker_embedding.parameters() if p.requires_grad) + \
-            sum(p.numel() for p in model.stat_pooling.parameters() if p.requires_grad)))
+    #if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
+    #    model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
+    #    print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
 
     return model
```
```diff
@@ -1105,7 +1117,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
         training_df, validation_df = train_test_split(df,
                                                       test_size=dataset_opts["validation_ratio"],
                                                       stratify=stratify)
 
+    # TODO
     torch.manual_seed(training_opts['torch_seed'] + local_rank)
     torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)
```
```diff
@@ -1115,7 +1128,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                            transform_number=dataset_opts['train']['transform_number'],
                            overlap=dataset_opts['train']['overlap'],
                            dataset_df=training_df,
-                           output_format="pytorch",
+                           output_format="pytorch")
 
     validation_set = SideSet(dataset_opts,
```
```diff
@@ -1131,20 +1144,20 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     if training_opts["multi_gpu"]:
         assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
         assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"] // (torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
+        batch_size = dataset_opts["batch_size"] // torch.cuda.device_count()
 
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
                                    samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
-                                   batch_size=batch_size,
+                                   batch_size=batch_size * torch.cuda.device_count(),
                                    seed=training_opts['torch_seed'],
                                    rank=local_rank,
                                    num_process=torch.cuda.device_count(),
                                    num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                    )
     else:
-        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
+        batch_size = dataset_opts["batch_size"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
```
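With this change the per-process batch size under multi-GPU is the configured size divided by the device count only, while `SideSampler` is handed the global size back. A worked example with assumed numbers:

```python
batch_size_cfg = 256    # dataset_opts["batch_size"] (assumed)
device_count = 4        # torch.cuda.device_count() (assumed)

assert batch_size_cfg % device_count == 0
batch_size = batch_size_cfg // device_count     # 64 segments per process
sampler_batch = batch_size * device_count       # 256 handed to SideSampler
```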
```diff
@@ -1411,8 +1424,18 @@ def xtrain(dataset_description,
     # Initialize the model
     model = get_network(model_opts, local_rank)
 
     if local_rank < 1:
+        monitor.logger.info(model)
+
+        monitor.logger.info("Model_parameters_count: {:d}".format(
+            sum(p.numel() for p in model.sequence_network.parameters() if p.requires_grad) + \
+            sum(p.numel() for p in model.before_speaker_embedding.parameters() if p.requires_grad) + \
+            sum(p.numel() for p in model.stat_pooling.parameters() if p.requires_grad)))
 
     embedding_size = model.embedding_size
 
     aam_scheduler = None
```
```diff
@@ -1601,7 +1624,7 @@ def train_epoch(model,
                 loss += criterion(output, target)
             elif loss_criteria == 'aps':
                 output_tuple, _ = model(data, target=target)
-                loss, output = output_tuple
+                loss, no_margin_output = output_tuple
             else:
                 output, _ = model(data, target=None)
                 loss = criterion(output, target)
```
```diff
@@ -1635,7 +1658,7 @@ def train_epoch(model,
             if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
                 batch_size = target.shape[0]
                 training_monitor.update(training_loss=loss.item(),
-                                        training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
+                                        training_acc=100.0 * accuracy / ((batch_idx + 1) * batch_size))
 
                 training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                     training_monitor.current_epoch,
```
```diff
@@ -1817,9 +1840,10 @@ def extract_embeddings(idmap_name,
                 if sliding_window:
                     tmp_start = numpy.arange(0, data.shape[0] * win_shift, win_shift)
                     starts.extend(tmp_start * sample_rate + start.detach().cpu().numpy())
+                    win_duration = int(len(tmp_data))
                 else:
                     starts.append(start.numpy())
-                    stops.append(len(tmp_data))
+                    stops.append(tmp_data[0].shape[1])
 
     embeddings = StatServer()
     embeddings.stat1 = numpy.concatenate(embed)
```