Commit daf472ca authored by Anthony Larcher

extract_embedding

parent e3a2b75a
@@ -540,7 +540,7 @@ class IdMapSet(Dataset):
     DataSet that provide data according to a sidekit.IdMap object
     """
-    def __init__(self, idmap_name, data_root_path, file_extension):
+    def __init__(self, idmap_name, data_root_path, file_extension, transform_pipeline=None):
         """
         :param data_root_name:
@@ -550,6 +550,19 @@ class IdMapSet(Dataset):
         self.data_root_path = data_root_path
         self.file_extension = file_extension
         self.len = self.idmap.leftids.shape[0]
+        self.transform_pipeline = transform_pipeline
+
+        _transform = []
+        if transform_pipeline is not None:
+            trans = transform_pipeline.split(",")
+            for t in trans:
+                if 'PreEmphasis' in t:
+                    _transform.append(PreEmphasis())
+                if 'MFCC' in t:
+                    _transform.append(MFCC())
+                if "CMVN" in t:
+                    _transform.append(CMVN())
+        self.transforms = transforms.Compose(_transform)

     def __getitem__(self, index):
         """
@@ -558,7 +571,12 @@
         :return:
         """
         sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
-        return sig, self.idmap.leftids[index], self.idmap.rightids[index]
+        sig += 0.0001 * numpy.random.randn(sig.shape[0])
+
+        if self.transform_pipeline is not None:
+            sig, _, ___, _____ = self.transforms((sig, 0, 0, 0))
+
+        return torch.from_numpy(sig).type(torch.FloatTensor), self.idmap.leftids[index], self.idmap.rightids[index]

     def __len__(self):
         """
...

-# -*- coding: utf-8 -*-
+# coding: utf-8 -*-
 #
 # This file is part of SIDEKIT.
 #
@@ -45,6 +45,7 @@ from ..statserver import StatServer
 from torch.utils.data import DataLoader
 from sklearn.model_selection import train_test_split
 from .sincnet import SincNet
+from tqdm import tqdm

 __license__ = "LGPL"
 __author__ = "Anthony Larcher"
@@ -339,8 +340,8 @@ def xtrain(speaker_number,
     training_set = SideSet(dataset_yaml,
                            set_type="train",
                            dataset_df=training_df,
-                           chunk_per_segment=dataset_params['chunk_per_segment'],
-                           overlap=dataset_params['overlap'])
+                           chunk_per_segment=dataset_params['train']['chunk_per_segment'],
+                           overlap=dataset_params['train']['overlap'])
     training_loader = DataLoader(training_set,
                                  batch_size=dataset_params["batch_size"],
                                  shuffle=True,
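
Reading chunk_per_segment and overlap through dataset_params['train'] implies the dataset YAML now groups these options under a train section. A sketch of the assumed structure once the YAML is loaded; only the keys dereferenced in this hunk come from the code, the values are placeholders:

# Assumed shape of dataset_params after loading the dataset YAML (values are illustrative).
dataset_params = {
    "batch_size": 64,                # still read at the top level by the DataLoader call
    "train": {
        "chunk_per_segment": 1,      # now nested under the 'train' section
        "overlap": 0.0,
    },
}
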
@@ -487,10 +488,10 @@ def cross_validation(model, validation_loader, device):
            loss.cpu().numpy() / ((batch_idx + 1) * batch_size)


-def extract_embeddings(idmap, speaker_number, model_filename, model_yaml, data_root_name , device):
+def extract_embeddings(idmap_name, speaker_number, model_filename, model_yaml, data_root_name , device, file_extension="wav", transform_pipeline=None):

     # Create dataset to load the data
-    dataset = IdMapSet(data_root_name, idmap_name)
+    dataset = IdMapSet(idmap_name=idmap_name, data_root_path=data_root_name, file_extension=file_extension, transform_pipeline=transform_pipeline)

     # Load the model
     checkpoint = torch.load(model_filename)
@@ -502,21 +503,23 @@ def extract_embeddings(idmap, speaker_number, model_filename, model_yaml, data_root_name , device):
     # Get the size of embeddings to extract
     name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
     emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]

     # Create the StatServer
-    embeddings = sidekit.StatServer()
+    idmap = IdMap(idmap_name)
+    embeddings = StatServer()
     embeddings.modelset = idmap.leftids
     embeddings.segset = idmap.rightids
     embeddings.start = idmap.start
     embeddings.stop = idmap.stop
     embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
     embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

     # Process the data
     with torch.no_grad():
-        for idx, (data, mod, seg) in tqdm(enumerate(dataset)):
-            vec = model(data.to(device), is_eval=True)
-            current_idx = numpy.argwhere(numpy.logical_and(im.leftids == mod, im.rightids == seg))[0][0]
+        for idx in tqdm(range(len(dataset))):
+            data, mod, seg = dataset[idx]
+            vec = model(data[None, :, :].to(device), is_eval=True)
+            current_idx = numpy.argwhere(numpy.logical_and(idmap.leftids == mod, idmap.rightids == seg))[0][0]
             embeddings.stat1[current_idx, :] = vec.detach().cpu()

     return embeddings
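
Taken together, extract_embeddings now receives the IdMap file name rather than a loaded IdMap, builds the IdMapSet itself, and fills one StatServer row per segment. A hypothetical call, with placeholder paths, speaker count and configuration names (none of them taken from this commit):

import torch
# assuming extract_embeddings has been imported from the module shown above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embeddings = extract_embeddings(idmap_name="exp/test_idmap.h5",      # placeholder IdMap file
                                speaker_number=1000,                  # placeholder
                                model_filename="exp/best_model.pt",   # placeholder checkpoint
                                model_yaml="cfg/xvector.yaml",        # placeholder architecture config
                                data_root_name="data/wav",            # placeholder audio root
                                device=device,
                                file_extension="wav",
                                transform_pipeline="MFCC,CMVN")

embeddings.write("exp/test_xvectors.h5")   # StatServer objects can be serialised to HDF5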