Commit c1f8779c authored by Anthony Larcher

IdmapSet

parent 006b1957
@@ -501,3 +501,37 @@ class SideSet(Dataset):
        :return:
        """
        return self.len


class IdMapSet(Dataset):
    """
    Dataset that provides data according to a sidekit.IdMap object
    """

    def __init__(self, idmap_name, data_root_path, file_extension):
        """
        :param idmap_name: name of the IdMap file to load
        :param data_root_path: root path of the data files
        :param file_extension: extension of the data files
        """
        self.idmap = sidekit.IdMap(idmap_name)
        self.data_root_path = data_root_path
        self.file_extension = file_extension
        self.len = self.idmap.leftids.shape[0]

    def __getitem__(self, index):
        """
        :param index: index of the session to return
        :return: the signal, the model ID and the segment ID
        """
        sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
        return sig, self.idmap.leftids[index], self.idmap.rightids[index]

    def __len__(self):
        """
        :return: the number of sessions in the IdMap
        """
        return self.len
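
# A minimal usage sketch of IdMapSet, assuming a hypothetical IdMap file
# "my_idmap.h5" and WAV files under "data/wav" (neither name comes from
# this commit):
#
#   dataset = IdMapSet(idmap_name="my_idmap.h5",
#                      data_root_path="data/wav",
#                      file_extension="wav")
#   for sig, model_id, seg_id in dataset:
#       print(model_id, seg_id, sig.shape)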
@@ -38,6 +38,7 @@ import yaml
from torchvision import transforms
from collections import OrderedDict
from .xsets import XvectorMultiDataset, StatDataset, VoxDataset, SideSet
from .xsets import IdMapSet
from .xsets import FrequencyMask, CMVN, TemporalMask, MFCC
from ..bosaris import IdMap
from ..statserver import StatServer
@@ -54,15 +55,25 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)

def get_lr(optimizer):
    """
    :param optimizer: the optimizer to read the learning rate from
    :return: the current learning rate
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    :param state: dictionary of objects to serialize
    :param is_best: if True, also copy the checkpoint to best_filename
    :param filename: name of the checkpoint file
    :param best_filename: name of the file keeping the best model
    :return:
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
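
# A minimal sketch of how get_lr and save_checkpoint combine inside a
# training loop; the model and optimizer below are placeholders:
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#   for epoch in range(1, epochs + 1):
#       # ...train one epoch...
#       logging.info(f"learning rate: {get_lr(optimizer)}")
#       save_checkpoint({'epoch': epoch,
#                        'model_state_dict': model.state_dict(),
#                        'optimizer_state_dict': optimizer.state_dict()},
#                       is_best=False)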
@@ -199,7 +210,8 @@ class Xtractor(torch.nn.Module):
                if cfg["before_embedding"][k]["output"] == "speaker_number":
                    before_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
                else:
                    before_embedding_layers.append((k, torch.nn.Linear(input_size,
                                                                       cfg["before_embedding"][k]["output"])))
                    input_size = cfg["before_embedding"][k]["output"]

            elif k.startswith("activation"):
@@ -221,7 +233,8 @@ class Xtractor(torch.nn.Module):
                if cfg["after_embedding"][k]["output"] == "speaker_number":
                    after_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
                else:
                    after_embedding_layers.append((k, torch.nn.Linear(input_size,
                                                                      cfg["after_embedding"][k]["output"])))
                    input_size = cfg["after_embedding"][k]["output"]

            elif k.startswith("activation"):
@@ -240,6 +253,7 @@ class Xtractor(torch.nn.Module):
        """
        :param x: the input tensor
        :param is_eval: if True, return the embedding instead of the classification output
        :return:
        """
        if self.preprocessor is not None:
@@ -260,8 +274,6 @@ class Xtractor(torch.nn.Module):
        return x


def xtrain(speaker_number,
           dataset_yaml,
           epochs=100,
@@ -274,21 +286,30 @@ def xtrain(speaker_number,
           clipping=False,
           num_thread=1):
"""
Initialize and train an x-vector on a single GPU
:param args:
:param speaker_number:
:param dataset_yaml:
:param epochs:
:param lr:
:param model_yaml:
:param model_name:
:param tmp_model_name:
:param best_model_name:
:param multi_gpu:
:param clipping:
:param num_thread:
:return:
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# If we start from an existing model
#if model_name is not None:
# if model_name is not None:
# # Load the model
# logging.critical(f"*** Load model from = {model_name}")
# checkpoint = torch.load(model_name)
# model = Xtractor(speaker_number, model_yaml)
# model.load_state_dict(checkpoint["model_state_dict"])
#else:
# else:
if True:
# Initialize a first model
if model_yaml is None:
@@ -358,9 +379,16 @@ def xtrain(speaker_number,
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

    best_accuracy = 0.0
    best_accuracy_epoch = 1
    for epoch in range(1, epochs + 1):
        # Process one epoch and return the current model
        model = train_epoch(model,
                            epoch,
                            training_loader,
                            optimizer,
                            dataset_params["log_interval"],
                            device=device,
                            clipping=clipping)

        # Add the cross validation here
        accuracy, val_loss = cross_validation(model, validation_loader, device=device)
@@ -374,28 +402,39 @@ def xtrain(speaker_number,
        is_best = accuracy > best_accuracy
        best_accuracy = max(accuracy, best_accuracy)

        # When the model is wrapped in DataParallel, its state_dict lives under model.module
        if type(model) is Xtractor:
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': best_accuracy,
                'scheduler': scheduler
            }, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
        else:
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': best_accuracy,
                'scheduler': scheduler
            }, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')

        if is_best:
            best_accuracy_epoch = epoch

    logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")


def train_epoch(model, epoch, training_loader, optimizer, log_interval, device, clipping=False):
    """
    :param model: the model to train
    :param epoch: index of the current epoch
    :param training_loader: the DataLoader providing the training data
    :param optimizer: the optimizer to use
    :param log_interval: number of batches between two logging outputs
    :param device: the device to run on
    :param clipping: if True, clip the gradients
    :return: the updated model
    """
    model.train()
@@ -417,8 +456,8 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
            batch_size = target.shape[0]
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, training_loader.__len__(),
                100. * batch_idx / training_loader.__len__(), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
    return model
@@ -444,13 +483,56 @@ def cross_validation(model, validation_loader, device):
            loss += criterion(output, target.to(device))

    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
           loss.cpu().numpy() / ((batch_idx + 1) * batch_size)


def extract_embeddings(idmap_name, speaker_number, model_filename, model_yaml, data_root_name, device, file_extension="wav"):
    """
    Extract the x-vectors for all segments of an IdMap and return them in a StatServer

    :param idmap_name: name of the IdMap file describing the segments to process
    :param speaker_number: number of speakers the model was trained on
    :param model_filename: name of the checkpoint file to load the model from
    :param model_yaml: name of the YAML file describing the model architecture
    :param data_root_name: root path of the data files
    :param device: the device to run on
    :param file_extension: extension of the data files
    :return: a StatServer containing one x-vector per segment
    """
    # Create dataset to load the data
    dataset = IdMapSet(idmap_name, data_root_name, file_extension)
    idmap = dataset.idmap

    # Load the model
    checkpoint = torch.load(model_filename)
    model = Xtractor(speaker_number, model_archi=model_yaml)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    model.to(device)

    # Get the size of embeddings to extract
    name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
    emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]

    # Create the StatServer
    embeddings = StatServer()
    embeddings.modelset = idmap.leftids
    embeddings.segset = idmap.rightids
    embeddings.start = idmap.start
    embeddings.stop = idmap.stop
    embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
    embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

    # Process the data
    with torch.no_grad():
        for idx, (data, mod, seg) in tqdm(enumerate(dataset)):
            # The signal read from file is a numpy array: turn it into a
            # FloatTensor and add a batch dimension before the forward pass
            vec = model(torch.FloatTensor(data).unsqueeze(0).to(device), is_eval=True)
            current_idx = numpy.argwhere(numpy.logical_and(idmap.leftids == mod, idmap.rightids == seg))[0][0]
            embeddings.stat1[current_idx, :] = vec.detach().cpu()

    return embeddings
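
# A hypothetical call to extract_embeddings; file names are placeholders:
#
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#   xv = extract_embeddings(idmap_name="test_idmap.h5",
#                           speaker_number=1211,
#                           model_filename="model/best_xtractor.pt",
#                           model_yaml="cfg/xtractor.yaml",
#                           data_root_name="data/wav",
#                           device=device)
#   xv.write("xvectors.h5")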


def extract_idmap(args, segment_indices, fs_params, idmap_name, output_queue):
    """
    Function that takes a model and an idmap, extracts all x-vectors based on this model
    and returns a StatServer containing the x-vectors

    :param args:
    :param segment_indices:
    :param fs_params:
    :param idmap_name:
    :param output_queue:
    :return:
    """
    # device = torch.device("cuda:{}".format(device_id))
    device = torch.device('cpu')
......