Commit 8e8fb525 authored by Anthony Larcher

update anthony

parent 88f4d2b9
@@ -429,7 +429,7 @@ class IdMapSet(Dataset):
                                                 frame_offset=start,
                                                 num_frames=duration)
-            speech += 10e-6 * torch.randn(speech.shape)
+            #speech += 10e-6 * torch.randn(speech.shape)
         if self.sliding_window:
             speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
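For context, the sliding-window branch kept above relies on `torch.Tensor.unfold` to cut the waveform into overlapping frames. A minimal sketch of that call, with illustrative window sizes (the real values come from `self.window_len` and `self.window_shift`):

```python
import torch

# Illustrative window sizes; the real values come from the dataset options.
speech = torch.randn(1, 48000)            # 3 s of mono audio at 16 kHz
window_len, window_shift = 24000, 8000    # 1.5 s windows, 0.5 s shift

# squeeze() drops the channel dimension; unfold(0, size, step) returns a
# view of shape (num_windows, window_len) with overlapping windows.
windows = speech.squeeze().unfold(0, window_len, window_shift)
print(windows.shape)                      # torch.Size([4, 24000])
```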
@@ -564,7 +564,7 @@ class IdMapSetPerSpeaker(Dataset):
             tmp_data.append(speech)
         speech = torch.cat(tmp_data, dim=1)
-        speech += 10e-6 * torch.randn(speech.shape)
+        #speech += 10e-6 * torch.randn(speech.shape)
         if len(self.transformation.keys()) > 0:
             speech = data_augmentation(speech,
@@ -978,6 +978,8 @@ def update_training_dictionary(dataset_description,
     fill_dict(model_opts, tmp_model_dict)
     fill_dict(training_opts, tmp_train_dict)

+    print(model_opts)
+
     # Overwrite with manually given parameters
     if "lr" in kwargs:
         training_opts["lr"] = kwargs['lr']
@@ -1010,26 +1012,28 @@ def get_network(model_opts, local_rank):
         model = Xtractor(model_opts["speaker_number"], model_opts, loss=model_opts["loss"]["type"], embedding_size=model_opts["embedding_size"])

     # Load the model if it exists
-    if model_opts["initial_model_name"] is not None and os.path.isfile(model_opts["initial_model_name"]):
-        logging.critical(f"*** Load model from = {model_opts['initial_model_name']}")
-        checkpoint = torch.load(model_opts["initial_model_name"])
-        """
-        Here we remove all layers that we don't want to reload
-        """
-        pretrained_dict = checkpoint["model_state_dict"]
-        for part in model_opts["reset_parts"]:
-            pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}
+    if model_opts["initial_model_name"] is not None:
+        if os.path.isfile(model_opts["initial_model_name"]):
+            print(f"model_opts['initial_model_name'] = {model_opts['initial_model_name']} and os.path.isfile(model_opts['initial_model_name']): {os.path.isfile(model_opts['initial_model_name'])}")
+            logging.critical(f"*** Load model from = {model_opts['initial_model_name']}")
+            checkpoint = torch.load(model_opts["initial_model_name"], map_location={"cuda:0" : "cuda:%d" % local_rank})
+            """
+            Here we remove all layers that we don't want to reload
+            """
+            pretrained_dict = checkpoint["model_state_dict"]
+            for part in model_opts["reset_parts"]:
+                pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(part)}

-        new_model_dict = model.state_dict()
-        new_model_dict.update(pretrained_dict)
-        model.load_state_dict(new_model_dict)
+            new_model_dict = model.state_dict()
+            new_model_dict.update(pretrained_dict)
+            model.load_state_dict(new_model_dict)

-        # Freeze required layers
-        for name, param in model.named_parameters():
-            if name.split(".")[0] in model_opts["reset_parts"]:
-                param.requires_grad = False
+            # Freeze required layers
+            for name, param in model.named_parameters():
+                if name.split(".")[0] in model_opts["reset_parts"]:
+                    param.requires_grad = False

     if model_opts["loss"]["type"] == "aam" and not (model_opts["loss"]["aam_margin"] == 0.2 and model_opts["loss"]["aam_s"] == 30):
         model.after_speaker_embedding.change_params(model_opts["loss"]["aam_s"], model_opts["loss"]["aam_margin"])
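The rewritten branch loads the checkpoint with a `map_location` that remaps tensors saved on `cuda:0` onto the process's own GPU (useful under distributed training), strips the `reset_parts` prefixes from the pretrained state dict, reloads the rest, and marks the reset layers as non-trainable. A self-contained sketch of that partial-reload pattern, using a hypothetical stand-in model and a faked checkpoint so it runs without a file on disk:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for Xtractor; the module names are illustrative.
model = nn.Sequential()
model.add_module("preprocessor", nn.Linear(10, 10))
model.add_module("after_speaker_embedding", nn.Linear(10, 5))

reset_parts = ["after_speaker_embedding"]
local_rank = 0

# Real code: torch.load(path, map_location={"cuda:0": "cuda:%d" % local_rank})
# Faked here so the sketch runs without a checkpoint file:
checkpoint = {"model_state_dict": model.state_dict()}

# Drop every tensor whose name starts with a reset prefix...
pretrained_dict = {k: v for k, v in checkpoint["model_state_dict"].items()
                   if not any(k.startswith(p) for p in reset_parts)}

# ...then merge what remains into a fresh state dict and reload it.
new_model_dict = model.state_dict()
new_model_dict.update(pretrained_dict)
model.load_state_dict(new_model_dict)

# Finally, freeze the parameters that belong to the reset parts.
for name, param in model.named_parameters():
    if name.split(".")[0] in reset_parts:
        param.requires_grad = False

print([n for n, p in model.named_parameters() if not p.requires_grad])
```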
@@ -1755,6 +1759,7 @@ def extract_embeddings(idmap_name,
         modelset= []
         segset = []
         starts = []
+        stops = []
         for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                       desc='xvector extraction',
@@ -1774,15 +1779,20 @@ def extract_embeddings(idmap_name,
             if sliding_window:
                 tmp_start = numpy.arange(0, data.shape[0] * win_shift, win_shift)
                 starts.extend(tmp_start * sample_rate + start.detach().cpu().numpy())
+                win_duration = int(len(tmp_data))
             else:
                 starts.append(start.numpy())
+                stops.append(tmp_data[0].shape[1])

     embeddings = StatServer()
     embeddings.stat1 = numpy.concatenate(embed)
     embeddings.modelset = numpy.array(modelset).astype('>U')
     embeddings.segset = numpy.array(segset).astype('>U')
     embeddings.start = numpy.array(starts).squeeze()
-    embeddings.stop = embeddings.start + win_duration
+    if sliding_window:
+        embeddings.stop = embeddings.start + win_duration
+    else:
+        embeddings.stop = embeddings.start + numpy.array(stops).squeeze()
     embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))

     return embeddings
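With the new `stops` list, the stop boundary of each embedding no longer assumes sliding-window extraction: in whole-segment mode it is derived from the per-segment sample counts gathered in the loop. A minimal numpy sketch of the two branches, with illustrative values:

```python
import numpy

sliding_window = False

# Two segments starting at these sample offsets (illustrative values).
starts = numpy.array([0, 32000])

if sliding_window:
    # One duration shared by every window, as in the first branch above.
    win_duration = 24000
    stops_out = starts + win_duration
else:
    # Per-segment lengths collected during extraction (the new stops list).
    stops = numpy.array([24000, 48000])
    stops_out = starts + stops

print(stops_out)  # [24000 80000]
```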