Commit a31a26d6 authored by Anthony Larcher

add fileset

parent 0fb39dc7
@@ -28,9 +28,11 @@ The authors would like to thank the BUT Speech@FIT group (http://speech.fit.vutb
for sharing the source code that strongly inspired this module. Thank you for your valuable contribution.
"""
import glob
import h5py
import numpy
import pandas
import os
import pickle
import random
import torch
@@ -50,6 +52,8 @@ from ..features_server import FeaturesServer
from scipy.fftpack.realtransforms import dct
from torchvision import transforms
from torch.utils.data import DataLoader
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
@@ -62,6 +66,32 @@ __docformat__ = 'reStructuredText'
wav_type = "float32" # can be int16, float64, int32 or float32
def write_batch(batch_idx, data, label, batch_fn_format):
"""
:param batch_idx: index of the batch, inserted into batch_fn_format to build the output file name
:param data: batch of features (torch tensor or numpy array)
:param label: batch of labels (torch tensor or numpy array)
:param batch_fn_format: format string with one placeholder for the batch index, e.g. "train_{}_batch.h5"
:return: None
"""
with h5py.File(batch_fn_format.format(batch_idx), "w") as h5f:
h5f.create_dataset('/data', data=data, fletcher32=True)
h5f.create_dataset('/label', data=label, fletcher32=True)
def load_batch(batch_fn):
"""
:param batch_fn: name of the HDF5 file storing one batch
:return: a tuple (data, label) of torch.FloatTensor and torch.LongTensor
"""
with h5py.File(batch_fn, "r") as h5f:
data = h5f["/data"][()]
label = h5f["/label"][()]
return torch.from_numpy(data).type(torch.FloatTensor), torch.from_numpy(label).type(torch.LongTensor)
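As a quick round-trip illustration of the two helpers above, the sketch below writes one batch and reads it back. The cache directory, file name pattern and array shapes are assumptions made for the example, not values used by the module, and the parent directory has to exist because write_batch does not create it.

import os
import numpy

batch_fn_format = "/tmp/sidekit_batches/train_{}_batch.h5"   # hypothetical cache pattern
os.makedirs(os.path.dirname(batch_fn_format), exist_ok=True)

data = numpy.random.randn(8, 30, 400).astype("float32")      # illustrative (batch, feature, frame) shape
label = numpy.random.randint(0, 100, size=8)                  # one class index per sample

write_batch(0, data, label, batch_fn_format)
data_t, label_t = load_batch(batch_fn_format.format(0))      # FloatTensor, LongTensor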
def read_batch(batch_file):
"""
:param batch_file: name of the file storing one batch
@@ -80,6 +110,8 @@ def read_batch(batch_file):
return data, label
class XvectorDataset(Dataset):
"""
Object that takes a list of files from a file and initialize a Dataset
@@ -339,7 +371,7 @@ class SideSet(Dataset):
overlap=0.,
dataset_df=None,
min_duration=0.165,
output_format="pytorch"
output_format="pytorch",
):
"""
@@ -517,6 +549,29 @@
"""
return self.len
def write_to_disk(self, batch_size, batch_fn_format, num_thread):
"""
:param batch_size: number of samples per batch written to disk
:param batch_fn_format: format string with one placeholder for the batch index
:param num_thread: number of DataLoader workers used to build the batches
:return: None
"""
# Check that the output directory exists, create it if needed
directory = os.path.dirname(batch_fn_format)
if not os.path.exists(directory):
os.makedirs(directory)
tmp_loader = DataLoader(self,
batch_size=batch_size,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=num_thread)
for batch_idx, (data, target) in enumerate(tmp_loader):
write_batch(batch_idx, data, target, batch_fn_format)
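A minimal usage sketch of write_to_disk follows; it mirrors how xtrain builds its training SideSet. The CSV and YAML file names, the validation ratio, the batch size, the worker count and the cache pattern are all placeholder values assumed for the example.

import pandas
from sklearn.model_selection import train_test_split

df = pandas.read_csv("dataset_description.csv")               # hypothetical CSV listing the segments
training_df, _ = train_test_split(df, test_size=0.1)

training_set = SideSet("dataset.yaml",                         # hypothetical dataset configuration
                       set_type="train",
                       dataset_df=training_df,
                       chunk_per_segment=1,
                       overlap=0.)
training_set.write_to_disk(batch_size=128,
                           batch_fn_format="/tmp/sidekit_batches/train/train_{}_batch.h5",
                           num_thread=4)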
def createSideSets(data_set_yaml,
chunk_per_segment=1,
overlap=0.,
@@ -640,3 +695,35 @@ class IdMapSet(Dataset):
:return: length of the dataset
"""
return self.len
class FileSet(Dataset):
"""
Dataset class to load from disk
"""
def __init__(self, batch_fn_format):
"""
:param batch_fn_format: format string with one placeholder for the batch index, matching the pattern used when the batches were written to disk
"""
self.batch_fn_format = batch_fn_format
# Get number of batches available on disk
batch_list = glob.glob(batch_fn_format.format('*'))
self.len = len(batch_list)
def __getitem__(self, idx):
"""
:param idx: index of the batch to load from disk
:return: a tuple (data, label) of torch.FloatTensor and torch.LongTensor
"""
return load_batch(self.batch_fn_format.format(idx))
def __len__(self):
"""
:return: number of batches found on disk
"""
return self.len
\ No newline at end of file
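Reading the cache back could then look like the sketch below, assuming the same pattern that was used for writing. Each item of a FileSet is already a full batch, so it is indexed directly here instead of being wrapped in another DataLoader.

train_batches = FileSet("/tmp/sidekit_batches/train/train_{}_batch.h5")
for idx in range(len(train_batches)):
    data, target = train_batches[idx]   # already batched FloatTensor / LongTensor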
@@ -37,12 +37,13 @@ import pickle
import shutil
import time
import torch
import torch.optim as optim
import tqdm
import yaml
from torchvision import transforms
from collections import OrderedDict
from .xsets import SideSet
from .xsets import FileSet
from .xsets import IdMapSet
from .res_net import RawPreprocessor, ResBlockWFMS
from ..bosaris import IdMap
@@ -57,8 +58,6 @@ from .loss import l2_norm
from torch.nn.parallel import DistributedDataParallel as DDP
import tqdm
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
@@ -606,7 +605,10 @@ def xtrain(speaker_number,
opt=None,
reset_parts=[],
freeze_parts=[],
num_thread=None):
num_thread=None,
write_batches_to_disk=False,
load_batches_from_disk=False,
tmp_batch_dir=None):
"""
:param speaker_number:
@@ -634,6 +636,9 @@
#writer = SummaryWriter("runs/xvectors_experiments_2")
writer = None
if write_batches_to_disk:
load_batches_from_disk = True
if num_thread is None:
num_thread = multiprocessing.cpu_count()
@@ -711,41 +716,52 @@
print("Train on a single GPU")
model.to(device)
"""
Set the dataloaders according to the dataset_yaml
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SideSet(dataset_yaml,
set_type="train",
dataset_df=training_df,
chunk_per_segment=dataset_params['train']['chunk_per_segment'],
overlap=dataset_params['train']['overlap'])
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=num_thread)
if load_batches_from_disk:
train_batch_fn_format = tmp_batch_dir + "/train/train_{}_batch.h5"
val_batch_fn_format = tmp_batch_dir + "/val/val_{}_batch.h5"
validation_set = SideSet(dataset_yaml, set_type="validation", dataset_df=validation_df)
validation_loader = DataLoader(validation_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
pin_memory=True,
num_workers=num_thread)
if not load_batches_from_disk or write_batches_to_disk:
"""
Set the dataloaders according to the dataset_yaml
First we load the dataframe from the CSV file in order to split it into training and validation sets
Then we pass those two dataframes to the training and validation SideSets
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SideSet(dataset_yaml,
set_type="train",
dataset_df=training_df,
chunk_per_segment=dataset_params['train']['chunk_per_segment'],
overlap=dataset_params['train']['overlap'])
validation_set = SideSet(dataset_yaml, set_type="validation", dataset_df=validation_df)
# Add for TensorBoard
#dataiter = iter(training_loader)
#data, labels = dataiter.next()
#writer.add_graph(model, data)
if write_batches_to_disk:
training_set.write_to_disk(dataset_params["batch_size"], train_batch_fn_format, num_thread)
validation_set.write_to_disk(dataset_params["batch_size"], val_batch_fn_format, num_thread)
else:
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=num_thread)
validation_loader = DataLoader(validation_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
pin_memory=True,
num_workers=num_thread)
if load_batches_from_disk:
training_loader = FileSet(train_batch_fn_format)
validation_loader = FileSet(val_batch_fn_format)
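Taken together, the new keyword arguments suggest a two-stage workflow: one run that dumps the batches under tmp_batch_dir (write implies load, so that run already trains from the cache), and later runs that only read them. The calls below are only a sketch: the other required arguments of xtrain, not visible in this diff, are omitted, the speaker count and paths are hypothetical, and dataset_yaml is assumed to be a parameter of the function, as its body suggests.

# First run: build the on-disk cache under tmp_batch_dir (hypothetical values)
xtrain(1000,
       dataset_yaml="dataset.yaml",
       write_batches_to_disk=True,
       tmp_batch_dir="/tmp/sidekit_batches")

# Later runs: train directly from the cached batches, skipping SideSet creation
xtrain(1000,
       dataset_yaml="dataset.yaml",
       load_batches_from_disk=True,
       tmp_batch_dir="/tmp/sidekit_batches")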
"""
Set the training options
@@ -886,14 +902,15 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval, device,
accuracy = 0.0
running_loss = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
data = data.to(device)
target = target.squeeze()
target = target.to(device)
optimizer.zero_grad()
if loss_criteria == 'aam':
output = model(data.to(device), target=target)
output = model(data, target=target)
else:
output = model(data.to(device), target=None)
output = model(data, target=None)
#with GuruMeditation():
loss = criterion(output, target)
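The 'aam' branch passes the target into the forward call because additive angular-margin softmax needs the true class to add its margin to that class's logit before the cross-entropy is computed, so criterion can remain a plain cross-entropy. The function below is a generic illustration of that idea, not the layer used by this model (whose internals are not part of this diff); margin and scale are arbitrary example values.

import torch
import torch.nn.functional as F

def aam_logits(cos_theta, target, margin=0.2, scale=30.0):
    # add the angular margin to the true-class angle only, then rescale all logits
    theta = torch.acos(cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    one_hot = F.one_hot(target, num_classes=cos_theta.size(1)).bool()
    theta = torch.where(one_hot, theta + margin, theta)
    return scale * torch.cos(theta)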
@@ -1243,7 +1260,8 @@ def xdebug(speaker_number,
set_type="train",
dataset_df=training_df,
chunk_per_segment=dataset_params['train']['chunk_per_segment'],
overlap=dataset_params['train']['overlap'])
overlap=dataset_params['train']['overlap'],
output_format="numpy")
validation_set = SideSet(dataset_yaml, set_type="validation", dataset_df=validation_df)