Commit 80673537 authored by Florent Desnous 's avatar Florent Desnous
Browse files

Upload tutorial for iv-PLDA clustering

parent 86e373f9
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"i-vector clustering with PLDA scoring\n",
"===\n",
"This script demonstrates the use of several clustering algorithms using PLDA scoring and i-vectors. The algorithms proposed are:\n",
" - Integer Linear Programming (ILP) IV\n",
" - HAC IV\n",
" - Connected Components (CC) IV\n",
" - Combination of CC and HAC, and CC and ILP\n",
"\n",
"It takes as input the segments generated by the second tutorial (BIC-HAC) and uses the model learned in the first tutorial."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Import theano\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Can not use cuDNN on context None: Disabled by dnn.enabled flag\n",
"Mapped name None to device cuda: GeForce GTX TITAN X (0000:03:00.0)\n"
]
}
],
"source": [
"%matplotlib inline\n",
"\n",
"from s4d.diar import Diar\n",
"from s4d.utils import *\n",
"from s4d import scoring\n",
"from s4d.model_iv import ModelIV\n",
"\n",
"from s4d.clustering.ilp_iv import ilp_iv\n",
"from s4d.clustering.hac_iv import hac_iv\n",
"from s4d.clustering.cc_iv import connexted_component\n",
"\n",
"from sidekit.sidekit_io import *\n",
"from sidekit.bosaris import IdMap, Scores\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"import numpy\n",
"import copy\n",
"import sys\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Paths shared by the rest of the notebook.\n",
"data_dir = 'data'\n",
"model_fn = os.path.join(data_dir, 'model', 'ester_model.h5')\n",
"\n",
"# Log at INFO level so the progress messages of the following cells are shown.\n",
"init_logging(level=logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"show = '20041008_1800_1830_INFO_DGA'\n",
"out_dir = os.path.join('out', show)\n",
"\n",
"# Audio recording of the show.\n",
"audio_fn = os.path.join(data_dir, 'audio', show + '.wav')\n",
"\n",
"# Files located in the per-show output directory.\n",
"bic_fn = os.path.join(out_dir, show + '.d.seg')\n",
"mfcc_fn = os.path.join(out_dir, show + '.test_mfcc.h5')\n",
"idmap_fn = os.path.join(out_dir, show + '.idmap.h5')\n",
"score_fn = os.path.join(out_dir, show + '.score_plda.h5')\n",
"\n",
"# Segmentation that every clustering below starts from.\n",
"diar_bic = Diar.read_seg(bic_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Extracting MFCC\n",
"==="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Feature extractor bound to the audio file of the show.\n",
"fe = get_feature_extractor(audio_fn, type_feature_extractor='sid')\n",
"\n",
"# Write the features for the BIC segmentation to mfcc_fn, then store the\n",
"# id map returned by the extractor in idmap_fn for the scoring cell.\n",
"bic_id_map = diar_bic.id_map()\n",
"idmap_bic = fe.save_multispeakers(bic_id_map, output_feature_filename=mfcc_fn, keep_all=False)\n",
"idmap_bic.write(idmap_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PLDA scoring\n",
"===\n",
"Train a PLDA model for the show and compute the distance matrix"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load: data/model/ester_model.h5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-20 17:11:30,775 - INFO - out/20041008_1800_1830_INFO_DGA ## 20041008_1800_1830_INFO_DGA ## .test_mfcc.h5\n",
"2018-06-20 17:11:30,776 - INFO - feature extractor --> None\n",
"2018-06-20 17:11:30,777 - INFO - --------------------\n",
"2018-06-20 17:11:30,778 - INFO - None\n",
"2018-06-20 17:11:30,779 - INFO - --------------------\n",
"2018-06-20 17:11:30,780 - INFO - \t show: empty \n",
"\n",
"\t input_feature_filename: empty \n",
"\n",
"\t feature_filename_structure: out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.test_mfcc.h5 \n",
"\t \n",
"\t \n",
"\n",
"\t Post processing options: \n",
"\t\t mask: None \n",
"\t\t feat_norm: cmvn_sliding \n",
"\t\t dct_pca: False, dct_pca_config: (12, 12, None) \n",
"\t\t sdc: False, sdc_config: (1, 3, 7) \n",
"\t\t delta: True, double_delta: True, delta_filter: [ 0.25 0.5 0.25 0. -0.25 -0.5 -0.25] \n",
"\t\t rasta: False \n",
"\t\t keep_all_features: True \n",
"\n",
"2018-06-20 17:11:31,020 - INFO - 20041008_1800_1830_INFO_DGA/S0 start: 0 stop: 4450\n",
"2018-06-20 17:11:32,990 - INFO - 20041008_1800_1830_INFO_DGA/S8 start: 0 stop: 25320\n",
"2018-06-20 17:11:38,165 - INFO - 20041008_1800_1830_INFO_DGA/S11 start: 0 stop: 4458\n",
"2018-06-20 17:11:38,864 - INFO - 20041008_1800_1830_INFO_DGA/S30 start: 0 stop: 2754\n",
"2018-06-20 17:11:39,307 - INFO - 20041008_1800_1830_INFO_DGA/S45 start: 0 stop: 4620\n",
"2018-06-20 17:11:40,048 - INFO - 20041008_1800_1830_INFO_DGA/S66 start: 0 stop: 4506\n",
"2018-06-20 17:11:40,440 - INFO - 20041008_1800_1830_INFO_DGA/S142 start: 0 stop: 1251\n",
"2018-06-20 17:11:41,097 - INFO - 20041008_1800_1830_INFO_DGA/S83 start: 0 stop: 15633\n",
"2018-06-20 17:11:44,250 - INFO - 20041008_1800_1830_INFO_DGA/S86 start: 0 stop: 1696\n",
"2018-06-20 17:11:44,715 - INFO - 20041008_1800_1830_INFO_DGA/S89 start: 0 stop: 5770\n",
"2018-06-20 17:11:46,077 - INFO - 20041008_1800_1830_INFO_DGA/S100 start: 0 stop: 4984\n",
"2018-06-20 17:11:47,253 - INFO - 20041008_1800_1830_INFO_DGA/S106 start: 0 stop: 6084\n",
"2018-06-20 17:11:48,819 - INFO - 20041008_1800_1830_INFO_DGA/S123 start: 0 stop: 10635\n",
"2018-06-20 17:11:50,960 - INFO - 20041008_1800_1830_INFO_DGA/S145 start: 0 stop: 545\n",
"2018-06-20 17:11:51,071 - INFO - 20041008_1800_1830_INFO_DGA/S146 start: 0 stop: 1438\n",
"2018-06-20 17:11:51,215 - INFO - 20041008_1800_1830_INFO_DGA/S148 start: 0 stop: 287\n",
"2018-06-20 17:11:51,289 - INFO - 20041008_1800_1830_INFO_DGA/S150 start: 0 stop: 1126\n",
"2018-06-20 17:11:51,697 - INFO - 20041008_1800_1830_INFO_DGA/S153 start: 0 stop: 8490\n",
"2018-06-20 17:11:54,865 - INFO - 20041008_1800_1830_INFO_DGA/S99 start: 0 stop: 42013\n",
"2018-06-20 17:12:03,409 - INFO - 20041008_1800_1830_INFO_DGA/S251 start: 0 stop: 5969\n",
"2018-06-20 17:12:04,979 - INFO - 20041008_1800_1830_INFO_DGA/S263 start: 0 stop: 10479\n",
"2018-06-20 17:12:07,069 - INFO - 20041008_1800_1830_INFO_DGA/S266 start: 0 stop: 1816\n",
"2018-06-20 17:12:07,687 - INFO - 20041008_1800_1830_INFO_DGA/S284 start: 0 stop: 12472\n"
]
}
],
"source": [
"# Load the model and the id map that the MFCC extraction cell wrote to idmap_fn.\n",
"model_iv = ModelIV(model_fn)\n",
"idmap_bic = IdMap(idmap_fn)\n",
"\n",
"# Feature server reading the MFCC file of the show.\n",
"fs = get_feature_server(mfcc_fn, 'sid')\n",
"model_iv.train(fs, idmap_bic)\n",
"\n",
"# Score with PLDA and save the result; the clustering cell reloads it from score_fn.\n",
"distance = model_iv.score_plda_slow()\n",
"distance.write(score_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the algorithms\n",
"===\n",
"The different algorithms are run using a variable threshold $t$, producing a segmentation file for each value of $t$."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2018-06-20 17:12:16,631 - INFO - Threshold t from -30 to 70 with a step of 10.\n",
"2018-06-20 17:12:16,636 - INFO - t = -30\n",
"2018-06-20 17:12:16,908 - INFO - t = -20\n",
"2018-06-20 17:12:17,045 - INFO - t = -10\n",
"2018-06-20 17:12:17,158 - INFO - t = 0\n",
"2018-06-20 17:12:17,273 - INFO - t = 10\n",
"2018-06-20 17:12:17,391 - INFO - t = 20\n",
"2018-06-20 17:12:17,508 - INFO - t = 30\n",
"2018-06-20 17:12:17,626 - INFO - t = 40\n",
"2018-06-20 17:12:17,743 - INFO - t = 50\n",
"2018-06-20 17:12:17,869 - INFO - t = 60\n",
"2018-06-20 17:12:17,986 - INFO - t = 70\n"
]
}
],
"source": [
"# One output segmentation file per algorithm; '{:.2f}' is filled with the threshold.\n",
"ilp_diar_fn = os.path.join(out_dir, show + '.ilp.{:.2f}.seg')\n",
"hac_diar_fn = os.path.join(out_dir, show + '.hac.{:.2f}.seg')\n",
"cc_diar_fn = os.path.join(out_dir, show + '.cc.{:.2f}.seg')\n",
"cc_ilp_diar_fn = os.path.join(out_dir, show + '.cc+ilp.{:.2f}.seg')\n",
"# CC+HAC is parameterised by two thresholds: the CC one, then the HAC one.\n",
"cc_hac_diar_fn = os.path.join(out_dir, show + '.cc+hac.{:.2f}_{:.2f}.seg')\n",
"\n",
"# Thresholds are swept over range(t_min, t_max, t_step), so t_max itself is excluded.\n",
"t_min = -30\n",
"t_max = 80\n",
"t_step = 10\n",
"logging.info(\"Threshold t from {} to {} with a step of {}.\".format(t_min, t_max-t_step, t_step))\n",
"\n",
"for t in range(t_min, t_max, t_step):\n",
"    logging.info(\"t = {}\".format(t))\n",
"\n",
"    # The scores are re-read from disk for every threshold, as in the original cell\n",
"    # (presumably so each run starts from an unmodified matrix - TODO confirm).\n",
"    scores = Scores(scores_file_name=score_fn)\n",
"\n",
"    # ILP clustering of the whole show.\n",
"    diar_iv, _ = ilp_iv(diar_bic, scores, threshold=t)\n",
"    Diar.write_seg(ilp_diar_fn.format(t), diar_iv)\n",
"\n",
"    # HAC clustering of the whole show.\n",
"    diar_iv, _, _ = hac_iv(diar_bic, scores, threshold=t)\n",
"    Diar.write_seg(hac_diar_fn.format(t), diar_iv)\n",
"\n",
"    # Connected components; only the diarization and the component list are used,\n",
"    # the three counters returned alongside them are ignored.\n",
"    diar_iv, cc_list, _, _, _ = connexted_component(diar_bic, scores, threshold=t)\n",
"    Diar.write_seg(cc_diar_fn.format(t), diar_iv)\n",
"\n",
"    # CC + ILP: components whose type is not 'cc' are taken as they are,\n",
"    # components of type 'cc' are clustered again with ILP.\n",
"    diar_out = Diar()\n",
"    for cc in copy.deepcopy(cc_list):\n",
"        if cc.type != 'cc':\n",
"            diar_out.append_diar(cc.diarization)\n",
"    for cc in cc_list:\n",
"        if cc.type == 'cc':\n",
"            diar_iv, _ = ilp_iv(cc.diarization, cc.scores, threshold=t)\n",
"            diar_out.append_diar(diar_iv)\n",
"    Diar.write_seg(cc_ilp_diar_fn.format(t), diar_out)\n",
"\n",
"    # CC + HAC: same split, with the HAC threshold t2 swept independently of t.\n",
"    diar_start = Diar()\n",
"    for cc in cc_list:\n",
"        if cc.type != 'cc':\n",
"            diar_start.append_diar(cc.diarization)\n",
"\n",
"    for t2 in range(t_min, t_max, t_step):\n",
"        diar_out = Diar()\n",
"        diar_out.append_diar(diar_start)\n",
"        # Work on a deep copy so every t2 starts from the same components\n",
"        # (presumably hac_iv modifies its inputs - TODO confirm).\n",
"        for cc in copy.deepcopy(cc_list):\n",
"            if cc.type == 'cc':\n",
"                diar_hac_iv, _, _ = hac_iv(cc.diarization, cc.scores, threshold=t2)\n",
"                diar_out.append_diar(diar_hac_iv)\n",
"        Diar.write_seg(cc_hac_diar_fn.format(t, t2), diar_out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the DER\n",
"===\n",
"Compute the diarization error rate (DER) of the ILP clustering on this show, for each threshold."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.-30.00.seg 32.0336835124\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.-20.00.seg 17.6346925965\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.-10.00.seg 7.94298437815\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.0.00.seg 2.60739284403\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.10.00.seg 2.60739284403\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.20.00.seg 2.60739284403\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.30.00.seg 2.60739284403\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.40.00.seg 2.60739284403\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.50.00.seg 5.0164035389\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.60.00.seg 5.0164035389\n",
"out/20041008_1800_1830_INFO_DGA/20041008_1800_1830_INFO_DGA.ilp.70.00.seg 5.0164035389\n"
]
}
],
"source": [
"# Reference segmentation (MDTM) and UEM file, both indexed by show name.\n",
"seg_dir = os.path.join(data_dir, 'seg')\n",
"ref = Diar.read_mdtm(os.path.join(seg_dir, 'ester1.tst.mdtm'))\n",
"uem = Diar.read_uem(os.path.join(seg_dir, 'ester1.tst.uem'))\n",
"uems = uem.make_index(['show'])\n",
"refs = ref.make_index(['show'])\n",
"\n",
"# Score the ILP output written by the clustering cell for every threshold.\n",
"for t in range(t_min, t_max, t_step):\n",
"    hyp_fn = ilp_diar_fn.format(t)\n",
"    hyp = Diar.read_seg(hyp_fn)\n",
"    res = scoring.compute_der(hyp, refs[show], uem=uems[show], collar=25, no_overlap=False)\n",
"    print(hyp_fn, res.get_der())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment