Commit be028f64 authored by Anthony Larcher

nnet accumulate stat

parent 0c34beba
@@ -493,6 +493,56 @@ class FForwardNetwork():
            stat1[idx, :] = s1.flatten()

    def compute_stat(self,
                     idmap,
                     ndim,
                     dnn_features_server,
                     features_server,
                     device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        """
        Single thread version of the statistic computation using a DNN.

        :param idmap: IdMap describing the segments to process
        :param ndim: size of the DNN output layer, i.e. the number of distributions
        :param dnn_features_server: FeaturesServer that provides input data for the DNN
        :param features_server: FeaturesServer that provides the features used to compute the first order statistics
        :param device: torch.device to run on, CUDA if available, CPU otherwise

        :return: a StatServer with all computed statistics
        """
        # Get the dimension of the features used to compute the first order statistics
        feature_size = features_server.load(idmap.rightids[0])[0].shape[1]

        # Create and initialize a StatServer
        ss = sidekit.StatServer(idmap)
        ss.stat0 = numpy.zeros((idmap.leftids.shape[0], ndim), dtype=numpy.float32)
        ss.stat1 = numpy.zeros((idmap.leftids.shape[0], ndim * feature_size), dtype=numpy.float32)

        self.model.cpu()
        for idx in numpy.arange(ss.segset.shape[0]):
            print("Compute statistics for {}".format(ss.segset[idx]))
            logging.debug('Compute statistics for {}'.format(ss.segset[idx]))

            show = ss.segset[idx]

            # Select the second channel when the show name carries the double-channel extension
            channel = 0
            if features_server.features_extractor is not None \
                    and show.endswith(features_server.double_channel_extension[1]):
                channel = 1

            stat_features, labels = features_server.load(show, channel=channel)
            features, _ = dnn_features_server.load(show, channel=channel)
            stat_features = stat_features[labels, :]

            # Zero order statistics: DNN output posteriors for the selected frames
            s0 = self.model(torch.from_numpy(dnn_features_server.get_context(feat=features)[0]).type(torch.FloatTensor).cpu())[labels]
            s0 = s0.cpu().data.numpy()

            # First order statistics: posterior-weighted sum of the features, one block per DNN output
            s1 = numpy.dot(stat_features.T, s0).T

            ss.stat0[idx, :] = s0.sum(axis=0)
            ss.stat1[idx, :] = s1.flatten()

        # Return the StatServer
        return ss
    def compute_stat_dnn_parallel(self,
                                  idmap,
                                  ndim,
    ......
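A rough usage sketch of compute_stat, assuming a trained FForwardNetwork instance net and two already-configured FeaturesServer objects dnn_fs (DNN input features) and fs (features for the first order statistics); the IdMap content and ndim=1024 are illustrative only, not values taken from this commit.

    import numpy
    import sidekit

    # Hypothetical segments: two speakers, one session each (names are placeholders)
    idmap = sidekit.IdMap()
    idmap.leftids = numpy.array(["spk1", "spk2"])
    idmap.rightids = numpy.array(["spk1/sess1", "spk2/sess1"])
    idmap.start = numpy.empty(idmap.rightids.shape, dtype="|O")
    idmap.stop = numpy.empty(idmap.rightids.shape, dtype="|O")

    # net, dnn_fs and fs are assumed to exist: a trained FForwardNetwork and the
    # FeaturesServer objects feeding the DNN and the statistic computation
    stat_server = net.compute_stat(idmap=idmap,
                                   ndim=1024,   # assumed size of the DNN output layer
                                   dnn_features_server=dnn_fs,
                                   features_server=fs)

    # stat_server.stat0 has shape (2, 1024); stat_server.stat1 has shape (2, 1024 * feature_size)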