xsets.py 24.8 KB
Newer Older
Anthony Larcher's avatar
Anthony Larcher committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU LLesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT.  If not, see <http://www.gnu.org/licenses/>.

"""
Anthony Larcher's avatar
v1.3.7    
Anthony Larcher committed
25
Copyright 2014-2021 Anthony Larcher
Anthony Larcher's avatar
Anthony Larcher committed
26
27
28

"""

Anthony Larcher's avatar
debug    
Anthony Larcher committed
29
import math
Anthony Larcher's avatar
Anthony Larcher committed
30
import numpy
Anthony Larcher's avatar
Anthony Larcher committed
31
32
import pandas
import random
Anthony Larcher's avatar
Anthony Larcher committed
33
import torch
Anthony Larcher's avatar
Anthony Larcher committed
34
import torchaudio
Anthony Larcher's avatar
Anthony Larcher committed
35
import tqdm
Anthony Larcher's avatar
Anthony Larcher committed
36
import soundfile
Anthony Larcher's avatar
Anthony Larcher committed
37
import yaml
Anthony Larcher's avatar
Anthony Larcher committed
38

Anthony Larcher's avatar
Anthony Larcher committed
39
from .augmentation import data_augmentation
Anthony Larcher's avatar
Anthony Larcher committed
40
from ..bosaris.idmap import IdMap
Anthony Larcher's avatar
New vad    
Anthony Larcher committed
41
from torch.utils.data import Dataset
Anthony Larcher's avatar
Anthony Larcher committed
42

Anthony Larcher's avatar
Anthony Larcher committed
43
44
__license__ = "LGPL"
__author__ = "Anthony Larcher"
Anthony Larcher's avatar
v1.3.7    
Anthony Larcher committed
45
__copyright__ = "Copyright 2015-2021 Anthony Larcher"
Anthony Larcher's avatar
Anthony Larcher committed
46
47
48
49
50
51
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'


Anthony Larcher's avatar
Anthony Larcher committed
52
class SideSampler(torch.utils.data.Sampler):
    """
    Data Sampler used to generate uniformly distributed batches.

    Every speaker contributes ``examples_per_speaker`` consecutive entries per
    batch and is visited ``samples_per_speaker`` times per epoch.  The order is
    deterministic for a given (seed, epoch) pair.
    """

    def __init__(self,
                 data_source,
                 spk_count,
                 examples_per_speaker,
                 samples_per_speaker,
                 batch_size,
                 seed=0,
                 rank=0,
                 num_process=1,
                 num_replicas=1):
        """
        :param data_source: sequence of speaker indices, one per training segment
        :param spk_count: total number of speakers
        :param examples_per_speaker: segments drawn per speaker within a batch
        :param samples_per_speaker: times each speaker is visited per epoch
        :param batch_size: global batch size (before splitting across replicas)
        :param seed: base random seed, combined with the epoch for shuffling
        :param rank: rank of the current process
        :param num_process: number of processes sharing the sampling
        :param num_replicas: number of GPUs for parallel computing
        """
        self.train_sessions = data_source
        self.labels_to_indices = dict()
        self.spk_count = spk_count
        self.examples_per_speaker = examples_per_speaker
        self.samples_per_speaker = samples_per_speaker
        self.epoch = 0
        self.seed = seed
        self.rank = rank
        self.num_process = num_process
        self.num_replicas = num_replicas

        assert batch_size % examples_per_speaker == 0
        assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0

        # number of distinct speakers per per-replica batch
        self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)

        # reference all segment indexes per speaker
        for idx in range(self.spk_count):
            self.labels_to_indices[idx] = list()
        for idx, value in enumerate(self.train_sessions):
            self.labels_to_indices[value].append(idx)

        # shuffle segments per speaker
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        for idx, ldlist in enumerate(self.labels_to_indices.values()):
            ldlist = numpy.array(ldlist)
            self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0], generator=g).numpy()]

        # next segment position to consume for each speaker
        # NOTE: `numpy.int` was removed in NumPy 1.24, use the builtin `int` instead
        self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=int)

    def __iter__(self):
        """
        Build the epoch's segment-index order and return an iterator over it.
        """
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        numpy.random.seed(self.seed + self.epoch)

        # Generate batches per speaker
        straight = numpy.arange(self.spk_count)
        indices = numpy.ones((self.samples_per_speaker, self.spk_count), dtype=int) * straight
        batch_cursor = 0
        # each line of "indices" represents all speaker indexes (shuffled in a different way)
        for idx in range(self.samples_per_speaker):
            if batch_cursor == 0:
                indices[idx, :] = numpy.random.permutation(straight)
            else:
                # if one batch is split between the end of previous line and the beginning
                # of current line we make sure no speaker is present twice in this batch
                probs = numpy.ones_like(straight)
                probs[indices[idx - 1, -batch_cursor:]] = 0
                probs = probs / numpy.sum(probs)
                indices[idx, :self.batch_size - batch_cursor] = numpy.random.choice(self.spk_count,
                                                                                    self.batch_size - batch_cursor,
                                                                                    replace=False,
                                                                                    p=probs)
                probs = numpy.ones_like(straight)
                probs[indices[idx, :self.batch_size - batch_cursor]] = 0
                to_pick = numpy.sum(probs).astype(int)
                probs = probs / numpy.sum(probs)
                indices[idx, self.batch_size - batch_cursor:] = numpy.random.choice(self.spk_count,
                                                                                    to_pick,
                                                                                    replace=False,
                                                                                    p=probs)

                assert numpy.sum(indices[idx, :]) == numpy.sum(straight)
            # NOTE(review): the increment is always indices.shape[1] (== spk_count), so
            # batch_cursor stays 0 whenever batch_size divides spk_count — confirm intent.
            batch_cursor = (batch_cursor + indices.shape[1]) % self.batch_size

        # now we have the speaker indexes to sample in batches
        batch_matrix = numpy.repeat(indices, self.examples_per_speaker, axis=1).flatten()

        # we want to convert the speaker indexes into segment indexes
        self.index_iterator = numpy.zeros_like(batch_matrix)

        # keep track of next segment index to sample for each speaker
        for idx, value in enumerate(batch_matrix):
            if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
                # exhausted this speaker's segments: reshuffle and restart from 0
                self.labels_to_indices[value] = self.labels_to_indices[value][torch.randperm(self.labels_to_indices[value].shape[0], generator=g)]
                self.segment_cursors[value] = 0
            self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
            self.segment_cursors[value] += 1

        # duplicate for each replica, then keep only this rank's slice
        self.index_iterator = numpy.repeat(self.index_iterator, self.num_replicas)
        self.index_iterator = self.index_iterator.reshape(-1, self.num_process * self.examples_per_speaker * self.num_replicas)[:, self.rank * self.examples_per_speaker * self.num_replicas:(self.rank + 1) * self.examples_per_speaker * self.num_replicas].flatten()

        return iter(self.index_iterator)

    def __len__(self) -> int:
        """
        :return: number of indices yielded per epoch by this process
        """
        return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker * self.num_replicas) // self.num_process

    def set_epoch(self, epoch: int) -> None:
        """
        Set the epoch number, which reseeds the shuffling in __iter__.

        :param epoch: current epoch index
        """
        self.epoch = epoch
Anthony Larcher's avatar
Anthony Larcher committed
166
167


Anthony Larcher's avatar
Anthony Larcher committed
168
169
170
class SideSet(Dataset):
    """
    Dataset of fixed-duration speech chunks cut from the segments described
    in a dataset-description CSV (or a pre-loaded DataFrame).
    """

    def __init__(self,
                 dataset,
                 set_type="train",
                 chunk_per_segment=1,
                 transform_number=1,
                 overlap=0.,
                 dataset_df=None,
                 min_duration=0.165,
                 output_format="pytorch"
                 ):
        """
        :param dataset: dictionary describing the dataset (typically loaded from YAML)
        :param set_type: string, can be "train" or "validation"
        :param chunk_per_segment: number of chunks to select for each segment
            default is 1 and -1 means select all possible chunks
        :param transform_number: number of augmentations to apply per chunk
        :param overlap: overlap (in seconds) allowed between consecutive chunks
        :param dataset_df: optional pre-loaded pandas.DataFrame describing the dataset;
            when None the CSV referenced by dataset["dataset_description"] is read
        :param min_duration: minimum chunk duration in seconds
        :param output_format: "pytorch" to get the label as a tensor, anything else
            returns it unchanged
        """
        self.data_path = dataset["data_path"]
        self.sample_rate = int(dataset["sample_rate"])
        self.data_file_extension = dataset["data_file_extension"]
        self.transformation = ''
        self.min_duration = min_duration
        self.output_format = output_format
        self.transform_number = transform_number

        if set_type == "train":
            self.duration = dataset["train"]["duration"]
            self.transformation = dataset["train"]["transformation"]
        else:
            self.duration = dataset["valid"]["duration"]
            self.transformation = dataset["valid"]["transformation"]

        self.sample_number = int(self.duration * self.sample_rate)
        self.overlap = int(overlap * self.sample_rate)

        # Load the dataset description as pandas.dataframe
        if dataset_df is None:
            df = pandas.read_csv(dataset["dataset_description"])
        else:
            assert isinstance(dataset_df, pandas.DataFrame)
            df = dataset_df

        # Keep only segments whose duration allows at least one full chunk.
        # (The historical validation branch tested `not "duration" == ''`, which is
        # always True, so both set types applied this very same filter; the
        # unreachable `self.sessions = df` fallback has been removed.)
        tmp_sessions = df.loc[df['duration'] > self.duration]

        # Create one list per column of the output dataframe.
        # NOTE: the description is expected to provide database, speaker_id, file_id,
        # start, duration, speaker_idx and gender columns (see the appends below).
        df_dict = {column: [] for column in df.columns}
        df_dict["file_start"] = list()
        df_dict["file_duration"] = list()

        # For each segment, get all possible chunks with the current overlap
        for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
            current_session = tmp_sessions.iloc[idx]

            # Compute possible starts (in samples), centered within the segment
            possible_starts = numpy.arange(0,
                                           int(self.sample_rate * (current_session.duration - self.duration)),
                                           self.sample_number
                                           ) + int(self.sample_rate * (current_session.duration % self.duration / 2))
            possible_starts += int(self.sample_rate * current_session.start)

            # Select max(seg_nb, possible_segments) segments
            if chunk_per_segment == -1:
                starts = possible_starts
                chunk_nb = len(possible_starts)
            else:
                chunk_nb = min(len(possible_starts), chunk_per_segment)
                starts = numpy.random.permutation(possible_starts)[:chunk_nb]

            # Once we know how many segments are selected, fill the DataFrame columns
            for ii in range(chunk_nb):
                df_dict["database"].append(current_session.database)
                df_dict["speaker_id"].append(current_session.speaker_id)
                df_dict["file_id"].append(current_session.file_id)
                df_dict["start"].append(starts[ii])
                df_dict["duration"].append(self.duration)
                df_dict["file_start"].append(current_session.start)
                df_dict["file_duration"].append(current_session.duration)
                df_dict["speaker_idx"].append(current_session.speaker_idx)
                df_dict["gender"].append(current_session.gender)

        self.sessions = pandas.DataFrame.from_dict(df_dict)
        self.len = len(self.sessions)

        # Build the augmentation configuration from the comma-separated pipeline string
        self.transform = dict()
        if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
            transforms = self.transformation["pipeline"].split(',')
            if "add_noise" in transforms:
                self.transform["add_noise"] = self.transformation["add_noise"]
            if "add_reverb" in transforms:
                self.transform["add_reverb"] = self.transformation["add_reverb"]
            if "codec" in transforms:
                self.transform["codec"] = []
            if "phone_filtering" in transforms:
                self.transform["phone_filtering"] = []
            if "stretch" in transforms:
                self.transform["stretch"] = []

        self.noise_df = None
        if "add_noise" in self.transform:
            # Load the noise dataset and keep only noises longer than a chunk
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            noise_df = noise_df.loc[noise_df.duration > self.duration]
            self.noise_df = noise_df.set_index(noise_df.type)

        self.rir_df = None
        if "add_reverb" in self.transform:
            # load the RIR database, keeping only simulated room impulse responses
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            tmp_rir_df = tmp_rir_df.loc[tmp_rir_df["type"] == "simulated_rirs"]
            self.rir_df = tmp_rir_df.set_index(tmp_rir_df.type)

    def __getitem__(self, index):
        """
        Load one chunk of speech and its speaker label.

        :param index: chunk index in self.sessions
        :return: (waveform tensor, speaker index) — the label is a torch tensor
            when output_format == "pytorch", the raw value otherwise
        """
        current_session = self.sessions.iloc[index]

        # TODO is this required ?
        nfo = torchaudio.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")

        # Optionally jitter the chunk start within the allowed overlap, while
        # keeping the chunk inside the parent segment's boundaries
        # (starts are expressed in samples at self.sample_rate).
        original_start = int(current_session['start'])
        if self.overlap > 0:
            lowest_shift = self.overlap / 2
            highest_shift = self.overlap / 2
            if original_start < (current_session['file_start'] * self.sample_rate + self.sample_number / 2):
                lowest_shift = int(original_start - current_session['file_start'] * self.sample_rate)
            if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration']) * self.sample_rate - self.sample_number / 2:
                highest_shift = int((current_session['file_start'] + current_session['file_duration']) * self.sample_rate - (original_start + self.sample_number))
            start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
        else:
            start_frame = original_start

        # ratio between the file's native rate and the target rate
        conversion_rate = nfo.sample_rate // self.sample_rate

        if start_frame + conversion_rate * self.sample_number >= nfo.num_frames:
            # Clamp so the chunk fits in the file.
            # BUG FIX: the original used numpy.min() on a scalar (a no-op), which
            # could yield a negative offset for files shorter than one chunk.
            start_frame = max(0, nfo.num_frames - conversion_rate * self.sample_number - 1)

        speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
                                            frame_offset=conversion_rate * start_frame,
                                            num_frames=conversion_rate * self.sample_number)

        if nfo.sample_rate != self.sample_rate:
            speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)

        # add a tiny dither to avoid strictly-zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if len(self.transform) > 0:
            speech = data_augmentation(speech,
                                       self.sample_rate,
                                       self.transform,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speaker_idx = current_session["speaker_idx"]

        if self.output_format == "pytorch":
            return speech, torch.tensor(speaker_idx)
        else:
            return speech, speaker_idx

    def __len__(self):
        """
        :return: number of chunks in the dataset
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
347

Anthony Larcher's avatar
merge    
Anthony Larcher committed
348
349
350
351
352
353
354
355
def get_sample(path, resample=None):
    """
    Load an audio file mixed down to its first channel, optionally resampled.

    :param path: path of the audio file to load
    :param resample: target sample rate; None keeps the native rate
    :return: the (waveform, sample_rate) pair returned by torchaudio's sox effects
    """
    effect_chain = [["remix", "1"]]
    if resample:
        effect_chain.append(["rate", f'{resample}'])
    return torchaudio.sox_effects.apply_effects_file(path, effects=effect_chain)

356

Anthony Larcher's avatar
Anthony Larcher committed
357
358
359
360
361
class IdMapSet(Dataset):
    """
    DataSet that provide data according to a sidekit.IdMap object
    """

    def __init__(self,
                 idmap_name,
                 data_path,
                 file_extension,
                 transform_pipeline=None,
                 transform_number=1,
                 sliding_window=False,
                 window_len=3.,
                 window_shift=1.5,
                 sample_rate=16000,
                 min_duration=0.165
                 ):
        """
        :param idmap_name: IdMap object or path of an IdMap file to load
        :param data_path: root directory of the audio files
        :param file_extension: audio file extension (without the leading dot)
        :param transform_pipeline: dict describing the augmentation pipeline;
            defaults to no augmentation (BUG FIX: was a mutable default ``{}``)
        :param transform_number: number of augmentations to apply per item
        :param sliding_window: when True, return overlapping windows of the signal
        :param window_len: sliding window length in seconds
        :param window_shift: sliding window shift in seconds
        :param sample_rate: expected sample rate of the audio
        :param min_duration: minimum duration of a returned segment in seconds
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_path = data_path
        self.file_extension = file_extension
        self.len = self.idmap.leftids.shape[0]
        self.transformation = transform_pipeline if transform_pipeline is not None else {}
        self.min_duration = min_duration
        self.sample_rate = sample_rate
        self.sliding_window = sliding_window
        self.window_len = int(window_len * self.sample_rate)
        self.window_shift = int(window_shift * self.sample_rate)
        self.transform_number = transform_number

        self.noise_df = None
        if "add_noise" in self.transformation:
            # Load the noise dataset, filter according to the duration.
            # BUG FIX: the original filtered on `self.duration`, which is never set on
            # this class (AttributeError); use the minimum segment duration instead.
            # TODO confirm the intended threshold.
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            tmp_df = noise_df.loc[noise_df['duration'] > self.min_duration]
            self.noise_df = tmp_df['file_id'].tolist()

        self.rir_df = None
        if "add_reverb" in self.transformation:
            # load the RIR database; materialize the pairs — a bare zip() iterator
            # would be exhausted after the first use
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            self.rir_df = list(zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist()))

    def __getitem__(self, index):
        """
        Load the speech segment described by one IdMap entry.

        :param index: row index in the IdMap
        :return: (waveform, left id, right id, start, stop) where start/stop are
            expressed in samples at self.sample_rate
        """
        # Read start and stop and convert to time in seconds
        # (IdMap start/stop are expressed in centiseconds)
        if self.idmap.start[index] is None:
            start = 0
        else:
            start = int(self.idmap.start[index] * 0.01 * self.sample_rate)

        if self.idmap.stop[index] is None:
            # No stop boundary: load the whole file
            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
            if nfo.sample_rate != self.sample_rate:
                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
            duration = int(speech.shape[1] - start)
        else:
            # TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
            nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
            assert nfo.sample_rate == self.sample_rate
            conversion_rate = nfo.sample_rate // self.sample_rate
            duration = (int(self.idmap.stop[index] * 0.01 * self.sample_rate) - start)
            # add this in case the segment is too short: recenter and widen it
            if duration <= self.min_duration * self.sample_rate:
                middle = start + duration // 2
                start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
                duration = int(self.min_duration * self.sample_rate)

            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
                                                frame_offset=start * conversion_rate,
                                                num_frames=duration * conversion_rate)

            if nfo.sample_rate != self.sample_rate:
                speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)

        # add a tiny dither to avoid strictly-zero samples
        speech += 10e-6 * torch.randn(speech.shape)

        if self.sliding_window:
            # split the signal into overlapping windows of window_len samples
            speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
            stop = start + duration
        else:
            stop = start + duration

        if len(self.transformation.keys()) > 0:
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transformation,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        speech = speech.squeeze()

        return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop

    def __len__(self):
        """
        :return: number of entries in the IdMap
        """
        return self.len
Anthony Larcher's avatar
Anthony Larcher committed
480
481


Anthony Larcher's avatar
Anthony Larcher committed
482
class IdMapSetPerSpeaker(Dataset):
    """
    DataSet that provides data according to a sidekit.IdMap object.

    Each item corresponds to one distinct speaker (left id) and is the
    concatenation of all of that speaker's segments into a single waveform.
    """

    def __init__(self,
                 idmap_name,
                 data_path,
                 file_extension,
                 transform_pipeline=None,
                 transform_number=1,
                 sample_rate=16000,
                 min_duration=0.165
                 ):
        """

        :param idmap_name: IdMap object, or name of a file to load an IdMap from
        :param data_path: root directory of the audio files
        :param file_extension: audio file extension (without the leading dot)
        :param transform_pipeline: dict describing the augmentations to apply
            (default: no augmentation; ``None`` replaces the original mutable
            ``{}`` default, which was shared across instances)
        :param transform_number: number of augmentations to pick per sample
        :param sample_rate: sample rate of the audio files
        :param min_duration: minimum segment duration in seconds; shorter
            segments are extended symmetrically around their middle point
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_path = data_path
        self.file_extension = file_extension
        # One item per distinct speaker (left id)
        self.len = len(set(self.idmap.leftids))
        self.transformation = {} if transform_pipeline is None else transform_pipeline
        self.transform_number = transform_number
        self.min_duration = min_duration
        self.sample_rate = sample_rate
        self.speaker_list = list(set(self.idmap.leftids))

        # Per-speaker output IdMap: right ids mirror the (unique) left ids
        self.output_im = IdMap()
        self.output_im.leftids = numpy.unique(self.idmap.leftids)
        self.output_im.rightids = self.output_im.leftids
        self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
        self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")

        self.noise_df = None
        if "add_noise" in self.transformation:
            # Load the noise dataset and keep only noises long enough to cover
            # a minimum-duration segment.
            # NOTE(fix): the original filtered on an undefined ``self.duration``
            # attribute, which raised AttributeError whenever add_noise was used.
            noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
            tmp_df = noise_df.loc[noise_df['duration'] > self.min_duration]
            self.noise_df = tmp_df['file_id'].tolist()

        self.rir_df = None
        if "add_reverb" in self.transformation:
            # Load the RIR database. Materialize the pairs as a list: the
            # original kept a zip iterator, which is exhausted after the first
            # __getitem__ call and silently yields nothing afterwards.
            tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
            self.rir_df = list(zip(tmp_rir_df['file_id'].tolist(),
                                   tmp_rir_df['channel'].tolist()))

    def __getitem__(self, index):
        """
        Load and concatenate every segment of the index-th speaker.

        :param index: index of the speaker in ``self.output_im``
        :return: tuple ``(speech, left_id, right_id, start, stop)`` where
            ``start`` and ``stop`` are sample offsets of the LAST segment read
            (kept for backward compatibility with the original return values)
        """
        # Loop on all segments from the given speaker to load data
        spk_id = self.output_im.leftids[index]
        tmp_data = []
        for sid, seg_id, seg_start, seg_stop in zip(self.idmap.leftids, self.idmap.rightids,
                                                    self.idmap.start, self.idmap.stop):
            if sid != spk_id:
                continue

            # start/stop are stored in centiseconds; convert to samples
            if seg_start is None:
                start = 0
            else:
                start = int(seg_start * 0.01 * self.sample_rate)

            if seg_stop is None:
                # No stop bound: read the whole file of THIS segment.
                # NOTE(fix): the original loaded self.idmap.rightids[index]
                # here, indexing the per-segment array with a per-speaker
                # index and thus reading the wrong file.
                speech, speech_fs = torchaudio.load(f"{self.data_path}/{seg_id}.{self.file_extension}")
                duration = int(speech.shape[1] - start)
            else:
                duration = int(seg_stop * 0.01 * self.sample_rate) - start
                # Extend too-short segments around their middle point
                if duration <= self.min_duration * self.sample_rate:
                    middle = start + duration // 2
                    start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
                    duration = int(self.min_duration * self.sample_rate)

                speech, speech_fs = torchaudio.load(f"{self.data_path}/{seg_id}.{self.file_extension}",
                                                    frame_offset=start,
                                                    num_frames=duration)

            # Add tiny dither so no sample is exactly zero
            speech += 10e-6 * torch.randn(speech.shape)
            tmp_data.append(speech)

        speech = torch.cat(tmp_data, dim=1)
        speech += 10e-6 * torch.randn(speech.shape)

        if len(self.transformation.keys()) > 0:
            speech = data_augmentation(speech,
                                       speech_fs,
                                       self.transformation,
                                       self.transform_number,
                                       noise_df=self.noise_df,
                                       rir_df=self.rir_df)

        stop = start + duration
        speech = speech.squeeze()

        # NOTE(fix): return the speaker id for both left and right ids.
        # ``index`` refers to the per-speaker output_im, so the original
        # ``self.idmap.leftids[index]``/``rightids[index]`` indexed the
        # per-segment arrays and could label the data with the wrong speaker.
        return speech, spk_id, spk_id, start, stop

    def __len__(self):
        """
        :return: number of distinct speakers in the IdMap
        """
        return self.len