Commit e6c25a87 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

seq2seq

parent be2b8c7e
......@@ -39,7 +39,6 @@ from .clustering.hac_utils import bic_square_root
from .clustering.cc_iv import ConnectedComponent
from .nnet.wavsets import SeqSet
from .nnet.seqtoseq import PreNet
from .nnet.seqtoseq import BLSTM
from .model_iv import ModelIV
......
......@@ -24,6 +24,5 @@ Copyright 2014-2020 Anthony Larcher
"""
from .wavsets import SeqSet
from .seqtoseq import PreNet
from .seqtoseq import BLSTM
from .seqtoseq import SeqToSeq
\ No newline at end of file
......@@ -25,8 +25,10 @@ Copyright 2014-2020 Anthony Larcher
import os
import sys
import logging
import pandas
import numpy
import OrderedDict
from collections import OrderedDict
import random
import h5py
import shutil
......@@ -37,6 +39,7 @@ import yaml
from torch import optim
from torch.utils.data import Dataset
from .wavsets import SeqSet
from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader
......@@ -76,6 +79,7 @@ class BLSTM(nn.Module):
self.input_size = input_size
self.blstm_sizes = blstm_sizes
self.blstm_layers = []
for blstm_size in blstm_sizes:
self.blstm_layers.append(nn.LSTM(input_size, blstm_size // 2, bidirectional=True, batch_first=True))
input_size = blstm_size
......@@ -95,15 +99,19 @@ class BLSTM(nn.Module):
hiddens = []
if self.hidden is None:
#hidden_1, hidden_2 = None, None
for _s in self.lstm_sizes:
for _s in self.blstm_sizes:
hiddens.append(None)
else:
hiddens = self.hidden
x = inputs
for idx, _s in enumerate(self.lstm_sizes):
outputs = []
for idx, _s in enumerate(self.blstm_sizes):
x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
outputs.append(x)
self.hidden = tuple(hiddens)
output = torch.cat(outputs, dim=2)
return x
def output_size(self):
......@@ -120,9 +128,6 @@ class SeqToSeq(nn.Module):
def __init__(self,
model_archi):
# Todo Write like the Xtractor in order to enable a flexible build of the model including \
# Sincnet preprocessor, Convolutional filters, TDNN, BLSTM and other possible layers
super(SeqToSeq, self).__init__()
# Load Yaml configuration
......@@ -130,6 +135,7 @@ class SeqToSeq(nn.Module):
cfg = yaml.load(fh, Loader=yaml.FullLoader)
self.loss = cfg["loss"]
self.feature_size = None
"""
Prepare Preprocessor
......@@ -161,15 +167,16 @@ class SeqToSeq(nn.Module):
input_size = self.feature_size
sequence_to_sequence = BLSTM(input_size=input_size,
blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
self.sequence_to_sequence = BLSTM(input_size=input_size,
blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
input_size = sequence_to_sequence.output_size()
input_size = self.sequence_to_sequence.output_size
"""
Prepare post-processing network
"""
# Create sequential object for the second part of the network
self.post_processing_activation = torch.nn.Tanh()
post_processing_layers = []
for k in cfg["post_processing"].keys():
......@@ -179,7 +186,7 @@ class SeqToSeq(nn.Module):
input_size = cfg["post_processing"][k]["output"]
elif k.startswith("activation"):
post_processing_layers.append((k, self.activation))
post_processing_layers.append((k, self.post_processing_activation))
elif k.startswith('batch_norm'):
post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
......@@ -187,8 +194,8 @@ class SeqToSeq(nn.Module):
elif k.startswith('dropout'):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(post_processing_layers))
self.before_speaker_embedding_weight_decay = cfg["before_embedding"]["weight_decay"]
self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
#self.before_speaker_embedding_weight_decay = cfg["post_processing"]["weight_decay"]
def forward(self, inputs):
......@@ -197,36 +204,32 @@ class SeqToSeq(nn.Module):
:param inputs:
:return:
"""
if self.hidden is None:
hidden_1, hidden_2 = None, None
if self.preprocessor is not None:
x = self.preprocessor(inputs)
else:
hidden_1, hidden_2 = self.hidden
tmp, hidden_1 = self.lstm_1(inputs, hidden_1)
x, hidden_2 = self.lstm_2(tmp, hidden_2)
self.hidden = (hidden_1, hidden_2)
x = torch.tanh(self.linear_1(x))
#x = torch.tanh(self.linear_2(x))
x = self.linear_2(x)
#x = torch.sigmoid(self.output(x))
x = inputs
x = self.sequence_to_sequence(x)
x = self.post_processing(x)
return x
def seqTrain(data_dir,
def seqTrain(dataset_yaml,
model_yaml,
mode,
duration=2.,
seg_shift=0.25,
filter_type="gate",
collar_duration=0.1,
framerate=16000,
epochs=100,
batch_size=32,
lr=0.0001,
loss="cross_validation",
patience=10,
model_name=None,
tmp_model_name=None,
best_model_name=None,
multi_gpu=True,
opt='sgd',
filter_type="gate",
collar_duration=0.1,
framerate=16000,
output_rate=100,
batch_size=32,
log_interval=10,
num_thread=10
):
"""
......@@ -251,8 +254,14 @@ def seqTrain(data_dir,
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch
model = SeqToSeq()
# TODO implement a model adaptation
if model_name is None:
model = SeqToSeq(model_yaml)
# If we start from an existing model
else:
# Load the model
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = SeqToSeq(model_yaml)
if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!")
......@@ -264,16 +273,31 @@ def seqTrain(data_dir,
"""
Create two dataloaders for training and evaluation
"""
training_set, validation_set = None, None
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SeqSet(dataset_yaml,
set_type="train",
dataset_df=training_df)
training_loader = DataLoader(training_set,
batch_size=batch_size,
batch_size=dataset_params["batch_size"],
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=num_thread)
validation_set = SeqSet(dataset_yaml,
set_type="validation",
dataset_df=validation_df)
validation_loader = DataLoader(validation_set,
batch_size=batch_size,
batch_size=dataset_params["batch_size"],
drop_last=True,
pin_memory=True,
num_workers=num_thread)
"""
......@@ -333,7 +357,7 @@ def seqTrain(data_dir,
log_interval,
device=device)
# Add the cross validation here
# Cross validation here
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
......
......@@ -303,6 +303,7 @@ class SeqSet(Dataset):
Object creates a dataset for sequence to sequence training
"""
def __init__(self,
dataset_yaml,
wav_dir,
mdtm_dir,
mode,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment