Commit e30b391e authored by Hubert Nourtel

Add use of the whole audio file instead of a fixed-size duration

parent a979bb2d
@@ -35,6 +35,7 @@ def build():
feature_size = self.preprocessor.n_mfcc
self.loss = loss
self.speaker_number = speaker_number
self.param_device_detection = nn.Parameter(torch.empty(0)) # Empty parameter used to detect the model's device
self.sequence_network = nn.Sequential(
OrderedDict(
@@ -112,7 +113,7 @@ def build():
- the x-vector embedding
i.e., (loss, cce), x_vector = model([...])
"""
x = args["speech"]
x = args["speech"].to(self.param_device_detection.device)
x = x.squeeze(1)
x = self.preprocessor(x)
x = self.sequence_network(x)
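The zero-element `nn.Parameter` added above is a common trick for recovering a module's device at runtime without storing it separately. A minimal standalone sketch of the pattern (the class name here is illustrative, not the committed model):

```python
import torch
import torch.nn as nn

class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        # Zero-element parameter: it follows .to()/.cuda() but adds no weights
        self.param_device_detection = nn.Parameter(torch.empty(0))

    def forward(self, x):
        # Batches built by CPU dataloader workers are moved onto the model's device
        return x.to(self.param_device_detection.device)

model = DeviceAware()
if torch.cuda.is_available():
    model = model.cuda()
print(model.param_device_detection.device)  # cpu, or cuda:0 after .cuda()
```

This matters here because the custom collate runs in CPU dataloader workers, so batches arrive on the CPU regardless of where the model lives.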
@@ -261,7 +262,6 @@ def get_data_loading_hook(sessions):
# This hook is executed during data loading (done by the CPU in parallel)
def _hook(speech, csv_line, file_ext):
if speech.ndim == 1:
speech = speech.unsqueeze(0)
# print(speech.shape, csv_line, file_ext)
@@ -276,7 +276,27 @@ def get_data_loading_hook(sessions):
# Fake emotion annotation
n_emo = 5
indice = torch.randint(0, n_emo, size=(1,))[0] # (Either 0,1,2,3,4)
args["emotion"] = indice # fake emotion anontation
args["emotion"] = indice # fake emotion annotation
return args
return _hook
# Custom data collate that pads with zeros
# when the whole audio file is used
def collate_hook(batch):
data_speech_list, data_f0_list, data_emotion_list, target_spk_list = [], [], [], []
# Extract data from batch
for data, target in batch:
data_speech_list.append(data["speech"].squeeze(0))
data_f0_list.append(data["F0"].squeeze(0))
data_emotion_list.append(data["emotion"])
target_spk_list.append(target)
# Pad the tensor lists if required and build the output batch
out_speech = nn.utils.rnn.pad_sequence(data_speech_list, batch_first=True, padding_value=0.0)
out_f0 = nn.utils.rnn.pad_sequence(data_f0_list, batch_first=True, padding_value=0.0)
out_data_dict = {"speech": out_speech.unsqueeze(1), "F0": out_f0.unsqueeze(1), "emotion": torch.tensor(data_emotion_list)}
out_target = torch.tensor(target_spk_list)
return out_data_dict, out_target
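`pad_sequence` with `batch_first=True` right-pads every waveform to the longest one in the batch; a small hedged demo with toy lengths (not the real features):

```python
import torch
import torch.nn as nn

# Two mono waveforms of different lengths, as whole-file loading produces
a = torch.randn(16000)  # 1.0 s at 16 kHz
b = torch.randn(24000)  # 1.5 s
padded = nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=0.0)
print(padded.shape)                   # torch.Size([2, 24000])
print(padded[0, 16000:].abs().sum())  # tensor(0.) -> trailing zero padding
# unsqueeze(1) restores the channel axis the model expects: (batch, 1, samples)
print(padded.unsqueeze(1).shape)      # torch.Size([2, 1, 24000])
```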
@@ -9,10 +9,12 @@ loss:
# Warning: this hook is experimental; it breaks some other scripts (extract_xvectors.py, scoring, ...)
data_loading_hook: ./config/custom/model.py
# Hook to use a custom collate when the selected duration is -1
collate_hook: ./config/custom/model.py
# Initialize model from file, reset and freeze parts of it
initial_model_name:
reset_parts: []
freeze_parts: []
model_type: ./config/custom/model.py
model_type: ./config/custom/model.py
\ No newline at end of file
@@ -181,6 +181,7 @@ def update_training_dictionary(
model_opts["loss"]["aam_s"] = 30
model_opts["data_loading_hook"] = None
model_opts["collate_hook"] = None
model_opts["initial_model_name"] = None
model_opts["reset_parts"] = []
......
@@ -225,55 +225,60 @@ class SideSet(Dataset):
assert isinstance(dataset_df, pandas.DataFrame)
df = dataset_df
# From each segment whose duration is longer than the chosen one,
# select the requested segments
if set_type == "train":
tmp_sessions = df.loc[df['duration'] > self.duration]
if self.duration == -1:
# Duration is set to -1, select the whole audio
self.sessions = df
self.len = len(self.sessions)
self.initial_len = len(self.sessions)
else:
if not "duration" == '':
# From each segment whose duration is longer than the chosen one,
# select the requested segments
if set_type == "train":
tmp_sessions = df.loc[df['duration'] > self.duration]
else:
self.sessions = df
# Create lists for each column of the dataframe
df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
df_dict["file_start"] = list()
df_dict["file_duration"] = list()
# For each segment, get all possible segments with the current overlap
for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
current_session = tmp_sessions.iloc[idx]
# Compute possible starts
possible_starts = numpy.arange(0,
int(self.sample_rate * (current_session.duration - self.duration)),
self.sample_number
) + int(self.sample_rate * (current_session.duration % self.duration / 2))
possible_starts += int(self.sample_rate * current_session.start)
# Select min(chunk_per_segment, len(possible_starts)) segments, or all of them when chunk_per_segment == -1
if chunk_per_segment == -1:
starts = possible_starts
chunk_nb = len(possible_starts)
else:
chunk_nb = min(len(possible_starts), chunk_per_segment)
starts = numpy.random.permutation(possible_starts)[:chunk_nb]
# Once we know how many segments are selected, create the other fields to fill the DataFrame
for ii in range(chunk_nb):
df_dict["database"].append(current_session.database)
df_dict["speaker_id"].append(current_session.speaker_id)
df_dict["file_id"].append(current_session.file_id)
df_dict["start"].append(starts[ii])
df_dict["duration"].append(self.duration)
df_dict["file_start"].append(current_session.start)
df_dict["file_duration"].append(current_session.duration)
df_dict["speaker_idx"].append(current_session.speaker_idx)
df_dict["gender"].append(current_session.gender)
self.sessions = pandas.DataFrame.from_dict(df_dict)
self.len = len(self.sessions)
self.initial_len = len(tmp_sessions)
if not "duration" == '':
tmp_sessions = df.loc[df['duration'] > self.duration]
else:
self.sessions = df
# Create lists for each column of the dataframe
df_dict = dict(zip(df.columns, [[], [], [], [], [], [], []]))
df_dict["file_start"] = list()
df_dict["file_duration"] = list()
# For each segment, get all possible segments with the current overlap
for idx in tqdm.trange(len(tmp_sessions), desc='indexing all ' + set_type + ' segments', mininterval=1, disable=None):
current_session = tmp_sessions.iloc[idx]
# Compute possible starts
possible_starts = numpy.arange(0,
int(self.sample_rate * (current_session.duration - self.duration)),
self.sample_number
) + int(self.sample_rate * (current_session.duration % self.duration / 2))
possible_starts += int(self.sample_rate * current_session.start)
# Select min(chunk_per_segment, len(possible_starts)) segments, or all of them when chunk_per_segment == -1
if chunk_per_segment == -1:
starts = possible_starts
chunk_nb = len(possible_starts)
else:
chunk_nb = min(len(possible_starts), chunk_per_segment)
starts = numpy.random.permutation(possible_starts)[:chunk_nb]
# Once we know how many segments are selected, create the other fields to fill the DataFrame
for ii in range(chunk_nb):
df_dict["database"].append(current_session.database)
df_dict["speaker_id"].append(current_session.speaker_id)
df_dict["file_id"].append(current_session.file_id)
df_dict["start"].append(starts[ii])
df_dict["duration"].append(self.duration)
df_dict["file_start"].append(current_session.start)
df_dict["file_duration"].append(current_session.duration)
df_dict["speaker_idx"].append(current_session.speaker_idx)
df_dict["gender"].append(current_session.gender)
self.sessions = pandas.DataFrame.from_dict(df_dict)
self.len = len(self.sessions)
self.initial_len = len(tmp_sessions)
self.transform = dict()
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
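The fixed-duration branch keeps the same start enumeration; a worked example with assumed toy values shows how `possible_starts` centres the chunks by splitting the leftover audio evenly on both sides:

```python
import numpy

# Assumed toy values: a 10.3 s segment starting at 0 s, 3 s chunks, 16 kHz
sample_rate = 16000
duration = 3
sample_number = duration * sample_rate  # 48000 samples per chunk
seg_start, seg_duration = 0.0, 10.3

possible_starts = numpy.arange(
    0,
    int(sample_rate * (seg_duration - duration)),
    sample_number,
) + int(sample_rate * (seg_duration % duration / 2))  # leftover 1.3 s -> 0.65 s margin each side
possible_starts += int(sample_rate * seg_start)
print(possible_starts)  # [ 10400  58400 106400] -> three 3 s chunks
```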
@@ -323,26 +328,32 @@ class SideSet(Dataset):
# TODO: is this required?
nfo = torchaudio.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
original_start = int(current_session['start'])
if self.overlap > 0:
lowest_shift = self.overlap/2
highest_shift = self.overlap/2
if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
else:
start_frame = original_start
conversion_rate = nfo.sample_rate // self.sample_rate
if start_frame + conversion_rate * self.sample_number >= nfo.num_frames:
start_frame = numpy.min(nfo.num_frames - conversion_rate * self.sample_number - 1)
if self.duration == -1:
frame_offset = 0
sample_number = int((nfo.num_frames/nfo.sample_rate) * self.sample_rate)
else:
original_start = int(current_session['start'])
if self.overlap > 0:
lowest_shift = self.overlap/2
highest_shift = self.overlap/2
if original_start < (current_session['file_start']*self.sample_rate + self.sample_number/2):
lowest_shift = int(original_start - current_session['file_start']*self.sample_rate)
if original_start + self.sample_number > (current_session['file_start'] + current_session['file_duration'])*self.sample_rate - self.sample_number/2:
highest_shift = int((current_session['file_start'] + current_session['file_duration'])*self.sample_rate - (original_start + self.sample_number))
start_frame = original_start + int(random.uniform(-lowest_shift, highest_shift))
else:
start_frame = original_start
if start_frame + conversion_rate * self.sample_number >= nfo.num_frames:
start_frame = numpy.min(nfo.num_frames - conversion_rate * self.sample_number - 1)
frame_offset = conversion_rate*start_frame
sample_number = self.sample_number
speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
frame_offset=conversion_rate*start_frame,
num_frames=conversion_rate*self.sample_number)
frame_offset=frame_offset,
num_frames=conversion_rate*sample_number)
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
......
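The `duration == -1` path reads from offset 0 and derives `sample_number` from the file header, so `num_frames` covers the whole file even when it has to be resampled. A hedged standalone sketch (the path is hypothetical, and like the patch it assumes the file rate is an integer multiple of the target rate):

```python
import torchaudio

path = "audio/example.wav"  # hypothetical file
target_sr = 16000

nfo = torchaudio.info(path)
conversion_rate = nfo.sample_rate // target_sr
# Whole-file length expressed at the target rate, as in the patch
sample_number = int((nfo.num_frames / nfo.sample_rate) * target_sr)

speech, fs = torchaudio.load(path,
                             frame_offset=0,
                             num_frames=conversion_rate * sample_number)
if fs != target_sr:
    speech = torchaudio.transforms.Resample(fs, target_sr)(speech)
print(speech.shape)  # (channels, ~sample_number) after resampling
```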
@@ -635,7 +635,18 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
num_process=1,
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
)
if dataset_opts["train"]["duration"] == -1:
# Use custom collate
collate_hook_file = model_opts["collate_hook"]
if not os.path.isfile(collate_hook_file):
raise FileNotFoundError(f"File: '{collate_hook_file}' doesn't exist")
spec = importlib.util.spec_from_file_location(collate_hook_file, collate_hook_file)
hook_file = importlib.util.module_from_spec(spec)
spec.loader.exec_module(hook_file)
collate_fn = hook_file.collate_hook
else:
# Use default collate
collate_fn = None
training_loader = DataLoader(training_set,
batch_size=batch_size * dataset_opts["train"]["sampler"]["augmentation_replica"],
shuffle=training_opts["sampling"] == 'whole',
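Falling back to `collate_fn=None` keeps PyTorch's default collate for the fixed-duration case; that only works because those batches are equal-length, as this hedged illustration shows:

```python
import torch
from torch.utils.data import default_collate

fixed = [torch.randn(48000), torch.randn(48000)]
print(default_collate(fixed).shape)  # torch.Size([2, 48000]) -> stacking works

variable = [torch.randn(16000), torch.randn(24000)]
try:
    default_collate(variable)
except RuntimeError as err:
    # default_collate stacks tensors, which requires equal sizes;
    # hence the custom padding collate when duration == -1
    print("default_collate failed:", err)
```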
@@ -644,7 +655,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
sampler=None if training_opts["sampling"] == 'whole' else side_sampler,
num_workers=training_opts["num_cpu"],
persistent_workers=False,
worker_init_fn=utils.seed_worker)
worker_init_fn=utils.seed_worker,
collate_fn=collate_fn)
validation_loader = DataLoader(validation_set,
batch_size=batch_size,
@@ -652,7 +664,8 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
pin_memory=True,
num_workers=training_opts["num_cpu"],
persistent_workers=False,
worker_init_fn=utils.seed_worker)
worker_init_fn=utils.seed_worker,
collate_fn=collate_fn)
# Compute indices for target and non-target trials once only to avoid recomputing for each epoch
classes = torch.ShortTensor(validation_set.sessions['speaker_idx'].to_numpy())
......