Commit 382760e8 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

new dataset for VAS

parent 00d73c1a
import os
import sys
import numpy
import random
import torch
import torch.nn as nn
from torch import optim
import pickle
from torch.utils.data import Dataset
# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SAD_Dataset(Dataset):
    """
    Dataset of fixed-length feature windows for Speech Activity Detection.

    Speech segments are read from an MDTM file and turned into a frame-level
    binary label vector per show (1 = speech).  The regions listed in the UEM
    file (or the full show when no UEM is given) are then sliced into
    overlapping windows of ``duration`` seconds, hopping every ``step``
    seconds.  ``__getitem__`` returns one batch of windows.
    """

    # Frame rate of the feature server output (frames per second); MDTM/UEM
    # times are given in seconds and converted to frames with this factor.
    FRAME_RATE = 100

    def __init__(self, mdtm_file, features_server, batch_size=512,
                 duration=3.2, step=0.8, uem_file=None, shuffle=False):
        """
        :param mdtm_file: MDTM file listing speech segments (times in seconds)
        :param features_server: object whose ``load(show)`` returns a tuple
            ``(features, _)`` with ``features`` of shape (frames, dim)
        :param batch_size: number of windows per batch
        :param duration: window length in seconds
        :param step: hop between consecutive windows in seconds
        :param uem_file: optional UEM file restricting the sampled regions;
            when None, the complete duration of each show is used
        :param shuffle: if True, shuffle the window list once at construction
        """
        self.batch_size = batch_size
        # Windows are handled internally in feature frames, not seconds
        # (the original stored float seconds, which broke range() below).
        self.duration = int(round(duration * self.FRAME_RATE))
        self.step = int(round(step * self.FRAME_RATE))
        self.features_server = features_server

        # show -> list of speech segments {"start", "stop"} in frames
        train_list = {}
        with open(mdtm_file, 'r') as f:
            for line in f:
                show, _, start, dur, _, _, _, _ = line.rstrip().split()
                train_list.setdefault(show, []).append(
                    {"start": int(float(start) * self.FRAME_RATE),
                     "stop": int((float(start) + float(dur)) * self.FRAME_RATE)})

        # show -> list of regions to slice windows from
        uem_list = {}
        if uem_file is not None:
            with open(uem_file, 'r') as f:
                for line in f:
                    show, _, start, stop = line.rstrip().split()
                    uem_list.setdefault(show, []).append(
                        {"start": int(float(start) * self.FRAME_RATE),
                         "stop": int(float(stop) * self.FRAME_RATE)})
        else:
            # No UEM: use each show entirely.  Bounds are resolved below once
            # the feature length is known.  (The original appended to a
            # missing key here, raising KeyError.)
            for show in train_list:
                uem_list[show] = [{"start": None, "stop": None}]

        self.vad = {}
        self.segments = []
        self.input_size = None
        for show in sorted(train_list.keys()):
            features, _ = features_server.load(show)
            if self.input_size is None:
                # Feature dimension, required to allocate batches later
                # (was never set in the original, breaking __getitem__).
                self.input_size = features.shape[1]
            # numpy.int was removed in numpy >= 1.24; plain int is equivalent.
            labels = numpy.zeros((len(features), 1), dtype=int)
            for seg in train_list[show]:
                labels[seg['start']:seg['stop']] = 1
            self.vad[show] = labels

            for seg in uem_list[show]:
                if seg['start'] is not None:
                    start, stop = seg['start'], seg['stop']
                else:
                    start, stop = 0, len(features)
                # Keep only windows that fit entirely inside the region.
                for i in range(start, min(stop, len(features)) - self.duration, self.step):
                    self.segments.append((show, i, i + self.duration))

        if shuffle:
            random.shuffle(self.segments)

        # Number of complete batches; trailing windows are dropped.
        self.len = len(self.segments) // self.batch_size

    def __getitem__(self, index):
        """
        Return batch ``index`` as a pair of FloatTensors:
        features of shape (batch_size, duration_frames, input_size) and
        labels of shape (batch_size, duration_frames, 1).
        """
        batch_X = numpy.zeros((self.batch_size, self.duration, self.input_size))
        batch_Y = numpy.zeros((self.batch_size, self.duration, 1))
        for i in range(self.batch_size):
            show, start, stop = self.segments[index * self.batch_size + i]
            # NOTE(review): features are reloaded per window; presumably the
            # features_server caches them — confirm, else this is slow.
            features, _ = self.features_server.load(show)
            batch_X[i] = features[start:stop]
            batch_Y[i] = self.vad[show][start:stop]
        return torch.Tensor(batch_X), torch.Tensor(batch_Y)

    def __len__(self):
        """Number of complete batches available."""
        return self.len
