Commit aaede2a8 authored by Anthony Larcher

new dataset for diarization

parent b951d7ce
@@ -708,6 +708,102 @@ class IdMapSet(Dataset):
        return self.len


class IdMapSet_per_speaker(Dataset):
    """
    Dataset that provides the data of one speaker at a time, according to a sidekit.IdMap object
    """
    def __init__(self,
                 idmap_name,
                 data_root_path,
                 file_extension,
                 transform_pipeline=None,
                 frame_rate=100,
                 min_duration=0.165
                 ):
        """
        :param idmap_name: name of the IdMap file (or an IdMap object) describing the segments
        :param data_root_path: root directory where the audio files are stored
        :param file_extension: extension of the audio files
        :param transform_pipeline: comma-separated list of transforms to apply to the signal
        :param frame_rate: number of feature frames per second
        :param min_duration: minimum duration of a segment in seconds
        """
        if isinstance(idmap_name, IdMap):
            self.idmap = idmap_name
        else:
            self.idmap = IdMap(idmap_name)

        self.data_root_path = data_root_path
        self.file_extension = file_extension
        self.transform_pipeline = transform_pipeline
        self.min_duration = min_duration
        self.sample_rate = frame_rate

        # One item per speaker: use a deterministic ordering of the speaker identifiers
        self.speaker_list = list(numpy.unique(self.idmap.leftids))
        self.len = len(self.speaker_list)

        # Build the transformation pipeline from its comma-separated description
        _transform = []
        if transform_pipeline is not None:
            trans = transform_pipeline.split(",")
            for t in trans:
                if 'PreEmphasis' in t:
                    _transform.append(PreEmphasis())
                if 'MFCC' in t:
                    _transform.append(MFCC())
                if "CMVN" in t:
                    _transform.append(CMVN())
                if 'add_noise' in t:
                    self.add_noise = numpy.ones(self.idmap.leftids.shape[0], dtype=bool)
                    numpy.random.shuffle(self.add_noise)
                    _transform.append(AddNoise(noise_db_csv="list/musan.csv",
                                               snr_min_max=[5.0, 15.0],
                                               noise_root_path="./data/musan/"))
        self.transforms = transforms.Compose(_transform)
    def __getitem__(self, index):
        """
        :param index: index of the speaker in the speaker list
        :return: a tuple (signal, speaker_id, speaker_id, start, stop)
        """
        # Loop on all segments from the given speaker to load data
        spk_id = self.speaker_list[index]
        tmp_data = []
        for seg_id, seg_file, start, stop in zip(self.idmap.leftids,
                                                 self.idmap.rightids,
                                                 self.idmap.start,
                                                 self.idmap.stop):
            if seg_id != spk_id:
                continue
            nfo = soundfile.info(f"{self.data_root_path}/{seg_file}.{self.file_extension}")
            start = int(start)
            stop = int(stop)
            # Enlarge the segment around its middle point in case it is too short
            if stop - start <= self.min_duration * nfo.samplerate:
                middle = start + (stop - start) // 2
                start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
                stop = int(start + self.min_duration * nfo.samplerate)
            # wav_type is expected to be a module-level constant (e.g. "float32")
            sig, _ = soundfile.read(f"{self.data_root_path}/{seg_file}.{self.file_extension}",
                                    start=start,
                                    stop=stop,
                                    dtype=wav_type)
            tmp_data.append(sig.astype(numpy.float32))

        # Concatenate all segments of the speaker and add a small dithering noise
        sig = numpy.concatenate(tmp_data, axis=0)
        sig += 0.0001 * numpy.random.randn(sig.shape[0])

        if self.transform_pipeline is not None:
            sig, _, ___, _____, _t, _s = self.transforms((sig, 0, 0, 0, 0, 0))

        # The speaker identifier is returned as both model and segment key
        return torch.from_numpy(sig).type(torch.FloatTensor), \
               spk_id, \
               spk_id, \
               start, stop
    def __len__(self):
        """
        :return: the number of speakers in the IdMap
        """
        return self.len
class FileSet(Dataset):
    """
    Dataset class to load from disk
......
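For reference, a minimal usage sketch of the new dataset, assuming a sidekit IdMap whose leftids hold speaker identifiers and whose rightids, start and stop describe the segments; the file paths and the transform string are hypothetical, and the class is assumed to be importable from the modified module:

from torch.utils.data import DataLoader

# Hypothetical IdMap file and audio root, for illustration only
dataset = IdMapSet_per_speaker(idmap_name="exp/diar.idmap.h5",
                               data_root_path="./data/audio",
                               file_extension="wav",
                               transform_pipeline="MFCC,CMVN",
                               frame_rate=100,
                               min_duration=0.165)

# One item per speaker: all of that speaker's segments concatenated into a single signal
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
for features, spk_id, seg_id, start, stop in loader:
    print(spk_id, features.shape)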
@@ -1085,6 +1085,80 @@ def extract_embeddings(idmap_name,
    return embeddings

def extract_embeddings_per_speaker(idmap_name,
                                   model_filename,
                                   data_root_name,
                                   device,
                                   model_yaml=None,
                                   speaker_number=None,
                                   file_extension="wav",
                                   transform_pipeline=None,
                                   frame_shift=0.01,
                                   frame_duration=0.025,
                                   num_thread=1):
    """
    Extract one embedding per speaker by concatenating all segments of each speaker listed in the IdMap.

    :param idmap_name: name of the IdMap file (or an IdMap object) describing the segments
    :param model_filename: name of the Xtractor checkpoint file (or an Xtractor object)
    :param data_root_name: root directory where the audio files are stored
    :param device: torch device to run the extraction on

    :return: a StatServer containing one embedding per speaker
    """
    # Load the model
    if isinstance(model_filename, str):
        checkpoint = torch.load(model_filename)
        if speaker_number is None:
            speaker_number = checkpoint["speaker_number"]
        if model_yaml is None:
            model_archi = checkpoint["model_archi"]
        else:
            model_archi = model_yaml
        model = Xtractor(speaker_number, model_archi=model_archi)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model = model_filename

    if isinstance(idmap_name, IdMap):
        idmap = idmap_name
    else:
        idmap = IdMap(idmap_name)

    min_duration = (model.context_size() - 1) * frame_shift + frame_duration
    # Create the dataset to load the data
    dataset = IdMapSet_per_speaker(idmap_name=idmap_name,
                                   data_root_path=data_root_name,
                                   file_extension=file_extension,
                                   transform_pipeline=transform_pipeline,
                                   frame_rate=int(1. / frame_shift),
                                   min_duration=(model.context_size() + 2) * frame_shift * 2
                                   )

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            pin_memory=True,
                            num_workers=num_thread)
    with torch.no_grad():
        model.eval()
        model.to(device)

        # Get the size of the embeddings to extract
        name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
        emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]

        # Create the StatServer with one line per speaker
        embeddings = StatServer()
        embeddings.modelset = numpy.unique(idmap.leftids)
        embeddings.segset = embeddings.modelset
        embeddings.start = numpy.empty(embeddings.modelset.shape[0], dtype="|O")
        embeddings.stop = numpy.empty(embeddings.modelset.shape[0], dtype="|O")
        embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
        embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))

        # Process the data: one embedding per speaker
        for idx, (data, mod, seg, start, stop) in tqdm.tqdm(enumerate(dataloader)):
            # Truncate very long signals to avoid running out of memory
            if data.shape[1] > 20000000:
                data = data[..., :20000000]
            vec = model(data.to(device), is_eval=True)
            embeddings.stat1[idx, :] = vec.detach().cpu()

    return embeddings
def extract_sliding_embedding(idmap_name,
                              window_length,
                              sample_rate,
......
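Likewise, a hedged sketch of how the per-speaker extraction could be called once this commit is merged; the IdMap file, checkpoint name, audio root and transform string are hypothetical, and the function is assumed to be importable from the modified module:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder IdMap file and checkpoint name, not part of this commit
embeddings = extract_embeddings_per_speaker(idmap_name="exp/diar.idmap.h5",
                                            model_filename="model/best_xtractor.pt",
                                            data_root_name="./data/audio",
                                            device=device,
                                            file_extension="wav",
                                            transform_pipeline="MFCC,CMVN",
                                            num_thread=5)

# embeddings is a sidekit StatServer: modelset holds the speaker identifiers and
# stat1 holds one x-vector per speaker, ready for clustering in a diarization back-end
print(embeddings.modelset.shape, embeddings.stat1.shape)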