# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d.  If not, see <http://www.gnu.org/licenses/>.


"""
Copyright 2014-2020 Anthony Larcher
"""

__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher and Sylvain Meignier"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'

import pathlib
import random
from collections import namedtuple
from pathlib import Path

import numpy
import scipy
import scipy.signal
import sidekit
import soundfile
import torch
from sidekit.nnet.augmentation import AddNoise
from sidekit.nnet.xsets import CMVN
from sidekit.nnet.xsets import FrequencyMask
from sidekit.nnet.xsets import MFCC
from sidekit.nnet.xsets import PreEmphasis
from sidekit.nnet.xsets import TemporalMask
from torch.utils.data import Dataset
from torchvision import transforms

from ..diar import Diar

#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])


def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
    """
    Cut a signal into (possibly overlapping) frames, optionally adding
    context samples on both sides of each frame.

    :param sig: input signal, can be mono or multi dimensional
    :param win_size: size of the window in term of samples
    :param win_shift: shift of the sliding window in term of samples
    :param context: tuple of left and right context sizes in samples
    :param pad: padding mode used for the context, can be 'zeros' or 'edge'
    :return: a (strided, read-only) view on the framed signal of shape
             (frame_number, win_size + sum(context), channel_number), squeezed
    """
    dsize = sig.dtype.itemsize
    if sig.ndim == 1:
        sig = sig[:, numpy.newaxis]
    # Pad the time axis by the requested context; other axes are untouched.
    c = (context, ) + (sig.ndim - 1) * ((0, 0), )
    _win_size = win_size + sum(context)
    # BUG FIX: the pad widths (c) were computed but numpy.pad was never
    # called, so the `context` and `pad` parameters were silently ignored
    # for any non-zero context. Behavior is unchanged for context=(0, 0),
    # which is what every caller in this file uses.
    if sum(context) > 0:
        sig = numpy.pad(sig, c, mode='constant' if pad == 'zeros' else 'edge')
    shape = (int((sig.shape[0] - _win_size) / win_shift) + 1, 1, _win_size, sig.shape[1])
    strides = tuple(map(lambda x: x * dsize, [win_shift * sig.shape[1], 1, sig.shape[1], 1]))
    return numpy.lib.stride_tricks.as_strided(sig,
                                              shape=shape,
                                              strides=strides).squeeze()


def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
    """
    Load one fixed-length window of a waveform.

    The whole file is read, cut into overlapping windows of `duration`
    seconds shifted by `seg_shift` seconds, and the window number `idx`
    is returned together with the total sample count of the signal.

    :param wav_file_name: path of the audio file to read
    :param idx: index of the window to return
    :param duration: duration of one window in seconds
    :param seg_shift: shift between two consecutive windows in seconds
    :param framerate: sampling rate used to read the file (default 16 kHz)
    :return: tuple (samples of window `idx`, total sample count of the file)
    """
    # Load waveform
    signal = sidekit.frontend.io.read_audio(wav_file_name, framerate)[0]
    tmp = framing(signal,
                  int(framerate * duration),
                  win_shift=int(framerate * seg_shift),
                  context=(0, 0),
                  pad='zeros')
    return tmp[idx], len(signal)




def mdtm_to_label(mdtm_filename,
                  start_time,
                  stop_time,
                  sample_number,
                  speaker_dict,
                  over=False):
    """
    Build the per-frame speaker labels and overlap counts of one audio chunk
    from an MDTM reference file.

    Segment boundaries in the MDTM are divided by 100 below, so they are
    presumably stored in centiseconds while `start_time`/`stop_time` are in
    seconds — TODO confirm against Diar.read_mdtm.

    :param mdtm_filename: path of the MDTM file to read
    :param start_time: start of the chunk in seconds
    :param stop_time: end of the chunk in seconds
    :param sample_number: number of output label samples for the chunk
    :param speaker_dict: dict mapping cluster (speaker) names to integer labels
    :param over: unused in this function; kept for call-site compatibility
        — NOTE(review): confirm whether it was meant to gate the overlap count
    :return: tuple (label, overlaps), two int arrays of length `sample_number`:
        the speaker label of each sample (0 = non-speech) and the number of
        reference segments covering each sample
    """
    diarization = Diar.read_mdtm(mdtm_filename)
    diarization.sort(['show', 'start'])
    overlaps = numpy.zeros(sample_number, dtype=int)

    # When one segment starts just the frame after the previous one ends,
    # we replace the start time by the previous stop time to avoid artificial holes.
    # NOTE(review): `previous_stop` is assigned but never read afterwards.
    previous_stop = 0
    for ii, seg in enumerate(diarization.segments):
        if ii == 0:
            previous_stop = seg['stop']
        else:
            if seg['start'] == diarization.segments[ii - 1]['stop'] + 1:
                diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']

    # Create the empty labels (0 = non-speech)
    label = numpy.zeros(sample_number, dtype=int)

    # Compute the time stamp of each sample (center of each output frame)
    time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
    period = (stop_time - start_time) / sample_number
    for t in range(sample_number):
        time_stamps[t] = start_time + (2 * t + 1) * period / 2

    # Count, for each output sample, how many reference segments cover it
    for ii, i in enumerate(numpy.linspace(start_time, stop_time, num=sample_number)):
        cnt = 0
        for d in diarization.segments:
            if d['start'] / 100 <= i <= d['stop'] / 100:
                cnt += 1
        overlaps[ii] = cnt

    # Find the first segment that ends after the start of the chunk
    seg_idx = 0
    while diarization.segments[seg_idx]['stop'] / 100. < start_time:
        seg_idx += 1

    for ii, t in enumerate(time_stamps):
        # Not yet inside the first overlapping segment (hence non-speech)
        if t <= diarization.segments[seg_idx]['start'] / 100.:
            # Keep label 0 (non-speech)
            pass
        # Already inside the current overlapping segment
        elif diarization.segments[seg_idx]['start'] / 100. < t < diarization.segments[seg_idx]['stop'] / 100.:
            label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
        # Moving to the next segment
        elif diarization.segments[seg_idx]['stop'] / 100. < t and len(diarization.segments) > seg_idx + 1:
            seg_idx += 1
            # Between two segments: keep non-speech
            if t < diarization.segments[seg_idx]['start'] / 100.:
                pass
            elif diarization.segments[seg_idx]['start'] / 100. < t < diarization.segments[seg_idx]['stop'] / 100.:
                label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]

    return (label, overlaps)


def get_segment_label(label,
                      overlaps,
                      seg_idx,
                      mode,
                      duration,
                      framerate,
                      seg_shift,
                      collar_duration,
                      filter_type="gate"):
    """
    Compute the target label sequence of one segment according to the
    training mode, then cut it into overlapping windows and return the
    window `seg_idx`.

    :param label: per-sample speaker label sequence of the whole show
    :param overlaps: per-sample count of overlapping speakers
    :param seg_idx: index of the window to return after framing
    :param mode: "vad", "spk_turn" or "overlap"
    :param duration: duration of one window in seconds
    :param framerate: number of label samples per second
    :param seg_shift: shift between two consecutive windows in seconds
    :param collar_duration: half-width in seconds of the collar placed around
        each speaker-change point ("spk_turn" mode only)
    :param filter_type: shape of the collar, "gate" or "triangle"
    :return: the label window number `seg_idx`
    :raises ValueError: if `mode` is not one of the three supported values
    """
    # Create labels with Diracs at every speaker change detection
    spk_change = numpy.zeros(label.shape, dtype=int)
    spk_change[:-1] = label[:-1] ^ label[1:]
    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

    # depending on the mode, generate the labels and select the segments
    if mode == "vad":
        # FIX: numpy.long was removed from NumPy (>= 1.24); int64 is equivalent
        output_label = (label > 0.5).astype(numpy.int64)

    elif mode == "spk_turn":
        # Apply convolution to replace diracs by a chosen shape (gate or triangle).
        # FIX: cast to int (consistent with process_segment_label) — numpy.ones
        # and triang require an integer length.
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            # FIX: scipy.signal.triang moved to scipy.signal.windows.triang
            # (removed from the top-level namespace in SciPy 1.13)
            conv_filt = scipy.signal.windows.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        output_label = (overlaps > 1).astype(numpy.int64)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    # Create segments with overlap
    segment_label = framing(output_label,
                            int(framerate * duration),
                            win_shift=int(framerate * seg_shift),
                            context=(0, 0),
                            pad='zeros')

    return segment_label[seg_idx]


def process_segment_label(label,
                          overlaps,
                          mode,
                          framerate,
                          collar_duration,
                          filter_type="gate"):
    """
    Compute the target label sequence of one chunk according to the
    training mode.

    :param label: per-sample speaker label sequence of the chunk
    :param overlaps: per-sample count of overlapping speakers
    :param mode: "vad", "spk_turn" or "overlap"
    :param framerate: number of label samples per second
    :param collar_duration: half-width in seconds of the collar placed around
        each speaker-change point ("spk_turn" mode only)
    :param filter_type: shape of the collar, "gate" or "triangle"
    :return: the processed label sequence (int array for "vad"/"overlap",
        float array for "spk_turn")
    :raises ValueError: if `mode` is not one of the three supported values
    """
    # Create labels with Diracs at every speaker change detection
    spk_change = numpy.zeros(label.shape, dtype=int)
    spk_change[:-1] = label[:-1] ^ label[1:]
    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

    # depending on the mode, generate the labels and select the segments
    if mode == "vad":
        # FIX: numpy.long was removed from NumPy (>= 1.24); int64 is equivalent
        output_label = (label > 0.5).astype(numpy.int64)

    elif mode == "spk_turn":
        # Apply convolution to replace diracs by a chosen shape (gate or triangle)
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            # FIX: scipy.signal.triang moved to scipy.signal.windows.triang
            # (removed from the top-level namespace in SciPy 1.13)
            conv_filt = scipy.signal.windows.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        output_label = (overlaps > 1).astype(numpy.int64)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    return output_label




def seqSplit(mdtm_dir,
             duration=2.):
    """
    Browse the MDTM files of a directory and build the list of candidate
    windows centered on speaker-turn borders, together with a dictionary
    mapping each speaker name to an integer index.

    For each border time B of a reference segment, a window
    [B - duration, B + duration] (in seconds) is stored; an audio chunk
    will later be picked randomly inside it.

    :param mdtm_dir: directory containing the MDTM reference files
    :param duration: half-size of the window stored around each border, in seconds
    :return: tuple (Diar of candidate windows, dict speaker name -> integer index)
    """
    segment_list = Diar()
    speaker_dict = dict()
    idx = 0
    # For each MDTM
    for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):

        # Load MDTM file
        ref = Diar.read_mdtm(mdtm_file)
        ref.sort()
        last_stop = ref.segments[-1]["stop"]

        # Keep the borders of the segments, excluding the start of the first
        # and the end of the last one.
        # For each border time B get a window between B - duration and B + duration
        # in which we will pick up randomly later.
        # NOTE(review): seg["start"]/seg["stop"] appear to be in centiseconds
        # (divided by 100 below) while `duration` and `last_stop` comparisons
        # mix units: `seg["start"] + duration < last_stop` adds seconds to
        # centiseconds — confirm whether `/ 100.` is missing in both bounds.
        # NOTE(review): this loop rebinds `idx`, so speaker indices assigned
        # below start from the last segment index instead of 0 — confirm intent.
        for idx, seg in enumerate(ref.segments):
            if idx > 0 and seg["start"] / 100. > duration and seg["start"] + duration < last_stop:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["start"]) / 100. - duration,
                                    stop=float(seg["start"]) / 100. + duration)

            elif idx < len(ref.segments) - 1 and seg["stop"] + duration < last_stop:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["stop"]) / 100. - duration,
                                    stop=float(seg["stop"]) / 100. + duration)

        # Get list of unique speakers
        speakers = ref.unique('cluster')
        for spk in speakers:
            if not spk in speaker_dict:
                speaker_dict[spk] = idx
                idx += 1

    return segment_list, speaker_dict


class SeqSet(Dataset):
    """
    Dataset for sequence-to-sequence training: yields raw waveform chunks
    together with their per-frame label sequences (VAD, speaker turn or
    overlap detection, depending on `mode`).
    """
    def __init__(self,
                 dataset_yaml,
                 wav_dir,
                 mdtm_dir,
                 mode,
                 duration=2.,
                 filter_type="gate",
                 collar_duration=0.1,
                 audio_framerate=16000,
                 output_framerate=100,
                 set_type="train",
                 dataset_df=None,
                 transform_pipeline=""):
        """
        :param dataset_yaml: unused in this constructor — kept for call-site
            compatibility (NOTE(review): confirm)
        :param wav_dir: directory containing the WAV files
        :param mdtm_dir: directory containing the MDTM reference files
        :param mode: "vad", "spk_turn" or "overlap"
        :param duration: duration of one chunk in seconds
        :param filter_type: collar shape for "spk_turn" mode, "gate" or "triangle"
        :param collar_duration: collar half-width in seconds for "spk_turn" mode
        :param audio_framerate: sampling rate of the audio files
        :param output_framerate: label samples per second
        :param set_type: "train" enables the noise-augmentation settings
        :param dataset_df: unused in this constructor (NOTE(review): confirm)
        :param transform_pipeline: comma-separated transform names, e.g.
            "add_noise,MFCC,CMVN"
        """
        self.wav_dir = wav_dir
        self.mdtm_dir = mdtm_dir
        self.mode = mode
        self.duration = duration
        self.filter_type = filter_type
        self.collar_duration = collar_duration
        self.audio_framerate = audio_framerate
        self.output_framerate = output_framerate
        self.transformation = ''
        self.transform_pipeline = transform_pipeline

        # NOTE(review): hard-coded, site-specific noise CSV path; these
        # attributes only exist for set_type == "train", so requesting
        # 'add_noise' in the pipeline for another set_type would raise
        # AttributeError below — confirm intended.
        if set_type == "train":
            self.transformation_noise_file_ratio = 1
            self.transformation_noise_db_csv = "/lium/raid01_b/mlebour/GEM/expes/10-20/overlap_generation/results/overlaps_gen.csv"
            self.transformation_noise_snr = [8.0, 12.0]
            self.transformation_noise_root_db = wav_dir

        segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
                                              duration=self.duration)
        self.segment_list = segment_list
        self.speaker_dict = speaker_dict
        self.len = len(segment_list)
        # `over` is forwarded to mdtm_to_label; set to True when noise
        # augmentation is part of the pipeline
        self.over = False

        # Build the transform pipeline from its comma-separated description
        _transform = []
        if not self.transform_pipeline == '':
            trans = self.transform_pipeline.split(',')
            # Boolean mask selecting which items get noise added
            self.add_noise = numpy.zeros(self.len, dtype=bool)
            for t in trans:
                if 'PreEmphasis' in t:
                    _transform.append(PreEmphasis())
                if 'add_noise' in t:
                    self.over = True
                    self.add_noise[:int(self.len * self.transformation_noise_file_ratio)] = 1
                    numpy.random.shuffle(self.add_noise)
                    _transform.append(AddNoise(noise_db_csv=self.transformation_noise_db_csv,
                                               snr_min_max=self.transformation_noise_snr,
                                               noise_root_path=self.transformation_noise_root_db))
                if 'MFCC' in t:
                    _transform.append(MFCC())
                if "CMVN" in t:
                    _transform.append(CMVN())
                if "FrequencyMask" in t:
                    # Expected syntax: "FrequencyMask(a-b)"
                    a = int(t.split('-')[0].split('(')[1])
                    b = int(t.split('-')[1].split(')')[0])
                    _transform.append(FrequencyMask(a, b))
                if "TemporalMask" in t:
                    # Expected syntax: "TemporalMask(a)"
                    a = int(t.split("(")[1].split(")")[0])
                    _transform.append(TemporalMask(a))
        self.transforms = transforms.Compose(_transform)

    def __getitem__(self, index):
        """
        Return one raw waveform chunk and its label sequence resampled at the
        output framerate (frames).

        :param index: index of the candidate window in self.segment_list
        :return: tuple (FloatTensor of features/waveform, LongTensor of labels)
        """
        # Get segment info to load from
        seg = self.segment_list[index]
        ok = False
        # Randomly pick an audio chunk within the current window; re-draw
        # until the transformed chunk has the expected width
        while not ok:
            ok = True
            start = random.uniform(seg["start"], seg["start"] + self.duration)
            # NOTE(review): dead assignment — `sig` is overwritten just below
            sig = numpy.array([0, 0])
            sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
                                    start=int(start * self.audio_framerate),
                                    stop=int((start + self.duration) * self.audio_framerate),
                                    )
            if len(sig) == 0:
                ok = False
                continue

            # Small dither to avoid strictly-zero samples
            sig += 0.0001 * numpy.random.randn(sig.shape[0])

            if self.transform_pipeline:
                sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))

            # NOTE(review): 198 is presumably the expected number of feature
            # frames for a chunk of `duration` seconds — confirm; chunks with
            # any other width are discarded and re-drawn.
            if sig.shape[1] != 198:
                ok = False
                continue

        tmp_label, overlaps = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
                                            start_time=start,
                                            stop_time=start + self.duration,
                                            sample_number=sig.shape[1],
                                            speaker_dict=self.speaker_dict,
                                            over=self.over)

        label = process_segment_label(label=tmp_label,
                                      overlaps=overlaps,
                                      mode=self.mode,
                                      framerate=self.output_framerate,
                                      collar_duration=self.collar_duration,
                                      filter_type=self.filter_type)

        return torch.from_numpy(sig.T).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))

    def __len__(self):
        """
        Return the number of candidate windows in the dataset.
        """
        return self.len