wavsets.py 14.4 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# -*- coding: utf-8 -*-
#
# This file is part of s4d.
#
# s4d is a python package for speaker diarization.
# Home page: http://www-lium.univ-lemans.fr/s4d/
#
# s4d is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# s4d is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with s4d.  If not, see <http://www.gnu.org/licenses/>.


"""
Copyright 2014-2020 Anthony Larcher
"""

__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher and Sylvain Meignier"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'

import numpy
import pathlib
import random
import scipy
import scipy.signal
import sidekit
import soundfile
import torch

from collections import namedtuple
from pathlib import Path
from torch.utils.data import Dataset
from torchvision import transforms

from sidekit.nnet.xsets import PreEmphasis
from sidekit.nnet.xsets import MFCC
from sidekit.nnet.xsets import CMVN
from sidekit.nnet.xsets import FrequencyMask
from sidekit.nnet.xsets import TemporalMask

from ..diar import Diar
Anthony Larcher's avatar
Anthony Larcher committed
52

Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
53
#Segment = namedtuple('Segment', ['show', 'start_time', 'end_time'])
Anthony Larcher's avatar
Anthony Larcher committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

def framing(sig, win_size, win_shift=1, context=(0, 0), pad='zeros'):
    """
    Split a signal into overlapping frames, optionally extended with context.

    The number of frames is computed from the *un-padded* signal length; the
    context only widens each frame by sum(context) samples.

    :param sig: input signal, can be mono or multi dimensional
    :param win_size: size of the window in term of samples
    :param win_shift: shift of the sliding window in terme of samples
    :param context: tuple of left and right context (in samples) added
        around every frame
    :param pad: 'zeros' to pad the signal edges with zeros, 'edge' to
        replicate the edge values
    :return: array of frames of shape (frame_number, win_size + sum(context))
        for mono input (extra trailing dim for multi-channel input)
    """
    dsize = sig.dtype.itemsize
    if sig.ndim == 1:
        sig = sig[:, numpy.newaxis]
    # Padding specification: context on the time axis, nothing elsewhere
    c = (context, ) + (sig.ndim - 1) * ((0, 0), )
    _win_size = win_size + sum(context)
    # Frame count is based on the original (un-padded) length
    shape = (int((sig.shape[0] - win_size) / win_shift) + 1, 1, _win_size, sig.shape[1])
    strides = tuple(map(lambda x: x * dsize, [win_shift * sig.shape[1], 1, sig.shape[1], 1]))
    # FIX: the padding was computed but never applied, so a nonzero context
    # made as_strided read out of the buffer bounds
    if pad == 'zeros':
        sig = numpy.pad(sig, c, 'constant', constant_values=(0,))
    elif pad == 'edge':
        sig = numpy.pad(sig, c, 'edge')
    return numpy.lib.stride_tricks.as_strided(sig,
                                              shape=shape,
                                              strides=strides).squeeze()
Anthony Larcher's avatar
Anthony Larcher committed
74

Martin Lebourdais's avatar
Martin Lebourdais committed
75
76
    def load_wav_segment(wav_file_name, idx, duration, seg_shift, framerate=16000):
        """
Anthony Larcher's avatar
Anthony Larcher committed
77
78
79
80
81
82
83
84
85
86
87

    :param wav_file_name:
    :param idx:
    :param duration:
    :param seg_shift:
    :param framerate:
    :return:
    """
    # Load waveform
    signal = sidekit.frontend.io.read_audio(wav_file_name, framerate)[0]
    tmp = framing(signal,
Martin Lebourdais's avatar
Martin Lebourdais committed
88
89
90
91
            int(framerate * duration),
            win_shift=int(framerate * seg_shift),
            context=(0, 0),
            pad='zeros')
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
92
    return tmp[idx], len(signal)
Anthony Larcher's avatar
Anthony Larcher committed
93
94
95


def mdtm_to_label(mdtm_filename,
        start_time,
        stop_time,
        sample_number,
        speaker_dict):
    """
    Build the frame-level speaker labels and overlap counts for a time
    window of a show, from an MDTM diarization reference.

    :param mdtm_filename: path of the MDTM reference file to read
    :param start_time: start of the window, in seconds
    :param stop_time: end of the window, in seconds
    :param sample_number: number of output label frames for the window
    :param speaker_dict: dict mapping speaker (cluster) names to integer labels
    :return: tuple (label, overlaps) of two numpy int arrays of length
        sample_number: label[i] is the speaker index active at frame i
        (0 = non-speech), overlaps[i] is the number of reference segments
        covering frame i
    """
    # NOTE(review): segment 'start'/'stop' values read below are divided by
    # 100, so MDTM times are assumed to be in centiseconds — confirm.
    diarization = Diar.read_mdtm(mdtm_filename)
    diarization.sort(['show', 'start'])

    # Number of reference segments covering each output frame
    overlaps = numpy.zeros(sample_number, dtype=int)

    # When one segment starts just the frame after the previous one ends,
    # we replace the start time by the previous stop to avoid artificial holes
    previous_stop = 0
    for ii, seg in enumerate(diarization.segments):
        if ii == 0:
            previous_stop = seg['stop']
        else:
            if seg['start'] == diarization.segments[ii - 1]['stop'] + 1:
                diarization.segments[ii]['start'] = diarization.segments[ii - 1]['stop']

    # Create the empty labels (0 = non-speech)
    label = numpy.zeros(sample_number, dtype=int)

    # Compute the time stamp of each sample (midpoint of each output frame)
    time_stamps = numpy.zeros(sample_number, dtype=numpy.float32)
    period = (stop_time - start_time) / sample_number
    for t in range(sample_number):
        time_stamps[t] = start_time + (2 * t + 1) * period / 2

    # Count, for each frame, how many reference segments cover it.
    # NOTE(review): this grid uses linspace(start, stop) sample points while
    # the labels below use frame midpoints — two slightly different time grids.
    for ii,i in enumerate(numpy.linspace(start_time,stop_time,num=sample_number)):
        cnt = 0
        for d in diarization.segments:
            if d['start']/100 <= i <= d['stop']/100:
                cnt+=1
        overlaps[ii]=cnt

    # Advance to the first segment that ends after the start of the window
    seg_idx = 0
    while diarization.segments[seg_idx]['stop'] / 100. < start_time:
        seg_idx += 1

    for ii, t in enumerate(time_stamps):
        # Before the first overlapping segment: still in non-speech
        if t <= diarization.segments[seg_idx]['start']/100.:
            # Keep label 0 (non-speech)
            pass
        # Already inside the current overlapping segment
        elif diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100. :
            label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]
        # Moving past the current segment: switch to the next one
        elif diarization.segments[seg_idx]['stop']/100. < t and len(diarization.segments) > seg_idx + 1:
            seg_idx += 1
            # Between two segments: non-speech
            if t < diarization.segments[seg_idx]['start']/100.:
                pass
            elif  diarization.segments[seg_idx]['start']/100. < t < diarization.segments[seg_idx]['stop']/100.:
                label[ii] = speaker_dict[diarization.segments[seg_idx]['cluster']]

    return (label,overlaps)
Anthony Larcher's avatar
Anthony Larcher committed
164
165


166
def get_segment_label(label,
        overlaps,
        seg_idx,
        mode,
        duration,
        framerate,
        seg_shift,
        collar_duration,
        filter_type="gate"):
    """
    Build the target label sequence for a whole show, split it into
    overlapping fixed-length segments and return the seg_idx-th one.

    :param label: numpy int array of speaker indices per frame (0 = non-speech)
    :param overlaps: numpy int array with the number of overlapping reference
        segments per frame
    :param seg_idx: index of the framed segment to return
    :param mode: "vad", "spk_turn" or "overlap"
    :param duration: duration of one segment in seconds
    :param framerate: label framerate (frames per second)
    :param seg_shift: shift between successive segments in seconds
    :param collar_duration: half-width in seconds of the collar placed around
        each speaker change (spk_turn mode only)
    :param filter_type: collar shape, "gate" (rectangular) or "triangle"
    :return: the seg_idx-th framed label sequence
    :raises ValueError: if mode is not one of the three expected values
    """
    # Create labels with Diracs at every speaker change detection
    spk_change = numpy.zeros(label.shape, dtype=int)
    spk_change[:-1] = label[:-1] ^ label[1:]
    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

    # Depending on the mode, generate the labels and select the segments
    if mode == "vad":
        # Speech activity: 1 for any speaker, 0 for non-speech.
        # FIX: numpy.long was removed in numpy 1.24; plain int is equivalent here.
        output_label = (label > 0.5).astype(int)

    elif mode == "spk_turn":
        # Apply convolution to replace diracs by a chosen shape (gate or triangle).
        # FIX: cast to int — consistent with process_segment_label, and a float
        # size is rejected by numpy.ones / triang on recent versions.
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            # FIX: scipy.signal.triang was removed in SciPy 1.13; the window
            # now lives in scipy.signal.windows.
            conv_filt = scipy.signal.windows.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        output_label = (overlaps > 0.5).astype(int)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    # Create segments with overlap
    segment_label = framing(output_label,
            int(framerate * duration),
            win_shift=int(framerate * seg_shift),
            context=(0, 0),
            pad='zeros')

    return segment_label[seg_idx]


221
def process_segment_label(label,
        overlaps,
        mode,
        framerate,
        collar_duration,
        filter_type="gate"):
    """
    Convert a frame-level speaker label sequence into the training target
    for one of three tasks.

    :param label: numpy int array of speaker indices per frame (0 = non-speech)
    :param overlaps: numpy int array with the number of overlapping reference
        segments per frame
    :param mode: "vad" (speech activity), "spk_turn" (speaker change collar)
        or "overlap" (more than one simultaneous speaker)
    :param framerate: label framerate (frames per second)
    :param collar_duration: half-width in seconds of the collar placed around
        each speaker change (spk_turn mode only)
    :param filter_type: collar shape, "gate" (rectangular) or "triangle"
    :return: numpy array of target labels, same length as label
    :raises ValueError: if mode is not one of the three expected values
    """
    # Create labels with Diracs at every speaker change detection
    spk_change = numpy.zeros(label.shape, dtype=int)
    spk_change[:-1] = label[:-1] ^ label[1:]
    spk_change = numpy.not_equal(spk_change, numpy.zeros(label.shape, dtype=int))

    # Depending on the mode, generate the labels
    if mode == "vad":
        # FIX: numpy.long was removed in numpy 1.24; plain int is equivalent here.
        output_label = (label > 0.5).astype(int)

    elif mode == "spk_turn":
        # Apply convolution to replace diracs by a chosen shape (gate or triangle)
        filter_sample = int(collar_duration * framerate * 2 + 1)
        conv_filt = numpy.ones(filter_sample)
        if filter_type == "triangle":
            # FIX: scipy.signal.triang was removed in SciPy 1.13; the window
            # now lives in scipy.signal.windows.
            conv_filt = scipy.signal.windows.triang(filter_sample)
        output_label = numpy.convolve(conv_filt, spk_change, mode='same')

    elif mode == "overlap":
        # Overlap means strictly more than one active speaker
        output_label = (overlaps > 1).astype(int)

    else:
        raise ValueError("mode parameter must be 'vad', 'spk_turn' or 'overlap'")

    return output_label
Anthony Larcher's avatar
seq2seq    
Anthony Larcher committed
264
265
266


def seqSplit(mdtm_dir,
        duration=2.):
    """
    Scan all MDTM files of a directory and build the list of candidate
    training windows centered on segment borders, together with a mapping
    from speaker name to a unique integer index.

    For every segment border time B (skipping the start of the first
    segment and the end of the last one), a window
    [B - duration, B + duration] in seconds is stored; a chunk will later
    be picked at random inside it.

    :param mdtm_dir: directory containing the MDTM reference files
    :param duration: half-width of each candidate window, in seconds
    :return: tuple (segment_list, speaker_dict) where segment_list is a
        Diar of candidate windows (times in seconds) and speaker_dict maps
        each speaker cluster name to a unique integer starting at 0
    """
    segment_list = Diar()
    speaker_dict = dict()
    # Next integer index to assign to a newly seen speaker.
    # FIX: this counter used to be the same variable as the enumerate() loop
    # index below, which clobbered it and produced arbitrary, non-contiguous
    # speaker indices.
    next_spk_idx = 0

    # For each MDTM
    for mdtm_file in pathlib.Path(mdtm_dir).glob('*.mdtm'):

        # Load MDTM file
        ref = Diar.read_mdtm(mdtm_file)
        ref.sort()
        last_stop = ref.segments[-1]["stop"]

        # Get the borders of the segments (not the start of the first and
        # not the end of the last); for each border time B keep a window
        # between B - duration and B + duration to pick from randomly later.
        for seg_pos, seg in enumerate(ref.segments):
            # MDTM times are in centiseconds; compare everything in seconds.
            # FIX: the end-of-file guards previously mixed centiseconds and
            # seconds (seg["start"] + duration < last_stop), letting windows
            # extend past the end of the file.
            if seg_pos > 0 and seg["start"] / 100. > duration and seg["start"] / 100. + duration < last_stop / 100.:
                segment_list.append(show=seg['show'],
                        cluster="",
                        start=float(seg["start"]) / 100. - duration,
                        stop=float(seg["start"]) / 100. + duration)

            elif seg_pos < len(ref.segments) - 1 and seg["stop"] / 100. + duration < last_stop / 100.:
                segment_list.append(show=seg['show'],
                        cluster="",
                        start=float(seg["stop"]) / 100. - duration,
                        stop=float(seg["stop"]) / 100. + duration)

        # Register the unique speakers of this file with contiguous indices
        speakers = ref.unique('cluster')
        for spk in speakers:
            if spk not in speaker_dict:
                speaker_dict[spk] = next_spk_idx
                next_spk_idx += 1

    return segment_list, speaker_dict


class SeqSet(Dataset):
    """
    Dataset for sequence-to-sequence training.

    Each item is a pair (feature sequence, frame-level label sequence)
    extracted from a randomly positioned chunk of a wav file, with labels
    derived from the corresponding MDTM reference.
    """
    def __init__(self,
            dataset_yaml,
            wav_dir,
            mdtm_dir,
            mode,
            duration=2.,
            filter_type="gate",
            collar_duration=0.1,
            audio_framerate=16000,
            output_framerate=100,
            set_type="train",
            dataset_df=None,
            transform_pipeline=""):
        """
        :param dataset_yaml: dataset description file (not used in this body)
        :param wav_dir: directory containing the wav files
        :param mdtm_dir: directory containing the MDTM reference files
        :param mode: "vad", "spk_turn" or "overlap" — the target task
        :param duration: duration in seconds of each training chunk
        :param filter_type: collar shape for "spk_turn" mode ("gate" or "triangle")
        :param collar_duration: half-width in seconds of the speaker-change collar
        :param audio_framerate: sample rate of the audio files
        :param output_framerate: framerate of the output label sequence
        :param set_type: subset flag (not used in this body)
        :param dataset_df: optional dataset dataframe (not used in this body)
        :param transform_pipeline: comma-separated transform names,
            e.g. "MFCC,CMVN"; empty string means raw waveform
        """
        self.wav_dir = wav_dir
        self.mdtm_dir = mdtm_dir
        self.mode = mode
        self.duration = duration
        self.filter_type = filter_type
        self.collar_duration = collar_duration
        self.audio_framerate = audio_framerate
        self.output_framerate = output_framerate

        self.transform_pipeline = transform_pipeline

        # Build the feature transform chain from the comma-separated spec;
        # FrequencyMask expects "FrequencyMask(a-b)", TemporalMask "TemporalMask(a)"
        _transform = []
        if not self.transform_pipeline == '':
            trans = self.transform_pipeline.split(',')
            for t in trans:
                if 'PreEmphasis' in t:
                    _transform.append(PreEmphasis())
                if 'MFCC' in t:
                    _transform.append(MFCC())
                if "CMVN" in t:
                    _transform.append(CMVN())
                if "FrequencyMask" in t:
                    a = int(t.split('-')[0].split('(')[1])
                    b = int(t.split('-')[1].split(')')[0])
                    _transform.append(FrequencyMask(a, b))
                if "TemporalMask" in t:
                    a = int(t.split("(")[1].split(")")[0])
                    _transform.append(TemporalMask(a))
        self.transforms = transforms.Compose(_transform)

        # Candidate windows and speaker name -> integer label mapping
        segment_list, speaker_dict = seqSplit(mdtm_dir=self.mdtm_dir,
                duration=self.duration)
        self.segment_list = segment_list
        self.speaker_dict = speaker_dict
        self.len = len(segment_list)

    def __getitem__(self, index):
        """
        Return one training pair for the index-th candidate window: a chunk
        of waveform picked at random inside the window, transformed into
        features, and its labels sampled at the output frame rate.

        :param index: index of the candidate window in self.segment_list
        :return: tuple (features FloatTensor transposed as (feat, frames) ->
            (frames, feat), labels LongTensor)
        """
        # Get segment info to load from
        seg = self.segment_list[index]
        ok = False
        # Randomly pick an audio chunk within the current window; retry as
        # long as the transform fails or the feature shape is unexpected
        while not ok:
            try:
                ok=True
                start = random.uniform(seg["start"], seg["start"] + self.duration)
                sig = numpy.array([0,0])
                try:
                    sig, _ = soundfile.read(self.wav_dir + seg["show"] + ".wav",
                        start=int(start * self.audio_framerate),
                        stop=int((start + self.duration) * self.audio_framerate),
                    )
                except RuntimeError:
                    # NOTE(review): the message mentions .flac although a .wav
                    # file was read above — confirm the intended extension
                    print("==============="+self.wav_dir+seg["show"]+".flac")

                # Add low-amplitude random noise to the signal
                sig += 0.0001 * numpy.random.randn(sig.shape[0])

                if self.transform_pipeline:
                    # NOTE(review): a 6-tuple is fed to the pipeline but only
                    # 4 values are unpacked — verify against the sidekit xsets
                    # transform contract
                    sig, speaker_idx,_t, _s = self.transforms((sig, None,  None, None, None, None))

            except ValueError as e:
                ok=False
            if sig.T.shape != (198,30):
                # Hard-coded expected feature shape (frames, coefficients);
                # retry when dimensions do not match
                # ("problem of dimension idk why ?" in the original)
                ok=False
        # Frame-level speaker labels and overlap counts for the chosen chunk
        tmp_label,overlaps = mdtm_to_label(mdtm_filename=self.mdtm_dir + seg["show"] + ".mdtm",
                start_time=start,
                stop_time=start + self.duration,
                sample_number=sig.shape[1],
                speaker_dict=self.speaker_dict)

        # Convert the speaker labels into the task target (vad / spk_turn / overlap)
        label = process_segment_label(label=tmp_label,
                overlaps=overlaps,
                mode=self.mode,
                framerate=self.output_framerate,
                collar_duration=self.collar_duration,
                filter_type=self.filter_type)

        return torch.from_numpy(sig.T).type(torch.FloatTensor), torch.from_numpy(label.astype('long'))

    def __len__(self):
        # Number of candidate windows produced by seqSplit
        return self.len