Commit 9a6ce925 authored by Anthony Larcher

debug transformation

parent 754a4f9c
@@ -32,6 +32,8 @@ from .augmentation import AddNoise
 from .feed_forward import FForwardNetwork
 from .feed_forward import kaldi_to_hdf5
 from .xsets import IdMapSetPerSpeaker
+from .xsets import SideSet
+from .xsets import SideSampler
 from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling
 from .res_net import ResBlock, PreResNet34
 from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
@@ -164,7 +164,7 @@ class SideSet(Dataset):
             self.transformation = dataset["train"]["transformation"]
         else:
             self.duration = dataset["eval"]["duration"]
-        self.transformation = dataset["eval"]["transformation"]
+            self.transformation = dataset["eval"]["transformation"]

         self.sample_number = int(self.duration * self.sample_rate)
@@ -312,9 +312,11 @@ class IdMapSet(Dataset):
         self.transform = []
-        if (len(self.transformation) > 0):
-            if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
-                self.transform_list = self.transformation["pipeline"].split(',')
+        #if (len(self.transformation) > 0):
+        #    if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
+        #        self.transform_list = self.transformation["pipeline"].split(',')
+        if self.transformation is not None:
+            self.transform_list = self.transformation.split(",")

         self.noise_df = None
         if "add_noise" in self.transform:
@@ -344,16 +346,14 @@ class IdMapSet(Dataset):
             stop = len(speech)
         else:
             nfo = soundfile.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
-            conversion_rate = nfo.samplerate // self.sample_rate
-            start = int(self.idmap.start[index]) * conversion_rate
-            stop = int(self.idmap.stop[index]) * conversion_rate
+            start = int(self.idmap.start[index])
+            stop = int(self.idmap.stop[index])

             # add this in case the segment is too short
             if stop - start <= self.min_duration * nfo.samplerate:
                 middle = start + (stop - start) // 2
                 start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
                 stop = int(start + self.min_duration * nfo.samplerate)

             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
                                                 frame_offset=start,
                                                 num_frames=stop - start)
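The conversion_rate factor is dropped here, so the idmap start and stop values are now treated directly as sample indices at the file's native sample rate. For illustration, a self-contained sketch of reading such a segment with torchaudio (file name and indices are made up):

import torchaudio

# start/stop as sample indices at the file's native rate
start, stop = 16000, 48000                # e.g. 1.0 s to 3.0 s of a 16 kHz file
speech, fs = torchaudio.load("utt1.wav",
                             frame_offset=start,       # first sample to read
                             num_frames=stop - start)  # number of samples to read
print(speech.shape, fs)                   # e.g. torch.Size([1, 32000]) 16000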
@@ -420,9 +420,11 @@ class IdMapSetPerSpeaker(Dataset):
         self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
         self.transform = []
-        if (len(self.transformation) > 0):
-            if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
-                self.transform_list = self.transformation["pipeline"].split(',')
+        #if (len(self.transformation) > 0):
+        #    if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
+        #        self.transform_list = self.transformation["pipeline"].split(',')
+        if self.transformation is not None:
+            self.transform_list = self.transformation.split(",")

         self.noise_df = None
         if "add_noise" in self.transform:
@@ -224,12 +224,11 @@ def test_metrics(model,
     idmap_test_filename = 'h5f/idmap_test.h5'
     ndx_test_filename = 'h5f/ndx_test.h5'
     key_test_filename = 'h5f/key_test.h5'
-    data_root_name='/data/larcher/voxceleb1/test/wav'
+    data_root_name='/lium/scratch/larcher/voxceleb1/test/wav'

     transform_pipeline = dict()

     xv_stat = extract_embeddings(idmap_name=idmap_test_filename,
                                  speaker_number=speaker_number,
                                  model_filename=model,
                                  data_root_name=data_root_name,
                                  device=device,
@@ -432,7 +431,7 @@ class Xtractor(torch.nn.Module):
         elif model_archi == "fastresnet34":
-            self.preprocessor = MelSpecFrontEnd()
+            self.preprocessor = MelSpecFrontEnd(n_mels=80)
             self.sequence_network = PreFastResNet34()
             self.embedding_size = 256
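MelSpecFrontEnd is the toolkit's own front end and its full signature is not shown in this diff. As a rough stand-in, torchaudio's MelSpectrogram illustrates what n_mels=80 controls: the number of mel filterbank channels, and hence the feature dimension handed to PreFastResNet34.

import torch
import torchaudio

# Stand-in for illustration only, not sidekit's MelSpecFrontEnd.
melspec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80)
wav = torch.randn(1, 16000)               # 1 s of dummy audio
print(melspec(wav).shape)                 # torch.Size([1, 80, 81]) -> 80 mel bands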
@@ -446,7 +445,7 @@ class Xtractor(torch.nn.Module):
             self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
                                                             int(self.speaker_number),
                                                             s = 30,
-                                                            m = 0.2,
+                                                            m = 0.4,
                                                             easy_margin = False)

             self.preprocessor_weight_decay = 0.000
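The margin m of the ArcMarginProduct head is doubled from 0.2 to 0.4. In an ArcFace-style loss the target-class logit is s * cos(theta + m), so a larger m lowers the target logit at a given angle and pushes training toward tighter, better-separated speaker clusters. A simplified sketch (the real layer also handles the easy_margin case):

import torch

def arc_margin_logit(cos_theta: torch.Tensor, s: float = 30.0, m: float = 0.4) -> torch.Tensor:
    """Target-class logit s * cos(theta + m), simplified ArcFace-style margin."""
    theta = torch.acos(cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    return s * torch.cos(theta + m)

cos = torch.tensor([0.9])
print(arc_margin_logit(cos, m=0.2))       # ~23.9
print(arc_margin_logit(cos, m=0.4))       # ~19.8 -> harder target under the new margin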
@@ -833,6 +832,10 @@ def xtrain(speaker_number,
         new_model_dict.update(pretrained_dict)
         model.load_state_dict(new_model_dict)

+        print("Modify margin: 0.4")
+        model.after_speaker_embedding.m = 0.4
+
     # Freeze required layers
     for name, param in model.named_parameters():
         if name.split(".")[0] in freeze_parts:
@@ -995,8 +998,8 @@ def xtrain(speaker_number,
     optimizer = _optimizer(param_list, **_options)

     scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                step_size=10 * training_loader.__len__(),
-                                                gamma=0.95)
+                                                step_size=20 * training_loader.__len__(),
+                                                gamma=0.5)

     if mixed_precision:
         scaler = torch.cuda.amp.GradScaler()
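The scheduler swaps a gentle decay (x0.95 every 10 epochs' worth of batches) for a much sharper one (x0.5 every 20). Assuming the scheduler is stepped once per batch, so that step_size = n_epochs * len(training_loader), the cumulative learning-rate multipliers compare as follows:

# Cumulative LR multipliers at a given epoch under both schedules.
old = lambda epoch: 0.95 ** (epoch // 10)   # x0.95 every 10 epochs
new = lambda epoch: 0.5 ** (epoch // 20)    # x0.5  every 20 epochs

for epoch in (20, 40, 100):
    print(epoch, old(epoch), new(epoch))
# epoch 20:  old ~0.90, new 0.50
# epoch 40:  old ~0.81, new 0.25
# epoch 100: old ~0.60, new ~0.03 -> far more aggressive overall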
@@ -1396,7 +1399,6 @@ def extract_embeddings_per_speaker(idmap_name,
         for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
             if data.shape[1] > 20000000:
                 data = data[..., :20000000]
-            print(f"Shape of data: {data.shape}")
             vec = model(data.to(device), is_eval=True)
             embeddings.stat1[idx, :] = vec.detach().cpu()