Commit 75570a57 authored by Anthony Larcher's avatar Anthony Larcher
Browse files
parents e7b00592 48f2656a
......@@ -30,6 +30,7 @@ import pdb
import traceback
import logging
import matplotlib.pyplot as plt
import multiprocessing
import numpy
import pandas
import pickle
......@@ -571,22 +572,22 @@ class Xtractor(torch.nn.Module):
def xtrain(speaker_number,
dataset_yaml,
epochs=100,
lr=0.01,
epochs=None,
lr=None,
model_yaml=None,
model_name=None,
loss="cce",
aam_margin=0.5,
aam_s=30,
patience=10,
loss=None,
aam_margin=None,
aam_s=None,
patience=None,
tmp_model_name=None,
best_model_name=None,
multi_gpu=True,
clipping=False,
opt='sgd',
opt=None,
reset_parts=[],
freeze_parts=[],
num_thread=1):
num_thread=None):
"""
:param speaker_number:
......@@ -614,6 +615,9 @@ def xtrain(speaker_number,
#writer = SummaryWriter("runs/xvectors_experiments_2")
writer = None
if num_thread is None:
num_thread = multiprocessing.cpu_count()
logging.critical(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
......@@ -786,7 +790,9 @@ def xtrain(speaker_number,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
'scheduler': scheduler,
'speaker_number' : speaker_number,
'model_archi': model_yaml
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
else:
save_checkpoint({
......@@ -794,7 +800,9 @@ def xtrain(speaker_number,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
'scheduler': scheduler,
'speaker_number': speaker_number,
'model_archi': model_yaml
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
if is_best:
......@@ -924,11 +932,11 @@ def cross_validation(model, validation_loader, device):
def extract_embeddings(idmap_name,
speaker_number,
model_filename,
model_yaml,
data_root_name ,
data_root_name,
device,
model_yaml=None,
speaker_number=None,
file_extension="wav",
transform_pipeline=None,
frame_shift=0.01,
......@@ -937,6 +945,10 @@ def extract_embeddings(idmap_name,
# Load the model
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename)
if speaker_number is None:
speaker_number = checkpoint["speaker_number"]
if model_yaml is None:
model_yaml = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_yaml)
model.load_state_dict(checkpoint["model_state_dict"])
else:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment