Commit 3bb25db0 authored Nov 20, 2020 by Meysam Shamsi

first stable version

parent f5c00aed
Pipeline #633 canceled with stages
Changes 5
s4d/nnet/seqtoseq.py
...
...
@@ -118,7 +118,7 @@ class SeqToSeq(torch.nn.Module):
                 model_archi):

        super(SeqToSeq, self).__init__()

        # Load Yaml configuration
        with open(model_archi, 'r') as fh:
            cfg = yaml.load(fh, Loader=yaml.FullLoader)
...
...
@@ -132,19 +132,20 @@ class SeqToSeq(torch.nn.Module):
        self.preprocessor = None
        if "preprocessor" in cfg:
            if cfg['preprocessor']["type"] == "sincnet":
-                self.preprocessor = SincNet(
-                    waveform_normalize=cfg['preprocessor']["waveform_normalize"],
-                    sample_rate=cfg['preprocessor']["sample_rate"],
-                    min_low_hz=cfg['preprocessor']["min_low_hz"],
-                    min_band_hz=cfg['preprocessor']["min_band_hz"],
-                    out_channels=cfg['preprocessor']["out_channels"],
-                    kernel_size=cfg['preprocessor']["kernel_size"],
-                    stride=cfg['preprocessor']["stride"],
-                    max_pool=cfg['preprocessor']["max_pool"],
-                    instance_normalize=cfg['preprocessor']["instance_normalize"],
-                    activation=cfg['preprocessor']["activation"],
-                    dropout=cfg['preprocessor']["dropout"]
-                )
+                self.preprocessor = SincNet()
+                # self.preprocessor = SincNet(
+                #     waveform_normalize=cfg['preprocessor']["waveform_normalize"],
+                #     sample_rate=cfg['preprocessor']["sample_rate"],
+                #     min_low_hz=cfg['preprocessor']["min_low_hz"],
+                #     min_band_hz=cfg['preprocessor']["min_band_hz"],
+                #     out_channels=cfg['preprocessor']["out_channels"],
+                #     kernel_size=cfg['preprocessor']["kernel_size"],
+                #     stride=cfg['preprocessor']["stride"],
+                #     max_pool=cfg['preprocessor']["max_pool"],
+                #     instance_normalize=cfg['preprocessor']["instance_normalize"],
+                #     activation=cfg['preprocessor']["activation"],
+                #     dropout=cfg['preprocessor']["dropout"]
+                # )
                self.feature_size = self.preprocessor.dimension
        """
...
...
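For context, the model_archi YAML loaded in this constructor is expected to carry a "preprocessor" section with the keys referenced in the hunk above. A rough sketch of the structure that yaml.load would return, written as a Python dict; only the keys are grounded in the diff, the values are illustrative guesses and not taken from the repository:

# Hypothetical cfg["preprocessor"] block (keys from the diff above, values are guesses)
cfg = {
    "preprocessor": {
        "type": "sincnet",
        "waveform_normalize": True,
        "sample_rate": 16000,
        "min_low_hz": 50,
        "min_band_hz": 50,
        "out_channels": [80, 60, 60],
        "kernel_size": [251, 5, 5],
        "stride": [1, 1, 1],
        "max_pool": [3, 3, 3],
        "instance_normalize": True,
        "activation": "leaky_relu",
        "dropout": 0.0,
    }
}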
@@ -182,6 +183,9 @@ class SeqToSeq(torch.nn.Module):
            elif k.startswith('dropout'):
                post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
            elif k.startswith('softmax'):
                post_processing_layers.append((k, torch.nn.Softmax(dim=2)))

        self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
        self.post_processing.apply(init_weights)
...
...
@@ -237,13 +241,14 @@ def seqTrain(dataset_yaml,
    # Start from scratch
    if model_name is None:
        model = SeqToSeq(model_yaml)
    # If we start from an existing model
    else:
        # Load the model
        logging.critical(f"*** Load model from = {model_name}")
-        checkpoint = torch.load(model_name)
+        checkpoint = torch.load(model_name, map_location='cpu')
        model = SeqToSeq(model_yaml)
        model.load_state_dict(checkpoint['model_state_dict'])

    if torch.cuda.device_count() > 1 and multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
...
...
@@ -402,8 +407,8 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
                                 ((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} ' \
-                            'Recall: {:.3f} Precision: {:.3f} " \
-                            F-Measure: {:.3f}'.format(epoch,
+                            'Recall: {:.3f} Precision: {:.3f} ' \
+                            ' F-Measure: {:.3f}'.format(epoch,
                                                       batch_idx + 1,
                                                       training_loader.__len__(),
                                                       100. * batch_idx / training_loader.__len__(),
                                                       loss.item(),
...
...
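The corrected logging call above builds its message from adjacent single-quoted literals continued with backslashes, which Python concatenates into one format string. A minimal standalone illustration with made-up values:

msg = 'Loss: {:.6f} ' \
      'Accuracy: {:.3f}'
print(msg.format(0.25, 98.1))  # Loss: 0.250000 Accuracy: 98.100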
@@ -455,8 +460,8 @@ def cross_validation(model, validation_loader, device):
                             ((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
        logging.critical('Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} ' \
-                        'Recall: {:.3f} Precision: {:.3f} " \
-                        F-Measure: {:.3f}'.format(batch_idx + 1,
+                        'Recall: {:.3f} Precision: {:.3f} ' \
+                        ' F-Measure: {:.3f}'.format(batch_idx + 1,
                                                   validation_loader.__len__(),
                                                   100. * batch_idx / validation_loader.__len__(),
                                                   loss.item(),
                                                   100.0 * accuracy / ((batch_idx + 1)),
...
...
@@ -464,7 +469,7 @@ def cross_validation(model, validation_loader, device):
                                                   100.0 * precision / ((batch_idx + 1)),
                                                   f_measure))
-    return accuracy, loss
+    return 100.0 * accuracy / ((batch_idx + 1)), loss / (batch_idx + 1)


def calc_recall(output, target, device):
...
...
@@ -487,6 +492,9 @@ def calc_recall(output,target,device):
        assert y_pred.ndim == 1 or y_pred.ndim == 2

        if y_pred.ndim == 2:
            y_pred = y_pred.argmax(dim=1)
+        # print("y_true:",y_true)
+        # print("y_pred:",y_pred)

        tp = (y_true * y_pred).sum().to(torch.float32)
        tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
...
...
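The confusion counts above are plain elementwise products over binary label tensors; the fp and fn lines (collapsed in this view) presumably follow the same pattern. A small self-contained sketch with made-up labels:

import torch

y_true = torch.tensor([1, 0, 1, 1, 0])
y_pred = torch.tensor([1, 0, 0, 1, 1])

tp = (y_true * y_pred).sum().to(torch.float32)              # 2.0
tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)  # 1.0
fp = ((1 - y_true) * y_pred).sum().to(torch.float32)        # 1.0
fn = (y_true * (1 - y_pred)).sum().to(torch.float32)        # 1.0

epsilon = 1e-7
precision = tp / (tp + fp + epsilon)                  # ~0.667
recall = tp / (tp + fn + epsilon)                     # ~0.667
accuracy = (tp + tn) / (tp + fp + tn + fn + epsilon)  # 0.6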
@@ -496,9 +504,8 @@ def calc_recall(output,target,device):
        pr += tp / (tp + fp + epsilon)
        rc += tp / (tp + fn + epsilon)
-        a = (tp + tn) / (tp + fp + tn + fn + epsilon)
        acc += (tp + tn) / (tp + fp + tn + fn + epsilon)

    rc /= len(y_trueb[0])
    pr /= len(y_trueb[0])
    acc /= len(y_trueb[0])
...
...
s4d/nnet/wavsets.py
...
...
@@ -52,7 +52,17 @@ from torchvision import transforms
from collections import namedtuple

#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])

+def overlapping(seg1, seg2):
+    seg1_start, seg1_stop = seg1
+    seg2_start, seg2_stop = seg2
+    if seg1_start <= seg2_start:
+        # |------------|
+        #       |-------|
+        return seg1_stop > seg2_start
+    else:
+        #       |---------------|
+        # |---------|
+        return seg2_stop > seg1_start

def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
    """
    :param sig: input signal, can be mono or multi dimensional
...
...
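A quick sanity check of the new overlap test, assuming the overlapping helper above is in scope (segment times are in centiseconds, as in the MDTM segments used further down):

overlapping((0, 150), (100, 300))   # True:  the two windows share [100, 150)
overlapping((200, 300), (0, 150))   # False: the windows are disjoint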
@@ -129,9 +139,10 @@ def mdtm_to_label(mdtm_filename,
    for t in range(sample_number):
        time_stamps[t] = start_time + (2 * t + 1) * period / 2

+    framed_segments = [seg for seg in diarization.segments if overlapping((seg['start'], seg['stop']), (start_time * 100, stop_time * 100))]
    for idx, time in enumerate(time_stamps):
        lbls = []
-        for seg in diarization.segments:
+        for seg in framed_segments:
            if seg['start'] / 100. <= time <= seg['stop'] / 100.:
                lbls.append(speaker_dict[seg['cluster']])
...
...
@@ -260,19 +271,20 @@ def process_segment_label(label,

def seqSplit(mdtm_dir,
             wav_dir,
+            uem_dir=None,
             duration=2.):
    """
    :param mdtm_dir:
    :param duration:
    :return:
    """
    segment_list = Diar()
    speaker_dict = dict()
    idx = 0
    # For each MDTM
    for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):

        # Load MDTM file
        ref = Diar.read_mdtm(mdtm_file)
        ref.sort()
...
...
@@ -282,10 +294,10 @@ def seqSplit(mdtm_dir,
        # Check the length of audio
        nfo = soundfile.info(wav_dir + str(mdtm_file)[len(mdtm_dir):].split(".")[0] + ".wav")

        # For each border time B get a segment between B - duration and B + duration
        # in which we will pick up randomly later
-        for idx, seg in enumerate(ref.segments):
+        for idx2, seg in enumerate(ref.segments):
            if seg["start"] / 100. > duration and seg["start"] / 100. + duration < nfo.duration:
                segment_list.append(show=seg['show'],
...
...
@@ -369,11 +381,13 @@ class SeqSet(Dataset):
        if segment_list is None and speaker_dict is None:
            segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
                                                  wav_dir=wav_dir,
                                                  duration=self.duration)
        self.segment_list = segment_list
        self.speaker_dict = speaker_dict
        self.len = len(segment_list)

    def __getitem__(self, index):
        """
...
...
@@ -395,7 +409,7 @@ class SeqSet(Dataset):
        sig += 0.0001 * numpy.random.randn(sig.shape[0])

        if self.transform_pipeline:
-            sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
+            sig, speaker_idx, _t, _s = self.transforms((sig, None, None, None, None, None))

        tmp_label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
                                  start_time=start,
...
...
s4d/nnet/wavsets_Org.py
0 → 100644
# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2020 Anthony Larcher
"""
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher and Sylvain Meignier"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'


import numpy
import pathlib
import random
import scipy
import sidekit
import soundfile
import torch
import yaml

from ..diar import Diar
from pathlib import Path
from sidekit.nnet.xsets import PreEmphasis
from sidekit.nnet.xsets import MFCC
from sidekit.nnet.xsets import CMVN
from sidekit.nnet.xsets import FrequencyMask
from sidekit.nnet.xsets import TemporalMask
from torch.utils.data import Dataset
from torchvision import transforms
from collections import namedtuple
#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])
def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
    """
    :param sig: input signal, can be mono or multi dimensional
    :param win_size: size of the window in term of samples
    :param win_shift: shift of the sliding window in terme of samples
    :param context: tuple of left and right context
    :param pad: can be zeros or edge
    """
    dsize = sig.dtype.itemsize
    if sig.ndim == 1:
        sig = sig[:, numpy.newaxis]
    # Manage padding
    c = (context, ) + (sig.ndim - 1) * ((0, 0), )
    _win_size = win_size + sum(context)
    shape = (int((sig.shape[0] - win_size) / win_shift) + 1, 1, _win_size, sig.shape[1])
    strides = tuple(map(lambda x: x * dsize, [win_shift * sig.shape[1], 1, sig.shape[1], 1]))
    return numpy.lib.stride_tricks.as_strided(sig,
                                              shape=shape,
                                              strides=strides).squeeze()
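# A minimal usage sketch of framing (illustrative, not part of the original file):
# >>> import numpy
# >>> sig = numpy.arange(10.)
# >>> framing(sig, win_size=4, win_shift=2).shape
# (4, 4)
# Each row is a 4-sample window shifted by 2 samples, built as a strided view of sig.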
def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
    """
    :param wav_file_name:
    :param idx:
    :param duration:
    :param seg_shift:
    :param framerate:
    :return:
    """
    # Load waveform
    signal = sidekit.frontend.io.read_audio(wav_file_name, framerate)[0]
    tmp = framing(signal,
                  int(framerate * duration),
                  win_shift=int(framerate * seg_shift),
                  context=(0, 0),
                  pad='zeros')
    return tmp[idx], len(signal)
def mdtm_to_label(mdtm_filename,
                  start_time,
                  stop_time,
                  sample_number,
                  speaker_dict):
    """
    :param mdtm_filename:
    :param start_time:
    :param stop_time:
    :param sample_number:
    :param speaker_dict:
    :return:
    """
    diarization = Diar.read_mdtm(mdtm_filename)
    diarization.sort(['show', 'start'])

    # When one segment starts just the frame after the previous one ends,
    # we replace the time of the start by the time of the previous stop to avoid artificial holes
    previous_stop = 0
    for ii, seg in enumerate(diarization.segments):
        if ii == 0:
            previous_stop = seg['stop']
        else:
            if seg['start'] == diarization.segments[ii - 1]['stop'] + 1:
                diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']

    # Create the empty labels
    label = []

    # Compute the time stamp of each sample
    time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
    period = (stop_time - start_time) / sample_number
    for t in range(sample_number):
        time_stamps[t] = start_time + (2 * t + 1) * period / 2

    for idx, time in enumerate(time_stamps):
        lbls = []
        for seg in diarization.segments:
            if seg['start'] / 100. <= time <= seg['stop'] / 100.:
                lbls.append(speaker_dict[seg['cluster']])
        if len(lbls) > 0:
            label.append(lbls)
        else:
            label.append([])

    return label
def get_segment_label(label,
                      seg_idx,
                      mode,
                      duration,
                      framerate,
                      seg_shift,
                      collar_duration,
                      filter_type="gate"):
    """
    :param label:
    :param seg_idx:
    :param mode:
    :param duration:
    :param framerate:
    :param seg_shift:
    :param collar_duration:
    :param filter_type:
    :return:
    """
    # Create labels with Diracs at every speaker change detection
    spk_change = numpy.zeros(label.shape, dtype=int)
    spk_change[:-1] = label[:-1] ^ label[1:]
    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

    # depending of the mode, generates the labels and select the segments
    if mode == "vad":
        output_label = (label > 0.5).astype(numpy.long)

    elif mode == "spk_turn":
        # Apply convolution to replace diracs by a chosen shape (gate or triangle)
        filter_sample = collar_duration * framerate * 2 + 1
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            conv_filt = scipy.signal.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        output_label = (label > 0.5).astype(numpy.long)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    # Create segments with overlap
    segment_label = framing(output_label,
                            int(framerate * duration),
                            win_shift=int(framerate * seg_shift),
                            context=(0, 0),
                            pad='zeros')

    return segment_label[seg_idx]
def process_segment_label(label,
                          mode,
                          framerate,
                          collar_duration,
                          filter_type="gate"):
    """
    :param label:
    :param seg_idx:
    :param mode:
    :param duration:
    :param framerate:
    :param seg_shift:
    :param collar_duration:
    :param filter_type:
    :return:
    """
    # depending of the mode, generates the labels and select the segments
    if mode == "vad":
        output_label = numpy.array([len(a) > 0 for a in label]).astype(numpy.long)

    elif mode == "spk_turn":
        tmp_label = []
        for a in label:
            if len(a) == 0:
                tmp_label.append(0)
            elif len(a) == 1:
                tmp_label.append(a[0])
            else:
                tmp_label.append(sum(a) * 1000)
        label = numpy.array(label)

        # Create labels with Diracs at every speaker change detection
        spk_change = numpy.zeros(label.shape, dtype=int)
        spk_change[:-1] = label[:-1] ^ label[1:]
        spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

        # Apply convolution to replace diracs by a chosen shape (gate or triangle)
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            conv_filt = scipy.signal.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        label = numpy.array([len(a) for a in label]).astype(numpy.long)

        # For the moment, we just consider two classes: overlap / no-overlap
        # in the future we might want to classify according to the number of speaker speaking at the same time
        output_label = (label > 1).astype(numpy.long)
        # output_label=label
        # for i in range(len(output_label)):
        #     if output_label[i]>1:
        #         output_label[i]=2

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    return output_label
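# Minimal sketch of the label post-processing above (illustrative input, not from the original file):
# each entry of `label` is the list of speaker indices active in one output frame.
# >>> frames = [[0], [0, 1], []]
# >>> process_segment_label(frames, mode="vad", framerate=100, collar_duration=0.1)
# array([1, 1, 0])
# >>> process_segment_label(frames, mode="overlap", framerate=100, collar_duration=0.1)
# array([0, 1, 0])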
def seqSplit(mdtm_dir,
             wav_dir,
             duration=2.):
    """
    :param mdtm_dir:
    :param duration:
    :return:
    """
    segment_list = Diar()
    speaker_dict = dict()
    idx = 0
    # For each MDTM
    for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):

        # Load MDTM file
        ref = Diar.read_mdtm(mdtm_file)
        ref.sort()
        last_stop = ref.segments[-1]["stop"]

        # Get the borders of the segments (not the start of the first and not the end of the last

        # Check the length of audio
        nfo = soundfile.info(wav_dir + str(mdtm_file)[len(mdtm_dir):].split(".")[0] + ".wav")

        # For each border time B get a segment between B - duration and B + duration
        # in which we will pick up randomly later
        for idx, seg in enumerate(ref.segments):
            if seg["start"] / 100. > duration and seg["start"] / 100. + duration < nfo.duration:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["start"]) / 100. - duration,
                                    stop=float(seg["start"]) / 100. + duration)

            if seg["stop"] / 100. > duration and seg["stop"] / 100. + duration < nfo.duration:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["stop"]) / 100. - duration,
                                    stop=float(seg["stop"]) / 100. + duration)

        # Get list of unique speakers
        speakers = ref.unique('cluster')
        for spk in speakers:
            if not spk in speaker_dict:
                speaker_dict[spk] = idx
                idx += 1

    return segment_list, speaker_dict
class SeqSet(Dataset):
    """
    Object creates a dataset for sequence to sequence training
    """
    def __init__(self,
                 wav_dir,
                 mdtm_dir,
                 mode,
                 segment_list=None,
                 speaker_dict=None,
                 duration=2.,
                 filter_type="gate",
                 collar_duration=0.1,
                 audio_framerate=16000,
                 output_framerate=100,
                 transform_pipeline=""):
        """
        :param wav_dir:
        :param mdtm_dir:
        :param mode:
        :param duration:
        :param filter_type:
        :param collar_duration:
        :param audio_framerate:
        :param output_framerate:
        :param transform_pipeline:
        """

        self.wav_dir = wav_dir
        self.mdtm_dir = mdtm_dir
        self.mode = mode
        self.duration = duration
        self.filter_type = filter_type
        self.collar_duration = collar_duration
        self.audio_framerate = audio_framerate
        self.output_framerate = output_framerate
        self.transform_pipeline = transform_pipeline

        _transform = []
        if not self.transform_pipeline == '':
            trans = self.transform_pipeline.split(',')
            for t in trans:
                if 'PreEmphasis' in t:
                    _transform.append(PreEmphasis())
                if 'MFCC' in t:
                    _transform.append(MFCC())
                if "CMVN" in t:
                    _transform.append(CMVN())
                if "FrequencyMask" in t:
                    a = int(t.split('-')[0].split('(')[1])
                    b = int(t.split('-')[1].split(')')[0])
                    _transform.append(FrequencyMask(a, b))
                if "TemporalMask" in t:
                    a = int(t.split("(")[1].split(")")[0])
                    _transform.append(TemporalMask(a))
        self.transforms = transforms.Compose(_transform)

        if segment_list is None and speaker_dict is None:
            segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
                                                  duration=self.duration)
        self.segment_list = segment_list
        self.speaker_dict = speaker_dict
        self.len = len(segment_list)
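        # Example of the transform_pipeline string format parsed above
        # (illustrative values, not taken from the original configuration):
        # >>> t = "FrequencyMask(12-30)"
        # >>> int(t.split('-')[0].split('(')[1]), int(t.split('-')[1].split(')')[0])
        # (12, 30)
        # so "PreEmphasis,MFCC,CMVN,FrequencyMask(12-30),TemporalMask(100)" would stack
        # PreEmphasis(), MFCC(), CMVN(), FrequencyMask(12, 30) and TemporalMask(100).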
    def __getitem__(self, index):
        """
        Return a raw waveform segment; the labels must be sampled at the matching frame rate.
        :param index:
        :return:
        """
        # Get segment info to load from
        seg = self.segment_list[index]

        # Randomly pick an audio chunk within the current segment
        start = random.uniform(seg["start"], seg["start"] + self.duration)

        sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
                                start=int(start * self.audio_framerate),
                                stop=int((start + self.duration) * self.audio_framerate)
                                )
        sig += 0.0001 * numpy.random.randn(sig.shape[0])

        if self.transform_pipeline:
            sig, speaker_idx, _t, _s = self.transforms((sig, None, None, None, None, None))

        tmp_label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",