xsets.py 22.2 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT.  If not, see <http://www.gnu.org/licenses/>.

"""
Anthony Larcher's avatar
v1.3.7    
Anthony Larcher committed
25
Copyright 2014-2021 Anthony Larcher
Anthony Larcher's avatar
Anthony Larcher committed
26
27
28

"""

Anthony Larcher's avatar
debug    
Anthony Larcher committed
29
import math
Anthony Larcher's avatar
Anthony Larcher committed
30
import numpy
Anthony Larcher's avatar
Anthony Larcher committed
31
32
import pandas
import random
Anthony Larcher's avatar
Anthony Larcher committed
33
import torch
Anthony Larcher's avatar
Anthony Larcher committed
34
import torchaudio
Anthony Larcher's avatar
Anthony Larcher committed
35
import tqdm
Anthony Larcher's avatar
Anthony Larcher committed
36
import soundfile
Anthony Larcher's avatar
Anthony Larcher committed
37
import yaml
Anthony Larcher's avatar
Anthony Larcher committed
38

Anthony Larcher's avatar
Anthony Larcher committed
39
from .augmentation import data_augmentation
Anthony Larcher's avatar
Anthony Larcher committed
40
from ..bosaris.idmap import IdMap
Anthony Larcher's avatar
New vad    
Anthony Larcher committed
41
from torch.utils.data import Dataset
Anthony Larcher's avatar
Anthony Larcher committed
42

Anthony Larcher's avatar
Anthony Larcher committed
43
44
__license__ = "LGPL"
__author__ = "Anthony Larcher"
Anthony Larcher's avatar
v1.3.7    
Anthony Larcher committed
45
__copyright__ = "Copyright 2015-2021 Anthony Larcher"
Anthony Larcher's avatar
Anthony Larcher committed
46
47
48
49
50
51
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'


Anthony Larcher's avatar
Anthony Larcher committed
52
class SideSampler(torch.utils.data.Sampler):
    """
    Data Sampler used to generate uniformly distributed batches
    """

    def __init__(self,
                 data_source,
                 spk_count,
                 examples_per_speaker,
                 samples_per_speaker,
                 batch_size,
                 seed=0,
                 rank=0,
                 num_process=1,
                 num_replicas=1):
        """
        :param data_source: sequence of speaker indexes, one per training session
        :param spk_count: total number of speakers
        :param examples_per_speaker: number of examples of each speaker in a batch
        :param samples_per_speaker: number of times each speaker is sampled per epoch
        :param batch_size: global batch size
        :param seed: base random seed, combined with the epoch before shuffling
        :param rank: rank of the current process among num_process
        :param num_process: number of processes sharing the data
        :param num_replicas: number of GPUs for parallel computing
        """
        self.train_sessions = data_source
        self.labels_to_indices = dict()
        self.spk_count = spk_count
        self.examples_per_speaker = examples_per_speaker
        self.samples_per_speaker = samples_per_speaker
        self.epoch = 0
        self.seed = seed
        self.rank = rank
        self.num_process = num_process
        self.num_replicas = num_replicas

        assert batch_size % examples_per_speaker == 0
        assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0

        # batch_size counted in speakers, per replica
        self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)

        # reference all segment indexes per speaker
        for idx in range(self.spk_count):
            self.labels_to_indices[idx] = list()
        for idx, value in enumerate(self.train_sessions):
            self.labels_to_indices[value].append(idx)
        # shuffle segments per speaker
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        for idx, ldlist in enumerate(self.labels_to_indices.values()):
            ldlist = numpy.array(ldlist)
            self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0], generator=g).numpy()]

        # next segment position to sample for each speaker
        # NOTE: numpy.int was removed in NumPy 1.24; use the builtin int instead
        self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=int)

    def __iter__(self):
        """
        Build the epoch's sequence of segment indexes: speakers are spread
        uniformly over batches, then each speaker occurrence is mapped to its
        next unused segment.

        :return: iterator over segment indexes for this rank
        """
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        numpy.random.seed(self.seed + self.epoch)

        # Generate batches per speaker
        straight = numpy.arange(self.spk_count)
        indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=int) * straight
        batch_cursor = 0
        # each line of "indices" represents all speaker indexes (shuffled in a different way)
        for idx in range(self.samples_per_speaker):
            if batch_cursor == 0:
                indices[idx, :] = numpy.random.permutation(straight)
            else:
                # if one batch is split between the end of previous line and the beginning of current line
                # we make sure no speaker is present twice in this batch
                probs = numpy.ones_like(straight)
                probs[indices[idx-1, -batch_cursor:]] = 0
                probs = probs / numpy.sum(probs)
                indices[idx, :self.batch_size - batch_cursor] = numpy.random.choice(self.spk_count,
                                                                                    self.batch_size - batch_cursor,
                                                                                    replace=False,
                                                                                    p=probs)
                probs = numpy.ones_like(straight)
                probs[indices[idx, :self.batch_size - batch_cursor]] = 0
                to_pick = numpy.sum(probs).astype(int)
                probs = probs / numpy.sum(probs)
                indices[idx, self.batch_size - batch_cursor:] = numpy.random.choice(self.spk_count,
                                                                                    to_pick,
                                                                                    replace=False,
                                                                                    p=probs)

                assert numpy.sum(indices[idx, :]) == numpy.sum(straight)
            batch_cursor = (batch_cursor + indices.shape[1]) % self.batch_size

        # now we have the speaker indexes to sample in batches
        batch_matrix = numpy.repeat(indices, self.examples_per_speaker, axis=1).flatten()

        # we want to convert the speaker indexes into segment indexes
        self.index_iterator = numpy.zeros_like(batch_matrix)

        # keep track of next segment index to sample for each speaker
        for idx, value in enumerate(batch_matrix):
            if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
                # all segments of this speaker have been used: reshuffle and restart
                self.labels_to_indices[value] = self.labels_to_indices[value][torch.randperm(self.labels_to_indices[value].shape[0], generator=g)]
                self.segment_cursors[value] = 0
            self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
            self.segment_cursors[value] += 1

        # duplicate each index for the replicas, then keep this rank's slice
        self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
        self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()

        return iter(self.index_iterator)

    def __len__(self) -> int:
        """
        :return: number of segment indexes yielded per epoch on this process
        """
        return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process

    def set_epoch(self, epoch: int) -> None:
        """
        Set the current epoch; the epoch is combined with the seed in __iter__
        so each epoch uses a different deterministic shuffle.

        :param epoch: epoch index
        """
        self.epoch = epoch
Anthony Larcher's avatar
Anthony Larcher committed
161
162


Anthony Larcher's avatar
Anthony Larcher committed
163
164
165
class SideSet(Dataset):
    """
    Dataset of fixed-duration speech chunks cut from the segments described
    in a dataset-description CSV / DataFrame.
    """

    def __init__(self,
                 dataset,
                 set_type="train",
                 chunk_per_segment=1,
                 transform_number=1,
                 overlap=0.,
                 dataset_df=None,
                 min_duration=0.165,
                 output_format="pytorch",
                 ):
        """
        :param dataset: dictionary describing the dataset (usually loaded from YAML)
        :param set_type: string, can be "train" or "validation"
        :param chunk_per_segment: number of chunks to select for each segment
            default is 1 and -1 means select all possible chunks
        :param transform_number: number of augmentation transforms to apply
        :param overlap: overlap between consecutive chunks, in seconds
        :param dataset_df: optional pandas.DataFrame describing the dataset;
            when None it is read from dataset["dataset_description"]
        :param min_duration: minimum duration of a chunk in seconds
        :param output_format: "pytorch" to get labels as torch tensors
        """
        self.data_path = dataset["data_path"]
        self.sample_rate = int(dataset["sample_rate"])
        self.data_file_extension = dataset["data_file_extension"]
        self.transformation = ''
        self.min_duration = min_duration
        self.output_format = output_format
        self.transform_number = transform_number

        if set_type == "train":
            self.duration = dataset["train"]["duration"]
            self.transformation = dataset["train"]["transformation"]
        else:
            self.duration = dataset["valid"]["duration"]
            self.transformation = dataset["valid"]["transformation"]

        self.sample_number = int(self.duration * self.sample_rate)
        self.overlap = int(overlap * self.sample_rate)

        # Load the dataset description as pandas.dataframe
        if dataset_df is None:
            df = pandas.read_csv(dataset["dataset_description"])
        else:
            assert isinstance(dataset_df, pandas.DataFrame)
            df = dataset_df

        # From each segment which duration is longer than the chosen one
        # select the requested segments
        if set_type == "train":
            tmp_sessions = df.loc[df['duration'] > self.duration]
        else:
            # BUG FIX: the original tested `not "duration" == ''` (a constant,
            # always True) and left tmp_sessions undefined in its else branch.
            # Filter only when a duration is actually configured.
            if self.duration:
                tmp_sessions = df.loc[df['duration'] > self.duration]
            else:
                tmp_sessions = df

        # Create lists for each column of the dataframe
        df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
        df_dict["file_start"] = list()
        df_dict["file_duration"] = list()

        # For each segment, get all possible segments with the current overlap
        for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
            current_session = tmp_sessions.iloc[idx]

            # Compute possible starts
            possible_starts = numpy.arange(0,
                                           int(self.sample_rate * (current_session.duration - self.duration)),
                                           self.sample_number
                                           ) + int(self.sample_rate * (current_session.duration % self.duration / 2))
            possible_starts += int(self.sample_rate * current_session.start)

            # Select max(seg_nb, possible_segments) segments
            if chunk_per_segment == -1:
                starts = possible_starts
                chunk_nb = len(possible_starts)
            else:
                chunk_nb = min(len(possible_starts), chunk_per_segment)
                starts = numpy.random.permutation(possible_starts)[:chunk_nb]

            # Once we know how many segments are selected, create the other fields to fill the DataFrame
            for ii in range(chunk_nb):
                df_dict["database"].append(current_session.database)
                df_dict["speaker_id"].append(current_session.speaker_id)
                df_dict["file_id"].append(current_session.file_id)
                df_dict["start"].append(starts[ii])
                df_dict["duration"].append(self.duration)
                df_dict["file_start"].append(current_session.start)
                df_dict["file_duration"].append(current_session.duration)
                df_dict["speaker_idx"].append(current_session.speaker_idx)
                df_dict["gender"].append(current_session.gender)

        self.sessions = pandas.DataFrame.from_dict(df_dict)
        self.len = len(self.sessions)

        # Build the augmentation configuration from the pipeline description
        self.transform = dict()
        if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
            transforms = self.transformation["pipeline"].split(',')
            if "add_noise" in transforms:
                self.transform["add_noise"] = self.transformation["add_noise"]
            if "add_reverb" in transforms:
                self.transform["add_reverb"] = self.transformation["add_reverb"]
            if "codec" in transforms:
                self.transform["codec"] = []
            if "phone_filtering" in transforms:
                self.transform["phone_filtering"] = []

        # Load the noise dataset, keeping only noises long enough for a chunk
        self.noise_df = None
        if "add_noise" in self.transform:
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            noise_df = noise_df.loc[noise_df.duration > self.duration]
            self.noise_df = noise_df.set_index(noise_df.type)

        # load the RIR database (simulated RIRs only)
        self.rir_df = None
        if "add_reverb" in self.transform:
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            tmp_rir_df = tmp_rir_df.loc[tmp_rir_df["type"] == "simulated_rirs"]
            self.rir_df = tmp_rir_df.set_index(tmp_rir_df.type)

    def __getitem__(self, index):
        """
        Load one chunk of speech and its speaker label.

        :param index: index of the chunk in self.sessions
        :return: (speech, label) where label is a torch.Tensor when
            output_format is "pytorch", the raw value otherwise
        """
        # Check the size of the file
        current_session = self.sessions.iloc[index]

        # TODO is this required ?
        nfo = torchaudio.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
        original_start = int(current_session['start'])
        if self.overlap > 0:
            # jitter the start position, staying inside the parent segment
            lowest_shift = self.overlap / 2
            highest_shift = self.overlap / 2
            if original_start < (current_session['file_start'] * self.sample_rate + self.sample_number / 2):
                lowest_shift = int(original_start - current_session['file_start'] * self.sample_rate)
            if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration']) * self.sample_rate - self.sample_number / 2:
                highest_shift = int((current_session['file_start'] + current_session['file_duration']) * self.sample_rate - (original_start + self.sample_number))
            start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
        else:
            start_frame = original_start

        conversion_rate = nfo.sample_rate // self.sample_rate

        if start_frame + conversion_rate * self.sample_number >= nfo.num_frames:
            # Clamp the window inside the file.  The original called numpy.min()
            # on a scalar (a no-op) and could yield a negative offset for files
            # shorter than one chunk.
            start_frame = max(0, nfo.num_frames - conversion_rate * self.sample_number - 1)

        speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
                                            frame_offset=conversion_rate * start_frame,
                                            num_frames=conversion_rate * self.sample_number)

        if nfo.sample_rate != self.sample_rate:
            speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)

        # add a tiny dither to avoid strictly-zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if len(self.transform) > 0:
            speech = data_augmentation(speech,
                                       self.sample_rate,
                                       self.transform,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speaker_idx = current_session["speaker_idx"]

        if self.output_format == "pytorch":
            return speech, torch.tensor(speaker_idx)
        else:
            return speech, speaker_idx

    def __len__(self):
        """
        :return: number of chunks in the dataset
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
340

Anthony Larcher's avatar
merge    
Anthony Larcher committed
341
342
343
344
345
346
347
348
def get_sample(path, resample=None):
    """Load an audio file through sox effects, down-mixed to one channel.

    :param path: path to the audio file
    :param resample: optional target sample rate; when falsy, keep the original rate
    :return: (waveform, sample_rate) as returned by torchaudio
    """
    effect_chain = [["remix", "1"]]
    if resample:
        effect_chain.append(["rate", f'{resample}'])
    return torchaudio.sox_effects.apply_effects_file(path, effects=effect_chain)

349

Anthony Larcher's avatar
Anthony Larcher committed
350
351
352
353
354
class IdMapSet(Dataset):
    """
    DataSet that provide data according to a sidekit.IdMap object
    """

    def __init__(self,
                 idmap_name,
                 data_path,
                 file_extension,
                 transform_pipeline=None,
                 transform_number=1,
                 sliding_window=False,
                 window_len=3.,
                 window_shift=1.,
                 sample_rate=16000,
                 min_duration=0.165
                 ):
        """
        :param idmap_name: IdMap object or name of the file to load it from
        :param data_path: root directory of the audio files
        :param file_extension: extension of the audio files (without the dot)
        :param transform_pipeline: dict describing the augmentation pipeline;
            None (the default) means no augmentation.  The original default was
            a mutable `{}`, which is the shared-mutable-default pitfall.
        :param transform_number: number of augmentation transforms to apply
        :param sliding_window: if True, unfold the signal into windows
        :param window_len: sliding window length
        :param window_shift: sliding window shift
        :param sample_rate: expected sample rate of the audio files
        :param min_duration: minimum duration of a segment in seconds
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_path = data_path
        self.file_extension = file_extension
        self.len = self.idmap.leftids.shape[0]
        self.transformation = {} if transform_pipeline is None else transform_pipeline
        self.min_sample_nb = min_duration * sample_rate
        self.sample_rate = sample_rate
        self.sliding_window = sliding_window
        self.window_len = window_len
        self.window_shift = window_shift
        self.transform_number = transform_number

        self.noise_df = None
        if "add_noise" in self.transformation:
            # Load the noise dataset, filter according to the duration
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            self.noise_df = noise_df.set_index(noise_df.type)

        self.rir_df = None
        if "add_reverb" in self.transformation:
            # load the RIR database
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            # BUG FIX: the original stored a bare zip() iterator, which is
            # exhausted after the first traversal; materialize it as a list
            self.rir_df = list(zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist()))

    def __getitem__(self, index):
        """
        Load the speech segment for one IdMap entry.

        :param index: index of the entry in the IdMap
        :return: (speech, leftid, rightid, start, stop) with start/stop in samples
        """
        # IdMap start/stop are in frames at 100 frames/s; the *160 factor
        # assumes a 16 kHz signal — TODO confirm for other sample rates
        if self.idmap.start[index] is None:
            start = 0
        else:
            start = int(self.idmap.start[index]) * 160

        if self.idmap.stop[index] is None:
            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
            duration = int(speech.shape[1] - start)
        else:
            duration = int(self.idmap.stop[index]) * 160 - start
            # add this in case the segment is too short
            if duration <= self.min_sample_nb:
                middle = start + duration // 2
                start = max(0, int(middle - (self.min_sample_nb / 2)))
                duration = int(self.min_sample_nb)

            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
                                                frame_offset=start,
                                                num_frames=duration)

        # add a tiny dither to avoid strictly-zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if self.sliding_window:
            # NOTE(review): window_len / window_shift are passed to unfold() as
            # sample counts, but the defaults (3., 1.) look like seconds — confirm
            speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)

        if len(self.transformation.keys()) > 0:
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transformation,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speech = speech.squeeze()

        return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, start + duration

    def __len__(self):
        """
        :return: number of entries in the IdMap
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
450
451


Anthony Larcher's avatar
Anthony Larcher committed
452
class IdMapSetPerSpeaker(Dataset):
    """
    DataSet that provides data according to a sidekit.IdMap object.

    Contrary to IdMapSet, one item corresponds to one *speaker*: all segments
    whose left-id matches the requested speaker are loaded from disk,
    concatenated along the time axis and optionally augmented.
    """

    def __init__(self,
                 idmap_name,
                 data_root_path,
                 file_extension,
                 transform_pipeline=None,
                 frame_rate=100,
                 min_duration=0.165,
                 transform_number=1
                 ):
        """
        :param idmap_name: IdMap object, or name of a file to load an IdMap from
        :param data_root_path: root directory containing the audio files
        :param file_extension: audio file extension (without the leading dot)
        :param transform_pipeline: dict describing the augmentation pipeline,
            keyed by transformation name (e.g. "add_noise", "add_reverb"),
            as in IdMapSet. A legacy comma-separated string is also accepted.
        :param frame_rate: output sample rate in frames per second
        :param min_duration: minimum duration of a segment in seconds
        :param transform_number: number of transformations to apply per item
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_root_path = data_root_path
        self.file_extension = file_extension
        self.len = len(set(self.idmap.leftids))
        # Default was a mutable dict literal; normalize None to {} so the
        # length checks below are always safe.
        self.transformation = transform_pipeline if transform_pipeline is not None else {}
        self.min_duration = min_duration
        self.sample_rate = frame_rate
        # BUGFIX: transform_number was read in __getitem__ but never defined.
        self.transform_number = transform_number
        self.speaker_list = list(set(self.idmap.leftids))

        # Output IdMap with one entry per unique speaker.
        self.output_im = IdMap()
        self.output_im.leftids = numpy.unique(self.idmap.leftids)
        self.output_im.rightids = self.output_im.leftids
        self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
        self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")

        # BUGFIX: transformation names were stored in self.transform_list
        # while every consumer ("add_noise" in ..., len(...) > 0) tests
        # self.transform, so augmentation was silently never applied.
        # Also, .split(",") crashed on the default dict value.
        self.transform = []
        if isinstance(self.transformation, str):
            # Legacy API: comma-separated list of transformation names.
            self.transform = [t for t in self.transformation.split(",") if t]
        elif len(self.transformation) > 0:
            self.transform = list(self.transformation.keys())
        # Kept for backward compatibility with code that read transform_list.
        self.transform_list = self.transform

        self.noise_df = None
        if "add_noise" in self.transform and isinstance(self.transformation, dict):
            # Load the noise dataset, filter according to the duration.
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            # BUGFIX: self.duration never existed; filter on min_duration.
            tmp_df = noise_df.loc[noise_df['duration'] > self.min_duration]
            self.noise_df = tmp_df['file_id'].tolist()

        self.rir_df = None
        if "add_reverb" in self.transform and isinstance(self.transformation, dict):
            # Load the RIR database.
            # BUGFIX: materialize the pairs; a bare zip iterator is exhausted
            # after the first __getitem__ call.
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            self.rir_df = list(zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist()))

    def __getitem__(self, index):
        """
        Concatenate all speech segments of one speaker into a single tensor.

        :param index: index of the speaker among the unique speakers
        :return: tuple (speech, speaker_id, speaker_id, start, stop) where
                 start and stop are the bounds of the last segment read
        :raises RuntimeError: if the speaker has no segment (empty torch.cat)
        """
        spk_id = self.output_im.leftids[index]

        # Loop on all segments from the given speaker to load data.
        tmp_data = []
        last_start = 0
        last_stop = 0
        for seg_spk, seg_file, seg_start, seg_stop in zip(self.idmap.leftids,
                                                          self.idmap.rightids,
                                                          self.idmap.start,
                                                          self.idmap.stop):
            if seg_spk == spk_id:
                # BUGFIX: the original read every segment from
                # self.idmap.rightids[index] (a speaker index into the
                # segment-wise array), so all segments came from one arbitrary
                # file; read each segment from its own file instead.
                seg_path = f"{self.data_root_path}/{seg_file}.{self.file_extension}"
                nfo = soundfile.info(seg_path)
                start = int(seg_start)
                stop = int(seg_stop)
                # Widen the window around its middle if the segment is too short.
                if stop - start <= self.min_duration * nfo.samplerate:
                    middle = start + (stop - start) // 2
                    start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
                    stop = int(start + self.min_duration * nfo.samplerate)

                speech, speech_fs = torchaudio.load(seg_path,
                                                    frame_offset=start,
                                                    num_frames=stop - start)
                tmp_data.append(speech)
                last_start = start
                last_stop = stop

        speech = torch.cat(tmp_data, dim=1)
        # Tiny dither to avoid exactly-zero samples.
        speech += 10e-6 * torch.randn(speech.shape)

        if len(self.transform) > 0:
            # BUGFIX: data_augmentation expects the transformation dictionary
            # (as in IdMapSet), not the list of transformation names.
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transformation,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speech = speech.squeeze()

        # BUGFIX: index enumerates unique speakers, so indexing the
        # segment-wise idmap arrays with it returned arbitrary ids;
        # return the speaker id (output_im.rightids == output_im.leftids).
        return speech, spk_id, spk_id, last_start, last_stop

    def __len__(self):
        """
        Return the number of unique speakers in the IdMap.

        :return: dataset length as computed in the constructor
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
558