Commit 03d0e5e1 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

sliding window extrctor

parent 107b690a
......@@ -1572,14 +1572,10 @@ def extract_sliding_embedding(idmap_name,
name = name + '.weight'
emb_size = model.module.before_speaker_embedding.state_dict()[name].shape[0]
# Create the StatServer
embeddings = StatServer()
embeddings.modelset = idmap.leftids
embeddings.segset = idmap.rightids
embeddings.start = idmap.start
embeddings.stop = idmap.stop
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
embeddings = []
modelset= []
segset = []
starts = []
# Process the data
with torch.no_grad():
......@@ -1588,9 +1584,18 @@ def extract_sliding_embedding(idmap_name,
mininterval=1)):
with torch.cuda.amp.autocast(enabled=mixed_precision):
vec = model(x=data.to(device), is_eval=True)
embeddings.stat1= vec.detach().cpu()
embeddings.append(vec.detach().cpu())
modelset += [mod,] * embeddings.shape[0]
segset += [seg,] * embeddings.shape[0]
starts += numpy.arange(start, start + embeddings.shape[0] * window_shift , window_shift)
# Create the StatServer
embeddings = StatServer()
embeddings.modelset = numpy.array(modelset).astype('>U')
embeddings.segset = numpy.array(segset).astype('>U')
embeddings.start = numpy.array(starts)
embeddings.stop = numpy.array(starts) + window_len
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
embeddings.stat1 = numpy.concatenate(embeddings)
return embeddings
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment