Commit 48678299 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

Refactoring and resnet18

parent 582f5048
......@@ -22,16 +22,17 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
Copyright 2014-2020 Anthony Larcher and Sylvain Meignier
"""
from ctypes import *
from ctypes.util import find_library
import importlib
import logging
import numpy
import os
import sys
import importlib
# Read environment variable if it exists
......@@ -59,69 +60,67 @@ PARAM_TYPE = numpy.float32
STAT_TYPE = numpy.float64 # can be numpy.float32 to speed up the computation but can lead to numerical issuess
# Import bosaris-like classes
from sidekit.bosaris import IdMap
from sidekit.bosaris import Ndx
from sidekit.bosaris import Key
from sidekit.bosaris import Scores
from sidekit.bosaris import DetPlot
from sidekit.bosaris import effective_prior
from sidekit.bosaris import logit_effective_prior
from sidekit.bosaris import fast_minDCF
from .bosaris import IdMap
from .bosaris import Ndx
from .bosaris import Key
from .bosaris import Scores
from .bosaris import DetPlot
from .bosaris import effective_prior
from .bosaris import logit_effective_prior
from .bosaris import fast_minDCF
# Import classes
from sidekit.features_extractor import FeaturesExtractor
from sidekit.features_server import FeaturesServer
from sidekit.mixture import Mixture
from sidekit.statserver import StatServer
from sidekit.factor_analyser import FactorAnalyser
from sidekit.frontend.io import write_pcm
from sidekit.frontend.io import read_pcm
from sidekit.frontend.io import pcmu2lin
from sidekit.frontend.io import read_sph
from sidekit.frontend.io import write_label
from sidekit.frontend.io import read_label
from sidekit.frontend.io import read_spro4
from sidekit.frontend.io import read_audio
from sidekit.frontend.io import write_spro4
from sidekit.frontend.io import read_htk
from sidekit.frontend.io import write_htk
from sidekit.frontend.vad import vad_energy
from sidekit.frontend.vad import vad_snr
from sidekit.frontend.vad import label_fusion
from sidekit.frontend.vad import speech_enhancement
from sidekit.frontend.normfeat import cms
from sidekit.frontend.normfeat import cmvn
from sidekit.frontend.normfeat import stg
from sidekit.frontend.normfeat import rasta_filt
from sidekit.frontend.features import compute_delta
from sidekit.frontend.features import framing
from sidekit.frontend.features import pre_emphasis
from sidekit.frontend.features import trfbank
from sidekit.frontend.features import mel_filter_bank
from sidekit.frontend.features import mfcc
from sidekit.frontend.features import pca_dct
from sidekit.frontend.features import shifted_delta_cepstral
from sidekit.iv_scoring import cosine_scoring
from sidekit.iv_scoring import mahalanobis_scoring
from sidekit.iv_scoring import two_covariance_scoring
from sidekit.iv_scoring import PLDA_scoring
from sidekit.gmm_scoring import gmm_scoring
from sidekit.jfa_scoring import jfa_scoring
from sidekit.sidekit_io import write_norm_hdf5
from sidekit.sidekit_io import write_matrix_hdf5
from sidekit.sv_utils import clean_stat_server
from .features_extractor import FeaturesExtractor
from .features_server import FeaturesServer
from .mixture import Mixture, vad_energy
from .statserver import StatServer
from .factor_analyser import FactorAnalyser
from .frontend.io import write_pcm
from .frontend.io import read_pcm
from .frontend.io import pcmu2lin
from .frontend.io import read_sph
from .frontend.io import write_label
from .frontend.io import read_label
from .frontend.io import read_spro4
from .frontend.io import read_audio
from .frontend.io import write_spro4
from .frontend.io import read_htk
from .frontend.io import write_htk
from .frontend.vad import vad_energy
from .frontend.vad import vad_snr
from .frontend.vad import label_fusion
from .frontend.vad import speech_enhancement
from .frontend.normfeat import cms
from .frontend.normfeat import cmvn
from .frontend.normfeat import stg
from .frontend.normfeat import rasta_filt
from .frontend.features import compute_delta
from .frontend.features import framing
from .frontend.features import pre_emphasis
from .frontend.features import trfbank
from .frontend.features import mel_filter_bank
from .frontend.features import mfcc
from .frontend.features import pca_dct
from .frontend.features import shifted_delta_cepstral
from .iv_scoring import cosine_scoring
from .iv_scoring import mahalanobis_scoring
from .iv_scoring import two_covariance_scoring
from .iv_scoring import PLDA_scoring
from .gmm_scoring import gmm_scoring
from .jfa_scoring import jfa_scoring
from .sidekit_io import write_norm_hdf5
from .sidekit_io import write_matrix_hdf5
from .sv_utils import clean_stat_server
libsvm_loaded = False
if SIDEKIT_CONFIG["libsvm"]:
......@@ -158,38 +157,39 @@ if SIDEKIT_CONFIG["cuda"]:
if CUDA:
from sidekit.nnet import FForwardNetwork
from sidekit.nnet import kaldi_to_hdf5
from sidekit.nnet import XvectorMultiDataset
from sidekit.nnet import XvectorDataset
from sidekit.nnet import StatDataset
from sidekit.nnet import Xtractor
from sidekit.nnet import xtrain
from sidekit.nnet import xtrain_single
from sidekit.nnet import xtrain_new
from sidekit.nnet import extract_idmap
from sidekit.nnet import extract_parallel
#from sidekit.nnet import SAD_RNN
from .nnet import FForwardNetwork
from .nnet import kaldi_to_hdf5
from .nnet import XvectorMultiDataset
from .nnet import XvectorDataset
from .nnet import StatDataset
from .nnet import Xtractor
from .nnet import xtrain
from .nnet import xtrain_single
from .nnet import extract_idmap
from .nnet import extract_parallel
from .nnet import ResBlock
from .nnet import ResNet18
#from .nnet import SAD_RNN
else:
print("Don't import Torch")
if SIDEKIT_CONFIG["mpi"]:
found_mpi4py = importlib.find_loader('mpi4py') is not None
if found_mpi4py:
from sidekit.sidekit_mpi import EM_split
from sidekit.sidekit_mpi import total_variability
from sidekit.sidekit_mpi import extract_ivector
from .sidekit_mpi import EM_split
from .sidekit_mpi import total_variability
from .sidekit_mpi import extract_ivector
print("Import MPI")
__author__ = "Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2019 Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-20120 Anthony Larcher and Sylvain Meignier"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
__version__="1.3.3"
__version__="1.3.4"
# __all__ = ["io",
# "vad",
......
......@@ -18,16 +18,15 @@ The BOSARIS toolkit in MATLAB can be downloaded from `the website
<https://sites.google.com/site/bosaristoolkit/>`_.
"""
from sidekit.bosaris.idmap import IdMap
from sidekit.bosaris.ndx import Ndx
from sidekit.bosaris.plotwindow import PlotWindow
from sidekit.bosaris.key import Key
from sidekit.bosaris.scores import Scores
from sidekit.bosaris.detplot import DetPlot
from sidekit.bosaris.detplot import effective_prior
from sidekit.bosaris.detplot import logit_effective_prior
from sidekit.bosaris.detplot import fast_minDCF
from .idmap import IdMap
from .ndx import Ndx
from .plotwindow import PlotWindow
from .key import Key
from .scores import Scores
from .detplot import DetPlot
from .detplot import effective_prior
from .detplot import logit_effective_prior
from .detplot import fast_minDCF
__author__ = "Anthony Larcher"
......
......@@ -43,9 +43,9 @@ import scipy
from collections import namedtuple
import logging
from sidekit.bosaris import PlotWindow
from sidekit.bosaris import Scores
from sidekit.bosaris import Key
from . import PlotWindow
from . import Scores
from . import Key
__author__ = "Anthony Larcher"
......
......@@ -26,7 +26,7 @@ import logging
import copy
import h5py
from sidekit.sidekit_wrappers import check_path_existance
from ..sidekit_wrappers import check_path_existance
__author__ = "Anthony Larcher"
......
......@@ -20,12 +20,12 @@
"""
This is the 'key' module
"""
import numpy
import sys
import h5py
import logging
from sidekit.bosaris.ndx import Ndx
from sidekit.sidekit_wrappers import check_path_existance
import numpy
import sys
from .ndx import Ndx
from ..sidekit_wrappers import check_path_existance
__author__ = "Anthony Larcher"
__maintainer__ = "Anthony Larcher"
......
......@@ -24,7 +24,7 @@ import h5py
import logging
import numpy
import sys
from sidekit.sidekit_wrappers import check_path_existance, deprecated
from ..sidekit_wrappers import check_path_existance, deprecated
__author__ = "Anthony Larcher"
__maintainer__ = "Anthony Larcher"
......
......@@ -25,9 +25,9 @@ import h5py
import logging
import numpy
import os
from sidekit.bosaris.ndx import Ndx
from sidekit.bosaris.key import Key
from sidekit.sidekit_wrappers import check_path_existance
from .ndx import Ndx
from .key import Key
from ..sidekit_wrappers import check_path_existance
__author__ = "Anthony Larcher"
......
......@@ -21,7 +21,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Sylvain Meignier and Anthony Larcher
Copyright 2014-2020 Sylvain Meignier and Anthony Larcher
:mod:`features_server` provides methods to manage features
......@@ -33,17 +33,18 @@ import numpy
import os
from sidekit import PARAM_TYPE
from sidekit.frontend.features import mfcc, plp
from sidekit.frontend.io import read_audio, read_label, write_hdf5, _add_reverb, _add_noise
from sidekit.frontend.vad import vad_snr, vad_energy, vad_percentil
from sidekit.sidekit_wrappers import process_parallel_lists
from sidekit.bosaris.idmap import IdMap
from . import PARAM_TYPE
from .frontend.features import mfcc, plp
from .frontend.io import read_audio, read_label, write_hdf5, _add_reverb, _add_noise
from .frontend.vad import vad_snr, vad_percentil
from .mixture import vad_energy
from .sidekit_wrappers import process_parallel_lists
from .bosaris import IdMap
__license__ = "LGPL"
__author__ = "Anthony Larcher & Sylvain Meignier"
__copyright__ = "Copyright 2014-2019 Anthony Larcher"
__copyright__ = "Copyright 2014-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
......
......@@ -22,52 +22,51 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
Copyright 2014-2020 Anthony Larcher and Sylvain Meignier
:mod:`frontend` provides methods to process an audio signal in order to extract
useful parameters for speaker verification.
"""
from sidekit.frontend.io import write_pcm
from sidekit.frontend.io import read_pcm
from sidekit.frontend.io import pcmu2lin
from sidekit.frontend.io import read_sph
from sidekit.frontend.io import write_label
from sidekit.frontend.io import read_label
from sidekit.frontend.io import read_spro4
from sidekit.frontend.io import read_audio
from sidekit.frontend.io import write_spro4
from sidekit.frontend.io import read_htk
from sidekit.frontend.io import write_htk
from sidekit.frontend.io import read_hdf5_segment
from sidekit.frontend.io import write_hdf5
from sidekit.frontend.io import read_hdf5
from .io import write_pcm
from .io import read_pcm
from .io import pcmu2lin
from .io import read_sph
from .io import write_label
from .io import read_label
from .io import read_spro4
from .io import read_audio
from .io import write_spro4
from .io import read_htk
from .io import write_htk
from .io import read_hdf5_segment
from .io import write_hdf5
from .io import read_hdf5
from sidekit.frontend.vad import vad_energy
from sidekit.frontend.vad import vad_snr
from sidekit.frontend.vad import label_fusion
from sidekit.frontend.vad import speech_enhancement
from .vad import vad_snr
from .vad import label_fusion
from .vad import speech_enhancement
from sidekit.frontend.normfeat import cms
from sidekit.frontend.normfeat import cmvn
from sidekit.frontend.normfeat import stg
from sidekit.frontend.normfeat import rasta_filt
from sidekit.frontend.normfeat import cep_sliding_norm
from .normfeat import cms
from .normfeat import cmvn
from .normfeat import stg
from .normfeat import rasta_filt
from .normfeat import cep_sliding_norm
from sidekit.frontend.features import compute_delta
from sidekit.frontend.features import framing
from sidekit.frontend.features import pre_emphasis
from sidekit.frontend.features import trfbank
from sidekit.frontend.features import mel_filter_bank
from sidekit.frontend.features import mfcc
from sidekit.frontend.features import pca_dct
from sidekit.frontend.features import shifted_delta_cepstral
from .features import compute_delta
from .features import framing
from .features import pre_emphasis
from .features import trfbank
from .features import mel_filter_bank
from .features import mfcc
from .features import pca_dct
from .features import shifted_delta_cepstral
__author__ = "Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2019 Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2020 Anthony Larcher and Sylvain Meignier"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
......
......@@ -22,7 +22,7 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
Copyright 2014-2020 Anthony Larcher and Sylvain Meignier
:mod:`frontend` provides methods to process an audio signal in order to extract
useful parameters for speaker verification.
......@@ -32,16 +32,13 @@ import numpy
import numpy.matlib
import scipy
from scipy.fftpack.realtransforms import dct
from sidekit.frontend.vad import pre_emphasis
from sidekit.frontend.io import *
from sidekit.frontend.normfeat import *
from sidekit.frontend.features import *
from .vad import pre_emphasis
import numpy.matlib
PARAM_TYPE = numpy.float32
__author__ = "Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2019 Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2020 Anthony Larcher and Sylvain Meignier"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
......
......@@ -22,7 +22,7 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher
Copyright 2014-2020 Anthony Larcher
:mod:`frontend` provides methods to process an audio signal in order to extract
useful parameters for speaker verification.
......@@ -40,13 +40,12 @@ import wave
import scipy.signal
import scipy.io.wavfile
from scipy.signal import lfilter
from scipy.signal import decimate
from sidekit.sidekit_wrappers import check_path_existance, process_parallel_lists
from ..sidekit_wrappers import check_path_existance, process_parallel_lists
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2014-2019 Anthony Larcher"
__copyright__ = "Copyright 2014-2020 Anthony Larcher"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
......
......@@ -22,7 +22,7 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
Copyright 2014-2020 Anthony Larcher and Sylvain Meignier
:mod:`frontend` provides methods to process an audio signal in order to extract
useful parameters for speaker verification.
......@@ -34,7 +34,7 @@ from scipy.signal import lfilter
__author__ = "Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2019 Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2020 Anthony Larcher and Sylvain Meignier"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
......@@ -163,21 +163,21 @@ def stg(features, label=None, win=301):
# Process first window
r = numpy.argsort(speech_features[:win, ], axis=0)
r = numpy.argsort(r, axis=0)
arg = (r[: (win - 1) / 2] + 0.5) / win
stg_features[: (win - 1) / 2, :] = stats.norm.ppf(arg, 0, 1)
arg = (r[: (win - 1) // 2] + 0.5) / win
stg_features[: (win - 1) // 2, :] = stats.norm.ppf(arg, 0, 1)
# process all following windows except the last one
for m in range(int((win - 1) / 2), int(nframes - (win - 1) / 2)):
idx = list(range(int(m - (win - 1) / 2), int(m + (win - 1) / 2 + 1)))
foo = speech_features[idx, :]
r = numpy.sum(foo < foo[(win - 1) / 2], axis=0) + 1
r = numpy.sum(foo < foo[(win - 1) // 2], axis=0) + 1
arg = (r - 0.5) / win
stg_features[m, :] = stats.norm.ppf(arg, 0, 1)
# Process the last window
r = numpy.argsort(speech_features[list(range(nframes - win, nframes)), ], axis=0)
r = numpy.argsort(r, axis=0)
arg = (r[(win + 1) / 2: win, :] + 0.5) / win
arg = (r[(win + 1) // 2: win, :] + 0.5) / win
stg_features[list(range(int(nframes - (win - 1) / 2), nframes)), ] = stats.norm.ppf(arg, 0, 1)
else:
......
......@@ -22,7 +22,7 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher and Sylvain Meignier
Copyright 2014-2020 Anthony Larcher and Sylvain Meignier
:mod:`frontend` provides methods to process an audio signal in order to extract
useful parameters for speaker verification.
......@@ -30,13 +30,12 @@ useful parameters for speaker verification.
import copy
import logging
import numpy
from scipy.fftpack import fft
from scipy.fftpack import fft, ifft
from scipy import ndimage
from sidekit.mixture import Mixture
__author__ = "Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2019 Anthony Larcher and Sylvain Meignier"
__copyright__ = "Copyright 2014-2020 Anthony Larcher and Sylvain Meignier"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
......@@ -283,9 +282,9 @@ def speech_enhancement(X, Gain, NN):
ffty = new_absx * argx # multiply amplitude with its normalized spectrum
y = numpy.real(numpy.fft.fftpack.ifft(numpy.concatenate((ffty,
numpy.conj(ffty[numpy.arange(Fmax - 2, 0, -1)])))))
#y = numpy.real(numpy.fft.fftpack.ifft(numpy.concatenate((ffty,
# numpy.conj(ffty[numpy.arange(Fmax - 2, 0, -1)])))))
y = numpy.real(ifft(numpy.concatenate((ffty, numpy.conj(ffty[numpy.arange(Fmax - 2, 0, -1)])))))
y[:FrameSize - FrameShift] = y[:FrameSize - FrameShift] + y0
y0 = y[FrameShift:FrameSize] # keep 129 to FrameSize point samples
x[:FrameSize - FrameShift] = x[FrameShift:FrameSize]
......
......@@ -22,11 +22,11 @@
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2019 Anthony Larcher
Copyright 2014-2020 Anthony Larcher
"""
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2014-2019 Anthony Larcher"
__copyright__ = "Copyright 2014-2020 Anthony Larcher"
__license__ = "LGPL"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
......@@ -34,5 +34,5 @@ __status__ = "Production"
__docformat__ = 'reStructuredText'
from sidekit.libsvm.svm import *
from sidekit.libsvm.svmutil import *
from .svm import *
from .svmutil import *
......@@ -6,6 +6,7 @@ All rights reserved.
"""
from ctypes import *
from ctypes import c_int, c_double
from ctypes.util import find_library
from os import path
import sys
......
......@@ -8,10 +8,11 @@ All rights reserved.
import sys
import os
import pickle
sys.path = [os.path.dirname(os.path.abspath(__file__))] + sys.path
from sidekit.libsvm.svm import svm_node, svm_problem, svm_parameter, svm_model, toPyModel
from sidekit.libsvm.svm import *
from .svm import svm_node, svm_problem, svm_parameter, svm_model, toPyModel, SVM_TYPE, KERNEL_TYPE, \
gen_svm_nodearray, print_null
from ctypes import c_int, c_double
sys.path = [os.path.dirname(os.path.abspath(__file__))] + sys.path
def save_svm(svm_file_name, w, b):
"""
......@@ -66,7 +67,7 @@ def svm_load_model(model_file_name):
Load a LIBSVM model from model_file_name and return.
:param model_file_name: file name to load from
"""
model = sidekit.libsvm.svm_load_model(model_file_name.encode())
model = svm_load_model(model_file_name.encode())
if not model:
print("can't open model file %s" % model_file_name)
return None
......@@ -82,7 +83,7 @@ def svm_save_model(model_file_name, model):
:param model_file_name: file name to write to
:param model: model to save
"""
sidekit.libsvm.svm_save_model(model_file_name.encode(), model)
svm_save_model(model_file_name.encode(), model)
def evaluations(ty, pv):
......@@ -189,10 +190,10 @@ def svm_train(arg1, arg2=None, arg3=None):
if param.cross_validation:
l, nr_fold = prob.l, param.nr_fold
target = (c_double * l)()
target = (c_double * l)() # pytype: disable=not-callable
libsvm.svm_cross_validation(prob, param, nr_fold, target)
ACC, MSE, SCC = evaluations(prob.y[:l], target[:l])