Commit 36ac5290 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

svm_scoring and _R - _r*_r

parent f7fae30c
This diff is collapsed.
......@@ -162,8 +162,13 @@ then stored in compressed pickle format:
print('Compute the sufficient statistics')
# Create a StatServer for the enrollment data and compute the statistics
enroll_stat = sidekit.StatServer(enroll_idmap, ubm)
enroll_stat.accumulate_stat(ubm=ubm, feature_server=features_server, seg_indices=range(enroll_stat.segset.shape[0]), num_thread=nbThread)
enroll_stat = sidekit.StatServer(enroll_idmap,
distrib_nb=512,
feature_size=60)
enroll_stat.accumulate_stat(ubm=ubm,
feature_server=features_server,
seg_indices=range(enroll_stat.segset.shape[0]),
num_thread=nbThread)
enroll_stat.write('data/stat_rsr2015_male_enroll.h5')
Adapt the GMM speaker models from the UBM via a MAP adaptation
......@@ -223,4 +228,4 @@ The following results should be obtained at the end of this tutorial:
.. image:: rsr2015_GMM-UBM512_map3_snr40_cmvn_rasta_logE.png
.. image:: rsr2015_gmm-ubm.pdf
......@@ -175,28 +175,36 @@ then computed in the StatServer which is then stored to disk:
.. code:: python
logging.info()
enroll_stat = sidekit.StatServer(enroll_idmap, ubm)
enroll_stat = sidekit.StatServer(enroll_idmap,
distrib_nb=512,
feature_size=60)
enroll_stat.accumulate_stat(ubm=ubm,
feature_server=features_server,
seg_indices=range(enroll_stat.segset.shape[0]),
num_thread=nbThread)
enroll_stat.write('data/stat_rsr2015_male_enroll.h5')
back_stat = sidekit.StatServer(back_idmap, ubm)
back_stat = sidekit.StatServer(back_idmap,
distrib_nb=512,
feature_size=60)
back_stat.accumulate_stat(ubm=ubm,
feature_server=features_server,
seg_indices=range(back_stat.segset.shape[0]),
num_thread=nbThread)
back_stat.write('data/stat_rsr2015_male_back.h5')
nap_stat = sidekit.StatServer(nap_idmap, ubm)
nap_stat = sidekit.StatServer(nap_idmap,
distrib_nb=512,
feature_size=60)
nap_stat.accumulate_stat(ubm=ubm,
feature_server=features_server,
seg_indices=range(nap_stat.segset.shape[0]),
num_thread=nbThread)
nap_stat.write('data/stat_rsr2015_male_nap.h5')
test_stat = sidekit.StatServer(test_idmap, ubm)
test_stat = sidekit.StatServer(test_idmap,
distrib_nb=512,
feature_size=60)
test_stat.accumulate_stat(ubm=ubm,
feature_server=features_server,
seg_indices=range(test_stat.segset.shape[0]),
......@@ -284,12 +292,12 @@ Plot DET curve and compute minDCF and EER
dp.plot_DR30_both(idx=0)
dp.plot_mindcf_point(prior, idx=0)
minDCF, Pmiss, Pfa, prbep, eer = sidekit.bosaris.detplot.fast_minDCF(dp.__tar__[0], dp.__non__[0], prior, normalize=False)
minDCF, Pmiss, Pfa, prbep, eer = sidekit.bosaris.detplot.fast_minDCF(dp.__tar__[0], dp.__non__[0], prior, normalize=True)
logging.info("minDCF = {}, eer = {}".format(minDCF, eer))
After running this script you should obtain the following curve
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. image:: SVM-GMM_NAP_512g.png
.. image:: rsr2015_svm_nap.pdf
......@@ -105,7 +105,7 @@ def e_gather(arg, q):
:param q: input queue that is filled by the producers and emptied in this function (a multiprocessing.Queue object)
:return: the three accumulators
"""
_A, _C, _R = arg
_A, _C, _R, _r = arg
while True:
......@@ -115,8 +115,9 @@ def e_gather(arg, q):
_A += stat0.T.dot(e_hh)
_C += e_h.T.dot(stat1)
_R += numpy.sum(e_hh, axis=0)
_r += numpy.sum(e_h, axis=0)
return _A, _C, _R
return _A, _C, _R, _r
def iv_extract_on_batch(arg, q):
......@@ -595,6 +596,7 @@ class FactorAnalyser:
_A = serialize(numpy.zeros((distrib_nb, tv_rank * (tv_rank + 1) // 2), dtype=numpy.float32))
_C = serialize(numpy.zeros((tv_rank, sv_size), dtype=numpy.float32))
_R = serialize(numpy.zeros((tv_rank * (tv_rank + 1) // 2), dtype=numpy.float32))
_r = serialize(numpy.zeros(tv_rank, dtype=numpy.float32))
total_session_nb = 0
......@@ -614,7 +616,7 @@ class FactorAnalyser:
pool = multiprocessing.Pool(num_thread + 2)
# put Consumer to work first
watcher = pool.apply_async(e_gather, ((_A, _C, _R), q))
watcher = pool.apply_async(e_gather, ((_A, _C, _R, _r), q))
# fire off workers
jobs = []
......@@ -630,14 +632,16 @@ class FactorAnalyser:
for job in jobs:
job.get()
#now we are done, kill the consumer
# now we are done, kill the consumer
q.put((None, None, None, None))
pool.close()
_A, _C, _R = watcher.get()
_A, _C, _R, _r = watcher.get()
_r /= total_session_nb
_R /= total_session_nb
print("_A = {}".format(_A[:4,:4]))
_R -= np.outer(_r, _r)
# M-step
_A_tmp = numpy.zeros((tv_rank, tv_rank), dtype=numpy.float32)
for c in range(distrib_nb):
......
......@@ -957,7 +957,7 @@ class StatServer:
index_map = numpy.repeat(numpy.arange(ubm.distrib_nb()), ubm.dim())
# Adapt mean vectors
alpha = self.stat0 / (self.stat0 + r) # Adaptation coefficient
alpha = (self.stat0 + numpy.finfo(np.float32).eps) / (self.stat0 + numpy.finfo(numpy.float32).eps + r) # Adaptation coefficient
M = self.stat1 / self.stat0[:, index_map]
M[numpy.isnan(M)] = 0 # Replace NaN due to divide by zeros
M = alpha[:, index_map] * M + (1 - alpha[:, index_map]) * \
......
......@@ -119,7 +119,7 @@ def svm_scoring(svm_filename_structure, test_sv, ndx, num_thread=1):
jobs = []
for idx in los:
p = multiprocessing.Process(target=svm_scoring_singleThread,
args=(svm_filename_structure, test_sv, ndx, score, idx))
args=(svm_filename_structure, test_sv, clean_ndx, score, idx))
jobs.append(p)
p.start()
for p in jobs:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment