sidekit / Commits / ebdf1b53

Commit ebdf1b53, authored Dec 13, 2021 by Anthony Larcher
Merge branch 'master' of https://git-lium.univ-lemans.fr/Larcher/sidekit
Parents: 8e8fb525 9b6795f6
Changes: 6 files
.gitignore
View file @ ebdf1b53

```diff
 *.pyc
 *.DS_Store
 docs
+.gitignore
 .vscode
+.history
```
nnet/augmentation.py
View file @ ebdf1b53

```diff
@@ -173,6 +173,9 @@ def data_augmentation(speech,
     aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
     augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]

+    if "none" in augmentations:
+        pass
+
     if "stretch" in augmentations:
         strech = torchaudio.functional.TimeStretch()
         rate = random.uniform(0.8, 1.2)
```
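For context, the selection logic above draws `transform_number` distinct augmentation names at random and the new branch makes `"none"` an explicit no-op. A minimal standalone sketch (the `transform_dict` contents here are illustrative stand-ins, not sidekit's actual dictionary):

```python
import random

import numpy

# Illustrative stand-in for sidekit's transform dictionary.
transform_dict = {"none": [], "stretch": [], "add_noise": [], "codec": []}
transform_number = 2

# Draw two distinct indices, then map them back to augmentation names.
aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
print(augmentations)  # e.g. ['stretch' 'none']; "none" falls through the new no-op branch
```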
```diff
@@ -261,6 +264,7 @@ def data_augmentation(speech,
         final_shape = speech.shape[1]
         configs = [
             ({"format": "wav", "encoding": 'ULAW', "bits_per_sample": 8}, "8 bit mu-law"),
+            ({"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}, "8 bit a-law"),
             ({"format": "gsm"}, "GSM-FR"),
             ({"format": "mp3", "compression": -9}, "MP3"),
             ({"format": "vorbis", "compression": -1}, "Vorbis")
```
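Each `(config, name)` pair is presumably passed to torchaudio's codec simulation; a minimal sketch of how the newly added a-law entry could be applied, assuming `torchaudio.functional.apply_codec` (available in torchaudio releases contemporary with this commit) and a 16 kHz mono waveform:

```python
import torch
import torchaudio

speech = torch.randn(1, 16000)  # one second of placeholder 16 kHz audio
config = {"format": "wav", "encoding": 'ALAW', "bits_per_sample": 8}
# Round-trip the waveform through the codec to simulate channel degradation.
degraded = torchaudio.functional.apply_codec(speech, 16000, **config)
```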
nnet/pooling.py
View file @ ebdf1b53

```diff
@@ -55,14 +55,68 @@ class MeanStdPooling(torch.nn.Module):
     def forward(self, x):
         """

-        :param x:
+        :param x: [B, C*F, T]
         :return:
         """
+        if len(x.shape) == 4:
+            # [B, C, F, T]
+            x = x.permute(0, 1, 3, 2)
+            x = x.flatten(start_dim=1, end_dim=2)
+        # [B, C*F]
         mean = torch.mean(x, dim=2)
+        # [B, C*F]
         std = torch.std(x, dim=2)
+        # [B, 2*C*F]
         return torch.cat([mean, std], dim=1)


+class ChannelWiseCorrPooling(torch.nn.Module):
+
+    def __init__(self, in_channels=256, out_channels=64, in_freqs=10, channels_dropout=0.25):
+        super(ChannelWiseCorrPooling, self).__init__()
+        self.channels_dropout = channels_dropout
+        self.merge_freqs_count = 2
+        assert in_freqs % self.merge_freqs_count == 0
+        self.groups = in_freqs // self.merge_freqs_count
+        self.out_channels = out_channels
+        self.out_dim = int(self.out_channels * (self.out_channels - 1) / 2) * self.groups
+        self.L_proj = torch.nn.Conv2d(in_channels * self.groups,
+                                      out_channels * self.groups,
+                                      kernel_size=(1, 1),
+                                      groups=self.groups)
+        #self.L_proj = torch.nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
+        self.mask = torch.tril(torch.ones((out_channels, out_channels)), diagonal=-1).type(torch.BoolTensor)
+
+    def forward(self, x):
+        """
+        :param x: [B, C, T, F]
+        :return:
+        """
+        batch_size = x.shape[0]
+        num_locations = x.shape[-1] * x.shape[-2] / self.groups
+        self.mask = self.mask.to(x.device)
+        if self.training:
+            x *= torch.nn.functional.dropout(torch.ones((1, x.shape[1], 1, 1), device=x.device),
+                                             p=self.channels_dropout)
+        #[B, T, C, F]
+        x = x.permute(0, 2, 1, 3)
+        #[B, T, C, Fr, f]
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[-2], self.groups, self.merge_freqs_count)
+        #[B, T, f, Fr, C]
+        x = x.permute(0, 1, 4, 3, 2)
+        #[B, T, f, Fr*C]
+        x = x.flatten(start_dim=3, end_dim=4)
+        #[B, Fr*C, T, f]
+        x = x.permute(0, 3, 1, 2)
+        #[B, Fr*C', T, f]
+        x = self.L_proj(x)
+        #[B, Fr, C', Tr]
+        x = x.reshape(x.shape[0], self.groups, self.out_channels, -1)
+        x -= torch.mean(x, axis=-1, keepdims=True)
+        out = x / (torch.std(x, axis=-1, keepdims=True) + 1e-5)
+        #[B, C', C']
+        out = torch.einsum('abci,abdi->abcd', out, out)
+        #[B, C'*(C'-1)/2]
+        out = torch.masked_select(out, self.mask).reshape(batch_size, -1)
+        out = out / num_locations
+        return out
+
+
 class AttentivePooling(torch.nn.Module):
     """
     Mean and Standard deviation attentive pooling
```
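A quick shape check of the new pooling layer (a sketch; the import path assumes the class lands in `sidekit.nnet.pooling` as the diff suggests): with `in_freqs=10` merged in pairs there are 5 frequency groups, and keeping the strictly lower triangle of each 64x64 correlation matrix yields 5 * 64 * 63 / 2 = 10080 output values per utterance.

```python
import torch
from sidekit.nnet.pooling import ChannelWiseCorrPooling  # path assumed from this diff

pool = ChannelWiseCorrPooling(in_channels=256, out_channels=64, in_freqs=10)
x = torch.randn(8, 256, 200, 10)   # [B, C, T, F]
emb = pool(x)
print(emb.shape, pool.out_dim)     # torch.Size([8, 10080]) 10080
```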
```diff
@@ -94,9 +148,14 @@ class AttentivePooling(torch.nn.Module):
     def forward(self, x):
         """

-        :param x:
+        :param x: [B, C*F, T]
         :return:
         """
+        if len(x.shape) == 4:
+            # [B, C, F, T]
+            x = x.permute(0, 1, 3, 2)
+            # [B, C*F, T]
+            x = x.flatten(start_dim=1, end_dim=2)
         if self.global_context:
             w = self.attention(torch.cat([x, self.gc(x).unsqueeze(2).repeat(1, 1, x.shape[-1])], dim=1))
         else:
```
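Both poolings now accept a 4-D feature map and flatten it themselves instead of requiring a pre-flattened [B, C*F, T] input. A standalone check of that reshape (the sizes are illustrative):

```python
import torch

x4 = torch.randn(8, 256, 200, 10)                            # 4-D map from the ResNet trunk
x = x4.permute(0, 1, 3, 2).flatten(start_dim=1, end_dim=2)   # -> [8, 2560, 200]
pooled = torch.cat([torch.mean(x, dim=2), torch.std(x, dim=2)], dim=1)
print(pooled.shape)                                          # torch.Size([8, 5120])
```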
nnet/res_net.py
View file @ ebdf1b53

```diff
@@ -480,17 +480,18 @@ class PreResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = self.layer5(out)
-        out = self.layer6(out)
-        out = self.layer7(out)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        x = self.layer6(x)
+        x = self.layer7(x)
+        return x


 class PreHalfResNet34(torch.nn.Module):
```
```diff
@@ -535,16 +536,15 @@ class PreHalfResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = out.contiguous(memory_format=torch.channels_last)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.contiguous(memory_format=torch.contiguous_format)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x


 class PreFastResNet34(torch.nn.Module):
```
```diff
@@ -557,7 +557,7 @@ class PreFastResNet34(torch.nn.Module):
         self.speaker_number = speaker_number
         self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=7,
-                                     stride=(2, 1), padding=3, bias=False)
+                                     stride=(1, 2), padding=3, bias=False)
         self.bn1 = torch.nn.BatchNorm2d(16)
         # With block = [3, 4, 6, 3]
```
```diff
@@ -589,16 +589,15 @@ class PreFastResNet34(torch.nn.Module):
         :param x:
         :return:
         """
-        out = x.unsqueeze(1)
-        out = out.contiguous(memory_format=torch.channels_last)
-        out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.layer4(out)
-        out = out.contiguous(memory_format=torch.contiguous_format)
-        out = torch.flatten(out, start_dim=1, end_dim=2)
-        return out
+        x = x.unsqueeze(1)
+        x = x.permute(0, 1, 3, 2)
+        x = x.to(memory_format=torch.channels_last)
+        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x


 def ResNet34():
```
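All three rewritten forwards permute the input so frequency becomes the last axis (consistent with the swapped conv stride above) and keep activations in channels_last memory format for the conv stack. A minimal illustration of what the memory-format call does (shapes are illustrative):

```python
import torch

x = torch.randn(8, 16, 200, 80)              # NCHW layout
x = x.to(memory_format=torch.channels_last)  # same shape, NHWC strides underneath
print(x.shape)                               # torch.Size([8, 16, 200, 80])
print(x.stride())                            # (256000, 1, 1280, 16)
```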
nnet/xsets.py
View file @ ebdf1b53

```diff
@@ -106,7 +106,7 @@ class SideSampler(torch.utils.data.Sampler):
         self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)

     def __iter__(self):
         g = torch.Generator()
         g.manual_seed(self.seed + self.epoch)
         numpy.random.seed(self.seed + self.epoch)
```
```diff
@@ -175,7 +175,7 @@ class SideSet(Dataset):
                  overlap=0.,
                  dataset_df=None,
                  min_duration=0.165,
                  output_format="pytorch",
                  ):
         """
```
```diff
@@ -269,6 +269,8 @@ class SideSet(Dataset):
             self.transform["codec"] = []
         if "phone_filtering" in transforms:
             self.transform["phone_filtering"] = []
+        if "stretch" in transforms:
+            self.transform["stretch"] = []

         self.noise_df = None
         if "add_noise" in self.transform:
```
```diff
@@ -416,18 +418,27 @@ class IdMapSet(Dataset):
         start = int(self.idmap.start[index] * 0.01 * self.sample_rate)

         if self.idmap.stop[index] is None:
+            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+            if nfo.sample_rate != self.sample_rate:
+                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
             duration = int(speech.shape[1] - start)
         else:
-            duration = int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start
+            # TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
+            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+            assert nfo.sample_rate == self.sample_rate
+            conversion_rate = nfo.sample_rate // self.sample_rate
+            duration = (int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start)
             # add this in case the segment is too short
             if duration <= self.min_duration * self.sample_rate:
                 middle = start + duration // 2
                 start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
                 duration = int(self.min_duration * self.sample_rate)

             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
-                                                frame_offset=start,
-                                                num_frames=duration)
+                                                frame_offset=start * conversion_rate,
+                                                num_frames=duration * conversion_rate)
+            if nfo.sample_rate != self.sample_rate:
+                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)

         #speech += 10e-6 * torch.randn(speech.shape)
```
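The 0.01 factors convert the idmap start/stop fields from centiseconds to seconds before scaling by the sample rate. A worked example with illustrative values:

```python
sample_rate = 16000
start_cs, stop_cs = 250, 750                           # 2.5 s and 7.5 s, in centiseconds
start = int(start_cs * 0.01 * sample_rate)             # 40000 samples
duration = int(stop_cs * 0.01 * sample_rate) - start   # 80000 samples = 5 s
print(start, duration)                                 # 40000 80000
```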
nnet/xvector.py
View file @ ebdf1b53

```diff
@@ -36,12 +36,12 @@ import shutil
 import torch
 import tqdm
 import yaml
-
+#torch.autograd.set_detect_anomaly(True)
 from collections import OrderedDict
 from torch.utils.data import DataLoader
 from sklearn.model_selection import train_test_split
 from .pooling import MeanStdPooling
-from .pooling import AttentivePooling
+from .pooling import AttentivePooling, ChannelWiseCorrPooling
 from .pooling import GruPooling
 from .preprocessor import MfccFrontEnd
 from .preprocessor import MelSpecFrontEnd
```
```diff
@@ -522,6 +522,35 @@ class Xtractor(torch.nn.Module):
             self.stat_pooling = AttentivePooling(256, 80, global_context=True)

+            self.loss = loss
+            if self.loss == "aam":
+                self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
+                                                                int(self.speaker_number),
+                                                                s=30,
+                                                                m=0.2,
+                                                                easy_margin=False)
+            elif self.loss == 'aps':
+                self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
+                                                                   emb_dim=self.embedding_size)
+
+            self.preprocessor_weight_decay = 0.00002
+            self.sequence_network_weight_decay = 0.00002
+            self.stat_pooling_weight_decay = 0.00002
+            self.before_speaker_embedding_weight_decay = 0.00002
+            self.after_speaker_embedding_weight_decay = 0.000
+
+        elif model_archi == "experimental":
+
+            self.preprocessor = MelSpecFrontEnd()
+            self.sequence_network = PreHalfResNet34()
+            self.embedding_size = embedding_size
+            #self.embedding_size = 256
+            #self.before_speaker_embedding = torch.nn.Linear(in_features = 5120,
+            #                                                out_features = self.embedding_size)
+            self.before_speaker_embedding = torch.nn.Linear(in_features=int(64 * 63 * 5 / 2),
+                                                            out_features=self.embedding_size)
+            self.stat_pooling = ChannelWiseCorrPooling(in_channels=256, out_channels=64)
+
             self.loss = loss
             if self.loss == "aam":
                 self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
```
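The new "experimental" architecture wires PreHalfResNet34 into ChannelWiseCorrPooling, so the embedding layer's input size is the pooling's out_dim. The arithmetic behind `int(64 * 63 * 5 / 2)`, spelled out:

```python
out_channels = 64          # channels after the 1x1 grouped projection
groups = 10 // 2           # in_freqs=10 merged in pairs -> 5 frequency groups
in_features = int(out_channels * (out_channels - 1) / 2) * groups
print(in_features)         # 10080 == int(64 * 63 * 5 / 2)
```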
```diff
@@ -788,7 +817,6 @@ class Xtractor(torch.nn.Module):
         # Mean and Standard deviation pooling
         x = self.stat_pooling(x)
         x = self.before_speaker_embedding(x)

         if norm_embedding:
```
```diff
@@ -1005,7 +1033,7 @@ def get_network(model_opts, local_rank):
    :return:
    """
-    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
+    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34", "experimental"]:
         model = Xtractor(model_opts["speaker_number"],
                          model_opts["model_type"],
                          loss=model_opts["loss"]["type"],
                          embedding_size=model_opts["embedding_size"])
     else:
         # Custom type of model
```
```diff
@@ -1035,24 +1063,9 @@ def get_network(model_opts, local_rank):
         if name.split(".")[0] in model_opts["reset_parts"]:
             param.requires_grad = False

-    if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
-        model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
-        print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")
-
-    if local_rank < 1:
-        logging.info(model)
-        logging.info("Model_parameters_count: {:d}".format(
-            sum(p.numel() for p in model.sequence_network.parameters() if p.requires_grad) + \
-            sum(p.numel() for p in model.before_speaker_embedding.parameters() if p.requires_grad) + \
-            sum(p.numel() for p in model.stat_pooling.parameters() if p.requires_grad)))
+    #if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
+    #    model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
+    #    print(f"Modified AAM: margin = {model.after_speaker_embedding.m} and s = {model.after_speaker_embedding.s}")

     return model
```
```diff
@@ -1080,7 +1093,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
         training_df, validation_df = train_test_split(df,
                                                       test_size=dataset_opts["validation_ratio"],
                                                       stratify=stratify)

+    # TODO
     torch.manual_seed(training_opts['torch_seed'] + local_rank)
     torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)
```
```diff
@@ -1090,7 +1104,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                            transform_number=dataset_opts['train']['transform_number'],
                            overlap=dataset_opts['train']['overlap'],
                            dataset_df=training_df,
-                           output_format="pytorch",
-                           )
+                           output_format="pytorch")

     validation_set = SideSet(dataset_opts,
```
```diff
@@ -1106,20 +1120,20 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
     if training_opts["multi_gpu"]:
         assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
         assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"] // (torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
+        batch_size = dataset_opts["batch_size"] // torch.cuda.device_count()
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
                                    samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
-                                   batch_size=batch_size,
+                                   batch_size=batch_size * torch.cuda.device_count(),
                                    seed=training_opts['torch_seed'],
                                    rank=local_rank,
                                    num_process=torch.cuda.device_count(),
                                    num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                    )
     else:
-        batch_size = dataset_opts["batch_size"] // dataset_opts["train"]["sampler"]["examples_per_speaker"]
+        batch_size = dataset_opts["batch_size"]
         side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                    spk_count=model_opts["speaker_number"],
                                    examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
```
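The multi-GPU branch now divides the global batch size only by the GPU count (instead of also by examples_per_speaker) and hands the sampler the full global size. A worked example with illustrative numbers:

```python
global_batch = 256
n_gpu = 4                                   # stand-in for torch.cuda.device_count()
per_process_batch = global_batch // n_gpu   # 64: what each process's DataLoader receives
sampler_batch = per_process_batch * n_gpu   # 256: what SideSampler coordinates across ranks
```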
```diff
@@ -1380,8 +1394,18 @@ def xtrain(dataset_description,
     # Initialize the model
     model = get_network(model_opts, local_rank)
     if local_rank < 1:
         monitor.logger.info(model)
+        monitor.logger.info("Model_parameters_count: {:d}".format(
+            sum(p.numel() for p in model.sequence_network.parameters() if p.requires_grad) + \
+            sum(p.numel() for p in model.before_speaker_embedding.parameters() if p.requires_grad) + \
+            sum(p.numel() for p in model.stat_pooling.parameters() if p.requires_grad)))

     embedding_size = model.embedding_size

     aam_scheduler = None
```
```diff
@@ -1569,7 +1593,7 @@ def train_epoch(model,
                 loss += criterion(output, target)
             elif loss_criteria == 'aps':
                 output_tuple, _ = model(data, target=target)
-                loss, output = output_tuple
+                loss, no_margin_output = output_tuple
             else:
                 output, _ = model(data, target=None)
                 loss = criterion(output, target)
```
```diff
@@ -1603,7 +1627,7 @@ def train_epoch(model,
             if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
                 batch_size = target.shape[0]
                 training_monitor.update(training_loss=loss.item(),
-                                        training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
+                                        training_acc=100.0 * accuracy / ((batch_idx + 1) * batch_size))

                 training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                     training_monitor.current_epoch,
```
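The logged figure is a running average over the epoch so far: accumulated correct predictions divided by examples seen. A quick numeric check (values are illustrative; in the real loop `accuracy` is an accumulated tensor):

```python
accuracy = 1500.0                 # correct predictions accumulated so far this epoch
batch_idx, batch_size = 9, 256    # ten batches of 256 examples processed
training_acc = 100.0 * accuracy / ((batch_idx + 1) * batch_size)
print(round(training_acc, 2))     # 58.59
```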