Anthony Larcher / sidekit · Commits

Commit 01af21ba, authored Apr 15, 2021 by Anthony Larcher

    cleaning

parent def54548

Changes: 1 file, nnet/xvector.py
@@ -52,8 +52,9 @@ from .xsets import IdMapSetPerSpeaker
 from .xsets import SideSampler
 from .res_net import ResBlockWFMS
 from .res_net import ResBlock
-from .res_net import PreResNet34
 from .res_net import PreFastResNet34
+from .res_net import PreHalfResNet34
+from .res_net import PreResNet34
 from ..bosaris import IdMap
 from ..bosaris import Key
 from ..bosaris import Ndx
@@ -417,20 +418,20 @@ class Xtractor(torch.nn.Module):
                 ("batch_norm5", torch.nn.BatchNorm1d(1536))
             ]))
+            self.embedding_size = 512

             self.stat_pooling = MeanStdPooling()
             self.stat_pooling_weight_decay = 0

             self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
-                ("linear6", torch.nn.Linear(3072, 512))
+                ("linear6", torch.nn.Linear(3072, self.embedding_size))
             ]))
-            self.embedding_size = 512

             if self.loss == "aam":
-                self.after_speaker_embedding = ArcMarginProduct(512,
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                                 int(self.speaker_number),
                                                                 s=64,
                                                                 m=0.2,
-                                                                easy_margin=True)
+                                                                easy_margin=False)

             elif self.loss == "cce":
                 self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
                     ("activation6", torch.nn.LeakyReLU(0.2)),
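Note: this hunk routes the embedding size through a single attribute, self.embedding_size, consumed by both the "linear6" layer and the AAM head (instead of a hard-coded 512), and switches easy_margin from True to False. As background, a minimal sketch of the additive-angular-margin (ArcFace) idea behind ArcMarginProduct follows; the class name AamHead and all details below are illustrative assumptions, not sidekit's actual implementation.

import torch
import torch.nn.functional as F

class AamHead(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self, emb_size, n_classes, s=64.0, m=0.2):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(n_classes, emb_size))
        self.s, self.m = s, m

    def forward(self, x, target):
        # cosine similarity between L2-normalised embeddings and class centres
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        # add the angular margin m on the target class only, then rescale by s
        one_hot = F.one_hot(target, cosine.shape[1]).bool()
        return self.s * torch.where(one_hot, torch.cos(theta + self.m), cosine)

With easy_margin=False the margin is applied unconditionally, which is the standard ArcFace behaviour; easy_margin=True only applies it where cos(theta) > 0.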
@@ -446,7 +447,6 @@ class Xtractor(torch.nn.Module):
             self.sequence_network_weight_decay = 0.0002
             self.before_speaker_embedding_weight_decay = 0.002
             self.after_speaker_embedding_weight_decay = 0.002
-            self.embedding_size = 512

         elif model_archi == "resnet34":
@@ -503,6 +503,32 @@ class Xtractor(torch.nn.Module):
             self.before_speaker_embedding_weight_decay = 0.00002
             self.after_speaker_embedding_weight_decay = 0.0002

+        elif model_archi == "halfresnet34":
+            self.preprocessor = MelSpecFrontEnd(n_fft=1024,
+                                                win_length=1024,
+                                                hop_length=256,
+                                                n_mels=80)
+            self.sequence_network = PreHalfResNet34()
+            self.embedding_size = 256
+
+            self.before_speaker_embedding = torch.nn.Linear(in_features=5120,
+                                                            out_features=self.embedding_size)
+
+            self.stat_pooling = AttentivePooling(256, 80, global_context=True)
+
+            self.loss = loss
+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
+                                                                int(self.speaker_number),
+                                                                s=30,
+                                                                m=0.2,
+                                                                easy_margin=False)
+            elif self.loss == 'aps':
+                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
+
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.0002
+
         elif model_archi == "rawnet2":
             if loss not in ["cce", 'aam']:
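Note: the new "halfresnet34" branch pairs a PreHalfResNet34 trunk with attentive statistics pooling and a 256-dimensional embedding. One plausible accounting for in_features=5120, inferred from the numbers in this hunk rather than stated in the source: if the trunk ends with 256 channels and downsamples the 80 mel bins by 8x, and AttentivePooling concatenates a mean and a standard-deviation vector, the sizes line up as follows.

# Hypothetical accounting for in_features=5120 (assumed shapes, not from the source)
channels, mel_bins, freq_downsampling = 256, 80, 8
per_frame = channels * (mel_bins // freq_downsampling)  # 256 * 10 = 2560 features per frame
pooled = 2 * per_frame                                  # mean + std concatenated = 5120
assert pooled == 5120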
@@ -850,6 +876,7 @@ def update_training_dictionary(dataset_description,
     # Initialize default dictionaries
     dataset_opts["data_path"] = None
     dataset_opts["dataset_csv"] = None
+    dataset_opts["stratify"] = False
     dataset_opts["data_file_extension"] = ".wav"
     dataset_opts["sample_rate"] = 16000
@@ -1029,9 +1056,12 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     """
     df = pandas.read_csv(dataset_opts["dataset_csv"])

+    stratify = None
+    if dataset_opts["stratify"]:
+        stratify = df["speaker_idx"]
     training_df, validation_df = train_test_split(df,
                                                   test_size=dataset_opts["validation_ratio"],
-                                                  stratify=df["speaker_idx"])
+                                                  stratify=stratify)

     training_set = SideSet(dataset_opts,
                            set_type="train",
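Note: together with the new dataset_opts["stratify"] default above, this hunk makes stratification optional: the split preserves per-speaker proportions only when the option is set, and otherwise passes stratify=None (sklearn's default). A self-contained usage sketch with invented toy data:

import pandas
from sklearn.model_selection import train_test_split

# three segments for each of three speakers (toy data)
df = pandas.DataFrame({"speaker_idx": [0, 0, 0, 1, 1, 1, 2, 2, 2]})
train_df, val_df = train_test_split(df, test_size=1/3, stratify=df["speaker_idx"])
print(val_df["speaker_idx"].value_counts())  # exactly one segment per speaker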
@@ -1055,7 +1085,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     if training_opts["multi_gpu"]:
         assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
         assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"] // torch.cuda.device_count()
+        batch_size = dataset_opts["batch_size"] // (torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
@@ -1068,7 +1098,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                    num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                    )
     else:
-        batch_size = dataset_opts["batch_size"]
+        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
@@ -1115,12 +1145,11 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     """
     :param model:
-    :param model_yaml:
+    :param model_opts:
+    :param train_opts:
+    :param training_loader:
     :return:
     """
-    """
-    Set the training options
-    """
     if train_opts["optimizer"]["type"] == 'adam':
         _optimizer = torch.optim.Adam
         _options = {'lr': train_opts["lr"]}
@@ -1161,11 +1190,15 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     optimizer = _optimizer(param_list, **_options)

     if train_opts["scheduler"]["type"] == 'CyclicLR':
+        cycle_momentum = True
+        if train_opts["optimizer"]["type"] == "aam":
+            cycle_momentum = False
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
                                                       max_lr=train_opts["lr"],
                                                       step_size_up=model_opts["speaker_number"] * 2,
                                                       step_size_down=None,
+                                                      cycle_momentum=cycle_momentum,
                                                       mode="triangular2")
     elif train_opts["scheduler"]["type"] == "MultiStepLR":
         scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
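Note: the added cycle_momentum guard matters because torch.optim.lr_scheduler.CyclicLR can only cycle momentum for optimizers that expose a momentum hyper-parameter; with Adam (which has betas instead) cycle_momentum=True raises a ValueError. The guard here keys on an optimizer type of "aam", but the underlying constraint is the missing momentum parameter. A minimal self-contained sketch:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# cycle_momentum=True would raise ValueError here: Adam has no "momentum" default
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                              base_lr=1e-8,
                                              max_lr=1e-3,
                                              step_size_up=1000,
                                              cycle_momentum=False,
                                              mode="triangular2")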
@@ -1541,9 +1574,6 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
         negatives = scores[non_indices]
         positives = scores[tar_indices]

-        # Faster EER computation available here : https://github.com/gl3lan/fast_eer
-        #equal_error_rate = eer(negatives, positives)
-
         pmiss, pfa = rocch(positives, negatives)
         equal_error_rate = rocch2eer(pmiss, pfa)
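Note: the equal error rate is the operating point where the miss rate and the false-alarm rate coincide; sidekit derives it from the ROC convex hull via rocch and rocch2eer. A naive brute-force version, for illustration only (the convex-hull computation is more robust on small score sets):

import numpy

def naive_eer(positives, negatives):
    # sweep every observed score as a decision threshold
    thresholds = numpy.sort(numpy.concatenate([positives, negatives]))
    pmiss = numpy.array([(positives < t).mean() for t in thresholds])  # rejected targets
    pfa = numpy.array([(negatives >= t).mean() for t in thresholds])   # accepted non-targets
    idx = numpy.argmin(numpy.abs(pmiss - pfa))
    return (pmiss[idx] + pfa[idx]) / 2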
@@ -1595,15 +1625,8 @@ def extract_embeddings(idmap_name,
     else:
         idmap = IdMap(idmap_name)

-    #if type(model) is Xtractor:
-    #    min_duration = (model.context_size() - 1) * win_shift + win_duration
-    #    model_cs = model.context_size()
-    #else:
-    #    min_duration = (model.module.context_size() - 1) * win_shift + win_duration
-    #    model_cs = model.module.context_size()
-
     # Create dataset to load the data
-    dataset = IdMapSet(idmap_name=idmap_name,
+    dataset = IdMapSet(idmap_name=idmap,
                        data_path=data_root_name,
                        file_extension=file_extension,
                        transform_pipeline=transform_pipeline,
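Note: the one-word fix in the IdMapSet call matters when idmap_name is already an IdMap object: the dataset now receives the parsed idmap in both branches instead of the raw argument. A hedged sketch of the dispatch this implies (the isinstance branch is assumed; only the else branch appears in the hunk):

from sidekit.bosaris import IdMap

def load_idmap(idmap_name):
    # assumed dispatch: accept either a ready IdMap or something IdMap can load
    if isinstance(idmap_name, IdMap):
        return idmap_name
    return IdMap(idmap_name)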
@@ -1615,7 +1638,6 @@ def extract_embeddings(idmap_name,
                        min_duration=win_duration
                        )
-
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
@@ -1624,33 +1646,9 @@ def extract_embeddings(idmap_name,
                             num_workers=num_thread)

     with torch.no_grad():
         model.eval()
         model.to(device)

-        # Get the size of embeddings to extract
-        if type(model) is Xtractor:
-            name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-            if name != 'bias':
-                name = name + '.weight'
-            emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
-        else:
-            name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
-            if name != 'bias':
-                name = name + '.weight'
-            emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
-
-        # Create the StatServer
-        #embeddings = StatServer()
-        #embeddings.modelset = idmap.leftids
-        #embeddings.segset = idmap.rightids
-        #embeddings.start = idmap.start
-        #embeddings.stop = idmap.stop
-        #embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
-        #embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

         embed = []
         modelset = []
         segset = []
@@ -1662,12 +1660,9 @@ def extract_embeddings(idmap_name,
                                                          desc='xvector extraction',
                                                          mininterval=1,
                                                          disable=None)):
-            #if data.shape[1] > 20000000:
-            #    data = data[...,:20000000]
-            print(f"data.shape = {data.shape}")
             if data.dim() > 2:
                 data = data.squeeze()
-            print(f"data.shape = {data.shape}")

             with torch.cuda.amp.autocast(enabled=mixed_precision):
                 tmp_data = torch.split(data, data.shape[0] // (max(1, data.shape[0] // 100)))
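Note: the torch.split line above (unchanged context) chops a long batch of sliding windows into roughly hundred-window chunks before the forward pass, presumably to bound GPU memory. A worked example with an invented shape:

import torch

data = torch.zeros(250, 16000)  # invented: 250 windows of raw samples
chunks = torch.split(data, data.shape[0] // max(1, data.shape[0] // 100))
# 250 // max(1, 250 // 100) = 250 // 2 = 125, so two chunks
print([c.shape[0] for c in chunks])  # [125, 125]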
@@ -1675,9 +1670,6 @@ def extract_embeddings(idmap_name,
                     _, vec = model(x=td.to(device), is_eval=True)
                     embed.append(vec.detach().cpu())

-                #modelset.extend([mod,] * data.shape[0])
                 modelset.extend(mod * data.shape[0])
                 segset.extend(seg * data.shape[0])
                 starts.extend(numpy.arange(start, start + vec.shape[0] * win_shift, win_shift))
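Note: dropping the commented-out [mod,] * data.shape[0] variant leaves the list-multiplication form, which presumably relies on mod and seg being one-element lists, so that multiplication replicates the identifier once per extracted window:

# presumed shape of mod/seg (one-element lists); toy values for illustration
mod = ["speaker_0001"]
modelset = []
modelset.extend(mod * 3)
print(modelset)  # ['speaker_0001', 'speaker_0001', 'speaker_0001']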
@@ -1717,8 +1709,6 @@ def extract_embeddings_per_speaker(idmap_name,
     model = model.to(memory_format=torch.channels_last)

-    min_duration = (model.context_size() - 1) * frame_shift + frame_duration
-
     # Create dataset to load the data
     dataset = IdMapSetPerSpeaker(idmap_name=idmap_name,
                                  data_root_path=data_root_name,