Martin Lebourdais / s4d · Commits

Commit 7df4d68e, authored Nov 09, 2020 by Martin Lebourdais

Merge branch 'master' of https://git-lium.univ-lemans.fr/Meignier/s4d

Parents: 4de1fd1a, b0835fe2

Changes: 5 files
s4d/__init__.py

@@ -60,4 +60,4 @@ __maintainer__ = "Sylvain Meignier"
 __email__ = "sylvain.meignierr@univ-lemans.fr"
 __status__ = "Production"
 __docformat__ = 'reStructuredText'
-__version__ = "0.1.4.7"
+__version__ = "0.1.4.8"
s4d/nnet/__init__.py

@@ -24,5 +24,7 @@ Copyright 2014-2020 Anthony Larcher
 """
 from .wavsets import SeqSet
+from .wavsets import create_train_val_seqtoseq
 from .seqtoseq import BLSTM
-from .seqtoseq import SeqToSeq
\ No newline at end of file
+from .seqtoseq import SeqToSeq
+from .seqtoseq import seqTrain
\ No newline at end of file
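Note: with these exports, the sequence-to-sequence pieces become importable from the package root. A minimal sketch of the resulting import surface (assumes an s4d checkout that includes this merge):

from s4d.nnet import BLSTM, SeqToSeq, seqTrain, create_train_val_seqtoseq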
s4d/nnet/seqtoseq.py

@@ -23,29 +23,22 @@
 Copyright 2014-2020 Anthony Larcher
 """
-import os
-import sys
 import logging
 import pandas
 import numpy
-from collections import OrderedDict
-import random
-import h5py
 import shutil
 import torch
 import torch.nn as nn
 import yaml
 from sklearn.model_selection import train_test_split
 from torch import optim
 from torch.utils.data import Dataset
-from .loss import ConcordanceCorCoeff
-from .wavsets import SeqSet
 from collections import OrderedDict
 from sidekit.nnet.sincnet import SincNet
 from torch.utils.data import DataLoader
 from .wavsets import SeqSet
+from .wavsets import create_train_val_seqtoseq

 __license__ = "LGPL"
-__author__ = "Anthony Larcher"
+__author__ = "Anthony Larcher, Martin Lebourdais, Meysam Shamsi"
 __copyright__ = "Copyright 2015-2020 Anthony Larcher"
 __maintainer__ = "Anthony Larcher"
 __email__ = "anthony.larcher@univ-lemans.fr"
@@ -68,7 +61,21 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename
         shutil.copyfile(filename, best_filename)


-class BLSTM(nn.Module):
+def init_weights(m):
+    """
+
+    :return:
+    """
+    if type(m) == torch.nn.Linear:
+        torch.nn.init.xavier_uniform_(m.weight)
+        m.bias.data.fill_(0.01)
+
+
+class BLSTM(torch.nn.Module):
+    """
+    Bi LSTM model used for voice activity detection, speaker turn detection, overlap detection and resegmentation
+    """
     def __init__(self, input_size, blstm_sizes):
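Note: the new init_weights helper is meant to be passed to Module.apply(), which walks the module tree and calls it on every submodule; SeqToSeq uses it below on its post-processing head. A self-contained sketch of the pattern (layer sizes are illustrative):

import torch

def init_weights(m):
    # Xavier-initialize the weight and set a small constant bias
    # on every linear layer that Module.apply() passes in.
    if type(m) == torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

head = torch.nn.Sequential(torch.nn.Linear(64, 32),
                           torch.nn.Tanh(),
                           torch.nn.Linear(32, 2))
head.apply(init_weights)    # recurses through both Linear layers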
@@ -80,20 +87,13 @@ class BLSTM(nn.Module):
         super(BLSTM, self).__init__()
         self.input_size = input_size
         self.blstm_sizes = blstm_sizes
-        #self.blstm_layers = []
-        # for blstm_size in blstm_sizes:
-        #     print(f"Input size {input_size},Output_size {self.output_size}")
-        #     self.blstm_layers.append(nn.LSTM(input_size, blstm_size, bidirectional=False, batch_first=True))
-        #     input_size = blstm_size
-        self.output_size = blstm_sizes[0] * 2
-        # self.blstm_layers = torch.nn.ModuleList(self.blstm_layers)
-        self.blstm_layers = nn.LSTM(input_size, blstm_sizes[0], bidirectional=True, batch_first=True, num_layers=2)
-        self.hidden = None
+        self.output_size = blstm_sizes * 2
         """
         Bi LSTM model used for voice activity detection or speaker turn detection
         """
+        self.blstm_layers = torch.nn.LSTM(input_size, blstm_sizes, bidirectional=True, batch_first=True, num_layers=2)

     def forward(self, inputs):
         """
@@ -101,35 +101,18 @@ class BLSTM(nn.Module):
         :param inputs:
         :return:
         """
-        #for idx, _s in enumerate(self.blstm_sizes):
-        #    self.blstm_layers[idx].flatten_parameters()
-        hiddens = []
-        if self.hidden is None:
-            #hidden_1, hidden_2 = None, None
-            for _s in self.blstm_sizes:
-                hiddens.append(None)
-        else:
-            hiddens = list(self.hidden)
-        x = inputs
-        outputs = []
-        # for idx, _s in enumerate(self.blstm_sizes):
-        #     # self.blstm_layers[idx].flatten_parameters()
-        #     print("IN",x.shape)
-        #     x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
-        #     print("OUT",x.shape)
-        #     outputs.append(x)
-        # self.hidden = tuple(hiddens)
-        # output = torch.cat(outputs, dim=2)
-        output, h = self.blstm_layers(x)
+        output, h = self.blstm_layers(inputs)
         return output

     def output_size(self):
         """

         :return:
         """
         return self.output_size


-class SeqToSeq(nn.Module):
+class SeqToSeq(torch.nn.Module):
     """
     Model used for voice activity detection or speaker turn detection
     This model can include a pre-processor to input raw waveform,
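Note: after this rewrite, BLSTM wraps a single two-layer bidirectional torch.nn.LSTM, so blstm_sizes is one hidden size (an int) rather than a list, the layer no longer carries hidden state between calls, and the output feature dimension is 2 * blstm_sizes. A quick shape check under assumed dimensions (30 input features, hidden size 64):

import torch

input_size, hidden_size = 30, 64    # illustrative, e.g. MFCC features
blstm = torch.nn.LSTM(input_size, hidden_size,
                      bidirectional=True, batch_first=True, num_layers=2)

x = torch.randn(8, 200, input_size)    # (batch, sequence, features)
output, h = blstm(x)
print(output.shape)    # torch.Size([8, 200, 128]), i.e. 2 * hidden_size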
@@ -205,8 +188,7 @@ class SeqToSeq(nn.Module):
                 post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))

         self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
-        #self.before_speaker_embedding_weight_decay = cfg["post_processing"]["weight_decay"]
+        self.post_processing.apply(init_weights)

     def forward(self, inputs):
         """
@@ -226,7 +208,6 @@ class SeqToSeq(nn.Module):
 def seqTrain(dataset_yaml,
-             val_dataset_yaml,
              model_yaml,
              mode,
              epochs=100,
              lr=0.0001,
              patience=10,
@@ -235,11 +216,6 @@ def seqTrain(dataset_yaml,
              best_model_name=None,
              multi_gpu=True,
              opt='sgd',
-             filter_type="gate",
-             collar_duration=0.1,
-             framerate=16000,
-             output_rate=100,
-             batch_size=32,
              log_interval=10,
              num_thread=10,
              non_overlap_dataset=None,
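Note: with validation handled by create_train_val_seqtoseq and the feature parameters read from the dataset YAML, the call site shrinks accordingly. A hypothetical invocation (file names and model names are placeholders; tmp_model_name is inferred from the checkpointing code below):

seqTrain(dataset_yaml="dataset.yaml",
         model_yaml="model.yaml",
         mode="overlap",                        # or "vad" / "spk_turn"
         epochs=100,
         lr=0.0001,
         patience=10,
         tmp_model_name="tmp_overlap_model",
         best_model_name="best_overlap_model",
         multi_gpu=False,
         num_thread=10)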
@@ -281,29 +257,17 @@ def seqTrain(dataset_yaml,
             model = torch.nn.DataParallel(model)
         else:
             print("Train on a single GPU")
     model.to(device)

-    with open(dataset_yaml, "r") as fh:
-        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
+    """
+    Create two dataloaders for training and evaluation
+    """
+    with open(dataset_yaml, "r") as fh:
+        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)

     df = pandas.read_csv(dataset_params["dataset_description"])
-    training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
-
-    _wav_dir = dataset_params['wav_dir']
-    _mdtm_dir = dataset_params['mdtm_dir']
+    torch.manual_seed(dataset_params['seed'])

-    training_set = SeqSet(dataset_yaml,
-                          wav_dir=_wav_dir,
-                          mdtm_dir=_mdtm_dir,
-                          mode=mode,
-                          duration=2.,
-                          filter_type="gate",
-                          collar_duration=0.1,
-                          audio_framerate=framerate,
-                          output_framerate=output_rate,
-                          transform_pipeline="MFCC")
+    training_set, validation_set = create_train_val_seqtoseq(dataset_yaml)

     training_loader = DataLoader(training_set,
                                  batch_size=dataset_params["batch_size"],
                                  drop_last=True,
@@ -311,23 +275,9 @@ def seqTrain(dataset_yaml,
                                  pin_memory=True,
                                  num_workers=num_thread)

-    validation_set = SeqSet(val_dataset_yaml,
-                            wav_dir=_wav_dir,
-                            mdtm_dir=_mdtm_dir,
-                            mode=mode,
-                            duration=2.,
-                            filter_type="gate",
-                            collar_duration=0.1,
-                            audio_framerate=framerate,
-                            output_framerate=output_rate,
-                            set_type="validation",
-                            transform_pipeline="MFCC")
-
     validation_loader = DataLoader(validation_set,
                                    batch_size=dataset_params["batch_size"],
                                    drop_last=True,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=num_thread)
@@ -359,24 +309,13 @@ def seqTrain(dataset_yaml,
         ]

     optimizer = _optimizer([{'params': model.parameters()},], **_options)
-    #if type(model) is SeqToSeq:
-    #    optimizer = _optimizer([
-    #        {'params': model.parameters(),
-    #         'weight_decay': model.weight_decay},],
-    #        **_options
-    #    )
-    #else:
-    #    optimizer = _optimizer([
-    #        {'params': model.module.sequence_network.parameters(),
-    #         #'weight_decay': model.module.sequence_network_weight_decay},],
-    #        **_options
-    #    )

     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

     best_fmes = 0.0
     best_fmes_epoch = 1
     curr_patience = patience
     for epoch in range(1, epochs + 1):
         # Process one epoch and return the current model
         if curr_patience == 0:
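Note that this hunk still initializes best_fmes while the updated comparison below reads best_accuracy, so the rename does not appear to reach these lines. The loop pairs ReduceLROnPlateau with a hand-rolled patience counter; a compact, runnable sketch of that interplay (model, values and loop bounds are illustrative stand-ins):

import torch

model = torch.nn.Linear(10, 2)    # stand-in for the SeqToSeq model
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

patience = curr_patience = 10
best_accuracy = 0.0
for epoch in range(1, 101):
    if curr_patience == 0:
        break                                # early stop, as in seqTrain
    accuracy, val_loss = 0.5, 1.0            # placeholders for cross_validation()
    scheduler.step(val_loss)                 # LR is reduced when val_loss plateaus
    if accuracy > best_accuracy:
        best_accuracy, curr_patience = accuracy, patience
    else:
        curr_patience -= 1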
@@ -390,23 +329,23 @@ def seqTrain(dataset_yaml,
                             device=device)

         # Cross validation here
-        fmes, val_loss = cross_validation(model, validation_loader, device=device)
-        logging.critical("*** Validation f-Measure = {}".format(fmes))
+        accuracy, val_loss = cross_validation(model, validation_loader, device=device)
+        logging.critical("*** Cross validation accuracy = {} %".format(accuracy))

         # Decrease learning rate according to the scheduler policy
         scheduler.step(val_loss)
         print(f"Learning rate is {optimizer.param_groups[0]['lr']}")

         # remember best accuracy and save checkpoint
-        is_best = fmes > best_fmes
-        best_fmes = max(fmes, best_fmes)
+        is_best = accuracy > best_accuracy
+        best_accuracy = max(accuracy, best_accuracy)

         if type(model) is SeqToSeq:
             save_checkpoint({'epoch': epoch,
                              'model_state_dict': model.state_dict(),
                              'optimizer_state_dict': optimizer.state_dict(),
-                             'accuracy': best_fmes,
+                             'accuracy': best_accuracy,
                              'scheduler': scheduler},
                             is_best,
                             filename=tmp_model_name + ".pt",
                             best_filename=best_model_name + '.pt')
         else:
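Note: save_checkpoint stores the epoch, both state dicts, the best accuracy and the scheduler in one dict; restoring is symmetric. A hedged sketch of the reload (the file name is a placeholder, the model a stand-in):

import torch

model = torch.nn.Linear(10, 2)    # stand-in for SeqToSeq
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

# Hypothetical restore of a checkpoint written by save_checkpoint() above.
checkpoint = torch.load("best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_accuracy = checkpoint['accuracy']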
@@ -414,46 +353,19 @@ def seqTrain(dataset_yaml,
                 'epoch': epoch,
                 'model_state_dict': model.module.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
-                'accuracy': best_fmes,
+                'accuracy': best_accuracy,
                 'scheduler': scheduler},
                 is_best,
                 filename=tmp_model_name + ".pt",
                 best_filename=best_model_name + '.pt')

         if is_best:
-            best_fmes_epoch = epoch
+            best_accuracy_epoch = epoch
             curr_patience = patience
         else:
             curr_patience -= 1

-    logging.critical(f"Best F-Mesure {best_fmes * 100.} obtained at epoch {best_fmes_epoch}")
-
-def calc_recall(output, target, device):
-    y_trueb = target.to(device)
-    y_predb = output
-    rc = 0.0
-    pr = 0.0
-    batch_size = y_trueb.shape[1]
-    for b in range(batch_size):
-        y_true = y_trueb[:, b]
-        y_pred = y_predb[:, :, b]
-        assert y_true.ndim == 1
-        assert y_pred.ndim == 1 or y_pred.ndim == 2
-        if y_pred.ndim == 2:
-            y_pred = y_pred.argmax(dim=1)
-        tp = (y_true * y_pred).sum().to(torch.float32)
-        tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
-        fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
-        fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
-        epsilon = 1e-7
-        precision = tp / (tp + fp + epsilon)
-        recall = tp / (tp + fn + epsilon)
-        rc += recall
-        pr += precision
-    return rc, pr
+    logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")


 def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
     """
@@ -468,12 +380,14 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
     """
     model.to(device)
     model.train()
-    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
+    #criterion = ccc_loss
+    criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor([0.1, 0.9]).to(device))
     recall = 0.0
+    precision = 0.0
     accuracy = 0.0
     for batch_idx, (data, target) in enumerate(training_loader):
         target = target.squeeze()
         # tnumpy = target.numpy()
         # print(tnumpy.shape)
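Note: the class weights [0.1, 0.9] down-weight the majority class, presumably because non-overlap frames vastly outnumber overlap frames. A self-contained example of how the weighting shifts the loss (logits and targets are illustrative):

import torch

# CrossEntropyLoss with per-class weights: errors on class 1 (the rare,
# positive class) cost nine times more than errors on class 0.
criterion = torch.nn.CrossEntropyLoss(reduction='mean',
                                      weight=torch.FloatTensor([0.1, 0.9]))

logits = torch.tensor([[2.0, 0.0], [0.0, 2.0]])    # two frames, both wrong
targets = torch.tensor([1, 0])
print(criterion(logits, targets))    # weighted mean, dominated by the class-1 error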
@@ -491,18 +405,36 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
         loss.backward(retain_graph=True)
         optimizer.step()
-        rc, pr = calc_recall(output.data, target, device)
-        accuracy += pr
-        recall += rc
+        rc, pr, acc = calc_recall(output.data, target, device)
+        recall += rc.item()
+        precision += pr.item()
+        accuracy += acc.item()

         if batch_idx % log_interval == 0:
             batch_size = target.shape[0]
-            # print(100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198))
-            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}\tRecall: {:.3f}'.format(
-                epoch, batch_idx + 1, training_loader.__len__(),
-                100. * batch_idx / training_loader.__len__(), loss.item(),
-                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size * 198),
-                100.0 * recall.item() / ((batch_idx + 1) * batch_size * 198)))
+            if precision != 0 or recall != 0:
+                f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \
+                            ((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
+                logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} '
+                                 'Recall: {:.3f} Precision: {:.3f} '
+                                 'F-Measure: {:.3f}'.format(epoch,
+                                                            batch_idx + 1,
+                                                            training_loader.__len__(),
+                                                            100. * batch_idx / training_loader.__len__(),
+                                                            loss.item(),
+                                                            100.0 * accuracy / ((batch_idx + 1)),
+                                                            100.0 * recall / ((batch_idx + 1)),
+                                                            100.0 * precision / ((batch_idx + 1)),
+                                                            f_measure))
+            else:
+                print(f"precision = {precision} and recall = {recall}")
     return model


 def pearsonr(x, y):
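Note: the logged F-measure is the harmonic mean of the batch-averaged precision and recall; dividing both running sums by (batch_idx + 1) turns them into means, after which the expression reduces to the standard F1 = 2PR / (P + R). A quick numeric check with illustrative running sums:

precision_sum, recall_sum, n_batches = 7.2, 6.0, 10    # illustrative values
p, r = precision_sum / n_batches, recall_sum / n_batches
f_measure = 2 * p * r / (p + r)
print(f_measure)    # 0.6545... for p = 0.72, r = 0.60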
@@ -560,7 +492,7 @@ def llincc(x, y):
     return ccc:
     '''


 def cross_validation(model, validation_loader, device):
     """
@@ -570,9 +502,12 @@ def cross_validation(model, validation_loader, device):
     :return:
     """
     model.eval()
     recall = 0.0
+    precision = 0.0
     accuracy = 0.0
     loss = 0.0
     criterion = torch.nn.CrossEntropyLoss()
     with torch.no_grad():
         for batch_idx, (data, target) in enumerate(validation_loader):
@@ -581,13 +516,66 @@ def cross_validation(model, validation_loader, device):
             output = model(data.to(device))
             output = output.permute(1, 2, 0)
             target = target.permute(1, 0)
-            nbpoint = output.shape[0]
-            rc, pr = calc_recall(output.data, target, device)
-            accuracy += pr
-            recall += rc
             loss += criterion(output, target.to(device))
+            rc, pr, acc = calc_recall(output.data, target, device)
+            recall += rc.item()
+            precision += pr.item()
+            accuracy += acc.item()
+            batch_size = target.shape[0]

-    fmes = 2 * (accuracy * recall) / (recall + accuracy)
-    return fmes / ((batch_idx + 1) * batch_size), \
-           loss.cpu().numpy() / ((batch_idx + 1) * batch_size)
+    if precision != 0 or recall != 0:
+        f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \
+                    ((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
+        logging.critical('Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} '
+                         'Recall: {:.3f} Precision: {:.3f} '
+                         'F-Measure: {:.3f}'.format(batch_idx + 1,
+                                                    validation_loader.__len__(),
+                                                    100. * batch_idx / validation_loader.__len__(),
+                                                    loss.item(),
+                                                    100.0 * accuracy / ((batch_idx + 1)),
+                                                    100.0 * recall / ((batch_idx + 1)),
+                                                    100.0 * precision / ((batch_idx + 1)),
+                                                    f_measure))
+    return accuracy, loss
+
+
+def calc_recall(output, target, device):
+    """
+
+    :param output:
+    :param target:
+    :param device:
+    :return:
+    """
+    y_trueb = target.to(device)
+    y_predb = output
+    rc = 0.0
+    pr = 0.0
+    acc = 0.0
+
+    for b in range(y_trueb.shape[-1]):
+        y_true = y_trueb[:, b]
+        y_pred = y_predb[:, :, b]
+        assert y_true.ndim == 1
+        assert y_pred.ndim == 1 or y_pred.ndim == 2
+        if y_pred.ndim == 2:
+            y_pred = y_pred.argmax(dim=1)
+
+        tp = (y_true * y_pred).sum().to(torch.float32)
+        tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
+        fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
+        fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
+
+        epsilon = 1e-7
+        pr += tp / (tp + fp + epsilon)
+        rc += tp / (tp + fn + epsilon)
+        acc += (tp + tn) / (tp + fp + tn + fn + epsilon)
+
+    rc /= len(y_trueb[0])
+    pr /= len(y_trueb[0])
+    acc /= len(y_trueb[0])
+
+    return rc, pr, acc
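Note: the new calc_recall averages per-sequence precision, recall and accuracy over the batch (the last tensor dimension), with epsilon guarding against division by zero. A toy check of the per-sequence arithmetic on five binary frames:

import torch

y_true = torch.tensor([1, 1, 0, 0, 1])    # reference frames
y_pred = torch.tensor([1, 0, 0, 1, 1])    # predicted frames
tp = (y_true * y_pred).sum().to(torch.float32)                # 2
tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)    # 1
fp = ((1 - y_true) * y_pred).sum().to(torch.float32)          # 1
fn = (y_true * (1 - y_pred)).sum().to(torch.float32)          # 1

epsilon = 1e-7                       # guards against division by zero
print(tp / (tp + fp + epsilon))      # precision = 2/3
print(tp / (tp + fn + epsilon))      # recall    = 2/3
print((tp + tn) / (tp + fp + tn + fn + epsilon))    # accuracy = 3/5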
s4d/nnet/wavsets.py

@@ -38,6 +38,7 @@ import scipy
 import sidekit
 import soundfile
 import torch
+import yaml
 from ..diar import Diar
 from pathlib import Path
@@ -123,7 +124,7 @@ def mdtm_to_label(mdtm_filename,
             diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']

     # Create the empty labels
-    label = numpy.zeros(sample_number, dtype=int)
+    label = []

     # Compute the time stamp of each sample
     time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
@@ -131,38 +132,16 @@ def mdtm_to_label(mdtm_filename,
     for t in range(sample_number):
         time_stamps[t] = start_time + (2 * t + 1) * period / 2

-    for ii, i in enumerate(numpy.linspace(start_time, stop_time, num=sample_number)):
-        cnt = 0
-        for d in diarization.segments:
-            # print(d)
-            # print(d['start'],d['stop'])
-            if d['start'] / 100 <= i <= d['stop'] / 100:
-                cnt += 1
-        overlaps[ii] = cnt
-
-    # Find the label of the first sample
-    seg_idx = 0
-    while diarization.segments[seg_idx]['stop'] / 100. < start_time:
-        #sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
-        seg_idx += 1
-
-    for ii, t in enumerate(time_stamps):
-        # If we are not yet inside the first overlapping segment (i.e. still in non-speech)
-        if t <= diarization.segments[seg_idx]['start'] / 100.:
-            # Leave the label at 0 (non-speech)
-            pass
-        # If we are already inside the first overlapping segment
-        elif diarization.segments[seg_idx]['start'] / 100. < t < diarization.segments[seg_idx]['stop'] / 100.:
-            label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
-        # If we move on to the next segment
-        elif diarization.segments[seg_idx]['stop'] / 100. < t and len(diarization.segments) > seg_idx + 1:
-            seg_idx += 1
-            # We are between two segments:
-            if t < diarization.segments[seg_idx]['start'] / 100.:
-                pass
-            elif diarization.segments[seg_idx]['start'] / 100. < t < diarization.segments[seg_idx]['stop'] / 100.:
-                label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
+    for idx, time in enumerate(time_stamps):
+        lbls = []
+        for seg in diarization.segments:
+            if seg['start'] / 100. <= time <= seg['stop'] / 100.:
+                lbls.append(speaker_dict[seg['cluster']])
+        if len(lbls) > 0:
+            label.append(lbls)
+        else:
+            label.append([])

     return (label, overlaps)
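Note: mdtm_to_label now returns, for every time stamp, the list of all speakers active at that instant instead of a single integer, which is what makes overlap visible downstream. A toy run of the new inner loop (segment times in centiseconds, matching the /100. in the diff; segments and speaker_dict are illustrative):

# Toy segments: speaker A speaks 0-2 s, speaker B 1.5-3 s.
segments = [{'start': 0, 'stop': 200, 'cluster': 'A'},
            {'start': 150, 'stop': 300, 'cluster': 'B'}]
speaker_dict = {'A': 1, 'B': 2}
time_stamps = [0.5, 1.75, 2.5, 3.5]    # seconds

label = []
for time in time_stamps:
    lbls = [speaker_dict[seg['cluster']]
            for seg in segments
            if seg['start'] / 100. <= time <= seg['stop'] / 100.]
    label.append(lbls)

print(label)    # [[1], [1, 2], [2], []] -> overlap at 1.75 s, silence at 3.5 s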
@@ -207,7 +186,7 @@ def get_segment_label(label,
         output_label = numpy.convolve(conv_filt, spk_change, mode='same')

     elif mode == "overlap":
-        output_label = (overlaps > 1).astype(numpy.long)
+        output_label = (label > 0.5).astype(numpy.long)

     else:
         raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
@@ -240,16 +219,29 @@ def process_segment_label(label,
     :param filter_type:
     :return:
     """
-    # Create labels with Diracs at every speaker change detection
-    spk_change = numpy.zeros(label.shape, dtype=int)
-    spk_change[:-1] = label[:-1] ^ label[1:]
-    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
-
     # depending of the mode, generates the labels and select the segments
     if mode == "vad":
-        output_label = (label > 0.5).astype(numpy.long)
+        output_label = numpy.array([len(a) > 0 for a in label]).astype(numpy.long)

     elif mode == "spk_turn":
+        tmp_label = []
+        for a in label:
+            if len(a) == 0:
+                tmp_label.append(0)
+            elif len(a) == 1:
+                tmp_label.append(a[0])
+            else:
+                tmp_label.append(sum(a) * 1000)
+        label = numpy.array(label)
+
+        # Create labels with Diracs at every speaker change detection
+        spk_change = numpy.zeros(label.shape, dtype=int)
+        spk_change[:-1] = label[:-1] ^ label[1:]
+        spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))
+
         # Apply convolution to replace diracs by a chosen shape (gate or triangle)
         filter_sample = int(collar_duration * framerate * 2 + 1)
@@ -259,7 +251,11 @@ def process_segment_label(label,
         output_label = numpy.convolve(conv_filt, spk_change, mode='same')

     elif mode == "overlap":
-        output_label = (overlaps > 1).astype(numpy.long)
+        label = numpy.array([len(a) for a in label]).astype(numpy.long)
+        # For the moment, we just consider two classes: overlap / no-overlap
+        # in the future we might want to classify according to the number of speaker speaking at the same time
+        output_label = (label > 1).astype(numpy.long)

     else:
         raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")
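Note: with the list-of-lists labels, each mode reduces them differently: vad keeps "any speaker active", overlap keeps "more than one active". A small sketch over the toy labels from above (the diff uses numpy.long, which recent numpy versions removed; plain int is used here so the sketch runs today):

import numpy

label = [[], [1], [1, 2], [2]]    # per-frame lists of active speakers

vad = numpy.array([len(a) > 0 for a in label]).astype(int)      # [0 1 1 1]
n_speakers = numpy.array([len(a) for a in label]).astype(int)   # [0 1 2 1]
overlap = (n_speakers > 1).astype(int)                          # [0 0 1 0]
print(vad, n_speakers, overlap)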
@@ -268,7 +264,8 @@