Commit a7ede196 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

factor_analysis for BEAT

parent 1b768454
......@@ -572,10 +572,15 @@ class FactorAnalyser:
gmm_covariance = "diag" if ubm.invcov.ndim == 2 else "full"
# Set useful variables
with h5py.File(stat_server_filename[0], 'r') as fh: # open the first StatServer to get size
_, sv_size = fh['stat1'].shape
feature_size = fh['stat1'].shape[1] // fh['stat0'].shape[1]
distrib_nb = fh['stat0'].shape[1]
if not isinstance(stat_server_filename[0], h5py._hl.files.File):
with h5py.File(stat_server_filename[0], 'r') as fh: # open the first StatServer to get size
_, sv_size = fh['stat1'].shape
feature_size = fh['stat1'].shape[1] // fh['stat0'].shape[1]
distrib_nb = fh['stat0'].shape[1]
else:
_, sv_size = stat_server_filename[0]['stat1'].shape
feature_size = stat_server_filename[0]['stat1'].shape[1] // stat_server_filename[0]['stat0'].shape[1]
distrib_nb = stat_server_filename[0]['stat0'].shape[1]
upper_triangle_indices = numpy.triu_indices(tv_rank)
......@@ -611,7 +616,11 @@ class FactorAnalyser:
for stat_server_file in stat_server_filename:
# get info from the current StatServer
with h5py.File(stat_server_file, 'r') as fh:
if not isinstance(stat_server_file, h5py._hl.files.File):
fh = h5py.File(stat_server_file, 'r')
else:
fh = stat_server_file
nb_sessions = fh["modelset"].shape[0]
total_session_nb += nb_sessions
batch_nb = int(numpy.floor(nb_sessions / float(batch_size) + 0.999))
......@@ -644,6 +653,9 @@ class FactorAnalyser:
_A, _C, _R = watcher.get()
if not isinstance(stat_server_file, h5py._hl.files.File):
fh.close()
_R /= total_session_nb
# M-step
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment