Martin Lebourdais / s4d · Commits

Commit 86e373f9, authored Jun 20, 2018 by Florent Desnous

    Upload tutorial for model training

parent f56d973b

Changes: 1 file, tutorials/tuto_1_iv_model.ipynb (new file, 0 → 100644)
%% Cell type:markdown id: tags:

Train models for diarization
====

This notebook trains the UBM, TV and PLDA models needed by an i-vector diarization system.

Initialization
---

Paths, model sizes and iteration counts for every step are set once here.
%% Cell type:code id: tags:

``` python
import logging
import os      # needed by the final merge step (Step 5); was missing from the original imports
import re

import h5py    # needed by the final merge step (Step 5); was missing from the original imports
import numpy

from s4d.diar import Diar
from s4d.utils import *

import sidekit
from sidekit import Mixture, FactorAnalyser, StatServer, IdMap
from sidekit.sidekit_io import *

try:
    # fall back to the built-in dict if sortedcontainers is not installed
    from sortedcontainers import SortedDict as dict
except ImportError:
    pass
```
%% Cell type:code id: tags:

``` python
# General settings
init_logging(level=logging.INFO)
num_thread = 4
audio_dir = '../data/train/{}.wav'

# UBM
ubm_seg_fn = './data/seg/ubm_ester.seg'
nb_gauss = 1024
mfcc_ubm_fn = './data/mfcc/ubm.h5'
ubm_idmap_fn = './data/mfcc/ubm_idmap.txt'
ubm_fn = './data/model/ester_ubm_' + str(nb_gauss) + '.h5'

# Total Variability (TV)
tv_seg_fn = './data/seg/train.tv.seg'
rank_tv = 300
it_max_tv = 10
mfcc_tv_fn = './data/mfcc/tv.h5'
tv_idmap_fn = './data/mfcc/tv_idmap.h5'
tv_stat_fn = './data/model/tv.stat.h5'
tv_fn = './data/model/tv_' + str(rank_tv) + '.h5'

# PLDA and i-vector normalization
plda_seg_fn = './data/seg/train.plda.seg'
rank_plda = 150
it_max_plda = 10
mfcc_plda_fn = './data/mfcc/norm_plda.h5'
plda_idmap_fn = './data/mfcc/plda_idmap.h5'
plda_fn = './data/model/plda_' + str(rank_tv) + '_' + str(rank_plda) + '.h5'
norm_stat_fn = './data/model/norm.stat.h5'
norm_fn = './data/model/norm.h5'
norm_iv_fn = './data/model/norm.iv.h5'

# Merged model output
matrices_fn = './data/model/matrices.h5'
model_fn = './data/model/ester_model_{}_{}_{}.h5'.format(nb_gauss, rank_tv, rank_plda)
```
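
The notebook assumes the `./data` tree already holds the segmentation files and has room for the features and models. As an optional sanity check, not part of the original notebook, the output directories can be created up front (`'some_show'` below is a hypothetical show name, used only to illustrate the `audio_dir` pattern):

``` python
# Optional sanity check: create the output directories so the feature and
# model writers below do not fail on a missing path.
import os

for d in ['./data/mfcc', './data/model']:
    os.makedirs(d, exist_ok=True)

# audio_dir is a pattern filled with a show name:
print(audio_dir.format('some_show'))  # -> ../data/train/some_show.wav
```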
%% Cell type:markdown id: tags:
Step 1: UBM
---

Extract the MFCC features for the UBM training segments. The Universal Background Model is a GMM trained over many speakers; the TV and PLDA steps below are built on top of it.
%% Cell type:code id: tags:

``` python
logging.info('Computing MFCC for UBM')
diar_ubm = Diar.read_seg(ubm_seg_fn, normalize_cluster=True)
fe = get_feature_extractor(audio_dir, 'sid')
ubm_idmap = fe.save_multispeakers(diar_ubm.id_map(), output_feature_filename=mfcc_ubm_fn, keep_all=False)
ubm_idmap.write_txt(ubm_idmap_fn)
```
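
`save_multispeakers` returns an `IdMap` pairing each speaker label with its feature-session key; the later cells rely on its `leftids` and `rightids` fields. A quick hedged check of what was written:

``` python
# IdMap pairs speaker labels (leftids) with feature-session keys (rightids);
# peek at the first few entries to verify the extraction.
print(ubm_idmap.leftids[:5])
print(ubm_idmap.rightids[:5])
```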
%% Cell type:markdown id: tags:
Train the UBM by EM, doubling the number of Gaussians at each split until the mixture reaches `nb_gauss` components.
%% Cell type:code id: tags:

``` python
ubm_idmap = IdMap.read_txt(ubm_idmap_fn)

fs = get_feature_server(mfcc_ubm_fn, 'sid')

spk_lst = ubm_idmap.rightids
ubm = Mixture()
ubm.EM_split(fs, spk_lst, nb_gauss,
             iterations=(1, 2, 2, 4, 4, 4, 8, 8, 8, 8, 8, 8, 8), num_thread=num_thread,
             llk_gain=0.01)
ubm.write(ubm_fn, prefix='ubm/')
```
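
Training the 1024-Gaussian UBM is the slowest part of the pipeline, so it pays to persist it. It can be reloaded later with exactly the calls Step 2 uses:

``` python
# Reload the UBM from disk (same calls as in Step 2 below).
ubm = Mixture()
ubm.read(ubm_fn, prefix='ubm/')
print('distributions:', ubm.get_distrib_nb(), '- feature dim:', ubm.dim())
```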
%% Cell type:markdown id: tags:
Step 2: TV
---

Extract the MFCC features for the Total Variability (TV) training segments.
%% Cell type:code id: tags:

``` python
logging.info('Computing MFCC for TV')
diar_tv = Diar.read_seg(tv_seg_fn, normalize_cluster=True)
fe = get_feature_extractor(audio_dir, 'sid')
tv_idmap = fe.save_multispeakers(diar_tv.id_map(), output_feature_filename=mfcc_tv_fn, keep_all=False)
tv_idmap.write(tv_idmap_fn)
```
%% Cell type:markdown id: tags:
Train a Total Variability model using the FactorAnalyser class. Each segment is treated as its own class (the `leftids` are copied from the `rightids`) so that all variability is captured by the TV space.
%% Cell type:code id: tags:

``` python
tv_idmap = IdMap.read(tv_idmap_fn)

ubm = Mixture()
ubm.read(ubm_fn, prefix='ubm/')

fs = get_feature_server(mfcc_tv_fn, 'sid')

tv_idmap.leftids = numpy.copy(tv_idmap.rightids)

tv_stat = StatServer(tv_idmap, ubm.get_distrib_nb(), ubm.dim())
tv_stat.accumulate_stat(ubm=ubm, feature_server=fs, seg_indices=range(tv_stat.segset.shape[0]), num_thread=num_thread)
tv_stat.write(tv_stat_fn)
fa = FactorAnalyser()
fa.total_variability(tv_stat_fn, ubm, rank_tv, nb_iter=it_max_tv, batch_size=1000, num_thread=num_thread)

write_tv_hdf5([fa.F, fa.mean, fa.Sigma], tv_fn)
```
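
The three TV parameters written above are read back in Step 3 with `read_tv_hdf5`. As a hedged check, the TV matrix maps the supervector space down to the `rank_tv`-dimensional i-vector space (assumed layout: one row per supervector dimension, i.e. `nb_gauss * feature_dim`, and `rank_tv` columns):

``` python
# Reload the TV model and inspect its shape (assumed supervector-by-rank layout).
tv, tv_mean, tv_sigma = read_tv_hdf5(tv_fn)
print('TV matrix shape:', tv.shape)
```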
%% Cell type:markdown id: tags:
Step 3: PLDA
---

Extract the MFCC features for the PLDA training segments.
%% Cell type:code id: tags:

``` python
logging.info('Computing MFCC for PLDA')
diar_plda = Diar.read_seg(plda_seg_fn, normalize_cluster=True)
fe = get_feature_extractor(audio_dir, 'sid')
plda_idmap = fe.save_multispeakers(diar_plda.id_map(), output_feature_filename=mfcc_plda_fn, keep_all=False)
plda_idmap.write(plda_idmap_fn)
```
%% Cell type:markdown id: tags:
Accumulate the zero- and first-order statistics of the PLDA training segments against the UBM.
%% Cell type:code id: tags:

``` python
plda_idmap = IdMap.read(plda_idmap_fn)

ubm = Mixture()
ubm.read(ubm_fn, prefix='ubm/')
tv, tv_mean, tv_sigma = read_tv_hdf5(tv_fn)

fs = get_feature_server(mfcc_plda_fn, 'sid')

plda_norm_stat = StatServer(plda_idmap, ubm.get_distrib_nb(), ubm.dim())
plda_norm_stat.accumulate_stat(ubm=ubm, feature_server=fs,
                               seg_indices=range(plda_norm_stat.segset.shape[0]), num_thread=num_thread)
plda_norm_stat.write(norm_stat_fn)
```
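
`accumulate_stat` fills the standard sidekit `StatServer` fields: `stat0` holds the per-segment posterior counts for each Gaussian and `stat1` the first-order statistics. A quick look at the shapes, assuming that standard layout of one row per segment:

``` python
# Assumed StatServer layout: one row per segment.
print('stat0:', plda_norm_stat.stat0.shape)  # (n_segments, nb_gauss)
print('stat1:', plda_norm_stat.stat1.shape)  # (n_segments, nb_gauss * feat_dim)
```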
%% Cell type:markdown id: tags:
Extract i-vectors, estimate the spherical normalization parameters, then apply one normalization iteration to the training i-vectors.
%% Cell type:code id: tags:

``` python
fa = FactorAnalyser(F=tv, mean=tv_mean, Sigma=tv_sigma)
norm_iv = fa.extract_ivectors(ubm, norm_stat_fn, num_thread=num_thread)
norm_iv.write(norm_iv_fn)

norm_mean, norm_cov = norm_iv.estimate_spectral_norm_stat1(1, 'sphNorm')

write_norm_hdf5([norm_mean, norm_cov], norm_fn)

norm_iv.spectral_norm_stat1(norm_mean[:1], norm_cov[:1])
```
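
`extract_ivectors` returns a `StatServer` whose `stat1` rows are the i-vectors, and `spectral_norm_stat1` whitens and length-normalizes them in place using the first iteration's parameters. A hedged check of the result:

``` python
# After normalization, each row of stat1 is one i-vector
# (assumed shape: n_segments x rank_tv).
print('i-vectors:', norm_iv.stat1.shape)
```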
%% Cell type:markdown id: tags:
Train the PLDA model on the normalized i-vectors.
%% Cell type:code id: tags:

``` python
fa = FactorAnalyser()
fa.plda(norm_iv, rank_plda, nb_iter=it_max_plda)
write_plda_hdf5([fa.mean, fa.F, numpy.zeros((rank_tv, 0)), fa.Sigma], plda_fn)
```
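
The `numpy.zeros((rank_tv, 0))` placeholder is written in place of the eigenchannel matrix `G`, which this simplified PLDA does not use. Assuming `sidekit_io` exposes a `read_plda_hdf5` counterpart to the `write_plda_hdf5` used above (an assumption, the reader is not shown in this notebook), the model can be reloaded like so:

``` python
# read_plda_hdf5 is assumed to mirror write_plda_hdf5 used above.
plda_mean, plda_f, plda_g, plda_sigma = read_plda_hdf5(plda_fn)
print('PLDA F:', plda_f.shape)  # expected (rank_tv, rank_plda)
```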
%% Cell type:markdown id: tags:
Step 4: Compute additional data (optional)
---

Add matrices for additional scoring methods:

* Mahalanobis matrix
* lower Cholesky decomposition of the WCCN matrix (stored under the `wccn_choleski` key)
* within- and between-class covariance matrices
%% Cell type:code id: tags:

``` python
iv = StatServer(norm_iv_fn)
matrix_dict = {}

logging.info('compute mahalanobis_matrix')
mahalanobis_matrix = iv.get_mahalanobis_matrix_stat1()
matrix_dict['mahalanobis_matrix'] = mahalanobis_matrix

logging.info('compute wccn_choleski')
wccn_choleski = iv.get_wccn_choleski_stat1()
matrix_dict['wccn_choleski'] = wccn_choleski

logging.info('compute two_covariance')
within_covariance = iv.get_within_covariance_stat1()
matrix_dict['two_covariance/within_covariance'] = within_covariance
between_covariance = iv.get_between_covariance_stat1()
matrix_dict['two_covariance/between_covariance'] = between_covariance

write_dict_hdf5(matrix_dict, matrices_fn)
```
%% Cell type:markdown id: tags:
Step 5: Merge into one model file
---

Copy every HDF5 group from the individual model files into a single file; files that were not produced (e.g. the optional Step 4 matrices) are skipped.
%% Cell type:code id: tags:

``` python
with h5py.File(model_fn, 'w') as model:
    for fn in [ubm_fn, tv_fn, norm_fn, plda_fn, matrices_fn]:
        if not os.path.exists(fn):
            continue
        with h5py.File(fn, 'r') as fh:
            for group in fh:
                logging.info(group)
                fh.copy(group, model)
```
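
To verify the merge, the contents of the combined file can be listed with plain h5py (a check added here, not in the original notebook):

``` python
# List every group and dataset path in the merged model file.
with h5py.File(model_fn, 'r') as model:
    model.visit(print)
```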