Commit 0a148e1e authored by Anthony Larcher

debug extract xv per speaker and related dataset

parent 575a86bc
@@ -475,36 +475,41 @@ class IdMapSetPerSpeaker(Dataset):
     def __init__(self,
                  idmap_name,
                  data_root_path,
+                 data_path,
                  file_extension,
                  transform_pipeline={},
                  transform_number=1,
                  frame_rate=100,
                  sample_rate=16000,
                  min_duration=0.165
                  ):
"""
:param data_root_name:
:param idmap_name:
:param data_root_path:
:param file_extension:
:param transform_pipeline:
:param transform_number:
:param sample_rate:
:param min_duration:
"""
         if isinstance(idmap_name, IdMap):
             self.idmap = idmap_name
         else:
             self.idmap = IdMap(idmap_name)
         self.data_root_path = data_root_path
+        self.data_path = data_path
         self.file_extension = file_extension
         self.len = len(set(self.idmap.leftids))
         self.transformation = transform_pipeline
         self.transform_number = transform_number
         self.min_duration = min_duration
-        self.sample_rate = frame_rate
+        self.sample_rate = sample_rate
         self.speaker_list = list(set(self.idmap.leftids))
         self.output_im = IdMap()
         self.output_im.leftids = numpy.unique(self.idmap.leftids)
         self.output_im.rightids = self.output_im.leftids
         self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
         self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
         self.transformation = transform_pipeline
         self.transform_number = transform_number
         self.noise_df = None
         if "add_noise" in self.transformation:
@@ -529,22 +534,33 @@ class IdMapSetPerSpeaker(Dataset):
         # Loop on all segments from the given speaker to load data
         spk_id = self.output_im.leftids[index]
         tmp_data = []
-        nfo = soundfile.info(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
-        for id, start, stop in zip(self.idmap.leftids, self.idmap.start, self.idmap.stop):
-            if id == spk_id:
-                start = int(start)
-                stop = int(stop)
-                # add this in case the segment is too short
-                if stop - start <= self.min_duration * nfo.samplerate:
-                    middle = start + (stop - start) // 2
-                    start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
-                    stop = int(start + self.min_duration * nfo.samplerate)
-                speech, speech_fs = torchaudio.load(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
-                                                    frame_offset=start,
-                                                    num_frames=stop - start)
+        #nfo = soundfile.info(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+        for sid, seg_id, seg_start, seg_stop in zip(self.idmap.leftids, self.idmap.rightids,
+                                                    self.idmap.start, self.idmap.stop):
+            if sid == spk_id:
+                # Read start and stop (in centiseconds) and convert them to sample indices
+                if seg_start is None:
+                    start = 0
+                else:
+                    start = int(seg_start * 0.01 * self.sample_rate)
+                if seg_stop is None:
+                    speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}")
+                    duration = int(speech.shape[1] - start)
+                else:
+                    duration = int(seg_stop * 0.01 * self.sample_rate) - start
+                # pad and re-center the segment in case it is too short
+                if duration <= self.min_duration * self.sample_rate:
+                    middle = start + duration // 2
+                    start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
+                    duration = int(self.min_duration * self.sample_rate)
+                speech, speech_fs = torchaudio.load(f"{self.data_path}/{seg_id}.{self.file_extension}",
+                                                    frame_offset=start,
+                                                    num_frames=duration)
                 speech += 10e-6 * torch.randn(speech.shape)
                 tmp_data.append(speech)
         speech = torch.cat(tmp_data, dim=1)
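
The boundary arithmetic above is the heart of the fix: IdMap start/stop values are interpreted as centiseconds (hence the * 0.01 * self.sample_rate factor), missing bounds fall back to the file limits, and segments shorter than min_duration are re-centered and padded to the minimum length. The same logic, isolated as a standalone sketch with hypothetical names:

def segment_bounds(seg_start, seg_stop, total_samples, sample_rate, min_duration):
    # Convert IdMap bounds (centiseconds, possibly None) to (start, num_samples)
    start = 0 if seg_start is None else int(seg_start * 0.01 * sample_rate)
    if seg_stop is None:
        duration = total_samples - start
    else:
        duration = int(seg_stop * 0.01 * sample_rate) - start
    # Pad a too-short segment to min_duration, re-centered on its middle
    if duration <= min_duration * sample_rate:
        middle = start + duration // 2
        start = max(0, int(middle - (min_duration * sample_rate / 2)))
        duration = int(min_duration * sample_rate)
    return start, duration
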
@@ -558,6 +574,7 @@ class IdMapSetPerSpeaker(Dataset):
                                        noise_df=self.noise_df,
                                        rir_df=self.rir_df)
+        stop = start + duration
         speech = speech.squeeze()
         return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, stop
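
Each item therefore yields one concatenated waveform per speaker, its speaker and segment identifiers, and the bounds of the last segment read. A short iteration sketch, assuming the hypothetical dataloader from the earlier example:

for speech, spk_id, seg_id, start, stop in dataloader:
    # speech has shape (1, total_samples) after default collation of the squeezed signal
    print(spk_id, speech.shape)
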
@@ -1779,7 +1779,6 @@ def extract_embeddings(idmap_name,
     embeddings.stop = numpy.array(starts) + win_duration
     embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
     return embeddings
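
The stop column simply pairs each window start with start + win_duration; a toy numpy illustration with made-up values:

import numpy

starts = numpy.arange(4) * 150        # made-up window starts: [  0 150 300 450]
stops = numpy.array(starts) + 300     # with win_duration=300:  [300 450 600 750]
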
@@ -1792,6 +1791,19 @@ def extract_embeddings_per_speaker(idmap_name,
                                     sample_rate=16000,
                                     mixed_precision=False,
                                     num_thread=1):
"""
:param idmap_name:
:param model_filename:
:param data_root_name:
:param device:
:param file_extension:
:param transform_pipeline:
:param sample_rate:
:param mixed_precision:
:param num_thread:
:return:
"""
     # Load the model
     if isinstance(model_filename, str):
         checkpoint = torch.load(model_filename, map_location=device)
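
A hypothetical call sketch for the documented signature; every path and option below is illustrative, and the function is assumed to be imported from the patched module:

import torch

embeddings = extract_embeddings_per_speaker(idmap_name="enroll.idmap.h5",
                                            model_filename="best_model.pt",
                                            data_root_name="/corpus/voxceleb/wav",
                                            device=torch.device("cuda"),
                                            file_extension="wav",
                                            transform_pipeline={},
                                            sample_rate=16000,
                                            mixed_precision=True,
                                            num_thread=4)
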
@@ -1827,7 +1839,6 @@ def extract_embeddings_per_speaker(idmap_name,
     model.to(device)

     # Get the size of embeddings to extract
-    name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
     emb_size = model.embedding_size

     # Create the StatServer
@@ -1840,8 +1851,10 @@ def extract_embeddings_per_speaker(idmap_name,
     embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

     # Process the data
     with torch.no_grad():
-        for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader, desc='xvector extraction', mininterval=1)):
+        with torch.cuda.amp.autocast(enabled=mixed_precision):
+            for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
+                                                                          desc='xvector extraction',
+                                                                          mininterval=1)):
                 if data.shape[1] > 20000000:
                     data = data[..., :20000000]
                 _, vec = model(x=data.to(device), is_eval=True, norm_embedding=True)
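
Condensed, the modified loop now opens the autocast context once around the whole iteration, truncates signals longer than 20,000,000 samples, and runs the forward pass in eval mode with embedding normalization. In the sketch below, the final row assignment is an assumption about the code that follows the truncated hunk:

import torch

# model, dataloader, device, mixed_precision and the embeddings StatServer
# are assumed from the surrounding function
with torch.no_grad():
    with torch.cuda.amp.autocast(enabled=mixed_precision):
        for idx, (data, mod, seg, start, stop) in enumerate(dataloader):
            if data.shape[1] > 20000000:
                data = data[..., :20000000]   # guard against extremely long signals
            _, vec = model(x=data.to(device), is_eval=True, norm_embedding=True)
            embeddings.stat1[idx, :] = vec.detach().cpu().numpy()  # assumed storage step
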