xsets.py 20.2 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT.  If not, see <http://www.gnu.org/licenses/>.

"""
Anthony Larcher's avatar
v1.3.7    
Anthony Larcher committed
25
Copyright 2014-2021 Anthony Larcher
Anthony Larcher's avatar
Anthony Larcher committed
26
27
28

"""

Anthony Larcher's avatar
debug    
Anthony Larcher committed
29
import math
Anthony Larcher's avatar
Anthony Larcher committed
30
import numpy
Anthony Larcher's avatar
Anthony Larcher committed
31
32
import pandas
import random
Anthony Larcher's avatar
Anthony Larcher committed
33
import torch
Anthony Larcher's avatar
Anthony Larcher committed
34
import torchaudio
Anthony Larcher's avatar
Anthony Larcher committed
35
import tqdm
Anthony Larcher's avatar
Anthony Larcher committed
36
import soundfile
Anthony Larcher's avatar
Anthony Larcher committed
37
import yaml
Anthony Larcher's avatar
Anthony Larcher committed
38

Anthony Larcher's avatar
Anthony Larcher committed
39
from .augmentation import data_augmentation
Anthony Larcher's avatar
Anthony Larcher committed
40
from ..bosaris.idmap import IdMap
Anthony Larcher's avatar
New vad    
Anthony Larcher committed
41
from torch.utils.data import Dataset
Anthony Larcher's avatar
Anthony Larcher committed
42

Anthony Larcher's avatar
Anthony Larcher committed
43
44
__license__ = "LGPL"
__author__ = "Anthony Larcher"
Anthony Larcher's avatar
v1.3.7    
Anthony Larcher committed
45
__copyright__ = "Copyright 2015-2021 Anthony Larcher"
Anthony Larcher's avatar
Anthony Larcher committed
46
47
48
49
50
51
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'


Anthony Larcher's avatar
Anthony Larcher committed
52
class SideSampler(torch.utils.data.Sampler):
    """
    Data Sampler used to generate uniformly distributed batches

    Each batch contains ``examples_per_speaker`` segments for each of
    ``batch_size // examples_per_speaker`` distinct speakers.
    """

    def __init__(self, data_source, spk_count, examples_per_speaker, samples_per_speaker, batch_size):
        """
        :param data_source: sequence of speaker indices, one per training session
        :param spk_count: total number of distinct speakers in data_source
        :param examples_per_speaker: number of segments of a same speaker in a batch
        :param samples_per_speaker: number of draws per speaker for one epoch
        :param batch_size: total number of segments per batch; must be a
            multiple of examples_per_speaker
        """
        self.train_sessions = data_source
        self.labels_to_indices = dict()
        self.spk_count = spk_count
        self.examples_per_speaker = examples_per_speaker
        self.samples_per_speaker = samples_per_speaker

        assert batch_size % examples_per_speaker == 0
        # number of distinct speakers per batch
        self.batch_size = batch_size // examples_per_speaker

        # reference all segment indexes per speaker
        for idx in range(self.spk_count):
            self.labels_to_indices[idx] = list()
        for idx, value in enumerate(self.train_sessions):
            self.labels_to_indices[value].append(idx)
        # shuffle segments per speaker
        for ldlist in self.labels_to_indices.values():
            random.shuffle(ldlist)

        # next segment position to sample for each speaker
        # (numpy.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the documented replacement)
        self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=int)


    def __iter__(self):
        # Generate batches per speaker
        straight = numpy.arange(self.spk_count)
        indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=int) * straight
        batch_cursor = 0
        # each line of "indices" represents all speaker indexes (shuffled in a different way)
        for idx in range(self.samples_per_speaker):
            if batch_cursor == 0:
                indices[idx, :] = numpy.random.permutation(straight)
            else:
                # if one batch is split between the end of previous line and the beginning of current line
                # we make sure no speaker is present twice in this batch
                probs = numpy.ones_like(straight)
                probs[indices[idx-1, -batch_cursor:]] = 0
                probs = probs/numpy.sum(probs)
                indices[idx, :self.batch_size - batch_cursor] = numpy.random.choice(self.spk_count, self.batch_size - batch_cursor, replace=False, p=probs)
                probs = numpy.ones_like(straight)
                probs[indices[idx, :self.batch_size - batch_cursor]] = 0
                to_pick = int(numpy.sum(probs))
                probs = probs/numpy.sum(probs)
                indices[idx, self.batch_size - batch_cursor:] = numpy.random.choice(self.spk_count, to_pick, replace=False, p=probs)

                assert numpy.sum(indices[idx, :]) == numpy.sum(straight)
            batch_cursor = (batch_cursor + indices.shape[1]) % self.batch_size

        # now we have the speaker indexes to sample in batches
        batch_matrix = numpy.repeat(indices, self.examples_per_speaker, axis=1).flatten()

        # we want to convert the speaker indexes into segment indexes
        self.index_iterator = numpy.zeros_like(batch_matrix)

        # keep track of next segment index to sample for each speaker
        for idx, value in enumerate(batch_matrix):
            if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
                # every segment of this speaker has been used: reshuffle and restart
                random.shuffle(self.labels_to_indices[value])
                self.segment_cursors[value] = 0
            self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
            self.segment_cursors[value] += 1
        return iter(self.index_iterator)


    def __len__(self) -> int:
        return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker)


Anthony Larcher's avatar
Anthony Larcher committed
132
133
134
class SideSet(Dataset):
    """
    PyTorch Dataset serving fixed-duration speech chunks for training or
    validation, with optional noise / reverberation augmentation.
    """

    def __init__(self,
                 dataset,
                 set_type="train",
                 chunk_per_segment=1,
                 transform_number=1,
                 overlap=0.,
                 dataset_df=None,
                 min_duration=0.165,
                 output_format="pytorch",
                 ):
        """

        :param dataset: dict parsed from the YAML file describing the dataset;
            must provide "data_path", "sample_rate", "data_file_extension",
            "dataset_description" and "train"/"valid" sub-dicts holding
            "duration" and "transformation"
        :param set_type: string, can be "train" or "validation"
        :param chunk_per_segment: number of chunks to select for each segment
        default is 1 and -1 means select all possible chunks
        :param transform_number: number of augmentations applied per chunk
        :param overlap: overlap between consecutive chunks, in seconds
        :param dataset_df: optional pandas.DataFrame describing the sessions;
            when None it is read from dataset["dataset_description"]
        :param min_duration: minimum chunk duration in seconds (stored only)
        :param output_format: "pytorch" to return the label as a torch.tensor
        """
        #with open(data_set_yaml, "r") as fh:
        #    dataset = yaml.load(fh, Loader=yaml.FullLoader)

        self.data_path = dataset["data_path"]
        self.sample_rate = int(dataset["sample_rate"])
        self.data_file_extension = dataset["data_file_extension"]
        # placeholder; replaced below by the train/valid transformation dict
        self.transformation = ''
        self.min_duration = min_duration
        self.output_format = output_format
        self.transform_number = transform_number

        # chunk duration and augmentation pipeline depend on the subset
        if set_type == "train":
            self.duration = dataset["train"]["duration"]
            self.transformation = dataset["train"]["transformation"]
        else:
            self.duration = dataset["valid"]["duration"]
            self.transformation = dataset["valid"]["transformation"]

        # chunk length and overlap converted from seconds to samples
        self.sample_number = int(self.duration * self.sample_rate)
        self.overlap = int(overlap * self.sample_rate)

        # Load the dataset description as pandas.dataframe
        if dataset_df is None:
            df = pandas.read_csv(dataset["dataset_description"])
        else:
            assert isinstance(dataset_df, pandas.DataFrame)
            df = dataset_df

        # From each segment which duration is longer than the chosen one
        # select the requested segments
        if set_type == "train":
            tmp_sessions = df.loc[df['duration'] > self.duration]
        else:
            # NOTE(review): 'not "duration" == ""' compares a string literal and
            # is always True, so the else branch below is unreachable — looks
            # like a variable was meant instead of the literal; confirm intent
            if not "duration" == '':
                tmp_sessions = df.loc[df['duration'] > self.duration]
            else:
                self.sessions = df

        # Create lists for each column of the dataframe
        # (assumes df has exactly 7 columns — TODO confirm against the CSV schema)
        df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
        df_dict["file_start"] = list()
        df_dict["file_duration"] = list()

        # For each segment, get all possible segments with the current overlap
        for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
            current_session = tmp_sessions.iloc[idx]

            # Compute possible starts: chunks are laid out every sample_number
            # samples, centered by half of the leftover duration
            possible_starts = numpy.arange(0,
                                           int(self.sample_rate * (current_session.duration - self.duration)),
                                           self.sample_number
                                           ) + int(self.sample_rate * (current_session.duration % self.duration / 2))
            possible_starts += int(self.sample_rate * current_session.start)

            # Select max(seg_nb, possible_segments) segments
            if chunk_per_segment == -1:
                # keep every possible chunk
                starts = possible_starts
                chunk_nb = len(possible_starts)
            else:
                chunk_nb = min(len(possible_starts), chunk_per_segment)
                starts = numpy.random.permutation(possible_starts)[:chunk_nb]

            # Once we know how many segments are selected, create the other fields to fill the DataFrame
            for ii in range(chunk_nb):
                df_dict["database"].append(current_session.database)
                df_dict["speaker_id"].append(current_session.speaker_id)
                df_dict["file_id"].append(current_session.file_id)
                df_dict["start"].append(starts[ii])
                df_dict["duration"].append(self.duration)
                df_dict["file_start"].append(current_session.start)
                df_dict["file_duration"].append(current_session.duration)
                df_dict["speaker_idx"].append(current_session.speaker_idx)
                df_dict["gender"].append(current_session.gender)

        self.sessions = pandas.DataFrame.from_dict(df_dict)
        self.len = len(self.sessions)

        # collect the requested augmentations from the pipeline description
        self.transform = dict()
        if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
            transforms = self.transformation["pipeline"].split(',')
            if "add_noise" in transforms:
                self.transform["add_noise"] = self.transformation["add_noise"]
            if "add_reverb" in transforms:
                self.transform["add_reverb"] = self.transformation["add_reverb"]

        self.noise_df = None
        if "add_noise" in self.transform:
            # load the noise database, keep only noises longer than one chunk,
            # and index by noise type for fast per-type sampling
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            noise_df = noise_df.loc[noise_df.duration > self.duration]
            self.noise_df = noise_df.set_index(noise_df.type)

        self.rir_df = None
        if "add_reverb" in self.transform:
            # load the RIR database
            self.rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])


    def __getitem__(self, index):
        """
        Load one chunk of speech (with optional random shift and augmentation).

        :param index: row index in self.sessions
        :return: (speech, label) where label is a torch.tensor when
            output_format == "pytorch", the raw speaker_idx otherwise
        """
        # Check the size of the file
        current_session = self.sessions.iloc[index]

        nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")

        original_start = int(current_session['start'])
        if self.overlap > 0:
            # jitter the chunk start within +/- overlap/2, clamped so the chunk
            # stays inside the file segment
            lowest_shift = self.overlap/2
            highest_shift = self.overlap/2
            if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
                lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
            if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
                highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
            start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
        else:
            start_frame = original_start

        if start_frame + self.sample_number >= nfo.frames:
            # NOTE(review): numpy.min of a scalar is a no-op; this simply clamps
            # the start so the chunk fits in the file. It can go negative for
            # files shorter than sample_number — presumably filtered upstream;
            # confirm.
            start_frame = numpy.min(nfo.frames - self.sample_number - 1)

        speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
                                            frame_offset=start_frame,
                                            num_frames=self.sample_number)

        # add a tiny dither to avoid strictly zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if len(self.transform) > 0:
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transform,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speaker_idx = current_session["speaker_idx"]

        if self.output_format == "pytorch":
            return speech, torch.tensor(speaker_idx)
        else:
            return speech, speaker_idx


    def __len__(self):
        """
        Number of chunks indexed by this dataset.

        :param self:
        :return:
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
299

300

Anthony Larcher's avatar
Anthony Larcher committed
301
302
303
304
305
class IdMapSet(Dataset):
    """
    DataSet that provide data according to a sidekit.IdMap object
    """

    def __init__(self,
                 idmap_name,
                 data_path,
                 file_extension,
                 transform_pipeline=None,
                 transform_number=1,
                 sliding_window=False,
                 window_len=24000,
                 window_shift=8000,
                 sample_rate=16000,
                 min_duration=0.165
                 ):
        """

        :param idmap_name: IdMap object, or name of a file to load an IdMap from
        :param data_path: root directory of the audio files
        :param file_extension: audio file extension (without the leading dot)
        :param transform_pipeline: dict describing the augmentation pipeline;
            None (the default) means no augmentation. The former default was a
            shared mutable dict ({}), which is a classic Python pitfall.
        :param transform_number: number of augmentations applied per sample
        :param sliding_window: when True, cut the signal into overlapping windows
        :param window_len: sliding window length, in samples
        :param window_shift: shift between sliding windows, in samples
        :param sample_rate: sample rate of the audio files, in Hz
        :param min_duration: minimum duration of a segment, in seconds
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_path = data_path
        self.file_extension = file_extension
        self.len = self.idmap.leftids.shape[0]
        # normalize the None default to an empty pipeline
        self.transformation = transform_pipeline if transform_pipeline is not None else {}
        self.min_sample_nb = min_duration * sample_rate
        self.sample_rate = sample_rate
        self.sliding_window = sliding_window
        self.window_len = window_len
        self.window_shift = window_shift
        self.transform_number = transform_number

        self.noise_df = None
        if "add_noise" in self.transformation:
            # Load the noise dataset, indexed by noise type
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            self.noise_df = noise_df.set_index(noise_df.type)

        self.rir_df = None
        if "add_reverb" in self.transformation:
            # load the RIR database
            # NOTE(review): a zip object is exhausted after one pass — confirm
            # data_augmentation only consumes it once
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            self.rir_df = zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist())


    def __getitem__(self, index):
        """
        Load and return the signal for segment ``index`` of the IdMap.

        :param index: index of the segment in the IdMap
        :return: tuple (speech, left-id, right-id, start, stop) with start/stop
            expressed in samples
        """
        # IdMap start/stop are in frames; * 160 converts to samples
        # (assumes 100 frames/s at 16 kHz — TODO confirm)
        if self.idmap.start[index] is None:
            start = 0
        else:
            start = int(self.idmap.start[index]) * 160

        if self.idmap.stop[index] is None:
            # no stop boundary: load the complete file
            # NOTE(review): 'start' is not applied here, the full signal is kept
            # even when start > 0 — confirm this is intended
            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
            duration = int(speech.shape[1] - start)
        else:
            duration = int(self.idmap.stop[index]) * 160 - start
            # enlarge the window around its center in case the segment is too short
            if duration <= self.min_sample_nb:
                middle = start + duration // 2
                start = max(0, int(middle - (self.min_sample_nb / 2)))
                duration = int(self.min_sample_nb)

            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
                                                frame_offset=start,
                                                num_frames=duration)

        # add a tiny dither to avoid strictly zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if self.sliding_window:
            speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)

        if len(self.transformation.keys()) > 0:
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transformation,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speech = speech.squeeze()

        return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, start + duration


    def __len__(self):
        """
        Number of segments in the IdMap.

        :param self:
        :return:
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
406
407


Anthony Larcher's avatar
Anthony Larcher committed
408
class IdMapSetPerSpeaker(Dataset):
    """
    DataSet that provide data according to a sidekit.IdMap object

    One item per unique speaker (left-id): all segments of that speaker are
    loaded and concatenated into a single waveform.
    """

    def __init__(self,
                 idmap_name,
                 data_root_path,
                 file_extension,
                 transform_pipeline=None,
                 transform_number=1,
                 frame_rate=100,
                 min_duration=0.165
                 ):
        """

        :param idmap_name: IdMap object, or name of a file to load an IdMap from
        :param data_root_path: root directory of the audio files
        :param file_extension: audio file extension (without the leading dot)
        :param transform_pipeline: dict describing the augmentation pipeline;
            None (the default) means no augmentation. The former default was a
            shared mutable dict ({}) and was then treated as a string
            (split(",")), which raised AttributeError.
        :param transform_number: number of augmentations applied per sample
            (new parameter with a default: __getitem__ used self.transform_number,
            which was previously never set)
        :param frame_rate: number of frames per second in the IdMap start/stop fields
        :param min_duration: minimum duration of a segment, in seconds
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_root_path = data_root_path
        self.file_extension = file_extension
        # one item per unique speaker
        self.len = len(set(self.idmap.leftids))
        self.transformation = transform_pipeline if transform_pipeline is not None else {}
        self.transform_number = transform_number
        self.min_duration = min_duration
        self.sample_rate = frame_rate
        self.speaker_list = list(set(self.idmap.leftids))

        # Output IdMap: one entry per unique speaker
        self.output_im = IdMap()
        self.output_im.leftids = numpy.unique(self.idmap.leftids)
        self.output_im.rightids = self.output_im.leftids
        self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
        self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")

        # Select the requested augmentations, consistent with SideSet/IdMapSet.
        # (The previous code stored a list here, so no augmentation could ever
        # be enabled, and the noise branch referenced an undefined attribute.)
        self.transform = dict()
        for key in ("add_noise", "add_reverb"):
            if key in self.transformation:
                self.transform[key] = self.transformation[key]

        self.noise_df = None
        if "add_noise" in self.transform:
            # Load the noise dataset, indexed by noise type (as in IdMapSet)
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            self.noise_df = noise_df.set_index(noise_df.type)

        self.rir_df = None
        if "add_reverb" in self.transform:
            # load the RIR database
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            self.rir_df = zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist())


    def __getitem__(self, index):
        """
        Concatenate all segments of the speaker at position ``index`` of the
        output IdMap and return them as a single waveform.

        :param index: index of the speaker in self.output_im.leftids
        :return: tuple (speech, left-id, right-id, start, stop); start/stop
            refer to the last segment loaded (kept for interface compatibility)
        """
        # Loop on all segments from the given speaker to load data
        spk_id = self.output_im.leftids[index]
        tmp_data = []
        for seg_id, seg_file, seg_start, seg_stop in zip(self.idmap.leftids,
                                                         self.idmap.rightids,
                                                         self.idmap.start,
                                                         self.idmap.stop):
            if seg_id == spk_id:
                # read each segment from its own file (the previous code always
                # read from the file at position ``index`` of the IdMap, which
                # mixed up speakers and files)
                nfo = soundfile.info(f"{self.data_root_path}/{seg_file}.{self.file_extension}")
                start = int(seg_start)
                stop = int(seg_stop)
                # enlarge the window around its center in case the segment is too short
                if stop - start <= self.min_duration * nfo.samplerate:
                    middle = start + (stop - start) // 2
                    start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
                    stop = int(start + self.min_duration * nfo.samplerate)

                speech, speech_fs = torchaudio.load(f"{self.data_root_path}/{seg_file}.{self.file_extension}",
                                                    frame_offset=start,
                                                    num_frames=stop - start)
                tmp_data.append(speech)

        # NOTE(review): raises if the speaker has no segment — assumed impossible
        # since output_im.leftids is built from idmap.leftids
        speech = torch.cat(tmp_data, dim=1)
        # add a tiny dither to avoid strictly zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if len(self.transform) > 0:
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transform,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speech = speech.squeeze()

        return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop


    def __len__(self):
        """
        Number of unique speakers in the IdMap.

        :param self:
        :return:
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
514