Commit c57d363e authored by Anthony Larcher's avatar Anthony Larcher
Browse files

environment variable

parent 6d8cc6dc
......@@ -32,10 +32,29 @@ import numpy
import os
import sys
# Read environment variable if it exists
SIDEKIT_CONFIG={"theano":True,
"theano_config":'gpu', # Can be 'cpu' or 'gpu'
"libsvm":True
}
for cfg in os.environ['SIDEKIT'].split(","):
k, val = cfg.split("=")
if k == "theano":
if val == "false":
SIDEKIT_CONFIG["theano"] = False
elif k == "theano_config":
SIDEKIT_CONFIG["theano_config"] = val
elif k == "libsvm":
if val == "false":
SIDEKIT_CONFIG["libsvm"] = False
PARALLEL_MODULE = 'multiprocessing' # can be , threading, multiprocessing MPI is planned in the future
PARAM_TYPE = numpy.float32
STAT_TYPE = numpy.float64
THEANO_CONFIG = "cpu" # can be gpu or cpu
# Import bosaris-like classes
......@@ -97,15 +116,15 @@ from sidekit.gmm_scoring import gmm_scoring
from sidekit.jfa_scoring import jfa_scoring
# Import NNET classes and functions
# Import NNET classes and functions if the FLAG is True
theano_imported = False
try:
if THEANO_CONFIG == "gpu":
os.environ['THEANO_FLAGS'] = 'mode=FAST_RUN,device=gpu,floatX=float32'
else:
os.environ['THEANO_FLAGS'] = 'mode=FAST_RUN,device=cpu,floatX=float32'
theano_imported = True
if SIDEKIT_CONFIG["theano"]:
if SIDEKIT_CONFIG["theano_config"] == "gpu":
os.environ['THEANO_FLAGS'] = 'mode=FAST_RUN,device=gpu,floatX=float32'
else:
os.environ['THEANO_FLAGS'] = 'mode=FAST_RUN,device=cpu,floatX=float32'
theano_imported = True
except ImportError:
print("Cannot import Theano")
......@@ -117,26 +136,27 @@ if theano_imported:
from sidekit.sv_utils import clean_stat_server
libsvm_loaded = False
try:
dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'libsvm')
if sys.platform == 'win32':
libsvm = CDLL(os.path.join(dirname, r'libsvm.dll'))
libsvm_loaded = True
else:
libsvm = CDLL(os.path.join(dirname, 'libsvm.so.2'))
libsvm_loaded = True
except:
# For unix the prefix 'lib' is not considered.
if find_library('svm'):
libsvm = CDLL(find_library('svm'))
libsvm_loaded = True
elif find_library('libsvm'):
libsvm = CDLL(find_library('libsvm'))
libsvm_loaded = True
else:
libsvm_loaded = False
logging.warning('WARNNG: libsvm is not installed, please refer to the' +
' documentation if you intend to use SVM classifiers')
if SIDEKIT_CONFIG["libsvm"]:
try:
dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'libsvm')
if sys.platform == 'win32':
libsvm = CDLL(os.path.join(dirname, r'libsvm.dll'))
libsvm_loaded = True
else:
libsvm = CDLL(os.path.join(dirname, 'libsvm.so.2'))
libsvm_loaded = True
except:
# For unix the prefix 'lib' is not considered.
if find_library('svm'):
libsvm = CDLL(find_library('svm'))
libsvm_loaded = True
elif find_library('libsvm'):
libsvm = CDLL(find_library('libsvm'))
libsvm_loaded = True
else:
libsvm_loaded = False
logging.warning('WARNNG: libsvm is not installed, please refer to the' +
' documentation if you intend to use SVM classifiers')
if libsvm_loaded:
from sidekit.libsvm import *
......
......@@ -71,7 +71,7 @@ def fa_model_loop(batch_start,
if sigma.ndim == 2:
A = phi.T.dot(phi)
inv_lambda_unique = dict()
for sess in numpy.unique(stat0[:,0]).astype(int):
for sess in numpy.unique(stat0[:,0]):
inv_lambda_unique[sess] = scipy.linalg.inv(sess * A + numpy.eye(A.shape[0]))
tmp = numpy.zeros((phi.shape[1], phi.shape[1]), dtype=data_type)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment