# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d.  If not, see <http://www.gnu.org/licenses/>.


"""
Copyright 2014-2020 Anthony Larcher
"""

__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher and Sylvain Meignier"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'

import numpy
import pathlib
import random
import scipy
import sidekit
import soundfile
import torch
import yaml

from ..diar import Diar
from pathlib import Path
from sidekit.nnet.xsets import PreEmphasis
from sidekit.nnet.xsets import MFCC
from sidekit.nnet.xsets import CMVN
from sidekit.nnet.xsets import FrequencyMask
from sidekit.nnet.xsets import TemporalMask
from torch.utils.data import Dataset
from torchvision import transforms
from collections import namedtuple

#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])

def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
    """
    :param sig: input signal, can be mono or multi dimensional
    :param win_size: size of the window in term of samples
    :param win_shift: shift of the sliding window in terme of samples
    :param context: tuple of left and right context
    :param pad: can be zeros or edge
    """
    dsize = sig.dtype.itemsize
    if sig.ndim == 1:
        sig = sig[:, numpy.newaxis]
    # Manage padding
    c = (context, ) + (sig.ndim - 1) * ((0, 0), )
    _win_size = win_size + sum(context)
    shape = (int((sig.shape[0] - win_size) / win_shift) + 1, 1, _win_size, sig.shape[1])
    strides = tuple(map(lambda x: x * dsize, [win_shift * sig.shape[1], 1, sig.shape[1], 1]))
    return numpy.lib.stride_tricks.as_strided(sig,
                                           shape=shape,
                                           strides=strides).squeeze()
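

# Usage sketch (illustrative only, not part of the package API): slice one
# second of 16 kHz audio into 25 ms frames with a 10 ms shift.
#
#     sig = numpy.random.randn(16000).astype(numpy.float32)
#     frames = framing(sig, win_size=400, win_shift=160)
#     assert frames.shape == (98, 400)   # 1 + (16000 - 400) // 160 frames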

def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
    """

    :param wav_file_name:
    :param idx:
    :param duration:
    :param seg_shift:
    :param framerate:
    :return:
    """
    # Load waveform
    signal = sidekit.frontend.io.read_audio(wav_file_name, framerate)[0]
    tmp = framing(signal,
                  int(framerate * duration),
                  win_shift=int(framerate * seg_shift),
                  context=(0, 0),
                  pad='zeros')
    return tmp[idx], len(signal)
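

# Usage sketch (illustrative; "show.wav" is a hypothetical file): pick the
# third 2-second chunk of a waveform framed with a 0.5-second shift.
#
#     chunk, n_samples = load_wav_segment("show.wav", idx=2,
#                                         duration=2., seg_shift=0.5)
#     assert chunk.shape == (32000,)   # 2 s at 16 kHz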


def mdtm_to_label(mdtm_filename,
                  start_time,
                  stop_time,
                  sample_number,
                  speaker_dict):
    """

Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
103
104
105
106
    :param mdtm_filename:
    :param start_time:
    :param stop_time:
    :param sample_number:
107
    :param speaker_dict:
Anthony Larcher's avatar
Anthony Larcher committed
108
109
110
111
112
    :return:
    """
    diarization = Diar.read_mdtm(mdtm_filename)
    diarization.sort(['show', 'start'])

    # When a segment starts on the frame right after the previous one ends,
    # replace its start time by the previous stop time to avoid artificial holes
    for ii in range(1, len(diarization.segments)):
        if diarization.segments[ii]['start'] == diarization.segments[ii - 1]['stop'] + 1:
            diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']

    # Create the empty labels
    label = []

    # Compute the time stamp of each sample, centered within its period
    period = (stop_time - start_time) / sample_number
    time_stamps = start_time + (2 * numpy.arange(sample_number, dtype=numpy.float32) + 1) * period / 2

    # For each time stamp, collect the indices of all speakers active at that time
    for time in time_stamps:
        lbls = []
        for seg in diarization.segments:
            if seg['start'] / 100. <= time <= seg['stop'] / 100.:
                lbls.append(speaker_dict[seg['cluster']])
        label.append(lbls)

    return label
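

# Usage sketch (illustrative; "show.mdtm" and the speaker mapping are
# hypothetical): label 200 output frames covering seconds 10 to 12.
#
#     labels = mdtm_to_label("show.mdtm",
#                            start_time=10., stop_time=12.,
#                            sample_number=200,
#                            speaker_dict={"spk1": 0, "spk2": 1})
#     # labels[i] lists the speakers active at frame i, e.g. [], [0] or [0, 1]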


def get_segment_label(label,
                      seg_idx,
                      mode,
                      duration,
                      framerate,
                      seg_shift,
                      collar_duration,
                      filter_type="gate"):
    """

    :param label:
    :param seg_idx:
    :param mode:
    :param duration:
    :param framerate:
    :param seg_shift:
    :param collar_duration:
    :param filter_type:
    :return:
    """

    # Create labels with Diracs at every speaker change detection
    spk_change = numpy.zeros(label.shape, dtype=int)
    spk_change[:-1] = label[:-1] ^ label[1:]
    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

    # Depending on the mode, generate the labels and select the segments
    if mode == "vad":
        output_label = (label > 0.5).astype(numpy.int64)

    elif mode == "spk_turn":
        # Apply a convolution to replace Diracs by a chosen shape (gate or triangle)
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            conv_filt = scipy.signal.windows.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
Anthony Larcher's avatar
VAD    
Anthony Larcher committed
185
        output_label = (label > 0.5).astype(numpy.long)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    # Create segments with overlap
    segment_label = framing(output_label,
                            int(framerate * duration),
                            win_shift=int(framerate * seg_shift),
                            context=(0, 0),
                            pad='zeros')

    return segment_label[seg_idx]
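

# Usage sketch (illustrative; `label` is assumed to be a frame-wise integer
# speaker array sampled at 100 Hz): extract the VAD target of the first
# 2-second segment of a 0.5-second sliding window.
#
#     seg_lbl = get_segment_label(label, seg_idx=0, mode="vad", duration=2.,
#                                 framerate=100, seg_shift=0.5,
#                                 collar_duration=0.1)
#     assert seg_lbl.shape == (200,)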


def process_segment_label(label,
                          mode,
                          framerate,
                          collar_duration,
                          filter_type="gate"):
    """

207
208
209
210
211
212
213
214
215
216
    :param label:
    :param seg_idx:
    :param mode:
    :param duration:
    :param framerate:
    :param seg_shift:
    :param collar_duration:
    :param filter_type:
    :return:
    """

    # Depending on the mode, generate the labels and select the segments
    if mode == "vad":
        output_label = numpy.array([len(a) > 0 for a in label]).astype(numpy.int64)

    elif mode == "spk_turn":

        # Collapse the per-frame speaker lists into one integer per frame;
        # frames with several speakers get an out-of-range value so that any
        # transition to or from them is detected as a speaker change
        tmp_label = []
        for a in label:
            if len(a) == 0:
                tmp_label.append(0)
            elif len(a) == 1:
                tmp_label.append(a[0])
            else:
                tmp_label.append(sum(a) * 1000)

        label = numpy.array(tmp_label)

        # Create labels with Diracs at every speaker change detection
        spk_change = numpy.zeros(label.shape, dtype=int)
        spk_change[:-1] = label[:-1] ^ label[1:]
        spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

        # Apply convolution to replace diracs by a chosen shape (gate or triangle)
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            conv_filt = scipy.signal.windows.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        label = numpy.array([len(a) for a in label]).astype(numpy.int64)

        # For the moment we just consider two classes: overlap / no-overlap;
        # in the future we might want to classify according to the number of
        # speakers talking at the same time
        output_label = (label > 1).astype(numpy.int64)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    return output_label
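

# Usage sketch (illustrative): turn the frame-wise speaker lists produced by
# mdtm_to_label into binary VAD or overlap targets.
#
#     frame_lists = [[], [0], [0, 1], [1]]
#     process_segment_label(frame_lists, mode="vad", framerate=100,
#                           collar_duration=0.1)       # -> [0, 1, 1, 1]
#     process_segment_label(frame_lists, mode="overlap", framerate=100,
#                           collar_duration=0.1)       # -> [0, 0, 1, 0]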


def seqSplit(mdtm_dir,
             duration=2.):
    """
    
    :param mdtm_dir: 
    :param duration: 
    :return: 
    """
    segment_list = Diar()
    speaker_dict = dict()
    spk_idx = 0
    # For each MDTM
    for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):

        # Load MDTM file
        ref = Diar.read_mdtm(mdtm_file)
        ref.sort()
        last_stop = ref.segments[-1]["stop"]

        # Get the borders of the segments (neither the start of the first one
        # nor the end of the last one)

        # For each border time B, keep a window between B - duration and
        # B + duration in which a chunk will be picked randomly later
        for idx, seg in enumerate(ref.segments):
            if idx > 0 and seg["start"] / 100. > duration and seg["start"] / 100. + duration < last_stop / 100.:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["start"]) / 100. - duration,
                                    stop=float(seg["start"]) / 100. + duration)

            elif idx < len(ref.segments) - 1 and seg["stop"] / 100. + duration < last_stop / 100.:
                segment_list.append(show=seg['show'],
                                    cluster="",
                                    start=float(seg["stop"]) / 100. - duration,
                                    stop=float(seg["stop"]) / 100. + duration)

        # Get the list of unique speakers and give every new one an index
        speakers = ref.unique('cluster')
        for spk in speakers:
            if spk not in speaker_dict:
                speaker_dict[spk] = spk_idx
                spk_idx += 1

    return segment_list, speaker_dict
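

# Usage sketch (illustrative; "./mdtm/" is a hypothetical directory):
#
#     segment_list, speaker_dict = seqSplit("./mdtm/", duration=2.)
#     # segment_list is a Diar of 2 * duration windows centered on speaker
#     # turns; speaker_dict maps each speaker name to an integer index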


class SeqSet(Dataset):
    """
    Object creates a dataset for sequence to sequence training
    """
    def __init__(self,
                 wav_dir,
                 mdtm_dir,
                 mode,
                 segment_list=None,
                 speaker_dict=None,
                 duration=2.,
                 filter_type="gate",
                 collar_duration=0.1,
                 audio_framerate=16000,
                 output_framerate=100,
                 transform_pipeline=""):
        """

        :param wav_dir:
        :param mdtm_dir:
        :param mode:
        :param duration:
        :param filter_type:
        :param collar_duration:
        :param audio_framerate:
        :param output_framerate:
        :param transform_pipeline:
        """

        self.wav_dir = wav_dir
        self.mdtm_dir = mdtm_dir
        self.mode = mode
        self.duration = duration
        self.filter_type = filter_type
        self.collar_duration = collar_duration
        self.audio_framerate = audio_framerate
        self.output_framerate = output_framerate

        self.transform_pipeline = transform_pipeline

        _transform = []
        if self.transform_pipeline != '':
            trans = self.transform_pipeline.split(',')
            for t in trans:
                if 'PreEmphasis' in t:
                    _transform.append(PreEmphasis())
                if 'MFCC' in t:
                    _transform.append(MFCC())
                if "CMVN" in t:
                    _transform.append(CMVN())
                if "FrequencyMask" in t:
                    a = int(t.split('-')[0].split('(')[1])
                    b = int(t.split('-')[1].split(')')[0])
                    _transform.append(FrequencyMask(a, b))
                if "TemporalMask" in t:
                    a = int(t.split("(")[1].split(")")[0])
                    _transform.append(TemporalMask(a))
        self.transforms = transforms.Compose(_transform)

        if segment_list is None and speaker_dict is None:
            segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
                                                  duration=self.duration)

        self.segment_list = segment_list
        self.speaker_dict = speaker_dict
        self.len = len(segment_list)

    def __getitem__(self, index):
        """
        On renvoie un segment wavform brut mais il faut que les labels soient échantillonés à la bonne fréquence
        (trames)
        :param index:
        :return:
        """
        # Get segment info to load from
        seg = self.segment_list[index]

        # Randomly pick an audio chunk within the current segment
        start = random.uniform(seg["start"], seg["start"] + self.duration)
        sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
                                start=int(start * self.audio_framerate),
                                stop=int((start + self.duration) * self.audio_framerate)
                                )
        # Add a small random dither to the signal
        sig += 0.0001 * numpy.random.randn(sig.shape[0])

        if self.transform_pipeline:
            sig, speaker_idx, _, __, _t, _s = self.transforms((sig, None, None, None, None, None))
        tmp_label = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
                                  start_time=start,
                                  stop_time=start + self.duration,
                                  sample_number=sig.shape[1],
                                  speaker_dict=self.speaker_dict)

        label = process_segment_label(label=tmp_label,
                                      mode=self.mode,
                                      framerate=self.output_framerate,
                                      collar_duration=self.collar_duration,
                                      filter_type=self.filter_type)

        return torch.from_numpy(sig.T).type(torch.FloatTensor), torch.from_numpy(label.astype(numpy.int64))

    def __len__(self):
        return self.len
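

# Usage sketch (illustrative; the directories and pipeline are hypothetical):
# a SeqSet is a standard torch Dataset and can be wrapped in a DataLoader.
#
#     dataset = SeqSet(wav_dir="./wav/", mdtm_dir="./mdtm/", mode="vad",
#                      transform_pipeline="MFCC,CMVN")
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8)
#     feats, labels = next(iter(loader))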


def create_train_val_seqtoseq(dataset_yaml):
    """

    :param self:
    :param wav_dir:
    :param mdtm_dir:
    :param mode:
    :param segment_list
    :param speaker_dict:
    :param duration:
    :param filter_type:
    :param collar_duration:
    :param audio_framerate:
    :param output_framerate:
    :param transform_pipeline:
    :return:
    """
    with open(dataset_yaml, "r") as fh:
        dataset_params = yaml.load(fh, Loader=yaml.FullLoader)

    torch.manual_seed(dataset_params['seed'])

    # Read all MDTM files and output a list of segments with minimum duration as well as a speaker dictionary
    segment_list, speaker_dict = seqSplit(mdtm_dir=dataset_params["mdtm_dir"],
                                          duration=dataset_params["train"]["duration"])

    # Split the list of segments between training and validation sets
    split_idx = numpy.random.choice([True, False],
                                    size=(len(segment_list),),
                                    p=[1 - dataset_params["validation_ratio"], dataset_params["validation_ratio"]])
    segment_list_train = Diar.copy_structure(segment_list)
    segment_list_val = Diar.copy_structure(segment_list)
    for idx, seg in enumerate(segment_list.segments):
        if split_idx[idx]:
            segment_list_train.append_seg(seg)
        else:
            segment_list_val.append_seg(seg)

    # Instantiate the training and validation datasets
    train_set = SeqSet(wav_dir=dataset_params["wav_dir"],
                       mdtm_dir=dataset_params["mdtm_dir"],
                       mode=dataset_params["mode"],
                       segment_list=segment_list_train,
                       speaker_dict=speaker_dict,
                       duration=dataset_params["train"]["duration"],
                       filter_type=dataset_params["filter_type"],
                       collar_duration=dataset_params["collar_duration"],
                       audio_framerate=dataset_params["sample_rate"],
                       output_framerate=dataset_params["output_rate"],
                       transform_pipeline=dataset_params["train"]["transformation"]["pipeline"])

    validation_set = SeqSet(wav_dir=dataset_params["wav_dir"],
                            mdtm_dir=dataset_params["mdtm_dir"],
                            mode=dataset_params["mode"],
                            segment_list=segment_list_val,
                            speaker_dict=speaker_dict,
                            duration=dataset_params["eval"]["duration"],
                            filter_type=dataset_params["filter_type"],
                            collar_duration=dataset_params["collar_duration"],
                            audio_framerate=dataset_params["sample_rate"],
                            output_framerate=dataset_params["output_rate"],
                            transform_pipeline=dataset_params["eval"]["transformation"]["pipeline"])

    return train_set, validation_set
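

# Sketch of the YAML layout expected by create_train_val_seqtoseq (field names
# taken from the keys read above; the values shown are only illustrative):
#
#     seed: 1234
#     wav_dir: ./wav/
#     mdtm_dir: ./mdtm/
#     mode: vad
#     validation_ratio: 0.1
#     sample_rate: 16000
#     output_rate: 100
#     filter_type: gate
#     collar_duration: 0.1
#     train:
#         duration: 2.
#         transformation:
#             pipeline: MFCC,CMVN
#     eval:
#         duration: 2.
#         transformation:
#             pipeline: MFCC,CMVN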