Commit a066bd43 authored by Florent Desnous 's avatar Florent Desnous
Browse files

Updated model_iv.train() to use the sidekit.FactorAnalyser i-vector extractor

parent 99e43094
......@@ -20,7 +20,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
from sidekit import Mixture, StatServer, Scores, Ndx, PLDA_scoring, cosine_scoring, mahalanobis_scoring, two_covariance_scoring
from sidekit import Mixture, StatServer, FactorAnalyser, Scores, Ndx, PLDA_scoring, cosine_scoring, mahalanobis_scoring, two_covariance_scoring
from sidekit.sidekit_io import *
import copy
import numpy as np
......@@ -67,12 +67,13 @@ class ModelIV:
print('sn_cov: ', self.sn_cov.shape)
def train(self, feature_server, idmap, normalization=True):
#stat = StatServer(idmap, self.ubm)
stat = StatServer(idmap, distrib_nb=self.ubm.distrib_nb(), feature_size=self.ubm.dim())
stat.accumulate_stat(ubm=self.ubm, feature_server=feature_server, seg_indices=range(stat.segset.shape[0]), num_thread=self.nb_thread)
stat = stat.sum_stat_per_model()[0]
self.ivectors = stat.estimate_hidden(self.tv_mean, self.tv_sigma, V=self.tv, U=None, D=None, num_thread=self.nb_thread)[0]
fa = FactorAnalyser(mean=self.tv_mean, Sigma=self.tv_sigma, F=self.tv)
self.ivectors = fa.extract_ivectors_single(self.ubm, stat)
if normalization:
self.ivectors.spectral_norm_stat1(self.norm_mean[:1], self.norm_cov[:1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment