Commit 2182f0f7 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug

parent baae957c
......@@ -58,7 +58,7 @@ class Key:
and columns to the test segments. True is non-target trial.
"""
def __init__(self, key_file_name='',
def __init__(self, key_file_name=None,
models=numpy.array([]),
testsegs=numpy.array([]),
trials=numpy.array([])):
......@@ -74,7 +74,10 @@ class Key:
self.tar = numpy.array([], dtype="bool")
self.non = numpy.array([], dtype="bool")
if key_file_name == '':
if key_file_name is None and models is None and testsegs is None and trials is None:
pass
elif key_file_name == None:
modelset = numpy.unique(models)
segset = numpy.unique(testsegs)
......@@ -103,6 +106,16 @@ class Key:
self.tar = tmp.tar
self.non = tmp.non
@classmethod
def create(cls, modelset, segset, tar, non):
    """Build a Key object from pre-computed components.

    :param modelset: array of unique model identifiers (rows of the key)
    :param segset: array of unique test-segment identifiers (columns of the key)
    :param tar: boolean matrix, True marks target trials
    :param non: boolean matrix, True marks non-target trials
    :return: a validated Key instance
    :raise AssertionError: if the assembled Key fails validation
    """
    # Use cls() rather than a hard-coded Key() so subclasses of Key
    # get instances of their own type from this alternate constructor.
    key = cls()
    key.modelset = modelset
    key.segset = segset
    key.tar = tar
    key.non = non
    # Raise explicitly instead of using `assert`: assert statements are
    # stripped when Python runs with -O, which would silently skip this
    # validation. AssertionError is kept so existing callers still work.
    if not key.validate():
        raise AssertionError("Wrong Key format")
    return key
@check_path_existance
def write(self, output_file_name):
""" Save Key in HDF5 format
......
......@@ -601,11 +601,13 @@ class IdMapSet(Dataset):
nfo = soundfile.info(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}")
start = int(self.idmap.start[index])
stop = int(self.idmap.stop[index])
print(f"Extract segment from {start} to {stop}")
# add this in case the segment is too short
if stop - start <= self.min_duration * nfo.samplerate:
middle = (stop - start) // 2
start = max(0, int(middle - (self.min_duration * nfo.samplerate / 2)))
stop = int(start + self.min_duration * nfo.samplerate)
print(f"\t{start} to {stop}")
sig, _ = soundfile.read(f"{self.data_root_path}/{self.idmap.rightids[index]}.{self.file_extension}",
start=start,
stop=stop)
......@@ -614,6 +616,8 @@ class IdMapSet(Dataset):
if self.transform_pipeline is not None:
sig, _, ___, _____, _t, _s = self.transforms((sig, 0, 0, 0, 0, 0))
print(f"Size of MFCCs : {sig.shape}")
return torch.from_numpy(sig).type(torch.FloatTensor), \
self.idmap.leftids[index], \
self.idmap.rightids[index], \
......
......@@ -398,6 +398,7 @@ class Xtractor(torch.nn.Module):
padding=cfg['preprocessor']["padding"],
dilation=cfg['preprocessor']["dilation"])
self.feature_size = cfg["feature_size"]
self.preprocessor_weight_decay = 0.000
"""
Prepare sequence network
......@@ -446,10 +447,12 @@ class Xtractor(torch.nn.Module):
Pooling
"""
self.stat_pooling = MeanStdPooling()
tmp_input_size = input_size * 2
if cfg["stat_pooling"]["type"] == "GRU":
self.stat_pooling = GruPooling(input_size=cfg["stat_pooling"]["input_size"],
gru_node=cfg["stat_pooling"]["gru_node"],
nb_gru_layer=cfg["stat_pooling"]["nb_gru_layer"])
tmp_input_size = cfg["stat_pooling"]["gru_node"]
self.stat_pooling_weight_decay = cfg["stat_pooling"]["weight_decay"]
......@@ -457,7 +460,7 @@ class Xtractor(torch.nn.Module):
Prepare last part of the network (after pooling)
"""
# Create sequential object for the second part of the network
input_size = input_size * 2
input_size = tmp_input_size
before_embedding_layers = []
for k in cfg["before_embedding"].keys():
if k.startswith("lin"):
......@@ -955,7 +958,6 @@ def extract_embeddings(idmap_name,
min_duration=model.context_size()
)
dataloader = DataLoader(dataset,
batch_size=1,
shuffle=False,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment