Commit fe6e8004 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

sideset per speaker

parent aaede2a8
......@@ -31,7 +31,7 @@ Copyright 2014-2021 Anthony Larcher and Sylvain Meignier
from .augmentation import AddNoise
from .feed_forward import FForwardNetwork
from .feed_forward import kaldi_to_hdf5
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset
from .xsets import XvectorMultiDataset, XvectorDataset, StatDataset, IdMapSet_per_speaker
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding
from .res_net import ResBlock, ResNet18
from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis
......
......@@ -738,6 +738,13 @@ class IdMapSet_per_speaker(Dataset):
self.min_duration = min_duration
self.sample_rate = frame_rate
self.speaker_list = list(set(self.idmap.leftids))
self.output_im = IdMap()
self.output_im.leftids = numpy.unique(self.idmap.leftids)
self.output_im.rightids = self.output_im.leftids
self.output_im.start = numpy.empty(self.output_im.rightids.shape[0], "|O")
self.output_im.stop = numpy.empty(self.output_im.rightids.shape[0], "|O")
_transform = []
if transform_pipeline is not None:
......@@ -766,23 +773,24 @@ class IdMapSet_per_speaker(Dataset):
"""
# Loop on all segments from the given speaker to load data
spk_id = self.idmap.leftids[idx]
spk_id = self.output_im.leftids[index]
tmp_data = []
nfo = soundfile.info(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
for id, start, stop in zip(self.idmap.leftids, self.idmap.start, self.idmap.stop):
start = int(start)
stop = int(stop)
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.samplerate:
middle = start + (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=start,
stop=stop,
dtype=wav_type)
tmp_data.append(sig.astype(numpy.float32))
if id == spk_id:
start = int(start)
stop = int(stop)
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.samplerate:
middle = start + (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=start,
stop=stop,
dtype=wav_type)
tmp_data.append(sig.astype(numpy.float32))
sig = numpy.concatenate(tmp_data, axis=0)
sig += 0.0001 * numpy.random.randn(sig.shape[0])
......
......@@ -47,6 +47,7 @@ from collections import OrderedDict
from .xsets import SideSet
from .xsets import FileSet
from .xsets import IdMapSet
from .xsets import IdMapSet_per_speaker
from .res_net import RawPreprocessor, ResBlockWFMS
from ..bosaris import IdMap
from ..statserver import StatServer
......@@ -1142,10 +1143,10 @@ def extract_embeddings_per_speaker(idmap_name,
# Create the StatServer
embeddings = StatServer()
embeddings.modelset = idmap.leftids
embeddings.segset = idmap.rightids
embeddings.start = idmap.start
embeddings.stop = idmap.stop
embeddings.modelset = dataset.output_im.leftids
embeddings.segset = dataset.output_im.rightids
embeddings.start = dataset.output_im.start
embeddings.stop = dataset.output_im.stop
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment