Meysam Shamsi / s4d / Commits / e120ee8a

Commit e120ee8a, authored Aug 17, 2020 by Anthony Larcher

    seq2seq

Parent: c447ff2b
Changes: 4 files
s4d/__init__.py

@@ -38,7 +38,7 @@ from .clustering.hac_utils import bic_square_root
from .clustering.cc_iv import ConnectedComponent
-from .nnet.wavsets import AlliesSet
+from .nnet.wavsets import SeqSet
from .nnet.seqtoseq import PreNet
from .nnet.seqtoseq import BLSTM
s4d/nnet/__init__.py

@@ -23,6 +23,6 @@
Copyright 2014-2020 Anthony Larcher
"""
-from .wavsets import AlliesSet
+from .wavsets import SeqSet
from .seqtoseq import PreNet
from .seqtoseq import BLSTM
\ No newline at end of file
s4d/nnet/seqtoseq.py

@@ -28,6 +28,7 @@ import sys
import numpy
import random
import h5py
+import shutil
import torch
import torch.nn as nn
from torch import optim

@@ -35,6 +36,7 @@ from torch.utils.data import Dataset
import logging
from sidekit.nnet.vad_rnn import BLSTM
+from torch.utils.data import DataLoader

__license__ = "LGPL"
__author__ = "Anthony Larcher"

@@ -45,7 +47,18 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    Save the current training state and, if it is the best seen so far, keep a copy aside.

    :param state: dict gathering the model, optimizer and scheduler states
    :param is_best: True if the current model achieves the best validation accuracy so far
    :param filename: name of the checkpoint file
    :param best_filename: name of the copy kept for the best model
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
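A minimal usage sketch (hypothetical names; state mirrors the dict built in seqTrain below):

    state = {'epoch': 3,
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict()}
    save_checkpoint(state, is_best=True, filename='tmp_seq2seq.pt', best_filename='best_seq2seq.pt')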

class PreNet(nn.Module):
    def __init__(self,
...
@@ -130,21 +143,269 @@ class BLSTM(nn.Module):
class SeqToSeq(nn.Module):
    """
    Bi LSTM model used for voice activity detection or speaker turn detection
    """
-    def __init__(self):
-        self.preprocessor = PreNet(sample_rate=16000,
-                                   windows_duration=0.2,
-                                   frame_shift=0.01)
-        self.sequence_model = BLSTM(input_size=1,
-                                    lstm_1=64,
-                                    lstm_2=40,
-                                    linear_1=40,
-                                    linear_2=10)
-
-    def forward(self, input):
-        x = self.preprocessor(input)
-        output = self.sequence_model(x)
-        return output
+    def __init__(self,
+                 input_size,
+                 lstm_1,
+                 lstm_2,
+                 linear_1,
+                 linear_2,
+                 output_size=1):
+        """
+        :param input_size: dimension of the input features
+        :param lstm_1: output dimension of the first BLSTM layer
+        :param lstm_2: output dimension of the second BLSTM layer
+        :param linear_1: output dimension of the first linear layer
+        :param linear_2: output dimension of the second linear layer
+        :param output_size: number of output classes
+        """
+        super(SeqToSeq, self).__init__()
+        self.lstm_1 = nn.LSTM(input_size, lstm_1 // 2, bidirectional=True, batch_first=True)
+        self.lstm_2 = nn.LSTM(lstm_1, lstm_2 // 2, bidirectional=True, batch_first=True)
+        self.linear_1 = nn.Linear(lstm_2, linear_1)
+        self.linear_2 = nn.Linear(linear_1, linear_2)
+        self.output = nn.Linear(linear_2, output_size)
+        self.hidden = None
+
+    def forward(self, inputs):
+        """
+        :param inputs: a batch of sequences, shape (batch, time, input_size)
+        :return: per-frame scores after a sigmoid, shape (batch, time, output_size)
+        """
+        if self.hidden is None:
+            hidden_1, hidden_2 = None, None
+        else:
+            hidden_1, hidden_2 = self.hidden
+        tmp, hidden_1 = self.lstm_1(inputs, hidden_1)
+        x, hidden_2 = self.lstm_2(tmp, hidden_2)
+        self.hidden = (hidden_1, hidden_2)
+        x = torch.tanh(self.linear_1(x))
+        x = torch.tanh(self.linear_2(x))
+        x = torch.sigmoid(self.output(x))
+        return x
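A quick smoke test of the class as reconstructed above (the sizes are illustrative, not values from the commit):

    model = SeqToSeq(input_size=30, lstm_1=64, lstm_2=40, linear_1=40, linear_2=10, output_size=1)
    x = torch.randn(8, 200, 30)   # 8 sequences of 200 frames with 30 features
    y = model(x)                  # sigmoid scores, shape (8, 200, 1)
    model.hidden = None           # reset the stored LSTM state between unrelated batches

Because forward() caches the hidden state on the instance, the state has to be reset whenever the batch size changes.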
def seqTrain(data_dir,
             mode,
             duration=2.,
             seg_shift=0.25,
             filter_type="gate",
             collar_duration=0.1,
             framerate=16000,
             epochs=100,
             batch_size=32,
             lr=0.0001,
             loss="cross_validation",
             patience=10,
             tmp_model_name=None,
             best_model_name=None,
             multi_gpu=True,
             opt='sgd',
             log_interval=10,
             num_thread=10):
    """
    Train a SeqToSeq model.

    :param data_dir: root directory of the data
    :param mode: can be "vad", "spk_turn", "overlap"
    :param duration: duration of the segments in seconds
    :param seg_shift: shift to generate overlapping segments
    :param filter_type: type of label filtering
    :param collar_duration: duration of the collar around speaker turns
    :param framerate: number of samples per second
    :param epochs: maximum number of epochs
    :param batch_size: size of the training batches
    :param lr: learning rate
    :param loss: name of the loss to use
    :param patience: number of epochs without improvement before stopping
    :param tmp_model_name: name of the file used to save the current model
    :param best_model_name: name of the file used to save the best model
    :param multi_gpu: if True, train on all available GPUs
    :param opt: optimizer to use, 'sgd', 'adam' or 'rmsprop'
    :param log_interval: number of batches between two training logs
    :param num_thread: number of workers for the dataloaders
    :return:
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Start from scratch
    # NOTE: SeqToSeq now requires its layer sizes; this call still uses the old no-argument form
    model = SeqToSeq()
    # TODO implement a model adaptation

    if torch.cuda.device_count() > 1 and multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    else:
        print("Train on a single GPU")
    model.to(device)

    # Create two dataloaders for training and evaluation
    training_set, validation_set = None, None  # TODO plug the SeqSet datasets in here

    training_loader = DataLoader(training_set,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 drop_last=True,
                                 num_workers=num_thread)

    validation_loader = DataLoader(validation_set,
                                   batch_size=batch_size,
                                   drop_last=True,
                                   num_workers=num_thread)

    # Set the training options
    if opt == 'sgd':
        _optimizer = torch.optim.SGD
        _options = {'lr': lr, 'momentum': 0.9}
    elif opt == 'adam':
        _optimizer = torch.optim.Adam
        _options = {'lr': lr}
    elif opt == 'rmsprop':
        _optimizer = torch.optim.RMSprop
        _options = {'lr': lr}

    # Parameter groups separating batch-norm parameters (defined but not used below yet)
    params = [
        {
            'params': [param for name, param in model.named_parameters() if 'bn' not in name]
        },
        {
            'params': [param for name, param in model.named_parameters() if 'bn' in name],
            'weight_decay': 0
        },
    ]

    if type(model) is SeqToSeq:
        # weight_decay is expected as an attribute of the model
        optimizer = _optimizer([{'params': model.parameters(),
                                 'weight_decay': model.weight_decay},], **_options)
    else:
        optimizer = _optimizer([{'params': model.module.sequence_network.parameters(),
                                 'weight_decay': model.module.sequence_network_weight_decay},], **_options)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

    best_accuracy = 0.0
    best_accuracy_epoch = 1
    curr_patience = patience
    for epoch in range(1, epochs + 1):
        # Process one epoch and return the current model
        if curr_patience == 0:
            print(f"Stopping at epoch {epoch} because patience is exhausted")
            break
        model = train_epoch(model, epoch, training_loader, optimizer, log_interval, device=device)

        # Add the cross validation here
        accuracy, val_loss = cross_validation(model, validation_loader, device=device)
        logging.critical("*** Cross validation accuracy = {} %".format(accuracy))

        # Decrease learning rate according to the scheduler policy
        scheduler.step(val_loss)
        print(f"Learning rate is {optimizer.param_groups[0]['lr']}")

        # Remember the best accuracy and save a checkpoint
        is_best = accuracy > best_accuracy
        best_accuracy = max(accuracy, best_accuracy)

        if type(model) is SeqToSeq:
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': best_accuracy,
                'scheduler': scheduler
            }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')
        else:
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': best_accuracy,
                'scheduler': scheduler
            }, is_best, filename=tmp_model_name + ".pt", best_filename=best_model_name + '.pt')

        if is_best:
            best_accuracy_epoch = epoch
            curr_patience = patience
        else:
            curr_patience -= 1

    logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
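A hypothetical invocation of the trainer (paths and values are illustrative; at this stage of the commit the datasets are still None placeholders, so this sketches the intended API rather than a working run):

    seqTrain(data_dir="/data/allies",
             mode="vad",
             epochs=50,
             batch_size=32,
             lr=0.0001,
             tmp_model_name="tmp_seq2seq",
             best_model_name="best_seq2seq",
             multi_gpu=False,
             num_thread=5)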
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
    """
    :param model: the model to train
    :param epoch: index of the current epoch
    :param training_loader: dataloader yielding the training (data, target) batches
    :param optimizer: the optimizer to use
    :param log_interval: number of batches between two training logs
    :param device: device on which to run the computation
    :return: the model after one epoch of training
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    accuracy = 0.0
    for batch_idx, (data, target) in enumerate(training_loader):
        target = target.squeeze()
        optimizer.zero_grad()
        output = model(data.to(device))
        loss = criterion(output, target.to(device))
        loss.backward()
        optimizer.step()
        accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()

        if batch_idx % log_interval == 0:
            # The running accuracy assumes a constant batch size across batches
            batch_size = target.shape[0]
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, training_loader.__len__(),
                100. * batch_idx / training_loader.__len__(), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
    return model
def cross_validation(model, validation_loader, device):
    """
    :param model: the model to evaluate
    :param validation_loader: dataloader yielding the validation (data, target) batches
    :param device: device on which to run the computation
    :return: the validation accuracy in percent and the averaged validation loss
    """
    model.eval()

    accuracy = 0.0
    loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(validation_loader):
            batch_size = target.shape[0]
            target = target.squeeze()
            output = model(data.to(device))
            accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
            loss += criterion(output, target.to(device))

    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
           loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
s4d/nnet/wavsets.py

@@ -32,14 +32,20 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'

import numpy
import pathlib
import random
import scipy
import sidekit
import soundfile
import torch

from ..diar import Diar
from pathlib import Path
from torch.utils.data import Dataset
from torchvision import transforms
from collections import namedtuple

#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])


def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
    """
...
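The body of framing is collapsed in this view; from its signature it slices a signal into windows of win_size samples taken every win_shift samples, with optional context frames and padding. A hedged sketch of the expected behavior (assumed, since the implementation is not shown here):

    sig = numpy.arange(10)
    frames = framing(sig, win_size=4, win_shift=2)
    # expected rows: [0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]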
@@ -82,39 +88,51 @@ def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
def mdtm_to_label(mdtm_filename,
-                  show_duration,
-                  framerate):
+                  start_time,
+                  stop_time,
+                  sample_number,
+                  speaker_dict):
    """
    Convert an MDTM reference file into a vector of per-sample speaker labels.

    :param mdtm_filename: name of the MDTM file to read
    :param start_time: start time of the segment to label
    :param stop_time: stop time of the segment to label
    :param sample_number: number of labels to generate
    :param speaker_dict: dictionary mapping speaker names to indices
    :return: a numpy array of sample_number speaker indices
    """
    diarization = Diar.read_mdtm(mdtm_filename)
    diarization.sort(['show', 'start'])

-    # Create a dictionary of speakers
-    speaker_set = diarization.unique('cluster')
-    speaker_dict = {}
-    for idx, spk in enumerate(speaker_set):
-        speaker_dict[spk] = idx
-
-    # Create the empty labels
-    label = numpy.zeros(show_duration, dtype=int)
+    # Create the empty labels
+    label = numpy.zeros(sample_number, dtype=int)
+
+    # Compute the time stamp of each sample (the middle of the sample period)
+    time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
+    period = (stop_time - start_time) / sample_number
+    for t in range(sample_number):
+        time_stamps[t] = start_time + (2 * t + 1) * period / 2
+
+    # Find the label of the first sample
+    seg_idx = 0
+    while diarization.segments[seg_idx]['stop'] < start_time:
+        seg_idx += 1
+
+    # RESUME HERE (work in progress)
+    #ii = 0
+    #while diarization.segments[seg_idx]['start'] < stop_time:
+    #    while time_stamps[ii] < diarization.segments[seg_idx]['stop']:
+    #        label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
+    #        ii += 1
+
+    # Fill the labels with spk_idx
+    # NOTE: framerate is no longer a parameter of this function; this block is
+    # superseded by the work-in-progress code above
+    for segment in diarization:
+        start = int(segment['start']) * framerate // 100
+        stop = int(segment['stop']) * framerate // 100
+        spk_idx = speaker_dict[segment['cluster']]
+        label[start:stop] = spk_idx
+
+    #     start = int(diarization.segments[seg_idx]['start']) * framerate // sampling_frequency
+    #     stop = int(diarization.segments[seg_idx]['stop']) * framerate // sampling_frequency
+    #     spk_idx = speaker_dict[segment['cluster']]
+    #     label[start:stop] = spk_idx
+    #     seg_idx += 1

    # Get the label of each sample
    return label
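As a worked check of the time-stamp computation above: with start_time=0.0, stop_time=2.0 and sample_number=200, the period is (2.0 - 0.0) / 200 = 0.01 s, and each sample is stamped at the middle of its period:

    # t = 0   -> 0.0 +   1 * 0.01 / 2 = 0.005
    # t = 1   -> 0.0 +   3 * 0.01 / 2 = 0.015
    # ...
    # t = 199 -> 0.0 + 399 * 0.01 / 2 = 1.995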
@@ -154,12 +172,12 @@ def get_segment_label(label, seg_idx, mode, duration, framerate, seg_shift, coll
        return segment_label[seg_idx]


-class AlliesSet(Dataset):
+class DiarSet(Dataset):
    """
    Object creates a dataset for
    """
    def __init__(self,
-                 allies_dir,
+                 data_dir,
                 mode,
                 duration=2.,
                 seg_shift=0.25,
...
@@ -170,7 +188,7 @@ class AlliesSet(Dataset):
        Create batches of waveform samples for deep neural network training

-        :param allies_dir: the root directory of ALLIES data
+        :param data_dir: the root directory of ALLIES data
        :param mode: can be "vad", "spk_turn", "overlap"
        :param duration: duration of the segments in seconds
        :param seg_shift: shift to generate overlapping segments
...
@@ -182,16 +200,16 @@ class AlliesSet(Dataset):
        self.segments = []
        self.duration = duration
        self.seg_shift = seg_shift
-        self.input_dir = allies_dir
+        self.input_dir = data_dir
        self.mode = mode
        self.filter_type = filter_type
        self.collar_duration = collar_duration
-        self.wav_name_format = allies_dir + '/wav/{}.wav'
-        self.mdtm_name_format = allies_dir + '/mdtm/{}.mdtm'
+        self.wav_name_format = data_dir + '/wav/{}.wav'
+        self.mdtm_name_format = data_dir + '/mdtm/{}.mdtm'

        # load the list of training file names
        training_file_list = [str(f).split("/")[-1].split('.')[0]
-                              for f in list(Path(allies_dir + "/wav/").rglob("*.[wW][aA][vV]"))]
+                              for f in list(Path(data_dir + "/wav/").rglob("*.[wW][aA][vV]"))]

        for show in training_file_list:
...
@@ -236,3 +254,143 @@ class AlliesSet(Dataset):
    def __len__(self):
        return self.len


def seqSplit(mdtm_dir,
             duration=2.):
    """
    Split the reference MDTM files into segments centered on speaker-turn borders.

    :param mdtm_dir: directory containing the MDTM files
    :param duration: half-width of the window kept around each border
    :return: a Diar object with one segment per border and a dictionary mapping
             speaker names to indices
    """
    segment_list = Diar()
    speaker_dict = dict()
    idx = 0
    # For each MDTM
    for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):

        # Load the MDTM file
        ref = Diar.read_mdtm(mdtm_file)
        ref.sort()
        last_stop = ref.segments[-1]["stop"]

        # Get the borders of the segments (neither the start of the first one
        # nor the end of the last one).
        # For each border time B, keep a segment between B - duration and
        # B + duration in which a chunk will be picked randomly later.
        # seg_idx enumerates the segments; idx keeps numbering the speakers
        # across files.
        for seg_idx, seg in enumerate(ref.segments):
            if seg_idx > 0 and seg["start"] > duration and seg["start"] + duration < last_stop:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["start"] - duration) / 100.,
                                    stop=float(seg["start"] + duration) / 100.)
            elif seg_idx < len(ref.segments) - 1 and seg["stop"] + duration < last_stop:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["stop"] - duration) / 100.,
                                    stop=float(seg["stop"] + duration) / 100.)

        # Get the list of unique speakers
        speakers = ref.unique('cluster')
        for spk in speakers:
            if spk not in speaker_dict:
                speaker_dict[spk] = idx
                idx += 1

    return segment_list, speaker_dict
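A hypothetical call, assuming MDTM references under /data/allies/mdtm/:

    segment_list, speaker_dict = seqSplit("/data/allies/mdtm/", duration=2.)
    print(len(segment_list), "border segments,", len(speaker_dict), "speakers")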
class SeqSet(Dataset):
    """
    Object that creates a dataset for sequence-to-sequence training
    """
    def __init__(self,
                 wav_dir,
                 mdtm_dir,
                 segment_list,
                 mode,
                 duration=2.,
                 filter_type="gate",
                 collar_duration=0.1,
                 framerate=16000,
                 transform_pipeline=None):

        self.wav_dir = wav_dir
        self.mdtm_dir = mdtm_dir
        self.segment_list = segment_list
        self.mode = mode
        self.duration = duration
        self.filter_type = filter_type
        self.collar_duration = collar_duration
        self.framerate = framerate
        self.transform_pipeline = transform_pipeline

        # transform_pipeline is a comma-separated string such as "MFCC,CMVN";
        # PreEmphasis, MFCC, CMVN, FrequencyMask and TemporalMask are expected
        # to come from sidekit (they are not imported in this commit yet)
        _transform = []
        if self.transform_pipeline:
            trans = self.transform_pipeline.split(',')
            for t in trans:
                if 'PreEmphasis' in t:
                    _transform.append(PreEmphasis())
                if 'MFCC' in t:
                    _transform.append(MFCC())
                if "CMVN" in t:
                    _transform.append(CMVN())
                if "FrequencyMask" in t:
                    # parse "FrequencyMask(a-b)"
                    a = int(t.split('-')[0].split('(')[1])
                    b = int(t.split('-')[1].split(')')[0])
                    _transform.append(FrequencyMask(a, b))
                if "TemporalMask" in t:
                    # parse "TemporalMask(a)"
                    a = int(t.split("(")[1].split(")")[0])
                    _transform.append(TemporalMask(a))
        self.transforms = transforms.Compose(_transform)

        # The segment list and the speaker dictionary are recomputed here and
        # override the segment_list given as argument
        segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
                                              duration=self.duration)
        self.segment_list = segment_list
        self.speaker_dict = speaker_dict
        self.len = len(segment_list)
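The pipeline string is parsed by plain substring matching; for example, "MFCC,CMVN,FrequencyMask(12-30),TemporalMask(7)" (illustrative values) yields an MFCC extraction, a mean/variance normalization, then the two maskings. The bracket parsing works as follows:

    t = "FrequencyMask(12-30)"
    a = int(t.split('-')[0].split('(')[1])   # 12
    b = int(t.split('-')[1].split(')')[0])   # 30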
    def __getitem__(self, index):
        """
        Return a raw waveform chunk; the labels still have to be sampled at the
        proper rate (frames).

        :param index: index of the segment to pick from
        :return: a (data, label) pair of tensors
        """
        # Get segment info to load from
        seg = self.segment_list[index]

        # Randomly pick an audio chunk within the current segment
        start = random.uniform(seg.start_time, seg.start_time + self.duration)
        sig, _ = soundfile.read(self.wav_dir + seg.show + ".wav",
                                start=int(start * self.framerate),
                                stop=int((start + self.duration) * self.framerate))
        sig += 0.0001 * numpy.random.randn(sig.shape[0])

        if self.transform_pipeline:
            sig, _, __, ___ = self.transforms((sig, None, None, None))

        label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg.show + ".mdtm",
                              start_time=start,
                              stop_time=start + self.duration,
                              sample_number=sig.shape[0],
                              speaker_dict=self.speaker_dict)

        # For each sampling time we need to get the label
        # TO BE MODIFIED (flagged as work in progress in the commit)
        label = get_segment_label(label, index, self.mode, self.duration, self.framerate,
                                  self.seg_shift, self.collar_duration,
                                  filter_type=self.filter_type)

        return torch.from_numpy(sig).type(torch.FloatTensor), \
               torch.from_numpy(label.astype('long'))

    def __len__(self):
        return self.len
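A sketch of how these pieces fit together (paths and values are illustrative; note that the constructor recomputes segment_list and speaker_dict itself by calling seqSplit):

    segment_list, speaker_dict = seqSplit(mdtm_dir="/data/allies/mdtm/", duration=2.)
    train_set = SeqSet(wav_dir="/data/allies/wav/",
                       mdtm_dir="/data/allies/mdtm/",
                       segment_list=segment_list,
                       mode="vad",
                       duration=2.,
                       transform_pipeline="MFCC,CMVN")
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=True)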