Commit 23c9bd94 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

mpi debug in process

parent b86c4d27
......@@ -119,7 +119,6 @@ from sidekit.gmm_scoring import gmm_scoring
from sidekit.jfa_scoring import jfa_scoring
from sidekit.sidekit_mpi import total_variability_mpi
# Import NNET classes and functions if the FLAG is True
theano_imported = False
......@@ -172,7 +171,7 @@ if libsvm_loaded:
if SIDEKIT_CONFIG["mpi"]:
found_mpi4py = importlib.find_loader('mpi4py') is not None
if found_mpi4py:
from sidekit.sidekit_mpi import *
from sidekit.sidekit_mpi import EM_split, total_variability, extract_ivector
print("Import MPI")
......
......@@ -27,6 +27,7 @@ Copyright 2014-2016 Sylvain Meignier and Anthony Larcher
:mod:`features_server` provides methods to manage features
"""
import multiprocessing
import numpy
import logging
import h5py
......@@ -615,3 +616,90 @@ class FeaturesServer(object):
features_list.append(self.load(*load_arg)[0])
return numpy.vstack(features_list)
def _stack_features_worker(self,
input_queue,
output_queue):
"""Load a list of feature files into a Queue object
:param input: a Queue object
:param output: a list of Queue objects to fill
"""
while True:
next_task = input_queue.get()
if next_task is None:
# Poison pill means shutdown
output_queue.put(None)
input_queue.task_done()
break
output_queue.put(self.load(*next_task)[0])
input_queue.task_done()
#@profile
def stack_features_parallel(self, # fileList, numThread=1):
show_list,
channel_list=None,
feature_filename_list=None,
label_list=None,
start_list=None,
stop_list=None,
num_thread=1):
"""Load a list of feature files and stack them in a unique ndarray.
The list of files to load is splited in sublists processed in parallel
:param fileList: a list of files to load
:param numThread: numbe of thead (optional, default is 1)
"""
if channel_list is None:
channel_list = numpy.zeros(len(show_list))
if feature_filename_list is None:
feature_filename_list = numpy.empty(len(show_list), dtype='|O')
if label_list is None:
label_list = numpy.empty(len(show_list), dtype='|O')
if start_list is None:
start_list = numpy.empty(len(show_list), dtype='|O')
if stop_list is None:
stop_list = numpy.empty(len(show_list), dtype='|O')
#queue_in = Queue.Queue(maxsize=len(fileList)+numThread)
queue_in = multiprocessing.JoinableQueue(maxsize=len(show_list)+num_thread)
queue_out = []
# Start worker processes
jobs = []
for i in range(num_thread):
queue_out.append(multiprocessing.Queue())
p = multiprocessing.Process(target=self._stack_features_worker,
args=(queue_in, queue_out[i]))
jobs.append(p)
p.start()
# Submit tasks
for task in zip(show_list, channel_list, feature_filename_list, label_list, start_list, stop_list):
queue_in.put(task)
# Add None to the queue to kill the workers
for task in range(num_thread):
queue_in.put(None)
# Wait for all the tasks to finish
queue_in.join()
output = []
for q in queue_out:
while True:
data = q.get()
if data is None:
break
output.append(data)
for p in jobs:
p.join()
return numpy.concatenate(output, axis=0)
......@@ -621,7 +621,7 @@ class Mixture(object):
"""
# Init using all data
features = features_server.stack_features(feature_list)
features = features_server.stack_features_parallel(feature_list, num_thread)
n_frames = features.shape[0]
mu = features.mean(0)
cov = (features**2).mean(0)
......
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment