Commit 11656bee authored by Anthony Larcher

sad_dataset

parent aba2e39e
@@ -8,6 +8,9 @@ import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset
from sidekit.frontend.io import _read_dataset_percentile
from sidekit.frontend.io import read_hdf5_segment
import logging
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -16,116 +19,116 @@ class SAD_Dataset(Dataset):
"""
Object that takes a list of files from a file and initialize a Dataset
"""
    def __init__(self, mdtm_file, feature_file, batch_size=512, duration=3.2, step=0.8, uem_file=None,
                 shuffle=False):
        self.batch_size = batch_size
        # durations are converted from seconds to feature-frame counts (100 frames per second)
        self.duration = int(duration * 100)
        self.step = int(step * 100)
        # features are read directly from an HDF5 file instead of a FeaturesServer
        self.feature_file = h5py.File(feature_file, 'r')

        # parse the MDTM file to collect the speech segments of each show
        train_list = {}
        with open(mdtm_file, 'r') as f:
            for line in f:
                show, _, start, dur, _, _, _, _ = line.rstrip().split()
                if show not in train_list:
                    train_list[show] = []
                train_list[show].append({
                    "start": int(float(start) * 100),
                    "stop": int((float(start) + float(dur)) * 100)})
        # parse the UEM file to restrict training to the evaluated regions;
        # without a UEM file, the whole show is used
        uem_list = {}
        if uem_file is not None:
            with open(uem_file, 'r') as f:
                for line in f:
                    show, _, start, stop = line.rstrip().split()
                    if show not in uem_list:
                        uem_list[show] = []
                    uem_list[show].append({
                        "start": int(float(start) * 100),
                        "stop": int(float(stop) * 100)})
        else:
            for show in train_list.keys():
                uem_list[show] = [{"start": None, "stop": None}]
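        # For reference (illustrative values, not part of the commit): an MDTM line such as
        #     "show1 1 12.50 2.50 speech NA NA spk01"
        # and a UEM line such as
        #     "show1 1 0.00 30.00"
        # would, after multiplying the times by 100 to index 10 ms frames, produce
        #     train_list == {"show1": [{"start": 1250, "stop": 1500}]}
        #     uem_list == {"show1": [{"start": 0, "stop": 3000}]}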
        # build frame-level VAD labels and the list of training windows for each show
        self.vad = {}
        self.segments = []
        for show in sorted(train_list.keys()):
            # read the cepstral features of the show directly from the HDF5 file
            features = _read_dataset_percentile(self.feature_file, show + "/cep")
            labels = numpy.zeros((len(features), 1), dtype=int)
            speech_only_segments = []
            speech_nonspeech_segments = []
            if show in train_list and show in uem_list:
                for seg in train_list[show]:
                    labels[seg['start']:seg['stop']] = 1
                self.vad[show] = labels
                for seg in uem_list[show]:
                    if seg['start'] is not None:
                        start, stop = seg['start'], seg['stop']
                    else:
                        start, stop = 0, len(features)
                    # create windows that contain ONLY speech (no overlap)
                    for i in range(start, min(stop, len(features)) - self.duration, self.duration):
                        if labels[i:i + self.duration].sum() == self.duration:
                            speech_only_segments.append((show, i, i + self.duration))
                    # create windows that contain BOTH speech and non-speech
                    # (with overlap, to balance the two classes)
                    for i in range(start, min(stop, len(features)) - self.duration, self.step):
                        if labels[i:i + self.duration].sum() < self.duration - 1:
                            speech_nonspeech_segments.append((show, i, i + self.duration))
                tmp = speech_only_segments + speech_nonspeech_segments
                random.shuffle(tmp)
                self.segments += tmp
                print("Show {}, ratio S/NS = {}".format(
                    show,
                    len(speech_only_segments) / (len(speech_nonspeech_segments) + len(speech_only_segments))))
        self.input_size = features.shape[1]
        if shuffle:
            random.shuffle(self.segments)
        print("Final ratio S/NS = {}".format(
            len(speech_only_segments) / (len(speech_nonspeech_segments) + len(speech_only_segments))))
        self.len = len(self.segments) // self.batch_size
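To make the window-balancing strategy above concrete, here is a small standalone sketch (independent of the commit, with toy values) that applies the same two scans to a synthetic label vector: non-overlapping windows are kept only when they are entirely speech, while overlapping windows are kept when they contain a noticeable amount of non-speech.

    # Illustrative sketch, not part of the commit: the two window scans on toy labels.
    import numpy

    duration, step = 5, 2                     # window length and hop, in frames
    labels = numpy.array([1] * 12 + [0] * 8)  # 12 speech frames followed by 8 non-speech frames

    speech_only = [(i, i + duration)
                   for i in range(0, len(labels) - duration, duration)
                   if labels[i:i + duration].sum() == duration]
    mixed = [(i, i + duration)
             for i in range(0, len(labels) - duration, step)
             if labels[i:i + duration].sum() < duration - 1]

    print(speech_only)   # [(0, 5), (5, 10)]
    print(mixed)         # [(10, 15), (12, 17), (14, 19)]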
    def __getitem__(self, index):
        batch_X = numpy.zeros((self.batch_size, self.duration, self.input_size))
        batch_Y = numpy.zeros((self.batch_size, self.duration, 1))
        for i in range(self.batch_size):
            show, start, stop = self.segments[index * self.batch_size + i]
            # features are re-read from the HDF5 file for every window of the batch
            features = _read_dataset_percentile(self.feature_file, show + "/cep")
            batch_X[i] = features[start:stop]
            batch_Y[i] = self.vad[show][start:stop]
        return torch.Tensor(batch_X), torch.Tensor(batch_Y)
    def __len__(self):
        return self.len
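Since __getitem__ already assembles a complete batch, the dataset can be iterated directly, without a DataLoader. A minimal usage sketch, assuming hypothetical file paths and an HDF5 feature file organised as <show>/cep:

    # Illustrative sketch, not part of the commit; the three file paths below are hypothetical.
    training_set = SAD_Dataset("train.mdtm", "features.h5", batch_size=512,
                               duration=3.2, step=0.8, uem_file="train.uem", shuffle=True)
    for batch_idx in range(len(training_set)):
        X, Y = training_set[batch_idx]   # X: (512, 320, input_size), Y: (512, 320, 1)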
class SAD_RNN():
"""
A SAD_RNN is meant to use a PyTorch RNN model for Speech Activity Detection
@@ -250,7 +253,7 @@ class SAD_RNN():
            for batch_idx, (X, Y) in enumerate(training_set):
                batch_loss = self._fit_batch(optimizer, criterion, X, Y)
                losses[epoch].append(batch_loss)
            logging.critical("Epoch {}/{}, loss {:.5f}".format(
                epoch + 1, nb_epochs, numpy.mean(losses[epoch])))
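The epoch summary now goes through the logging module; logging.critical is used so the message is emitted even under Python's default WARNING log level. A minimal sketch of a logging setup (format and level are illustrative) for anyone who also wants lower-severity messages with timestamps:

    # Illustrative sketch, not part of the commit: basic logging configuration.
    import logging
    logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)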