Commit 754a4f9c authored by Anthony Larcher

debug

parent 9f28b58f
......@@ -170,7 +170,6 @@ if CUDA:
from .nnet import extract_embeddings
from .nnet import extract_sliding_embedding
from .nnet import ResBlock
from .nnet import ResNet18
from .nnet import SincNet
else:
......
......@@ -31,11 +31,14 @@ Copyright 2014-2021 Anthony Larcher and Sylvain Meignier
from .augmentation import AddNoise
from .feed_forward import FForwardNetwork
from .feed_forward import kaldi_to_hdf5
from .xsets import IdMapSet_per_speaker, SpkSet
from .xsets import IdMapSetPerSpeaker
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling
from .res_net import ResBlock, ResNet18, PreResNet34
from .res_net import ResBlock, PreResNet34
from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
from .sincnet import SincNet
from .preprocessor import RawPreprocessor
from .preprocessor import MfccFrontEnd
from .preprocessor import MelSpecFrontEnd
has_pyroom = True
try:
......
......@@ -42,12 +42,11 @@ import yaml
from collections import OrderedDict
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .augmentation import PreEmphasis
from .xsets import SideSet
from .xsets import FileSet
from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .xsets import IdMapSetPerSpeaker
from .xsets import SideSampler
from .res_net import RawPreprocessor
from .res_net import ResBlockWFMS
from .res_net import ResBlock
from .res_net import PreResNet34
......@@ -214,6 +213,8 @@ class MelSpecFrontEnd(torch.nn.Module):
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if x.dim() == 1:
x = x.unsqueeze(0)
out = self.PreEmphasis(x)
out = self.MelSpec(out)+1e-6
out = torch.log(out)
......
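Note on the MelSpecFrontEnd change above: the front end now accepts an unbatched 1-D waveform by adding a batch dimension before pre-emphasis and the log-Mel computation. A minimal sketch of the same guard, assuming a plain torchaudio MelSpectrogram (the n_fft and n_mels values are illustrative, not the repository settings):

import torch
import torchaudio

mel_spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, n_mels=80)

def log_mel(x):
    # x may be a single waveform (samples,) or a batch (batch, samples)
    with torch.no_grad():
        if x.dim() == 1:
            x = x.unsqueeze(0)      # add the batch dimension, as in the patch above
        out = mel_spec(x) + 1e-6    # small offset avoids log(0)
        return torch.log(out)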
......@@ -36,8 +36,6 @@ import torch.optim as optim
import torch.multiprocessing as mp
from torchvision import transforms
from collections import OrderedDict
from .xsets import FrequencyMask, CMVN, TemporalMask
from .sincnet import SincNet, SincConv1d
from ..bosaris import IdMap
from ..statserver import StatServer
from torch.utils.data import DataLoader
......
......@@ -416,7 +416,7 @@ class SincNet(torch.nn.Module):
#return output.transpose(1, 2)
return output
def dimension(self):
def dimension():
doc = "Output features dimension."
def fget(self):
......
......@@ -306,13 +306,15 @@ class IdMapSet(Dataset):
self.data_path = data_path
self.file_extension = file_extension
self.len = self.idmap.leftids.shape[0]
self.transform_pipeline = transform_pipeline
self.transformation = transform_pipeline
self.min_duration = min_duration
self.sample_rate = frame_rate
self.transform = []
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
if (len(self.transformation) > 0):
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
self.noise_df = None
if "add_noise" in self.transform:
......@@ -342,14 +344,16 @@ class IdMapSet(Dataset):
stop = len(speech)
else:
nfo = soundfile.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
start = int(self.idmap.start[index])
stop = int(self.idmap.stop[index])
conversion_rate = nfo.samplerate // self.sample_rate
start = int(self.idmap.start[index]) * conversion_rate
stop = int(self.idmap.stop[index]) * conversion_rate
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.samplerate:
middle = start + (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
num_frames=stop - start)
......@@ -364,6 +368,8 @@ class IdMapSet(Dataset):
noise_df=self.noise_df,
rir_df=self.rir_df)
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
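The start/stop fix in the IdMapSet hunk above rescales segment boundaries, which the IdMap stores at the dataset frame rate (self.sample_rate), into sample indices of the audio file before slicing it with torchaudio.load; the following lines then widen segments shorter than min_duration around their midpoint. A short illustration of the arithmetic, with a 100 Hz frame rate and a 16 kHz file assumed purely for the example:

# Hypothetical numbers for illustration only.
frame_rate = 100                 # rate at which idmap start/stop are expressed
file_samplerate = 16000          # reported by soundfile.info(...) in the real code
conversion_rate = file_samplerate // frame_rate   # 160 samples per frame

start = 1250 * conversion_rate   # frame 1250 -> sample 200000
stop = 1750 * conversion_rate    # frame 1750 -> sample 280000
# torchaudio.load(path, frame_offset=start, num_frames=stop - start)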
......@@ -403,7 +409,7 @@ class IdMapSetPerSpeaker(Dataset):
self.data_root_path = data_root_path
self.file_extension = file_extension
self.len = len(set(self.idmap.leftids))
self.transform_pipeline = transform_pipeline
self.transformation = transform_pipeline
self.min_duration = min_duration
self.sample_rate = frame_rate
self.speaker_list = list(set(self.idmap.leftids))
......@@ -414,8 +420,9 @@ class IdMapSetPerSpeaker(Dataset):
self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.transform = []
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
if (len(self.transformation) > 0):
if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
self.transform_list = self.transformation["pipeline"].split(',')
self.noise_df = None
if "add_noise" in self.transform:
......@@ -458,7 +465,7 @@ class IdMapSetPerSpeaker(Dataset):
tmp_data.append(speech)
speech = torch.cat(tmp_data, dim=0)
speech = torch.cat(tmp_data, dim=1)
speech += 10e-6 * torch.randn(speech.shape)
if len(self.transform) > 0:
......@@ -469,6 +476,8 @@ class IdMapSetPerSpeaker(Dataset):
noise_df=self.noise_df,
rir_df=self.rir_df)
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
def __len__(self):
......
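About the dim change in IdMapSetPerSpeaker.__getitem__ above: torchaudio.load returns waveforms shaped (channels, samples), so concatenating one speaker's segments end to end has to happen along the time axis (dim=1) rather than dim=0, and the final squeeze() drops the channel dimension. A small sketch with made-up segment lengths:

import torch

seg_a = torch.zeros(1, 16000)    # hypothetical 1-second mono segment
seg_b = torch.zeros(1, 24000)    # hypothetical 1.5-second mono segment
speech = torch.cat([seg_a, seg_b], dim=1)   # shape (1, 40000): segments joined in time
speech = speech.squeeze()                   # shape (40000,)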
......@@ -1226,10 +1226,8 @@ def extract_embeddings(idmap_name,
model_filename,
data_root_name,
device,
model_yaml=None,
speaker_number=None,
file_extension="wav",
transform_pipeline=None,
transform_pipeline={},
frame_shift=0.01,
frame_duration=0.025,
extract_after_pooling=False,
......@@ -1255,12 +1253,8 @@ def extract_embeddings(idmap_name,
# Load the model
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename, map_location=device)
if speaker_number is None:
speaker_number = checkpoint["speaker_number"]
if model_yaml is None:
model_archi = checkpoint["model_archi"]
else:
model_archi = model_yaml
speaker_number = checkpoint["speaker_number"]
model_archi = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
else:
......@@ -1271,18 +1265,18 @@ def extract_embeddings(idmap_name,
else:
idmap = IdMap(idmap_name)
if type(model) is Xtractor:
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
model_cs = model.context_size()
else:
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size()
if type(model) is Xtractor:
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
model_cs = model.context_size()
else:
min_duration = (model.module.context_size() - 1) * frame_shift + frame_duration
model_cs = model.module.context_size()
# Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap_name,
data_root_path=data_root_name,
data_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=int(1. / frame_shift),
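With the changes above, extract_embeddings no longer takes speaker_number or model_yaml (both are read from the checkpoint) and transform_pipeline defaults to an empty dict. A hedged usage sketch of the new call, where the file paths are placeholders and the import path assumes the sidekit.nnet export shown in the first hunk:

import torch
from sidekit.nnet import extract_embeddings   # exported in the first hunk above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# speaker_number and model_yaml are no longer arguments: both come from the checkpoint.
embeddings = extract_embeddings("lists/test_idmap.h5",   # placeholder IdMap file
                                "models/best_model.pt",  # placeholder checkpoint
                                "data/wav",               # placeholder audio root
                                device,
                                file_extension="wav",
                                transform_pipeline={},
                                frame_shift=0.01,
                                frame_duration=0.025)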
......@@ -1341,7 +1335,6 @@ def extract_embeddings_per_speaker(idmap_name,
model_filename,
data_root_name,
device,
model_yaml=None,
file_extension="wav",
transform_pipeline=None,
frame_shift=0.01,
......@@ -1352,10 +1345,7 @@ def extract_embeddings_per_speaker(idmap_name,
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename)
if model_yaml is None:
model_archi = checkpoint["model_archi"]
else:
model_archi = model_yaml
model_archi = checkpoint["model_archi"]
model = Xtractor(checkpoint["speaker_number"], model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
......@@ -1390,7 +1380,7 @@ def extract_embeddings_per_speaker(idmap_name,
if extract_after_pooling:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[1]
else:
emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
emb_size = model.embedding_size
# Create the StatServer
embeddings = StatServer()
......@@ -1406,6 +1396,7 @@ def extract_embeddings_per_speaker(idmap_name,
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
if data.shape[1] > 20000000:
data = data[..., :20000000]
print(f"Shape of data: {data.shape}")
vec = model(data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
......@@ -1415,13 +1406,12 @@ def extract_sliding_embedding(idmap_name,
window_length,
sample_rate,
overlap,
speaker_number,
model_filename,
model_yaml,
data_root_name ,
device,
file_extension="wav",
transform_pipeline=None):
transform_pipeline=None,
num_thread=1):
"""
:param idmap_name:
......@@ -1438,7 +1428,10 @@ def extract_sliding_embedding(idmap_name,
:return:
"""
# From the original IdMap, create the new one to extract x-vectors
input_idmap = IdMap(idmap_name)
if not isinstance(idmap_name, IdMap):
input_idmap = IdMap(idmap_name)
else:
input_idmap = idmap_name
# Create temporary lists
nb_chunks = 0
......@@ -1469,12 +1462,11 @@ def extract_sliding_embedding(idmap_name,
assert sliding_idmap.validate()
embeddings = extract_embeddings(sliding_idmap,
speaker_number,
model_filename,
model_yaml,
data_root_name,
device,
file_extension=file_extension,
transform_pipeline=transform_pipeline)
transform_pipeline=transform_pipeline,
num_thread=num_thread)
return embeddings