Commit 54484b38 authored by Anthony Larcher

extract

parent 7e812325
@@ -1063,6 +1063,9 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
test_size=dataset_opts["validation_ratio"],
stratify=stratify)
torch.manual_seed(training_opts['torch_seed'] + local_rank)
torch.cuda.manual_seed(training_opts['torch_seed'] + local_rank)
training_set = SideSet(dataset_opts,
set_type="train",
chunk_per_segment=-1,
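The two manual_seed calls added above offset the RNG seed by local_rank. A minimal sketch of the effect, assuming a torchrun-style distributed launch (the LOCAL_RANK lookup and seed value are illustrative, not part of this commit):

```python
import os
import torch

# Hypothetical driver mirroring the seeding added above (values are illustrative)
torch_seed = 1234
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun in DDP jobs

torch.manual_seed(torch_seed + local_rank)
if torch.cuda.is_available():
    torch.cuda.manual_seed(torch_seed + local_rank)

# Each rank now draws a different permutation, so random chunk selection
# and augmentation are decorrelated across workers
print(local_rank, torch.randperm(5).tolist())
```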
@@ -1640,7 +1643,7 @@ def extract_embeddings(idmap_name,
data_root_name,
device,
file_extension="wav",
transform_pipeline="",
transform_pipeline={},
sliding_window=False,
win_duration=3.,
win_shift=1.5,
@@ -1707,36 +1710,35 @@ def extract_embeddings(idmap_name,
segset = []
starts = []
# Process the data
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                              desc='xvector extraction',
                                                              mininterval=1,
                                                              disable=None)):
    if data.dim() > 2:
        data = data.squeeze()

    with torch.cuda.amp.autocast(enabled=mixed_precision):
        # Split large batches into chunks of roughly 100 segments to bound GPU memory;
        # max(1, ...) avoids a zero split size for short batches
        tmp_data = torch.split(data, data.shape[0] // (max(1, data.shape[0] // 100)))
        for td in tmp_data:
            _, vec = model(x=td.to(device), is_eval=True)
            embed.append(vec.detach().cpu())

        modelset.extend(mod * data.shape[0])
        segset.extend(seg * data.shape[0])
        if sliding_window:
            starts.extend(numpy.arange(start, start + data.shape[0] * win_shift, win_shift))
        else:
            starts.append(start)
embeddings = StatServer()
embeddings.modelset = numpy.array(modelset).astype('>U')
embeddings.segset = numpy.array(segset).astype('>U')
embeddings.start = numpy.array(starts)
embeddings.stop = numpy.array(starts) + win_duration
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
embeddings.stat1 = numpy.concatenate(embed)
return embeddings
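For context, a sketch of how this entry point is typically called. The paths are invented, and `model_filename`, `mixed_precision` and `num_thread` are assumed from the surrounding code and the sibling functions, since the full signature is truncated in this hunk:

```python
import torch
from sidekit.bosaris import IdMap  # assumption: IdMap is sidekit's bosaris IdMap

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

xv = extract_embeddings(idmap_name="lists/enroll_idmap.h5",       # hypothetical path
                        model_filename="model/best_xtractor.pt",  # assumed parameter
                        data_root_name="data/wav",                # hypothetical path
                        device=device,
                        file_extension="wav",
                        transform_pipeline={},
                        sliding_window=False,
                        mixed_precision=False,
                        num_thread=4)                             # assumed parameter

# One stat1 row per segment of the IdMap
print(xv.stat1.shape)
```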
@@ -1746,30 +1748,31 @@ def extract_embeddings_per_speaker(idmap_name,
data_root_name,
device,
file_extension="wav",
frame_shift=0.01,
frame_duration=0.025,
extract_after_pooling=False,
transform_pipeline={},
sample_rate=16000,
mixed_precision=False,
num_thread=1):
# Load the model
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"]
model_opts = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_opts["model_type"], loss=model_opts["loss"]["type"])
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
model = model.to(memory_format=torch.channels_last)
if isinstance(idmap_name, IdMap):
idmap = idmap_name
else:
idmap = IdMap(idmap_name)
# Create dataset to load the data
dataset = IdMapSetPerSpeaker(idmap_name=idmap,
data_root_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
frame_rate=sample_rate,
min_duration=(model.context_size() + 2) * frame_shift * 2)
dataloader = DataLoader(dataset,
@@ -1809,117 +1812,3 @@ def extract_embeddings_per_speaker(idmap_name,
return embeddings
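The reworked loader above pulls the architecture and loss type from the checkpoint itself rather than taking them as arguments. A sketch of the checkpoint layout this implies; the key names come from the diff, the concrete values are invented:

```python
import torch

model = torch.nn.Linear(256, 1211)  # stand-in for a trained Xtractor

checkpoint = {
    "speaker_number": 1211,           # hypothetical value
    "model_archi": {
        "model_type": "halfresnet34", # hypothetical model type
        "loss": {"type": "aam"},
    },
    "model_state_dict": model.state_dict(),
}
torch.save(checkpoint, "best_xtractor.pt")  # hypothetical path
```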
def extract_sliding_embedding(idmap_name,
window_len,
window_shift,
model_filename,
data_root_name,
device,
sample_rate=16000,
file_extension="wav",
transform_pipeline=None,
num_thread=1,
mixed_precision=False):
"""
:param idmap_name: name of the IdMap file (or an IdMap object) listing the segments to process
:param window_len: length of the sliding window in seconds
:param window_shift: shift of the sliding window in seconds
:param model_filename: name of the checkpoint file, or the loaded model itself
:param data_root_name: root directory where the audio files are stored
:param device: torch device on which to run the extraction
:param sample_rate: sample rate of the audio signals
:param file_extension: extension of the audio files
:param transform_pipeline: pipeline of transformations applied when loading the data
:param num_thread: number of DataLoader workers
:param mixed_precision: if True, run the forward pass under torch.cuda.amp.autocast
:return: a StatServer containing the sliding-window embeddings
"""
# Load the model
if isinstance(model_filename, str):
checkpoint = torch.load(model_filename, map_location=device)
speaker_number = checkpoint["speaker_number"]
model_archi = checkpoint["model_archi"]
model = Xtractor(speaker_number, model_archi=model_archi)
model.load_state_dict(checkpoint["model_state_dict"])
else:
model = model_filename
if isinstance(idmap_name, IdMap):
idmap = idmap_name
else:
idmap = IdMap(idmap_name)
# Create dataset to load the data
dataset = IdMapSet(idmap_name=idmap,
data_path=data_root_name,
file_extension=file_extension,
transform_pipeline=transform_pipeline,
sliding_window=True,
window_len=window_len,
window_shift=window_shift,
sample_rate=sample_rate,
min_duration=0.1
)
dataloader = DataLoader(dataset,
batch_size=1,
shuffle=False,
drop_last=False,
pin_memory=True,
num_workers=num_thread)
with torch.no_grad():
model.eval()
model.to(device)
# Get the size of embeddings to extract
if type(model) is Xtractor:
name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
if name != 'bias':
name = name + '.weight'
emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
else:
name = list(model.module.before_speaker_embedding.state_dict().keys())[-1].split('.')[0]
if name != 'bias':
name = name + '.weight'
emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
embed = []
modelset = []
segset = []
starts = []
# Process the data
with torch.no_grad():
    for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
                                                                  desc='xvector extraction',
                                                                  mininterval=1)):
        with torch.cuda.amp.autocast(enabled=mixed_precision):
            data = data.squeeze()
            # max(1, ...) avoids a division by zero on signals shorter than 100 chunks
            tmp_data = torch.split(data, data.shape[0] // (max(1, data.shape[0] // 100)))
            for td in tmp_data:
                _, vec = model(x=td.to(device), is_eval=True)
                embed.append(vec.detach().cpu())
            modelset += [mod, ] * data.shape[0]
            segset += [seg, ] * data.shape[0]
            starts.extend(numpy.arange(start, start + data.shape[0] * window_shift, window_shift))
# Create the StatServer
embeddings = StatServer()
embeddings.modelset = numpy.array(modelset).astype('>U')
embeddings.segset = numpy.array(segset).astype('>U')
embeddings.start = numpy.array(starts)
embeddings.stop = numpy.array(starts) + window_len
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
embeddings.stat1 = numpy.concatenate(embed)
return embeddings
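As a sanity check on the start/stop bookkeeping shared by the sliding-window paths above (3.0 s and 1.5 s are the win_duration/win_shift defaults from extract_embeddings; the window count is illustrative):

```python
import numpy

win_duration, win_shift = 3.0, 1.5  # defaults from extract_embeddings
n_windows = 4                       # illustrative: windows extracted from one file

starts = numpy.arange(0.0, n_windows * win_shift, win_shift)
stops = starts + win_duration
print(starts.tolist())  # [0.0, 1.5, 3.0, 4.5]
print(stops.tolist())   # [3.0, 4.5, 6.0, 7.5]
```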