Commit 0bbb111a authored by Anthony Larcher's avatar Anthony Larcher
Browse files

speed up sideset

parent a72d950a
......@@ -29,6 +29,7 @@ Copyright 2014-2021 Anthony Larcher
import glob
import h5py
import numpy
import multiprocessing
import pandas
import os
import pickle
......@@ -110,6 +111,53 @@ def read_batch(batch_file):
class SegSelection(Dataset):
def __init__(sessions, sample_rate, duration, sample_number, overlap)
self.sessions = sessions
self.sample_rate = sample_rate
self.duration = duration
self.sample_number = sample_number
self.overlap = overlap
self.len = len(sessions)
def __getitem__(self, index):
# Compute possible starts
possible_starts = numpy.arange(0,
int(self.sample_rate * (sessions.iloc[index].duration - self.duration)),
self.sample_number - int(self.sample_rate * self.overlap)
)
possible_starts += int(self.sample_rate * sessions.iloc[index].start)
# Select max(seg_nb, possible_segments) segments
if chunk_per_segment == -1:
starts = possible_starts
chunk_nb = len(possible_starts)
else:
chunk_nb = min(len(possible_starts), chunk_per_segment)
starts = numpy.random.permutation(possible_starts)[:chunk_nb] / self.sample_rate
# On renvoie des listes:
seg_database = [sessions.iloc[index].database,] * chunk_nb
seg_speaker_id = [df.iloc[index].speaker_id,] * chunk_nb
seg_file_id = [df.iloc[index].file_id,] * chunk_nb
seg_start = starts.tolist()
seg_duration = [self.duration,] * chunk_nb
seg_speaker_idx = [df.iloc[index].speaker_idx,] * chunk_nb
seg_gender = [df.iloc[index].gender,] * chunk_nb
return seg_database, seg_speaker_id, seg_file_id, seg_start, seg_duration, seg_speaker_idx, seg_gender
def __len__(self):
"""
:param self:
:return:
"""
return self.len
class XvectorDataset(Dataset):
"""
Object that takes a list of files from a file and initialize a Dataset
......@@ -421,14 +469,15 @@ class SideSet(Dataset):
# Create lists for each column of the dataframe
df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
"""
# For each segment, get all possible segments with the current overlap
for idx in tqdm.trange(len(tmp_sessions)):
# Compute possible starts
possible_starts = numpy.arange(0,
int(self.sample_rate * (df.iloc[idx].duration - self.duration)),
int(self.sample_rate * (tmp_sessions.iloc[idx].duration - self.duration)),
self.sample_number - int(self.sample_rate * overlap)
)
possible_starts += int(self.sample_rate * df.iloc[idx].start)
possible_starts += int(self.sample_rate * tmp_sessions.iloc[idx].start)
# Select max(seg_nb, possible_segments) segments
if chunk_per_segment == -1:
......@@ -440,13 +489,33 @@ class SideSet(Dataset):
# Once we know how many segments are selected, create the other fields to fill the DataFrame
for ii in range(chunk_nb):
df_dict["database"].append(df.iloc[idx].database)
df_dict["speaker_id"].append(df.iloc[idx].speaker_id)
df_dict["file_id"].append(df.iloc[idx].file_id)
df_dict["database"].append(tmp_sessions.iloc[idx].database)
df_dict["speaker_id"].append(tmp_sessions.iloc[idx].speaker_id)
df_dict["file_id"].append(tmp_sessions.iloc[idx].file_id)
df_dict["start"].append(starts[ii])
df_dict["duration"].append(self.duration)
df_dict["speaker_idx"].append(df.iloc[idx].speaker_idx)
df_dict["gender"].append(df.iloc[idx].gender)
df_dict["speaker_idx"].append(tmp_sessions.iloc[idx].speaker_idx)
df_dict["gender"].append(tmp_sessions.iloc[idx].gender)
"""
"""
New parallel version of segment selection
"""
segset = SegSelection(tmp_sessions, self.sample_rate, self.duration, self.sample_number, self.overlap)
num_thread = multiprocessing.cpu_count()
segloader = DataLoader(segset,
batch_size=1,
drop_last=False,
pin_memory=True,
num_workers=num_thread)
for seg_database, seg_speaker_id, seg_file_id, seg_start, seg_duration, seg_speaker_idx, seg_gender in enumerate(segloader):
df_dict["database"] += seg_database
df_dict["speaker_id"] += seg_speaker_id
df_dict["file_id"] += seg_file_id
df_dict["start"] += seg_start
df_dict["duration"] += seg_duration
df_dict["speaker_idx"] += seg_speaker_idx
df_dict["gender"] += seg_gender
self.sessions = pandas.DataFrame.from_dict(df_dict)
self.len = len(self.sessions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment