Commit 3b87841b authored by Anthony Larcher's avatar Anthony Larcher
Browse files

add-on for beat

parent 4a23887a
......@@ -190,7 +190,7 @@ __maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
__version__="1.3.4"
__version__="1.3.6.2"
# __all__ = ["io",
# "vad",
......
......@@ -565,9 +565,7 @@ class FeaturesServer(object):
features_list = []
for idx, load_arg in enumerate(zip(show_list, channel_list, feature_filename_list, label_list, start_list, stop_list)):
logging.critical("load file {} / {}".format(idx + 1, len(show_list)))
print("load file {} / {}".format(idx + 1, len(show_list)))
print("min: {}, max{}".format(self.load(*load_arg)[0].min(), self.load(*load_arg)[0].max()))
logging.critical("load file {} / {}".format(idx + 1, len(show_list)))
features_list.append(self.load(*load_arg)[0])
return numpy.vstack(features_list)
......
......@@ -706,10 +706,6 @@ class Mixture(object):
start_list=start_list,
stop_list=stop_list,
num_thread=num_thread)
print("\n\n Features pour l'init\n\n")
print("max: {}, min: {}, mean(0): {}".format(features.min(), features.max(), features.mean(0)))
print(features)
n_frames = features.shape[0]
mu = features.mean(0)
cov = (features**2).mean(0)
......
......@@ -333,45 +333,49 @@ class StatServer:
"""
line_number = 0
for idx, ss in enumerate(arg):
assert(isinstance(ss, sidekit.StatServer) and ss.validate()), "Arguments must be proper StatServers"
assert (isinstance(ss, sidekit.StatServer) and ss.validate()), "Arguments must be proper StatServers"
# Check consistency of StatServers (dimension of the stat0 and stat1)
if idx == 0:
dim_stat0 = ss.stat0.shape[1]
dim_stat1 = ss.stat1.shape[1]
dim_stat1 = ss.stat1.shape[1]
else:
assert(dim_stat0 == ss.stat0.shape[1] and
dim_stat1 == ss.stat1.shape[1]), "Stat dimensions are not consistent"
assert (dim_stat0 == ss.stat0.shape[1] and
dim_stat1 == ss.stat1.shape[1]), "Stat dimensions are not consistent"
line_number += ss.modelset.shape[0]
# Get a list of unique modelID-segmentID
# Get a list of unique modelID-segmentID-start-stop
id_list = []
for ss in arg:
id_list += list(ss.segset)
for m, s, start, stop in zip(ss.modelset, ss.segset, ss.start, ss.stop):
id_list.append("{}-{}-{}-{}".format(m, s, str(start), str(stop)))
id_set = set(id_list)
if line_number != len(id_set):
print("WARNING: duplicated segmentID in input StatServers")
# Initialize the new StatServer with unique set of segmentID
tmp = numpy.array(list(id_set))
new_stat_server = sidekit.StatServer()
new_stat_server.modelset = numpy.empty(len(id_set), dtype='object')
new_stat_server.segset = numpy.array(list(id_set))
new_stat_server.segset = numpy.empty(len(id_set), dtype='object')
new_stat_server.start = numpy.empty(len(id_set), 'object')
new_stat_server.stop = numpy.empty(len(id_set), dtype='object')
new_stat_server.stat0 = numpy.zeros((len(id_set), dim_stat0), dtype=STAT_TYPE)
new_stat_server.stat1 = numpy.zeros((len(id_set), dim_stat1), dtype=STAT_TYPE)
for ss in arg:
for idx, segment in enumerate(ss.segset):
new_idx = numpy.argwhere(new_stat_server.segset == segment)
for idx, (m, s, start, stop) in enumerate(zip(ss.modelset, ss.segset, ss.start, ss.stop)):
key = "{}-{}-{}-{}".format(m, s, str(start), str(stop))
new_idx = numpy.argwhere(tmp == key)
new_stat_server.modelset[new_idx] = ss.modelset[idx]
new_stat_server.segset[new_idx] = ss.segset[idx]
new_stat_server.start[new_idx] = ss.start[idx]
new_stat_server.stop[new_idx] = ss.stop[idx]
new_stat_server.stat0[new_idx, :] = ss.stat0[idx, :].astype(STAT_TYPE)
new_stat_server.stat1[new_idx, :] = ss.stat1[idx, :].astype(STAT_TYPE)
assert(new_stat_server.validate()), "Problem in StatServer Merging"
assert (new_stat_server.validate()), "Problem in StatServer Merging"
return new_stat_server
@staticmethod
......@@ -417,11 +421,11 @@ class StatServer:
start = copy.deepcopy(self.start)
start[numpy.isnan(self.start.astype('float'))] = -1
start = start.astype('int8', copy=False)
start = start.astype('int32', copy=False)
stop = copy.deepcopy(self.stop)
stop[numpy.isnan(self.stop.astype('float'))] = -1
stop = stop.astype('int8', copy=False)
stop = stop.astype('int32', copy=False)
with h5py.File(output_file_name, mode) as f:
......@@ -484,11 +488,11 @@ class StatServer:
start = copy.deepcopy(self.start)
start[numpy.isnan(self.start.astype('float'))] = -1
start = start.astype('int8', copy=False)
start = start.astype('int32', copy=False)
stop = copy.deepcopy(self.stop)
stop[numpy.isnan(self.stop.astype('float'))] = -1
stop = stop.astype('int8', copy=False)
stop = stop.astype('int32', copy=False)
# If the file doesn't exist before, create it
if mode == "w":
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment