Sylvain Meignier / s4d · Commits · 0403dd7f

Commit 0403dd7f authored Jan 27, 2022 by Martin Lebourdais

Prediction without labels (not clean, but should work, testing in progress)

parent 2bbed2f5
Changes 4
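The substance of the commit is making the prediction path tolerate DataLoaders that yield batches without labels. A minimal, self-contained sketch of that pattern (the names below are illustrative only, not s4d's API; the real change is in seqtoseq_predict.py further down):

import torch

def consume(loader):
    # Accept batches shaped (win_idx, data, target) or (win_idx, data).
    outputs, targets = [], []
    for batch in loader:
        if len(batch) == 3:          # labelled: window index, signal, label
            win_idx, data, target = batch
            targets.append(target)
        else:                        # unlabelled: window index, signal
            win_idx, data = batch
        outputs.append(data.mean())  # stand-in for the model forward pass
    return outputs, targets

labelled = [(0, torch.randn(4), torch.zeros(4)), (1, torch.randn(4), torch.ones(4))]
unlabelled = [(0, torch.randn(4)), (1, torch.randn(4))]
print(len(consume(labelled)[1]))    # 2 targets collected
print(len(consume(unlabelled)[1]))  # 0 targets collected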
s4d/nnet/seqtoseq_predict.py
@@ -109,7 +109,7 @@ def calculate_pyannote_metrics(predf,trueyaml,task='overlap'):
         cfg = yaml.load(fh, Loader=yaml.FullLoader)
     else:
         cfg = trueyaml
     show_list = cfg["file_list"]
     uem = None
     if "uem_dir" in cfg:
@@ -238,8 +238,8 @@ def model_pred(prediction_yaml,
     wav_dir = cfg["wav_dir"]
     model_name = cfg["model_name"]
     model_yaml = cfg["model_archi"]
-    seg_set = cfg["seg_set"]
-    label_set = cfg["label_set"]
+    seg_set = cfg.get("seg_set", None)
+    label_set = cfg.get("label_set", None)
     uem_dir = cfg.get("uem_dir", None)
     batch_size = cfg["batch_size"]
     audio_fr = cfg["audio_samplerate"]
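The switch from cfg["seg_set"] to cfg.get("seg_set", None) is what lets a prediction config omit the segment and label sets entirely: indexing raises KeyError on a missing key, while .get() falls back to a default. A toy illustration (keys match the diff, values are made up):

cfg = {"wav_dir": "/data/wav/", "batch_size": 8}   # no seg_set / label_set
seg_set = cfg.get("seg_set", None)       # None instead of KeyError
label_set = cfg.get("label_set", None)
print(seg_set, label_set)                # None None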
@@ -259,7 +259,7 @@ def model_pred(prediction_yaml,
     # loading model
     with open(model_yaml, "r") as fh:
         archi = yaml.load(fh, Loader=yaml.FullLoader)
     # manage context in case of convolutional pre-processing
     context = 0
     if not is_multichannel:
@@ -272,7 +272,7 @@ def model_pred(prediction_yaml,
     checkpoint = torch.load(model_name, map_location=device)
     # Pytorch fails with saving mel scale transformation
     if 'pre_processing.sacc.mel_scale.fb' in checkpoint['model_state_dict'].keys():
         checkpoint['model_state_dict']['pre_processing.sacc.mel_scale.fb'] = torch.Tensor([])
     if 'pre_processing.complex_sacc.mel_scale.fb' in checkpoint['model_state_dict'].keys():
@@ -329,7 +329,7 @@ def model_pred(prediction_yaml,
show
=
show
,
first
=
first
,
out_dir
=
out_dir
,
)
pred_history
.
append
({
"show"
:
show
,
"pred"
:
pred
,
"target"
:
target
,
"raw_pred"
:
raw_pred
[:,
1
]})
...
...
@@ -347,12 +347,12 @@ def model_pred(prediction_yaml,
             os.mkdir(res_path + "rttm/")
         # compute scores + create RTTM file with current prediction
         scores = metrics(pred_history, out_rttm, min_on, min_off, compute_ap_score=True, uem=uem_dir)
         # Serializing json
         json_object = json.dumps(scores, indent=4)
         # Writing to json file
         with open(fname, "w") as outfile:
             outfile.write(json_object)
@@ -375,7 +375,7 @@ def model_pred(prediction_yaml,
                                  min_off=min_off,
                                  metrics=scores)
         scores["test_config"] = {"th_in": th_in, "th_out": th_out, "min_on": min_on, "min_off": min_off}
         # metrics from Pyannote as a comparison
         if pyametrics is not None:
             logger.message("\n\n--------Pyannote--------")
@@ -388,9 +388,46 @@ def model_pred(prediction_yaml,
         return pred_history, scores
     else:
-        return pred_history
+        res_path = "{}metrics/".format(exp_dir)
+        out_rttm = "{}rttm/{:s}_in_{:d}_out_{:d}_on_{:d}_off_{:d}.rttm".format(res_path,
+                                                                               task,
+                                                                               int(th_in * 100),
+                                                                               int(th_out * 100),
+                                                                               int(min_on * 1e3),
+                                                                               int(min_off * 1e-3),)
+        resf = open(out_rttm, 'w')
+        for show in pred_history:
+            pred = show["pred"]
+            uem = uem_dir
+            create_rttm(resf, pred, show['show'], min_on=min_on, min_off=min_off, uem=uem, show=show["show"])
+        return pred_history, 0
+'''
+def metrics(predict_per_show,resfpath,min_on,min_off,compute_ap_score=False,uem=None):
+    resf = open(resfpath,'w')
+    macrotp = 0
+    macrotn = 0
+    macrofp = 0
+    macrofn = 0
+    ap=0.0
+    for show in predict_per_show:
+        true = show["target"]
+        uem_mask = get_uem_mask(uem,show["show"],true)
+        pred = show["pred"]
+        truem = true[uem_mask]
+        predm = pred[uem_mask]
+        tp = numpy.sum(truem * predm)
+        tn = numpy.sum((1 - truem) * (1 - predm))
+        fp = numpy.sum((1 - truem) * predm)
+        fn = numpy.sum(truem * (1 - predm))
+        macrotp += tp
+        macrotn += tn
+        macrofp += fp
+        macrofn += fn
+        epsilon = 1e-7
+        if compute_ap_score:
+            raw_pred=show["raw_pred"]
+            ap+=average_precision_score(true,raw_pred,)
+        create_rttm(resf,pred,show['show'],min_on=min_on,min_off=min_off,uem=uem,show=show["show"])
+'''
 def predict(batch_size,
             validation_loader,
             model,
     # TODO (2021/03/26) define synthetic overlap ratio to guarantee
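The rewritten else branch replaces metric computation with plain RTTM output via create_rttm. For readers unfamiliar with the format, here is a hedged sketch of turning a binary frame-level prediction (100 frames per second) into standard RTTM SPEAKER records; this is not s4d's create_rttm, and the min_on/min_off smoothing is deliberately omitted:

import io
import sys
import numpy

def to_rttm(fh, pred, show, frame_rate=100):
    # Locate rising and falling edges of the binary prediction.
    pred = numpy.asarray(pred).astype(int)
    edges = numpy.diff(numpy.concatenate(([0], pred, [0])))
    onsets = numpy.where(edges == 1)[0]
    offsets = numpy.where(edges == -1)[0]
    for on, off in zip(onsets, offsets):
        start = on / frame_rate
        dur = (off - on) / frame_rate
        fh.write(f"SPEAKER {show} 1 {start:.2f} {dur:.2f} <NA> <NA> speech <NA> <NA>\n")

buf = io.StringIO()
to_rttm(buf, [0, 1, 1, 1, 0, 0, 1, 1], "show1")
sys.stdout.write(buf.getvalue())
# SPEAKER show1 1 0.01 0.03 <NA> <NA> speech <NA> <NA>
# SPEAKER show1 1 0.06 0.02 <NA> <NA> speech <NA> <NA>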
@@ -405,7 +442,7 @@ def predict(batch_size,
                 show=None,
                 first=True,
                 out_dir=None,
                 ):
     """
     TO BE MODIFIED: take only the name of the WAV file and create the data loader inside
@@ -423,18 +460,27 @@ def predict(batch_size,
     output_target = []
     output_idx = []
     done = first
+    islabel = False
     if first:
         sm = torch.nn.Softmax(dim=2)
         with torch.no_grad():
-            for batch_idx, (win_idx, data, target) in enumerate(validation_loader):
-                target = target.squeeze().cpu().numpy()
+            for batch_idx, data_full in enumerate(validation_loader):
+                if len(data_full) == 3:
+                    win_idx, data, target = data_full
+                    islabel = True
+                else:
+                    win_idx, data = data_full
+                    islabel = False
+                if islabel:
+                    target = target.squeeze().cpu().numpy()
                 output = sm(model(data.to(device))).cpu().numpy()
                 del(data)
                 for ii in range(output.shape[0]):
                     output_data.append(output[ii])
-                    output_target.append(target[ii])
+                    if islabel:
+                        output_target.append(target[ii])
                     output_idx.append(int(win_idx[ii]))
         # Unfold outputs by averaging sliding windows
         final_output, final_target = multi_label_combination(output_idx,
                                                              output_target,
@@ -445,11 +491,15 @@ def predict(batch_size,
         raw_output = final_output[:,1]
         if out_dir:
             pickle.dump(final_output, open(out_dir + f"pred_{show}.pkl", "wb"))
-            pickle.dump(final_target, open(out_dir + f"target_{show}.pkl", "wb"))
+            if islabel:
+                pickle.dump(final_target, open(out_dir + f"target_{show}.pkl", "wb"))
     else:
         if out_dir:
             final_output = pickle.load(open(out_dir + f"pred_{show}.pkl", "rb"))
-            final_target = pickle.load(open(out_dir + f"target_{show}.pkl", "rb"))
+            if islabel:
+                final_target = pickle.load(open(out_dir + f"target_{show}.pkl", "rb"))
+            else:
+                final_target = []
         else:
             raise Exception("No path where to find previously computed predictions !")
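The first pass caches per-show predictions as pickle files so later threshold sweeps can reload them instead of re-running the model; without labels, only pred_{show}.pkl is written and final_target degrades to an empty list. A small sketch of that round-trip (the temporary directory is only for the example; file names mirror the diff):

import os
import pickle
import tempfile
import numpy

out_dir = tempfile.mkdtemp() + os.sep
show, islabel = "show1", False
final_output = numpy.random.rand(50, 2)

pickle.dump(final_output, open(out_dir + f"pred_{show}.pkl", "wb"))
if islabel:  # target file exists only when the loader provided labels
    pickle.dump(numpy.zeros(50), open(out_dir + f"target_{show}.pkl", "wb"))

# Second pass: reload instead of recomputing.
final_output = pickle.load(open(out_dir + f"pred_{show}.pkl", "rb"))
final_target = (pickle.load(open(out_dir + f"target_{show}.pkl", "rb"))
                if islabel else [])
print(final_output.shape, final_target)   # (50, 2) []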
@@ -544,11 +594,12 @@ def multi_label_combination(output_idx, output_target, output_data, shift, outpu
     Author: Anthony Larcher
     """
+    islabel = len(output_target) > 0
     win_shift = int(shift * output_rate)
     # Initialize the size of final_output
     final_output = numpy.zeros((win_shift * (len(output_data) - 1) + output_data[0].shape[0],
                                 output_data[0].shape[1]))
     final_target = numpy.zeros(win_shift * (len(output_data) - 1) + output_data[0].shape[0])
     overlaping_label_count = numpy.zeros(final_output.shape)
@@ -556,19 +607,27 @@ def multi_label_combination(output_idx, output_target, output_data, shift, outpu
tmp
=
numpy
.
ones
(
output_data
[
0
].
shape
)
# Loop on the overlaping windows
for
idx
,
tmp_t
,
tmp_d
in
zip
(
output_idx
,
output_target
,
output_data
):
start_idx
=
win_shift
*
idx
stop_idx
=
start_idx
+
win_len
if
islabel
:
for
idx
,
tmp_t
,
tmp_d
in
zip
(
output_idx
,
output_target
,
output_data
):
start_idx
=
win_shift
*
idx
stop_idx
=
start_idx
+
win_len
overlaping_label_count
[
start_idx
:
stop_idx
,
:]
+=
tmp
final_output
[
start_idx
:
stop_idx
,
:]
+=
tmp_d
final_target
[
start_idx
:
stop_idx
]
+=
tmp_t
else
:
for
idx
,
tmp_d
in
zip
(
output_idx
,
output_data
):
start_idx
=
win_shift
*
idx
stop_idx
=
start_idx
+
win_len
overlaping_label_count
[
start_idx
:
stop_idx
,
:]
+=
tmp
final_output
[
start_idx
:
stop_idx
,
:]
+=
tmp_d
overlaping_label_count
[
start_idx
:
stop_idx
,
:]
+=
tmp
final_output
[
start_idx
:
stop_idx
,
:]
+=
tmp_d
final_target
[
start_idx
:
stop_idx
]
+=
tmp_t
# Divide by the number of overlapping values
raw_output
=
final_output
final_output
/=
overlaping_label_count
final_target
/=
overlaping_label_count
[:,
0
].
squeeze
()
if
islabel
:
final_target
/=
overlaping_label_count
[:,
0
].
squeeze
()
return
final_output
,
final_target
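multi_label_combination reassembles a long sequence from overlapping window outputs by summing each window into a shared buffer and dividing every frame by the number of windows that covered it. A self-contained numeric sketch of that overlap-add averaging, with toy window sizes:

import numpy

win_len, win_shift = 4, 2
output_idx = [0, 1, 2]
output_data = [numpy.full((win_len, 2), float(i + 1)) for i in output_idx]

total = win_shift * (len(output_data) - 1) + win_len
final_output = numpy.zeros((total, 2))
count = numpy.zeros((total, 2))
for idx, tmp_d in zip(output_idx, output_data):
    start = win_shift * idx
    count[start:start + win_len, :] += 1.0
    final_output[start:start + win_len, :] += tmp_d
final_output /= count
print(final_output[:, 0])   # [1.  1.  1.5 1.5 2.5 2.5 3.  3. ]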
s4d/nnet/seqtoseq_training.py
@@ -104,11 +104,13 @@ def prepare_loaders(dataset_yaml, logger=None, rng_=None, seed=1234):
                                      task=task,
                                      rng=rng)
     elif sampler_typ == "random":
+        print(f"TASK : {task}, OVART: {art_ola_ratio}")
         sampler = SeqSetRandomSampler(batch_size=batch_size,
                                       batch_num=batch_num,
                                       list_file=file_list,
                                       seg_set=seg_set,
                                       task=task,
                                       mode="train",
+                                      artificial_ov_ratio=art_ola_ratio,
                                       rng=rng)
     else:
@@ -126,6 +128,8 @@ def prepare_loaders(dataset_yaml, logger=None, rng_=None, seed=1234):
batch_num
=
batch_num
,
list_file
=
eval_file_list
,
seg_set
=
eval_seg_set
,
task
=
task
,
mode
=
"eval"
,
rng
=
rng
)
eval_loader
=
DataLoader
(
evaluation_set
,
...
...
s4d/nnet/sequence_models.py
@@ -156,12 +156,14 @@ class SeqToSeq(torch.nn.Module):
         self.feature_size = cfg["feature_size"]
         self.samplerate = cfg["samplerate"]
         self.channel_number = cfg["channel_number"]
         self.sum_channels = False
+        self.is_mfcc = False
         # pre-processing layers
         pre_processing_layers = []
         for k in cfg["pre_processing"].keys():
             if k.startswith("mfcc"):
+                self.is_mfcc = True
                 n_fft = cfg["pre_processing"][k]["n_fft"]
                 win_length = cfg["pre_processing"][k].get("win_length", 480)
                 hop_length = cfg["pre_processing"][k]["win_shift"]
@@ -190,6 +192,7 @@ class SeqToSeq(torch.nn.Module):
                 self.feature_size = n_mels // stride[0]
                 input_size = self.feature_size
             if k.startswith("mono_mel"):
+                self.is_mfcc = True
                 pre_processing_layers.append((k, MelSpec(samplerate=self.samplerate,
                                                          conf=cfg["pre_processing"][k])))
                 input_size = self.feature_size
@@ -204,6 +207,7 @@ class SeqToSeq(torch.nn.Module):
                 input_size = self.feature_size
             # In case of multichannel input, tConv is a learnable filter and sum beamforming
             if k.startswith("tConv"):
+                self.sum_channels = True
                 window_length = cfg["pre_processing"][k]["window_length"]
                 hop_length = cfg["pre_processing"][k]["hop_length"]
                 # number of FIR filters applied to each channel of the array
@@ -282,50 +286,57 @@ class SeqToSeq(torch.nn.Module):
         post_processing_activation = torch.nn.Tanh()
         post_processing_layers = []
-        for k in cfg["post_processing"].keys():
-            if k.startswith("lin"):
-                post_processing_layers.append((k, torch.nn.Linear(input_size,
-                                                                  cfg["post_processing"][k]["output"])))
-                input_size = cfg["post_processing"][k]["output"]
-            elif k.startswith("activation"):
-                post_processing_layers.append((k, post_processing_activation))
-            elif k.startswith('batch_norm'):
-                post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
-            elif k.startswith('dropout'):
-                post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
-        self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
-        self.post_processing.apply(self._init_weights)
+        self.post_processing_fl = False
+        if "post_processing" in cfg:
+            self.post_processing_fl = True
+            for k in cfg["post_processing"].keys():
+                if k.startswith("lin"):
+                    post_processing_layers.append((k, torch.nn.Linear(input_size,
+                                                                      cfg["post_processing"][k]["output"])))
+                    input_size = cfg["post_processing"][k]["output"]
+                elif k.startswith("activation"):
+                    post_processing_layers.append((k, post_processing_activation))
+                elif k.startswith('batch_norm'):
+                    post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
+                elif k.startswith('dropout'):
+                    post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
+            self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
+            self.post_processing.apply(self._init_weights)
         self.output_size = input_size

     def forward(self, inputs):
         """
         :param inputs: raw audio signal
         :return:
         """
-        if self.center:
+        if self.sum_channels:
+            nch = inputs.shape[1]
+            inputs = inputs.sum(dim=1, keepdim=True) / nch
+        if self.is_mfcc and self.center:
             x = self.pre_processing(inputs[:,:,:-1])
         elif self.is_wavlm:
             #inputs = inputs.unfold(-1,self.wavlm_win_length, self.wavlm_hop_length)
             # print('in wavelm shape before squeeze',inputs.shape)
             inputs = inputs.squeeze(dim=1)
             # print('in wavelm shape after squeeze',inputs.shape)
             x = self.pre_processing(inputs)
             # print("out wavelm",x.shape)
         else:
-            x = inputs
+            x = self.pre_processing(inputs)
         # remove energy
         if len(x.shape) == 4:
             x = torch.squeeze(x[:,:,:])
-        x = self.sequence_to_sequence(x.permute(0, 2, 1))
-        x = self.post_processing(x)
+        if self.is_lstm:
+            x = x.permute(0, 2, 1)
+        x = self.sequence_to_sequence(x)
+        if self.post_processing_fl:
+            x = self.post_processing(x)
+        else:
+            x = x.permute(0, 2, 1)
         return x

     def get_features(self, inputs):
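The post_processing_fl guard makes the whole post-processing head optional: layers are built only when the architecture config has a "post_processing" section, and forward() falls back to a bare permute otherwise. A runnable sketch of the same config-driven construction (the toy config below is not s4d's schema):

from collections import OrderedDict
import torch

cfg = {"post_processing": {"lin0": {"output": 16},
                           "activation0": None,
                           "dropout0": 0.2}}
input_size, layers = 32, []
post_processing_fl = "post_processing" in cfg
if post_processing_fl:
    for k in cfg["post_processing"]:
        if k.startswith("lin"):
            layers.append((k, torch.nn.Linear(input_size, cfg["post_processing"][k]["output"])))
            input_size = cfg["post_processing"][k]["output"]
        elif k.startswith("activation"):
            layers.append((k, torch.nn.Tanh()))
        elif k.startswith("dropout"):
            layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
    post_processing = torch.nn.Sequential(OrderedDict(layers))
    print(post_processing(torch.randn(3, 32)).shape)   # torch.Size([3, 16])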
@@ -341,11 +352,11 @@ class SeqToSeq(torch.nn.Module):
             return x, w_comb, w_att
         elif self.is_wavlm:
             #inputs = inputs.unfold(-1,self.wavlm_win_length, self.wavlm_hop_length)
-            print('in wavelm shape before squeeze', inputs.shape)
+            # print('in wavelm shape before squeeze',inputs.shape)
             inputs = inputs.squeeze(dim=1)
-            print('in wavelm shape after squeeze', inputs.shape)
+            # print('in wavelm shape after squeeze',inputs.shape)
             x = self.pre_processing(inputs)
-            print("out wavelm", x.shape)
+            # print("out wavelm",x.shape)
             return x
         else:
             x = self.pre_processing(inputs)
s4d/nnet/sequence_sets.py
@@ -108,36 +108,30 @@ class SeqSetRandomSampler(torch.utils.data.Sampler):
         idx_iter = list()
         segments = shelve.open(seg_set)
         for show in tqdm.tqdm(self.show_list,
                               desc="Sampler initialization,first pass",
                               unit="show"):
             #print(list(segments.keys()))
-            seg_tags = segments.get(show)
-            for ii in range(len(seg_tags)):
-                seg = seg_tags[ii]
+            try:
+                seg_tags = segments.get(show)
+            except:
+                print(segments[show])
             ovseg = None
             if mode == "train" and task == "overlap":
                 if rng.random() < artificial_ov_ratio:
                     ovseg = seg_tags[rng.integers(len(seg_tags))]
-                idx_iter.append({'seg': seg, 'overlap': None})
+            for seg in seg_tags:
+                idx_iter.append({'seg': seg, 'overlap': ovseg})
         self.index_iterator = numpy.array(idx_iter)
         self.length = batch_num * batch_size
         segments.close()

     def __iter__(self):
         """
         :return:
         """
         self.rng.shuffle(self.index_iterator)
         self.iter = self.index_iterator[:self.length]
         return iter(self.iter)

     def __len__(self) -> int:
         """
         :return:
         """
         return self.length
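The sampler now draws one candidate overlap partner per show: with probability artificial_ov_ratio a random segment of the same show is attached to every segment's 'overlap' field, otherwise the field stays None. A sketch of just that decision (toy segment list; the rng calls match numpy's Generator API):

import numpy

rng = numpy.random.default_rng(1234)
artificial_ov_ratio = 0.5
seg_tags = ["seg_a", "seg_b", "seg_c"]

idx_iter = []
ovseg = None
if rng.random() < artificial_ov_ratio:
    ovseg = seg_tags[rng.integers(len(seg_tags))]
for seg in seg_tags:
    idx_iter.append({"seg": seg, "overlap": ovseg})
print(idx_iter)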
@@ -630,14 +624,14 @@ class SeqSet(torch.utils.data.Dataset):
         for show in segments:
             crnt_set = segments.get(show)
             self.time_base_start[show] = crnt_set[0]["start"]  # centiseconds
         # labels augmentation for speaker turn detection
         if not "labelling" in dataset_params.keys():
             self.collar_duration = 0.125
         else:
             self.collar_duration = dataset_params["labelling"]["collar_duration"]
         # read the files containing the per-show t_start values to build a list of segments
         self.duration = numpy.ceil(dataset_params[mode]["duration"] * self.audio_fr
@@ -658,6 +652,7 @@ class SeqSet(torch.utils.data.Dataset):
         self.transformation = dataset_params["eval"]["transformation"]
         self.transform = dict()
+        self.spec_aug = False
         if (self.transformation["pipeline"] != "") and (self.transformation["pipeline"] is not None):
@@ -708,54 +703,43 @@ class SeqSet(torch.utils.data.Dataset):
                "stop": float(segarr[4]),
                }
-        idx_start = numpy.round(seg["start"] / 100.0 * self.audio_fr).astype(int)
-        # load audio waveform (dim = (channels,length))
-        # normalization is applied here
-        if self.mod == 'xvector':
-            pass
-            # Extract the segment from a pretrained xvector file
-        else:
-            waveform, speech_fs = torchaudio.load(filepath=self.wav_dir + seg["show"] + ".wav",
-                                                  frame_offset=idx_start,
-                                                  num_frames=self.duration,
-                                                  channels_first=True,
-                                                  normalize=True)
-            #waveform = normalize(waveform)
+        idx_start = numpy.round(seg["start"] / 100.0 * self.audio_fr
+                                - numpy.ceil(self.context / 2)).astype(int)
+        waveform = waveform_loader(self.wav_dir + seg["show"] + ".wav",
+                                   idx_start=idx_start,
+                                   seg_len=self.duration,
+                                   context=self.context)
         # is the signal mono or not ?
         is_multichannel = waveform.shape[0] > 1
         # add low energy noise to avoid zero values
         waveform += 1e-6 * torch.randn(waveform.shape[0], waveform.shape[1])
         # data augmentation if needed
         if self.transform and not is_multichannel:
             waveform = data_augmentation(waveform,
-                                         speech_fs,
+                                         self.audio_fr,
                                          self.transform,
                                          self.transform_number,
                                          noise_df=self.noise_df,
                                          rir_df=self.rir_df,
                                          babble_noise=self.babble_noise)
         # in case of multichannel signal, each channel has to be transformed
         elif self.transform and is_multichannel:
             for ii in range(waveform.shape[0]):
-                waveform[ii,:] = data_augmentation(waveform[ii,:],
-                                                   speech_fs,
+                waveform[ii,:] = data_augmentation(waveform[ii,:].unsqueeze(0),
+                                                   self.audio_fr,
                                                    self.transform,
                                                    self.transform_number,
                                                    noise_df=self.noise_df,
                                                    rir_df=self.rir_df,
                                                    babble_noise=self.babble_noise)
+        if waveform is None:
+            print(seg["show"], seg["start"] / 100.0, seg["stop"] / 100.0)
         if self.spec_aug:
             waveform = self.spec_aug(waveform)
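The new idx_start arithmetic backs the read position up by half the convolutional context so that, after the front-end consumes its extra samples, the model output stays aligned with the annotation timeline. A worked example with assumed values (16 kHz audio, 320 samples of context):

import numpy

audio_fr = 16000       # samples per second (assumed)
context = 320          # extra samples required by the front-end (assumed)
start_centisec = 250   # segment start in centiseconds, as in the diff

idx_start = numpy.round(start_centisec / 100.0 * audio_fr
                        - numpy.ceil(context / 2)).astype(int)
print(idx_start)       # 39840 rather than 40000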
@@ -766,15 +750,16 @@ class SeqSet(torch.utils.data.Dataset):
with
h5py
.
File
(
self
.
label_set
,
"r"
)
as
data
:
crnt_label
=
data
[
seg
[
"show"
]][
"total"
][:,
start
:
stop
]
expected_frames_num
=
int
(
self
.
duration
/
self
.
audio_fr
*
self
.
output_fr
)
if
crnt_label
.
shape
[
1
]
<
expected_frames_num
:
crnt_label
=
numpy
.
pad
(
crnt_label
,[(
0
,
0
),(
0
,
expected_frames_num
-
crnt_label
.
shape
[
1
])])
if
crnt_label
.
shape
[
1
]
>
expected_frames_num
:
crnt_label
=
crnt_label
[:,:
expected_frames_num
]
# crnt_label = numpy.ones((stop-start,))
if
self
.
task
==
"vad"
:
output_label
=
(
crnt_label
>
0
).
astype
(
numpy
.
long
)
# may probably be optimized...
elif
self
.
task
==
"spk_turn"
:
label
=
numpy
.
zeros_like
(
crnt_label
)
label
[:,:
-
1
]
=
(
numpy
.
abs
(
crnt_label
[:,:
-
1
]
-
crnt_label
[:,
1
:])
>
0
).
astype
(
...
...
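The pad-or-truncate step pins the label matrix to exactly expected_frames_num output frames, whatever the HDF5 slice returned. A toy demonstration:

import numpy

expected_frames_num = 10
crnt_label = numpy.ones((2, 7))    # too short: zero-pad on the right
crnt_label = numpy.pad(crnt_label, [(0, 0), (0, expected_frames_num - crnt_label.shape[1])])
print(crnt_label.shape)            # (2, 10)

crnt_label = numpy.ones((2, 13))   # too long: truncate
crnt_label = crnt_label[:, :expected_frames_num]
print(crnt_label.shape)            # (2, 10)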
@@ -787,8 +772,9 @@ class SeqSet(torch.utils.data.Dataset):
output_label
=
numpy
.
convolve
(
conv_filt
,
label
.
squeeze
(),
mode
=
'same'
)
output_label
=
(
numpy
.
expand_dims
(
output_label
,
axis
=
0
)
>=
1
).
astype
(
numpy
.
long
)
elif
"ov"
in
self
.
task
:
elif
self
.
task
==
"overlap"
:
# batch overlap ratio
if
struct
[
'overlap'
]
is
not
None
:
# Loading artificial overlap
...
...
@@ -798,27 +784,29 @@ class SeqSet(torch.utils.data.Dataset):
"show"
:
ov_segarr
[
0
],
"stop"
:
float
(
ov_segarr
[
4
]),
}
idx_start_ov
=
numpy
.
round
(
ov_seg
[
"start"
]
/
100.0
*
self
.
audio_fr
).
astype
(
int
)
frame_count_ov
=
self
.
duration
idx_start_ov
=
numpy
.
round
(
ov_seg
[
"start"
]
/
100.0
*
self
.
audio_fr
-
numpy
.
ceil
(
self
.
context
/
2
)).
astype
(
int
)
waveform_ov
=
waveform_loader
(
self
.
wav_dir
+
ov_seg
[
"show"
]
+
".wav"
,
idx_start
=
idx_start_ov
,
seg_len
=
frame_count_ov
,
context
=
self
.
context
)
waveform_ov
,
_
=
torchaudio
.
load
(
filepath
=
self
.
wav_dir
+
ov_seg
[
"show"
]
+
".wav"
,
frame_offset
=
idx_start_ov
,
num_frames
=
frame_count_ov
,
channels_first
=
True
,
)
start_ov
=
numpy
.
round
((
ov_seg
[
"start"
]
-
self
.
time_base_start
[
ov_seg
[
"show"
]])
/
100.0
*
self
.
output_fr
).
astype
(
int
)
stop_ov
=
numpy
.
round
((
ov_seg
[
"stop"
]
-
self
.
time_base_start
[
ov_seg
[
"show"
]])
/
100.0
*
self
.
output_fr
).
astype
(
int
)
start_ov
=
numpy
.
round
(
ov_seg
[
"start"
]
/
100.0
*
self
.
output_fr
).
astype
(
int
)
stop_ov
=
numpy
.
round
(
ov_seg
[
"stop"
]
/
100.0
*
self
.
output_fr
).
astype
(
int
)
with
h5py
.
File
(
self
.
label_set
,
"r"
)
as
data
:
label_ov
=
data
[
seg
[
"show"
]][
"total"
][:,
start_ov
:
stop_ov
]
label_ov
=
data
[
ov_seg
[
"show"
]][
"total"
][:,
start_ov
:
stop_ov
]
if
label_ov
.
shape
[
1
]
<
expected_frames_num
:
label_ov
=
numpy
.
pad
(
label_ov
,[(
0
,
0
),(
0
,
expected_frames_num
-
label_ov
.
shape
[
1
])])
if
label_ov
.
shape
[
1
]
>
expected_frames_num
:
label_ov
=
label_ov
[:,:
expected_frames_num
]
speech_power
=
waveform
.
norm
(
p
=
2
)
noise_power
=
waveform_ov
.
norm
(
p
=
2
)
snr_db
=
10
*
self
.
rng
.
random
()
+
1
#snr_db=0
snr
=
10
**
(
snr_db
/
20
)
scale
=
snr
*
noise_power
/
speech_power
...
...
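The scale factor converts a random SNR in [1, 11) dB into an amplitude ratio between the original speech and the injected overlap segment. A sketch of the computation; the final mixing line is not visible in this hunk, so the last statement below is an assumption, not the diff's code:

import torch

rng_value = 0.37                    # stand-in for self.rng.random()
waveform = torch.randn(1, 16000)
waveform_ov = torch.randn(1, 16000)

speech_power = waveform.norm(p=2)
noise_power = waveform_ov.norm(p=2)
snr_db = 10 * rng_value + 1         # 4.7 dB in this example
snr = 10 ** (snr_db / 20)
scale = snr * noise_power / speech_power
mixed = (scale * waveform + waveform_ov) / 2   # assumed combination step
print(round(float(snr_db), 2), float(scale))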
@@ -833,11 +821,10 @@ class SeqSet(torch.utils.data.Dataset):
         else:
             raise NotImplementedError()
         if torch.isnan(waveform).any():
             print("Waveform NAN !!!")
         # self.logger.segmentslog(seg['show'],seg['start']/100,seg['stop']/100,(struct['overlap'] is not None))
         return waveform, torch.from_numpy(output_label).T