class SAD_RNN():
"""
A SAD_RNN is meant to use a PyTorch RNN model for Speech Activity Detection
......
......@@ -63,21 +63,6 @@ def read_batch(batch_file):
data[idx] = (data[idx] - m) / s
return data, label
def read_hot_batch(batch_file, spk_nb):
    """
    Read one HDF5 training batch and return per-sample normalized features
    together with one-hot speaker labels.

    :param batch_file: path to an HDF5 file holding 'data' and 'label' datasets
    :param spk_nb: total number of speakers (width of the one-hot matrix)
    :return: tuple (data, lbl) where data has shape
        (batch, feat_dim, frames) and lbl has shape (batch, spk_nb)
    """
    with h5py.File(batch_file, 'r') as h5f:
        data = _read_dataset_percentile(h5f, 'data')
        # h5py >= 3.0 removed Dataset.value; [()] is the supported accessor.
        label = h5f['label'][()]

    # Split the stacked frames into per-sample blocks, then move the feature
    # axis before the frame axis: (batch, feat_dim, frames).
    data = data.reshape((len(label), data.shape[0] // len(label), data.shape[1])).transpose(0, 2, 1)
    # Mean/variance normalization per sample (axis 0 of each data[idx] —
    # NOTE(review): this normalizes across the feature axis; confirm intended).
    for idx in range(data.shape[0]):
        m = data[idx].mean(axis=0)
        s = data[idx].std(axis=0)
        data[idx] = (data[idx] - m) / s

    # One-hot encode the labels; sized from the actual batch rather than a
    # hard-coded 128 so smaller trailing batches also work.
    lbl = numpy.zeros((len(label), spk_nb))
    lbl[numpy.arange(len(label)), label] += 1
    return data, lbl
class XvectorDataset(Dataset):
"""
......@@ -95,22 +80,6 @@ class XvectorDataset(Dataset):
def __len__(self):
return self.len
class XvectorHotDataset(Dataset):
    """
    Dataset over pre-computed batch files listed in a text file, yielding
    normalized features and one-hot speaker labels per batch file.
    """
    def __init__(self, batch_list, batch_path, spk_nb):
        """
        :param batch_list: text file with one batch file name per line
        :param batch_path: directory containing the batch files
        :param spk_nb: total number of speakers
        """
        with open(batch_list, 'r') as fh:
            names = [line.rstrip() for line in fh]
        self.batch_files = [batch_path + '/' + name for name in names]
        self.len = len(self.batch_files)
        self.spk_nb = spk_nb

    def __getitem__(self, index):
        """Load batch file ``index`` and return (features, one_hot) tensors."""
        features, one_hot = read_hot_batch(self.batch_files[index], self.spk_nb)
        x = torch.from_numpy(features).type(torch.FloatTensor)
        y = torch.from_numpy(one_hot.astype(numpy.float32))
        return x, y

    def __len__(self):
        """Number of batch files."""
        return self.len
class XvectorMultiDataset(Dataset):
"""
......@@ -127,21 +96,6 @@ class XvectorMultiDataset(Dataset):
def __len__(self):
return self.len
class XvectorMultiDataset_hot(Dataset):
    """
    Dataset over a Python list of batch file names, yielding normalized
    features and one-hot speaker labels per batch file.
    """
    def __init__(self, batch_list, batch_path, spk_nb):
        """
        :param batch_list: list of batch file names
        :param batch_path: directory containing the batch files
        :param spk_nb: total number of speakers
        """
        self.batch_files = ['/'.join((batch_path, name)) for name in batch_list]
        self.len = len(self.batch_files)
        self.spk_nb = spk_nb

    def __getitem__(self, index):
        """Load batch file ``index`` and return (features, one_hot) tensors."""
        features, one_hot = read_hot_batch(self.batch_files[index], self.spk_nb)
        x = torch.from_numpy(features).type(torch.FloatTensor)
        y = torch.from_numpy(one_hot.astype(numpy.float32))
        return x, y

    def __len__(self):
        """Number of batch files."""
        return self.len
class StatDataset(Dataset):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment