Commit faf8ddd0 authored by Gaël Le Lan's avatar Gaël Le Lan
Browse files

Kaldi PLDA reading method

parent 8d64e00b
# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU LLesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher
:mod:`kaldi_io` provides utilities to import Kaldi PLDA/LDA binaries.
"""
__author__ = "Gaël Le Lan"
__email__ = "gael.lelan@orange.com"
import sys
import numpy as np
BYTES_PER_DOUBLE = 8
BYTES_PER_FLOAT = 4
def read_plda(plda_file):
"""
Import PLDA from Kaldi format
output PLDA parameters
:param plda_file: Kaldi format PLDA file
"""
with open(plda_file, 'rb') as ins:
ins.read(2)
assert ins.read(7) == b'<Plda> '
fv_dv_value = ins.read(2)
assert ins.read(1) == b' '
next_bytes_to_read = ins.read(1)[0]
vec_dim = int.from_bytes(ins.read(next_bytes_to_read), byteorder='little')
if fv_dv_value == b'DV':
mean = np.fromstring(ins.read(vec_dim*BYTES_PER_DOUBLE), np.float64)
elif fv_dv_value == b'FV':
mean = np.fromstring(ins.read(vec_dim*BYTES_PER_FLOAT), np.float64)
else:
sys.exit('error reading mean')
fm_dm_value = ins.read(2)
assert ins.read(1) == b' '
next_bytes_to_read = ins.read(1)[0]
n_rows = int.from_bytes(ins.read(next_bytes_to_read), byteorder='little')
next_bytes_to_read = ins.read(1)[0]
n_cols = int.from_bytes(ins.read(next_bytes_to_read), byteorder='little')
if fm_dm_value == b'DM':
transform = np.fromstring(ins.read(n_rows*n_cols*BYTES_PER_DOUBLE),
np.float64).reshape(n_rows, n_cols)
elif fm_dm_value == b'FM':
transform = np.fromstring(ins.read(n_rows*n_cols*BYTES_PER_FLOAT),
np.float64).reshape(n_rows, n_cols)
else:
sys.exit('error reading transform')
fv_dv_value = ins.read(2)
assert ins.read(1) == b' '
next_bytes_to_read = ins.read(1)[0]
vec_dim = int.from_bytes(ins.read(next_bytes_to_read), byteorder='little')
if fv_dv_value == b'DV':
psi = np.fromstring(ins.read(vec_dim*BYTES_PER_DOUBLE), np.float64)
elif fv_dv_value == b'FV':
psi = np.fromstring(ins.read(vec_dim*BYTES_PER_FLOAT), np.float64)
else:
sys.exit('error reading Psi')
assert ins.readline() == b'</Plda> '
return mean, transform, psi
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment