Commit a832fd2d authored by Anthony Larcher
Browse files

cleaning extract embedding

parent f383fc1a
......@@ -414,8 +414,8 @@ class IdMapSet(Dataset):
# add this in case the segment is too short
if duration <= self.min_duration * self.sample_rate:
middle = start + duration // 2
start = max(0, int(middle - (self.min_sample_nb / 2)))
duration = int(self.min_sample_nb)
start = int(max(0, int(middle - (self.min_duration * self.sample_rate / 2))))
duration = int(self.min_duration * self.sample_rate)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{self.idmap.rightids[index]}.{self.file_extension}",
frame_offset=start,
......@@ -435,7 +435,7 @@ class IdMapSet(Dataset):
rir_df=self.rir_df)
speech = speech.squeeze()
return speech, self.idmap.leftids[index], self.idmap.rightids[index], start, start + duration
def __len__(self):
......
......@@ -228,7 +228,6 @@ def test_metrics(model,
model_filename=model,
data_root_name=data_opts["test"]["data_path"],
device=device,
loss=model_opts["loss"]["type"],
transform_pipeline=transform_pipeline,
num_thread=train_opts["num_cpu"],
mixed_precision=train_opts["mixed_precision"])
......@@ -1065,8 +1064,6 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
batch_size=batch_size,
seed=training_opts['torch_seed'],
rank=local_rank,
num_process=torch.cuda.device_count(),
num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
)
else:
batch_size = dataset_opts["batch_size"]
......@@ -1613,7 +1610,7 @@ def extract_embeddings(idmap_name,
window_len=win_duration,
window_shift=win_shift,
sample_rate=sample_rate,
min_duration=sliding_window
min_duration=win_duration
)
......@@ -1643,13 +1640,19 @@ def extract_embeddings(idmap_name,
# Create the StatServer
embeddings = StatServer()
embeddings.modelset = idmap.leftids
embeddings.segset = idmap.rightids
embeddings.start = idmap.start
embeddings.stop = idmap.stop
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
#embeddings = StatServer()
#embeddings.modelset = idmap.leftids
#embeddings.segset = idmap.rightids
#embeddings.start = idmap.start
#embeddings.stop = idmap.stop
#embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
#embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
embed = []
modelset= []
segset = []
starts = []
# Process the data
with torch.no_grad():
......@@ -1657,11 +1660,31 @@ def extract_embeddings(idmap_name,
desc='xvector extraction',
mininterval=1,
disable=None)):
if data.shape[1] > 20000000:
data = data[...,:20000000]
#if data.shape[1] > 20000000:
# data = data[...,:20000000]
print(f"data.shape = {data.shape}")
if data.dim() > 2:
data = data.squeeze()
print(f"data.shape = {data.shape}")
with torch.cuda.amp.autocast(enabled=mixed_precision):
_, vec = model(x=data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
tmp_data = torch.split(data,data.shape[0]//(max(1, data.shape[0]//100)))
for td in tmp_data:
_, vec = model(x=td.to(device), is_eval=True)
embed.append(vec.detach().cpu())
modelset += [mod,] * data.shape[0]
segset += [seg,] * data.shape[0]
starts += [numpy.arange(start, start + vec.shape[0] * win_shift , win_shift),]
embeddings = StatServer()
embeddings.modelset = numpy.array(modelset).astype('>U')
embeddings.segset = numpy.array(segset).astype('>U')
embeddings.start = numpy.array(starts)
embeddings.stop = numpy.array(starts) + win_duration
embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
print(f"type = {type(embed)}, {type(embed[0])}")
embeddings.stat1 = numpy.concatenate(embed)
return embeddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.