Commit f56d973b authored by Florent Desnous

Delete tuto_1_iv_model.ipynb

parent 2ca453b4
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train model for Diarization\n",
"====\n",
"\n",
"This script trains UBM, TV and PLDA models for a diarization system.\n",
"\n",
"Initialization\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"from s4d.diar import Diar\n",
"from s4d.utils import *\n",
"\n",
"from sidekit import Mixture, FactorAnalyser, StatServer, IdMap\n",
"import numpy\n",
"import logging\n",
"import re\n",
"import sidekit\n",
"from sidekit.sidekit_io import *\n",
"try:\n",
" from sortedcontainers import SortedDict as dict\n",
"except ImportError:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"audio_dir = '../data/train/{}.wav'\n",
"num_thread = 4\n",
"\n",
"\n",
"ubm_seg_fn = './data/seg/ubm_ester.seg'\n",
"nb_gauss = 1024\n",
"mfcc_ubm_fn = './data/mfcc/ubm.h5'\n",
"ubm_idmap_fn = './data/mfcc/ubm_idmap.txt'\n",
"ubm_fn = './data/model/ester_ubm_'+str(nb_gauss)+'.h5'\n",
"\n",
"tv_seg_fn = './data/seg/train.tv.seg'\n",
"rank_tv = 300\n",
"it_max_tv = 10\n",
"mfcc_tv_fn = './data/mfcc/tv.h5'\n",
"tv_idmap_fn = './data/mfcc/tv_idmap.h5'\n",
"tv_stat_fn = './data/model/tv.stat.h5'\n",
"tv_fn = './data/model/tv_'+str(rank_tv)+'.h5'\n",
"\n",
"\n",
"plda_seg_fn = './data/seg/train.plda.seg'\n",
"rank_plda = 150\n",
"it_max_plda = 10\n",
"mfcc_plda_fn = './data/mfcc/norm_plda.h5'\n",
"plda_idmap_fn = './data/mfcc/plda_idmap.h5'\n",
"plda_fn = './data/model/plda_'+str(rank_tv)+'_'+str(rank_plda)+'.h5'\n",
"norm_stat_fn = './data/model/norm.stat.h5'\n",
"norm_fn = './data/model/norm.h5'\n",
"norm_iv_fn = './data/model/norm.iv.h5'\n",
"\n",
"matrices_fn = './data/model/matrices.h5'\n",
"model_fn = './data/model/ester_model_{}_{}_{}.h5'.format(nb_gauss, rank_tv, rank_plda)\n",
"\n",
"init_logging(level=logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 1: UBM\n",
"---\n",
"Extract MFCC for the UBM"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"logging.info('Computing MFCC for UBM')\n",
"diar_ubm = Diar.read_seg(ubm_seg_fn, normalize_cluster=True)\n",
"fe = get_feature_extractor(audio_dir, 'sid')\n",
"ubm_idmap = fe.save_multispeakers(diar_ubm.id_map(), output_feature_filename=mfcc_ubm_fn, keep_all=False)\n",
"ubm_idmap.write_txt(ubm_idmap_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the UBM by EM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ubm_idmap = IdMap.read_txt(ubm_idmap_fn)\n",
"\n",
"fs = get_feature_server(mfcc_ubm_fn, 'sid')\n",
"\n",
"spk_lst = ubm_idmap.rightids\n",
"ubm = Mixture()\n",
"ubm.EM_split(fs, spk_lst, nb_gauss,\n",
" iterations=(1, 2, 2, 4, 4, 4, 8, 8, 8, 8, 8, 8, 8), num_thread=num_thread,\n",
" llk_gain=0.01)\n",
"ubm.write(ubm_fn, prefix='ubm/')"
]
},
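{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check: reload the UBM and log its size. A minimal sketch reusing only the `Mixture` calls already used above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ubm_check = Mixture()\n",
"ubm_check.read(ubm_fn, prefix='ubm/')\n",
"# the number of distributions should equal nb_gauss; dim() is the MFCC feature dimension\n",
"logging.info('UBM: %d distributions, feature dim %d', ubm_check.get_distrib_nb(), ubm_check.dim())"
]
},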
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 2: TV\n",
"---\n",
"Extract MFCC for TV"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-14 17:20:11,559 - INFO - [1 1 1 ..., 1 1 1]\n",
"2018-06-14 17:20:49,962 - INFO - percentil 34.8445533752\n",
"2018-06-14 17:20:49,963 - INFO - show: 20010803_1000_1100_rfi cluster: 20010803_1000_1100_rfi_alain_janse thr:34.8445533752\n",
"2018-06-14 17:20:49,965 - INFO - keep_all id: 20010803_1000_1100_rfi show: 20010803_1000_1100_rfi/20010803_1000_1100_rfi_alain_janse start: 0 stop: 4608\n",
"2018-06-14 17:20:49,965 - INFO - output finename: ./data/mfcc/tv.h5\n",
"2018-06-14 17:21:13,165 - INFO - [1 1 1 ..., 1 1 1]\n",
"2018-06-14 17:21:51,840 - INFO - percentil 36.9146942139\n",
"2018-06-14 17:21:51,841 - INFO - show: 20010803_1000_1100_rfi cluster: gladys_say##11 thr:36.9146942139\n",
"2018-06-14 17:21:51,843 - INFO - keep_all id: 20010803_1000_1100_rfi show: 20010803_1000_1100_rfi/gladys_say##11 start: 0 stop: 6140\n",
"2018-06-14 17:21:51,843 - INFO - output finename: ./data/mfcc/tv.h5\n",
"2018-06-14 17:22:14,705 - INFO - [1 1 1 ..., 1 1 1]\n"
]
}
],
"source": [
"logging.info('Computing MFCC for TV')\n",
"diar_tv = Diar.read_seg(tv_seg_fn, normalize_cluster=True)\n",
"fe = get_feature_extractor(audio_dir, 'sid')\n",
"tv_idmap = fe.save_multispeakers(diar_tv.id_map(), output_feature_filename=mfcc_tv_fn, keep_all=False)\n",
"tv_idmap.write(tv_idmap_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train a Total Variability model using the FactorAnalyser class"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-15 12:13:46,646 - INFO - ./data/mfcc ## tv ## .h5\n",
"2018-06-15 12:13:46,648 - INFO - feature extractor --> None\n",
"2018-06-15 12:13:46,648 - INFO - --------------------\n",
"2018-06-15 12:13:46,649 - INFO - None\n",
"2018-06-15 12:13:46,650 - INFO - --------------------\n",
"2018-06-15 12:13:46,651 - INFO - \t show: empty \n",
"\n",
"\t input_feature_filename: empty \n",
"\n",
"\t feature_filename_structure: ./data/mfcc/tv.h5 \n",
"\t \n",
"\t \n",
"\n",
"\t Post processing options: \n",
"\t\t mask: None \n",
"\t\t feat_norm: cmvn_sliding \n",
"\t\t dct_pca: False, dct_pca_config: (12, 12, None) \n",
"\t\t sdc: False, sdc_config: (1, 3, 7) \n",
"\t\t delta: True, double_delta: True, delta_filter: [ 0.25 0.5 0.25 0. -0.25 -0.5 -0.25] \n",
"\t\t rasta: False \n",
"\t\t keep_all_features: True \n",
"\n"
]
}
],
"source": [
"tv_idmap = IdMap.read(tv_idmap_fn)\n",
"\n",
"ubm = Mixture()\n",
"ubm.read(ubm_fn, prefix='ubm/')\n",
"\n",
"fs = get_feature_server(mfcc_tv_fn, 'sid')\n",
"\n",
"tv_idmap.leftids = numpy.copy(tv_idmap.rightids)\n",
"\n",
"tv_stat = StatServer(tv_idmap, ubm.get_distrib_nb(), ubm.dim())\n",
"tv_stat.accumulate_stat(ubm=ubm, feature_server=fs, seg_indices=range(tv_stat.segset.shape[0]), num_thread=num_thread)\n",
"tv_stat.write(tv_stat_fn)\n",
"fa = FactorAnalyser()\n",
"fa.total_variability(tv_stat_fn, ubm, rank_tv, nb_iter=it_max_tv, batch_size=1000, num_thread=num_thread)\n",
"\n",
"write_tv_hdf5([fa.F, fa.mean, fa.Sigma], tv_fn)"
]
},
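{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check: reload the TV matrix with `read_tv_hdf5` (the same reader used in Step 3) and confirm its rank. A sketch assuming the usual sidekit layout, where the TV matrix has one column per total-variability dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tv_check, tv_mean_check, tv_sigma_check = read_tv_hdf5(tv_fn)\n",
"# expected: tv_check.shape[1] == rank_tv; tv_check.shape[0] is the supervector size\n",
"logging.info('TV matrix shape: %s', str(tv_check.shape))"
]
},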
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 3: PLDA\n",
"---\n",
"Extract the MFCC for the PLDA"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"logging.info('Computing MFCC for PLDA')\n",
"diar_plda = Diar.read_seg(plda_seg_fn, normalize_cluster=True)\n",
"fe = get_feature_extractor(audio_dir, 'sid')\n",
"plda_idmap = fe.save_multispeakers(diar_plda.id_map(), output_feature_filename=mfcc_plda_fn, keep_all=False)\n",
"plda_idmap.write(plda_idmap_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accumulate statistics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plda_idmap = IdMap.read(plda_idmap_fn)\n",
"\n",
"ubm = Mixture()\n",
"ubm.read(ubm_fn, prefix='ubm/')\n",
"tv, tv_mean, tv_sigma = read_tv_hdf5(tv_fn)\n",
"\n",
"fs = get_feature_server(mfcc_plda_fn, 'sid')\n",
"\n",
"plda_norm_stat = StatServer(plda_idmap, ubm.get_distrib_nb(), ubm.dim())\n",
"plda_norm_stat.accumulate_stat(ubm=ubm, feature_server=fs, \n",
" seg_indices=range(plda_norm_stat.segset.shape[0]), num_thread=num_thread)\n",
"plda_norm_stat.write(norm_stat_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Extract i-vectors and compute norm"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"fa = FactorAnalyser(F=tv, mean=tv_mean, Sigma=tv_sigma)\n",
"norm_iv = fa.extract_ivectors(ubm, norm_stat_fn, num_thread=num_thread)\n",
"norm_iv.write(norm_iv_fn)\n",
"\n",
"norm_mean, norm_cov = norm_iv.estimate_spectral_norm_stat1(1, 'sphNorm')\n",
"\n",
"write_norm_hdf5([norm_mean, norm_cov], norm_fn)\n",
"\n",
"norm_iv.spectral_norm_stat1(norm_mean[:1], norm_cov[:1])"
]
},
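{
"cell_type": "markdown",
"metadata": {},
"source": [
"At test time, the same normalization must be applied to new i-vectors before PLDA scoring. A minimal sketch, assuming `sidekit_io` provides `read_norm_hdf5` as the counterpart of the `write_norm_hdf5` call above; it is illustrated on the un-normalized i-vectors saved to `norm_iv_fn`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"iv_test = StatServer(norm_iv_fn)  # un-normalized i-vectors written above\n",
"norm_mean_t, norm_cov_t = read_norm_hdf5(norm_fn)\n",
"# one iteration of spherical norm, as in the training cell above\n",
"iv_test.spectral_norm_stat1(norm_mean_t[:1], norm_cov_t[:1])"
]
},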
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the PLDA model"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:43,844 - INFO - Estimate between class covariance, it 1 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:44,138 - INFO - Estimate between class covariance, it 2 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:44,450 - INFO - Estimate between class covariance, it 3 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:44,711 - INFO - Estimate between class covariance, it 4 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:44,980 - INFO - Estimate between class covariance, it 5 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:45,277 - INFO - Estimate between class covariance, it 6 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:45,568 - INFO - Estimate between class covariance, it 7 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:45,854 - INFO - Estimate between class covariance, it 8 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:46,154 - INFO - Estimate between class covariance, it 9 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:31:46,450 - INFO - Estimate between class covariance, it 10 / 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E_step\n"
]
}
],
"source": [
"fa = FactorAnalyser()\n",
"fa.plda(norm_iv, rank_plda, nb_iter=it_max_plda)\n",
"write_plda_hdf5([fa.mean, fa.F, numpy.zeros((rank_tv, 0)), fa.Sigma], plda_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 4: Compute additional data (optional)\n",
"---\n",
"Adding matrices for additional scoring methods: \n",
"* Mahalonobis matrix\n",
"* Lower Choleski decomposition of the WCCN matrix\n",
"* Within- and Between-class Covariance matrices"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-05-15 07:25:52,093 - INFO - compute mahalanobis_matrix\n",
"2018-05-15 07:25:52,399 - INFO - compute wccn_choleski\n",
"2018-05-15 07:25:52,799 - INFO - compute two_covariance\n"
]
}
],
"source": [
"iv = StatServer(norm_iv_fn)\n",
"matrix_dict = {}\n",
"\n",
"logging.info('compute mahalanobis_matrix')\n",
"mahalanobis_matrix = iv.get_mahalanobis_matrix_stat1()\n",
"matrix_dict['mahalanobis_matrix'] = mahalanobis_matrix\n",
"\n",
"logging.info('compute wccn_choleski')\n",
"wccn_choleski = iv.get_wccn_choleski_stat1()\n",
"matrix_dict['wccn_choleski'] = wccn_choleski\n",
"\n",
"logging.info('compute two_covariance')\n",
"within_covariance = iv.get_within_covariance_stat1()\n",
"matrix_dict['two_covariance/within_covariance'] = within_covariance\n",
"between_covariance = iv.get_between_covariance_stat1()\n",
"matrix_dict['two_covariance/between_covariance'] = between_covariance\n",
"\n",
"write_dict_hdf5(matrix_dict, matrices_fn)"
]
},
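{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dictionary keys become HDF5 paths, so the matrices can be read back with plain `h5py`. A sketch assuming `write_dict_hdf5` stores each key as a dataset path (keys containing '/' become nested groups)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with h5py.File(matrices_fn, 'r') as fh:\n",
"    mahalanobis_check = fh['mahalanobis_matrix'][()]\n",
"    within_check = fh['two_covariance/within_covariance'][()]\n",
"logging.info('mahalanobis: %s, within-cov: %s', mahalanobis_check.shape, within_check.shape)"
]
},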
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 5: Merge in one model\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-18 12:32:56,168 - INFO - ubm\n",
"2018-06-18 12:32:56,208 - INFO - tv\n",
"2018-06-18 12:32:57,057 - INFO - norm\n",
"2018-06-18 12:32:57,070 - INFO - plda\n",
"2018-06-18 12:32:57,084 - INFO - mahalanobis_matrix\n",
"2018-06-18 12:32:57,098 - INFO - two_covariance\n",
"2018-06-18 12:32:57,108 - INFO - wccn_choleski\n"
]
}
],
"source": [
"with h5py.File(model_fn, 'w') as model:\n",
" for fn in [ubm_fn, tv_fn, norm_fn, plda_fn, matrices_fn]:\n",
" if not os.path.exists(fn):\n",
" continue\n",
" with h5py.File(fn, 'r') as fh:\n",
" for group in fh:\n",
" logging.info(group)\n",
" fh.copy(group, model)"
]
},
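{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check: list the top-level groups of the merged model with plain `h5py`; they should match the groups logged above (ubm, tv, norm, plda, mahalanobis_matrix, two_covariance, wccn_choleski)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with h5py.File(model_fn, 'r') as model:\n",
"    for group in model:\n",
"        logging.info('merged group: %s', group)"
]
},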
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}