Commit 00d73c1a authored by Florent Desnous 's avatar Florent Desnous
Browse files

corrections to features_server.py and nnt/sad_rnn.py

parent 4d4c8d39
......@@ -208,10 +208,6 @@ class FeaturesServer(object):
:return: the matrix of acoustic parameters and their VAD labels after post-processing
"""
# Apply a mask on the features
if self.mask is not None:
feat = self._mask(feat)
# Perform RASTA filtering if required
if self.rasta:
feat, label = self._rasta(feat, label)
......@@ -224,6 +220,10 @@ class FeaturesServer(object):
elif self.sdc:
feat = shifted_delta_cepstral(feat, d=self.sdc_config[0], p=self.sdc_config[1], k=self.sdc_config[2])
# Apply a mask on the features
if self.mask is not None:
feat = self._mask(feat)
# Smooth the labels and fuse the channels if more than one.
logging.debug('Smooth the labels and fuse the channels if more than one')
label = label_fusion(label)
......
......@@ -24,8 +24,8 @@ class SAD_RNN():
followed by two linear layers of dimension 40 and 10.
"""
self.input_size = input_size
self.duration = duration * 100
self.step = step * 100
self.duration = int(duration * 100)
self.step = int(step * 100)
self.batch_size = batch_size
if model is None: # load default model
......@@ -52,7 +52,6 @@ class SAD_RNN():
for show in sorted(train_list.keys()):
features, _ = features_server.load(show)
features = features[:, 1:] # tmp TODO
labels = numpy.zeros((len(features), 1), dtype=numpy.int)
for seg in train_list[show]:
labels[seg['start']:seg['stop']] = 1
......@@ -80,7 +79,6 @@ class SAD_RNN():
:return: loss of current batch
"""
X = X.to(device)
Y = Y.to(device)
self.model.hidden = None
......@@ -91,26 +89,25 @@ class SAD_RNN():
optimizer.step()
return float(loss.data)
def _get_scores(self, show, scores_fmt, features_server):
def get_scores(self, show, features_server, score_file_format=''):
"""
Internal method to compute the scores for one show from the output of the network
:param show:
:param epoch:
:param features_server:
Computes the scores for one show from the output of the network
:param show: the show to extract
:param features_server: a sidekit FeaturesServer object
:param score_file_format: optional, used to save or load a score file
:return: scores of the show, as an array of 0..1
"""
if scorces_fmt == '':
if scorce_fmt == '':
score_fn = ''
else:
score_fn = scores_fmt.format(show)
score_fn = score_file_format.format(show)
if os.path.exists(score_fn):
print("Warning: loading existing scores")
return numpy.load(score_fn)
features, _ = features_server.load(show)
features = features[:, 1:] #tmp TODO
x = []
X = torch.tensor([]).to(device)
......@@ -186,6 +183,39 @@ class SAD_RNN():
est_it = len(losses[epoch])
torch.save(self.model.state_dict(), model_file_format.format(epoch+1))
def get_labels(self, model_fn, show, features_server,
               onset=0.8, offset=0.95, scores_fn=''):
    """
    Get the SAD labels for one show

    Loads the trained model, computes the per-frame scores of the show with
    ``self.get_scores`` and turns them into a 0/1 label vector using two
    thresholds: a segment opens when a score rises above ``onset`` and
    closes when a score falls below ``offset``.

    :param model_fn: File name for the trained model
    :param show: show to generate the SAD from
    :param features_server: a sidekit FeaturesServer object
    :param onset: score threshold above which a segment should start
    :param offset: score threshold under which a segment should stop
    :param scores_fn: optional file name to save the scores
    :return: numpy array with one entry per score, set to 1 inside the
        detected segments and 0 elsewhere
    """
    self.model.load_state_dict(torch.load(model_fn))
    self.model.to(device)

    scores = self.get_scores(show, features_server, scores_fn)
    labels = numpy.zeros(len(scores))

    start = 0
    segment = False
    for i, s in enumerate(scores):
        if not segment and s > onset:  # speech segment begins
            start = i
            segment = True
        # NOTE(review): this is a separate `if`, so a frame whose score lies
        # between onset and offset opens and immediately closes an empty
        # segment -- confirm this is intended with the default thresholds.
        if segment and s < offset:  # speech segment ends
            segment = False
            # The closing frame i itself is below the threshold: excluded.
            labels[start:i] = 1

    if segment:
        # The segment is still open when the scores run out: label it up to
        # and including the last frame. Slicing with the last loop index as
        # the (exclusive) upper bound would leave the final frame unlabeled.
        labels[start:] = 1

    return labels
def write_sad(self, model_fn, show_list, features_server,
onset, offset, sad_file_format, scores_file_format=''):
......@@ -195,8 +225,8 @@ class SAD_RNN():
:param model_fn: model file name
:param show_list: list of shows to generate the SAD from
:param features_server: a sidekit FeaturesServer object
:param onset: score value above which a segment will start
:param offset: score value below which a segment will stop
:param onset: score threshold above which a segment will start
:param offset: score threshold below which a segment will stop
:param sad_file_format: file format for the segments
:param scores_file_format: optional, used to save scores files
"""
......@@ -205,7 +235,7 @@ class SAD_RNN():
self.model.to(device)
for show in sorted(show_list):
scores = self._get_scores(show, scores_file_format, features_server)
scores = self.get_scores(show, features_server, scores_file_format)
sad = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment