Commit 7484699c authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

tentative sample_rate bugfix

parent 237c3760
......@@ -468,6 +468,7 @@ def data_augmentation(speech, sample_rate, transform_dict, transform_number, noi
rir_nfo = rir_df.iloc[random.randrange(rir_df.shape[0])].file_id
rir_fn = transform_dict["add_reverb"]["data_path"] + "/" + rir_nfo + ".wav"
rir, rir_fs = torchaudio.load(rir_fn)
assert rir_fs == sample_rate
#rir = rir[rir_nfo[1], :] #keep selected channel
speech = torch.tensor(signal.convolve(speech, rir, mode='full')[:, :speech.shape[1]])
......@@ -544,10 +545,10 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
noise_fn = data_path + "/" + noise_file_id + ".wav"
if noise_duration * sample_rate > speech_shape[1]:
noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
else:
noise_seg, _ = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate))
assert noise_sr == sample_rate
#if numpy.random.randint(0, 2) == 1:
# noise = torch.flip(noise, dims=[0, 1])
......
......@@ -275,7 +275,7 @@ class SideSet(Dataset):
current_session = self.sessions.iloc[index]
# TODO is this required ?
nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
nfo = torchaudio.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
original_start = int(current_session['start'])
if self.overlap > 0:
lowest_shift = self.overlap/2
......@@ -288,18 +288,23 @@ class SideSet(Dataset):
else:
start_frame = original_start
if start_frame + self.sample_number >= nfo.frames:
start_frame = numpy.min(nfo.frames - self.sample_number - 1)
conversion_rate = nfo.sample_rate // self.sample_rate
if start_frame + conversion_rate * self.sample_number >= nfo.num_frames:
start_frame = numpy.min(nfo.num_frames - conversion_rate * self.sample_number - 1)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
frame_offset=start_frame,
num_frames=self.sample_number)
frame_offset=conversion_rate*start_frame,
num_frames=conversion_rate*self.sample_number)
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
speech += 10e-6 * torch.randn(speech.shape)
if len(self.transform) > 0:
speech = data_augmentation(speech,
speech_fs,
self.sample_rate,
self.transform,
self.transform_number,
noise_df=self.noise_df,
......@@ -389,26 +394,30 @@ class IdMapSet(Dataset):
start = 0.0
if self.idmap.start[index] is None and self.idmap.stop[index] is None:
#speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
speech, speech_fs = get_sample(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}", resample=16000)
#speech, speech_fs = get_sample(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}", resample=self.sample_rate)
nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
start = 0
stop = len(speech)
stop = speech.shape[1]
else:
# TODO Check if that code is still relevant with torchaudio.load() in case of sample_rate mismatch
nfo = soundfile.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
conversion_rate = nfo.samplerate // self.sample_rate
nfo = torchaudio.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
assert nfo.sample_rate == self.sample_rate
conversion_rate = nfo.sample_rate // self.sample_rate
start = int(self.idmap.start[index]) * conversion_rate
stop = int(self.idmap.stop[index]) * conversion_rate
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.samplerate:
if stop - start <= self.min_duration * nfo.sample_rate:
middle = start + (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
start = max(0, int(middle - (self.min_duration * nfo.sample_rate / 2)))
stop = int(start + self.min_duration * nfo.sample_rate)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=stop - start)
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
speech += 10e-6 * torch.randn(speech.shape)
......
......@@ -1480,7 +1480,7 @@ def extract_embeddings(idmap_name,
data_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=int(1. / frame_shift),
frame_rate=16000,
min_duration=(model_cs + 2) * frame_shift * 2,
backward=backward
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment