Commit 107b690a authored by Anthony Larcher

sliding_window x-vectors

parent 53a1e215
@@ -293,8 +293,11 @@ class IdMapSet(Dataset):
                  data_path,
                  file_extension,
                  transform_pipeline={},
-                 frame_rate=100,
-                 min_duration=0.165
+                 sliding_window=True,
+                 window_len=24000,
+                 window_shift=8000,
+                 sample_rate=16000,
+                 min_duration=0.150
                  ):
         """
@@ -310,14 +313,13 @@ class IdMapSet(Dataset):
         self.file_extension = file_extension
         self.len = self.idmap.leftids.shape[0]
         self.transformation = transform_pipeline
-        self.min_duration = min_duration
-        self.sample_rate = frame_rate
+        # minimum duration in samples, cast to int so it can be used as num_frames
+        self.min_sample_nb = int(min_duration * sample_rate)
+        self.sample_rate = sample_rate
+        self.sliding_window = sliding_window
+        self.window_len = window_len
+        self.window_shift = window_shift
         self.transform = []
-        #if (len(self.transformation) > 0):
-        #    if (self.transformation["pipeline"] != '') and (self.transformation["pipeline"] is not None):
-        #        self.transform_list = self.transformation["pipeline"].split(',')
         if self.transformation is not None:
             self.transform_list = self.transformation.split(",")
@@ -341,28 +343,29 @@ class IdMapSet(Dataset):
         :return:
         """
         if self.idmap.start[index] is None:
-            start = 0.0
+            start = 0

-        if self.idmap.start[index] is None and self.idmap.stop[index] is None:
+        if self.idmap.stop[index] is None:
             speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             start = 0
-            stop = len(speech)
+            # torchaudio.load returns a (channel, sample) tensor
+            duration = speech.shape[1] - start
         else:
-            nfo = soundfile.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
             start = int(self.idmap.start[index])
-            stop = int(self.idmap.stop[index])
+            duration = int(self.idmap.stop[index]) - start
             # add this in case the segment is too short
-            if stop - start <= self.min_duration * nfo.samplerate:
-                middle = start + (stop - start) // 2
-                start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
-                stop = int(start + self.min_duration * nfo.samplerate)
+            if duration <= self.min_sample_nb:
+                middle = start + duration // 2
+                start = max(0, int(middle - (self.min_sample_nb / 2)))
+                duration = self.min_sample_nb

-            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
-                                                frame_offset=start,
-                                                num_frames=stop - start)
+            speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
+                                                frame_offset=start,
+                                                num_frames=duration)

         speech += 10e-6 * torch.randn(speech.shape)

+        if self.sliding_window:
+            # unfold the 1-D signal into overlapping windows: (num_windows, window_len)
+            speech = speech.squeeze().unfold(0, self.window_len, self.window_shift)
+
         if len(self.transform) > 0:
             speech = data_augmentation(speech,
                                        speech_fs,
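
The heart of the change is the unfold call: it reshapes the 1-D waveform into a 2-D tensor of overlapping windows without copying data. A self-contained shape check, with sizes matching the new defaults:

import torch

signal = torch.randn(160000)               # 10 s of audio at 16 kHz
windows = signal.unfold(0, 24000, 8000)    # (num_windows, window_len)
print(windows.shape)                       # torch.Size([18, 24000])
# window k covers samples [k * 8000, k * 8000 + 24000), so consecutive
# windows overlap by 24000 - 8000 = 16000 samples (1 s)
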
@@ -373,9 +376,7 @@ class IdMapSet(Dataset):
             speech = speech.squeeze()

-        return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
+        return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, start + duration

     def __len__(self):
         """
@@ -1418,6 +1418,7 @@ def extract_embeddings_per_speaker(idmap_name,
     return embeddings


+"""
 def extract_sliding_embedding(idmap_name,
                               window_length,
                               sample_rate,
@@ -1485,4 +1486,111 @@ def extract_sliding_embedding(idmap_name,
                                         transform_pipeline=transform_pipeline,
                                         num_thread=num_thread)
     return embeddings
+"""
+
+
+def extract_sliding_embedding(idmap_name,
+                              window_len,
+                              window_shift,
+                              model_filename,
+                              data_root_name,
+                              device,
+                              sample_rate=16000,
+                              file_extension="wav",
+                              transform_pipeline=None,
+                              num_thread=1,
+                              mixed_precision=False):
"""
:param idmap_name:
:param window_length:
:param sample_rate:
:param overlap:
:param speaker_number:
:param model_filename:
:param model_yaml:
:param data_root_name:
:param device:
:param file_extension:
:param transform_pipeline:
:return:
"""
+    # Load the IdMap listing the segments to process
+    if isinstance(idmap_name, IdMap):
+        idmap = idmap_name
+    else:
+        idmap = IdMap(idmap_name)
+
+    # Load the model from a checkpoint if a filename is given
+    if isinstance(model_filename, str):
+        checkpoint = torch.load(model_filename, map_location=device)
+        speaker_number = checkpoint["speaker_number"]
+        model_archi = checkpoint["model_archi"]
+        model = Xtractor(speaker_number, model_archi=model_archi)
+        model.load_state_dict(checkpoint["model_state_dict"])
+    else:
+        model = model_filename
+    # Create dataset to load the data
+    dataset = IdMapSet(idmap_name=idmap_name,
+                       data_path=data_root_name,
+                       file_extension=file_extension,
+                       transform_pipeline=transform_pipeline,
+                       sliding_window=True,
+                       window_len=window_len,
+                       window_shift=window_shift,
+                       sample_rate=sample_rate,
+                       min_duration=0.1
+                       )
+
+    dataloader = DataLoader(dataset,
+                            batch_size=1,
+                            shuffle=False,
+                            drop_last=False,
+                            pin_memory=True,
+                            num_workers=num_thread)
+    with torch.no_grad():
+        model.eval()
+        model.to(device)
+
+    # Get the size of the embeddings to extract from the last layer of the
+    # before_speaker_embedding block (the model may be wrapped in DataParallel)
+    if type(model) is Xtractor:
+        name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
+        if name != 'bias':
+            name = name + '.weight'
+        emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
+    else:
+        name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
+        if name != 'bias':
+            name = name + '.weight'
+        emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
+
+    # Create the StatServer
+    embeddings = StatServer()
+    embeddings.modelset = idmap.leftids
+    embeddings.segset = idmap.rightids
+    embeddings.start = idmap.start
+    embeddings.stop = idmap.stop
+    embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
+    embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
+    # Process the data
+    with torch.no_grad():
+        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
+                                                                      desc='xvector extraction',
+                                                                      mininterval=1)):
+            with torch.cuda.amp.autocast(enabled=mixed_precision):
+                vec = model(x=data.to(device), is_eval=True)
+            # one row per segment: average the x-vectors obtained on its successive windows
+            embeddings.stat1[idx] = vec.detach().cpu().numpy().mean(axis=0)
+
+    return embeddings
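
Assuming the SIDEKIT imports already present in this module (Xtractor, IdMap, StatServer, DataLoader), the new function could be driven as below; the checkpoint and IdMap paths are placeholders:

import torch

embeddings = extract_sliding_embedding(idmap_name="idmap_test.h5",         # placeholder path
                                       window_len=24000,                   # 1.5 s windows
                                       window_shift=8000,                  # 0.5 s hop
                                       model_filename="best_xtractor.pt",  # placeholder checkpoint
                                       data_root_name="data/wav",
                                       device=torch.device("cuda"),
                                       num_thread=4)

As written, the extraction loop averages the per-window x-vectors of each segment into a single row of the returned StatServer; keeping one row per window would require resizing stat1 accordingly.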