Commit f0ad2a1d authored by Anthony Larcher

debug

parent c99cba01
......
@@ -433,7 +433,12 @@ if has_pyroom:
         return data, sample[1], sample[2], sample[3] , sample[4], sample[5]
 
-def data_augmentation(speech, sample_rate, transform_dict, transform_number, noise_df=None, rir_df=None):
+def data_augmentation(speech,
+                      sample_rate,
+                      transform_dict,
+                      transform_number,
+                      noise_df=None,
+                      rir_df=None):
     """
     :param speech:
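Note: the commit replaces the old one-line signature with a dict-based pipeline (`transform_dict`) keyed by augmentation name. A minimal usage sketch, assuming only the `add_noise`/`add_reverb` keys and `*_db_csv` entries that `IdMapSet` reads below; the CSV paths and audio file are hypothetical:

```python
# Hedged sketch, not code from this commit: calling the reworked
# data_augmentation() with a dict-based pipeline. Only the "add_noise"/
# "add_reverb" keys and their *_db_csv entries appear in this diff.
import pandas
import torchaudio

transform_dict = {
    "add_noise": {"noise_db_csv": "list/musan.csv"},    # hypothetical path
    "add_reverb": {"rir_db_csv": "list/reverb.csv"},    # hypothetical path
}

noise_df = pandas.read_csv(transform_dict["add_noise"]["noise_db_csv"])
noise_df = noise_df.set_index(noise_df.type)            # mirrors IdMapSet below

speech, sample_rate = torchaudio.load("example.wav")    # hypothetical file
speech = data_augmentation(speech,
                           sample_rate,
                           transform_dict,
                           transform_number=1,           # apply one transform
                           noise_df=noise_df,
                           rir_df=None)
```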
......
......
@@ -251,7 +251,8 @@ class ArcMarginProduct(torch.nn.Module):
         self.mm = math.sin(math.pi - self.m) * self.m
 
     def forward(self, input, target=None):
-        assert input.size()[0] == target.size()[0]
+        if target is not None:
+            assert input.size()[0] == target.size()[0]
         assert input.size()[1] == self.in_features
 
         # cos(theta)
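The guard matters because ArcMarginProduct implements AAM-softmax (ArcFace): the angular margin is added to the target-class logit, so `target` exists only when labels are available and must stay optional at evaluation. A minimal sketch of the standard formulation (an assumed textbook version, not copied from this repository):

```python
# Standard AAM-softmax margin (assumed formulation, for illustration only).
import math
import torch

def aam_logits(cosine, target, s=30.0, m=0.2):
    """cosine: (batch, n_classes) cos(theta); target: (batch,) class indices."""
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
    phi = cosine * math.cos(m) - sine * math.sin(m)     # cos(theta + m)
    one_hot = torch.zeros_like(cosine)
    one_hot.scatter_(1, target.view(-1, 1), 1.0)
    # penalise only the target class; leave the other logits untouched
    return s * (one_hot * phi + (1.0 - one_hot) * cosine)
```

This is also why `cross_validation` further down now passes `target=target` for the 'aam' criterion: without labels the margin cannot be applied, and the validation loss would not match the training objective.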
......
......
@@ -292,7 +292,8 @@ class IdMapSet(Dataset):
                  idmap_name,
                  data_path,
                  file_extension,
-                 transform_pipeline="",
+                 transform_pipeline={},
+                 transform_number=1,
                  sliding_window=False,
                  window_len=24000,
                  window_shift=8000,
......
@@ -318,20 +319,21 @@
         self.sliding_window = sliding_window
         self.window_len = window_len
         self.window_shift = window_shift
+        self.transform_number = transform_number
 
         self.transform = []
-        if self.transformation is not None:
-            self.transform_list = self.transformation.split(",")
+        #if self.transformation is not None:
+        #    self.transform_list = self.transformation.split(",")
 
         self.noise_df = None
-        if "add_noise" in self.transform:
+        if "add_noise" in self.transformation:
             # Load the noise dataset, filter according to the duration
             noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
-            tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
-            self.noise_df = tmp_df['file_id'].tolist()
+            #tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
+            self.noise_df = noise_df.set_index(noise_df.type)
 
         self.rir_df = None
-        if "add_reverb" in self.transform:
+        if "add_reverb" in self.transformation:
             # load the RIR database
             tmp_rir_df = pandas.read_csv(self.transformation["add_reverb"]["rir_db_csv"])
             self.rir_df = zip(tmp_rir_df['file_id'].tolist(), tmp_rir_df['channel'].tolist())
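Rather than pre-filtering the noise CSV by duration and keeping a flat list of `file_id`s, the DataFrame is now kept whole and indexed by its `type` column, so the augmentation code can sample noise by category. A sketch of what the index enables, assuming MUSAN-style `type`/`file_id`/`duration` columns (path and category name are illustrative):

```python
# Hedged sketch: category lookup on the type-indexed noise DataFrame.
import pandas

noise_df = pandas.read_csv("list/musan.csv")        # hypothetical path
noise_df = noise_df.set_index(noise_df.type)        # as in __init__ above

music = noise_df.loc["music"]                       # all rows of one category
pick = music.sample(n=1)                            # draw one noise file
print(pick.file_id.values[0])
```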
......
@@ -344,18 +346,19 @@
         """
         if self.idmap.start[index] is None:
             start = 0
+        else:
+            start = int(self.idmap.start[index]) * 160
 
         if self.idmap.stop[index] is None:
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
-            duration = speech.shape[1] - start
+            duration = int(speech.shape[1] - start)
         else:
-            start = int(self.idmap.start[index])
-            duration = int(self.idmap.stop[index]) - start
+            duration = int(self.idmap.stop[index]) * 160 - start
 
             # add this in case the segment is too short
             if duration <= self.min_sample_nb:
                 middle = start + duration // 2
                 start = max(0, int(middle - (self.min_sample_nb / 2)))
-                duration = self.min_sample_nb
+                duration = int(self.min_sample_nb)
 
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
                                                 frame_offset=start,
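The new `* 160` factors convert IdMap `start`/`stop` values from 10 ms frame indices into sample counts before they reach `torchaudio.load`: assuming 16 kHz audio and the `frame_shift=0.01` default visible in `extract_embeddings` below, one frame shift is 0.01 s × 16000 = 160 samples. A worked sketch (the frame values are illustrative):

```python
# Frame-index to sample-count conversion assumed by the `* 160` above.
sample_rate = 16000                 # Hz, assumed
frame_shift = 0.01                  # s, cf. extract_embeddings default below
samples_per_frame = int(sample_rate * frame_shift)            # 160

start_frames, stop_frames = 300, 800                          # illustrative
frame_offset = start_frames * samples_per_frame               # 48000 -> 3.0 s
num_frames = stop_frames * samples_per_frame - frame_offset   # 80000 -> 5.0 s
```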
......
@@ -366,10 +369,10 @@
         if self.sliding_window:
             speech = speech.squeeze().unfold(0,self.window_len,self.window_shift)
 
-        if len(self.transform) > 0:
+        if len(self.transformation.keys()) > 0:
             speech = data_augmentation(speech,
                                        speech_fs,
-                                       self.transform,
+                                       self.transformation,
+                                       self.transform_number,
                                        noise_df=self.noise_df,
                                        rir_df=self.rir_df)
......
......
@@ -823,7 +823,7 @@ def xtrain(speaker_number,
     else:
         logging.critical(f"*** Load model from = {model_name}")
         checkpoint = torch.load(model_name)
-        model = Xtractor(speaker_number, model_yaml)
+        model = Xtractor(speaker_number, model_yaml, loss=loss)
 
         """
         Here we remove all layers that we don't want to reload
......
@@ -1216,7 +1216,7 @@ def cross_validation(model, validation_loader, device, validation_shape, mask, m
             data = data.squeeze().to(device)
             with torch.cuda.amp.autocast(enabled=mixed_precision):
                 if loss_criteria == 'aam':
-                    batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
+                    batch_predictions, batch_embeddings = model(data, target=target, is_eval=False)
                 elif loss_criteria == 'aps':
                     batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
                 else:
......
@@ -1247,6 +1247,7 @@ def extract_embeddings(idmap_name,
                        model_filename,
                        data_root_name,
                        device,
+                       loss,
                        file_extension="wav",
                        transform_pipeline="",
                        frame_shift=0.01,
......
@@ -1276,7 +1277,7 @@ def extract_embeddings(idmap_name,
         checkpoint = torch.load(model_filename, map_location=device)
         speaker_number = checkpoint["speaker_number"]
         model_archi = checkpoint["model_archi"]
-        model = Xtractor(speaker_number, model_archi=model_archi)
+        model = Xtractor(speaker_number, model_archi=model_archi, loss=loss)
         model.load_state_dict(checkpoint["model_state_dict"])
     else:
         model = model_filename
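The checkpoint stores `speaker_number` and `model_archi` but, judging from this change, not the loss type, so `extract_embeddings` (like `xtrain` above) now takes `loss` explicitly and forwards it to `Xtractor` so the matching output head is rebuilt before `load_state_dict`. A hedged call sketch; every value below is illustrative except `loss="aam"`, which matches the `loss_criteria` used in `cross_validation`:

```python
# Hedged usage sketch for the extended signature; all paths are hypothetical.
import torch

embeddings = extract_embeddings(idmap_name="trials.idmap.h5",
                                model_filename="model/best_xtractor.pt",
                                data_root_name="data/wav",
                                device=torch.device("cuda:0"),
                                loss="aam",
                                file_extension="wav")
```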
......
@@ -1342,6 +1343,7 @@ def extract_embeddings(idmap_name,
     for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                   desc='xvector extraction',
                                                                   mininterval=1)):
+        if data.shape[1] > 20000000:
+            data = data[...,:20000000]
 
         with torch.cuda.amp.autocast(enabled=mixed_precision):
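The `20000000` cap bounds what a single forward pass must hold: at an assumed 16 kHz, 20,000,000 samples / 16,000 Hz = 1250 s, roughly 21 minutes of audio. Truncating anything longer is presumably a guard against GPU out-of-memory on pathological segments; the motive is an assumption, only the truncation itself is in the diff.

```python
# Equivalent guard with the magic number named (sketch, not repo code).
MAX_SAMPLES = 20_000_000            # ~1250 s at 16 kHz
if data.shape[1] > MAX_SAMPLES:
    data = data[..., :MAX_SAMPLES]
```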
......