Commit e6c25a87 authored by Anthony Larcher
Browse files

seq2seq

parent be2b8c7e
...@@ -39,7 +39,6 @@ from .clustering.hac_utils import bic_square_root ...@@ -39,7 +39,6 @@ from .clustering.hac_utils import bic_square_root
from .clustering.cc_iv import ConnectedComponent from .clustering.cc_iv import ConnectedComponent
from .nnet.wavsets import SeqSet from .nnet.wavsets import SeqSet
from .nnet.seqtoseq import PreNet
from .nnet.seqtoseq import BLSTM from .nnet.seqtoseq import BLSTM
from .model_iv import ModelIV from .model_iv import ModelIV
......
...@@ -24,6 +24,5 @@ Copyright 2014-2020 Anthony Larcher ...@@ -24,6 +24,5 @@ Copyright 2014-2020 Anthony Larcher
""" """
from .wavsets import SeqSet from .wavsets import SeqSet
from .seqtoseq import PreNet
from .seqtoseq import BLSTM from .seqtoseq import BLSTM
from .seqtoseq import SeqToSeq from .seqtoseq import SeqToSeq
\ No newline at end of file
...@@ -25,8 +25,10 @@ Copyright 2014-2020 Anthony Larcher ...@@ -25,8 +25,10 @@ Copyright 2014-2020 Anthony Larcher
import os import os
import sys import sys
import logging
import pandas
import numpy import numpy
import OrderedDict from collections import OrderedDict
import random import random
import h5py import h5py
import shutil import shutil
...@@ -37,6 +39,7 @@ import yaml ...@@ -37,6 +39,7 @@ import yaml
from torch import optim from torch import optim
from torch.utils.data import Dataset from torch.utils.data import Dataset
from .wavsets import SeqSet
from sidekit.nnet.sincnet import SincNet from sidekit.nnet.sincnet import SincNet
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
...@@ -76,6 +79,7 @@ class BLSTM(nn.Module): ...@@ -76,6 +79,7 @@ class BLSTM(nn.Module):
self.input_size = input_size self.input_size = input_size
self.blstm_sizes = blstm_sizes self.blstm_sizes = blstm_sizes
self.blstm_layers = [] self.blstm_layers = []
for blstm_size in blstm_sizes: for blstm_size in blstm_sizes:
self.blstm_layers.append(nn.LSTM(input_size, blstm_size // 2, bidirectional=True, batch_first=True)) self.blstm_layers.append(nn.LSTM(input_size, blstm_size // 2, bidirectional=True, batch_first=True))
input_size = blstm_size input_size = blstm_size
...@@ -95,15 +99,19 @@ class BLSTM(nn.Module): ...@@ -95,15 +99,19 @@ class BLSTM(nn.Module):
hiddens = [] hiddens = []
if self.hidden is None: if self.hidden is None:
#hidden_1, hidden_2 = None, None #hidden_1, hidden_2 = None, None
for _s in self.lstm_sizes: for _s in self.blstm_sizes:
hiddens.append(None) hiddens.append(None)
else: else:
hiddens = self.hidden hiddens = self.hidden
x = inputs x = inputs
for idx, _s in enumerate(self.lstm_sizes): outputs = []
for idx, _s in enumerate(self.blstm_sizes):
x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx]) x, hiddens[idx] = self.blstm_layers[idx](x, hiddens[idx])
outputs.append(x)
self.hidden = tuple(hiddens) self.hidden = tuple(hiddens)
output = torch.cat(outputs, dim=2)
return x return x
def output_size(self): def output_size(self):
...@@ -120,9 +128,6 @@ class SeqToSeq(nn.Module): ...@@ -120,9 +128,6 @@ class SeqToSeq(nn.Module):
def __init__(self, def __init__(self,
model_archi): model_archi):
# Todo Write like the Xtractor in order to enable a flexible build of the model including \
# Sincnet preprocessor, Convolutional filters, TDNN, BLSTM and other possible layers
super(SeqToSeq, self).__init__() super(SeqToSeq, self).__init__()
# Load Yaml configuration # Load Yaml configuration
...@@ -130,6 +135,7 @@ class SeqToSeq(nn.Module): ...@@ -130,6 +135,7 @@ class SeqToSeq(nn.Module):
cfg = yaml.load(fh, Loader=yaml.FullLoader) cfg = yaml.load(fh, Loader=yaml.FullLoader)
self.loss = cfg["loss"] self.loss = cfg["loss"]
self.feature_size = None
""" """
Prepare Preprocessor Prepare Preprocessor
...@@ -161,15 +167,16 @@ class SeqToSeq(nn.Module): ...@@ -161,15 +167,16 @@ class SeqToSeq(nn.Module):
input_size = self.feature_size input_size = self.feature_size
sequence_to_sequence = BLSTM(input_size=input_size, self.sequence_to_sequence = BLSTM(input_size=input_size,
blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"]) blstm_sizes=cfg["sequence_to_sequence"]["blstm_sizes"])
input_size = sequence_to_sequence.output_size() input_size = self.sequence_to_sequence.output_size
""" """
Prepare post-processing network Prepare post-processing network
""" """
# Create sequential object for the second part of the network # Create sequential object for the second part of the network
self.post_processing_activation = torch.nn.Tanh()
post_processing_layers = [] post_processing_layers = []
for k in cfg["post_processing"].keys(): for k in cfg["post_processing"].keys():
...@@ -179,7 +186,7 @@ class SeqToSeq(nn.Module): ...@@ -179,7 +186,7 @@ class SeqToSeq(nn.Module):
input_size = cfg["post_processing"][k]["output"] input_size = cfg["post_processing"][k]["output"]
elif k.startswith("activation"): elif k.startswith("activation"):
post_processing_layers.append((k, self.activation)) post_processing_layers.append((k, self.post_processing_activation))
elif k.startswith('batch_norm'): elif k.startswith('batch_norm'):
post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size))) post_processing_layers.append((k, torch.nn.BatchNorm1d(input_size)))
...@@ -187,8 +194,8 @@ class SeqToSeq(nn.Module): ...@@ -187,8 +194,8 @@ class SeqToSeq(nn.Module):
elif k.startswith('dropout'): elif k.startswith('dropout'):
post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k]))) post_processing_layers.append((k, torch.nn.Dropout(p=cfg["post_processing"][k])))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(post_processing_layers)) self.post_processing = torch.nn.Sequential(OrderedDict(post_processing_layers))
self.before_speaker_embedding_weight_decay = cfg["before_embedding"]["weight_decay"] #self.before_speaker_embedding_weight_decay = cfg["post_processing"]["weight_decay"]
def forward(self, inputs): def forward(self, inputs):
...@@ -197,36 +204,32 @@ class SeqToSeq(nn.Module): ...@@ -197,36 +204,32 @@ class SeqToSeq(nn.Module):
:param inputs: :param inputs:
:return: :return:
""" """
if self.hidden is None: if self.preprocessor is not None:
hidden_1, hidden_2 = None, None x = self.preprocessor(inputs)
else: else:
hidden_1, hidden_2 = self.hidden x = inputs
tmp, hidden_1 = self.lstm_1(inputs, hidden_1) x = self.sequence_to_sequence(x)
x, hidden_2 = self.lstm_2(tmp, hidden_2) x = self.post_processing(x)
self.hidden = (hidden_1, hidden_2)
x = torch.tanh(self.linear_1(x))
#x = torch.tanh(self.linear_2(x))
x = self.linear_2(x)
#x = torch.sigmoid(self.output(x))
return x return x
def seqTrain(data_dir, def seqTrain(dataset_yaml,
model_yaml,
mode, mode,
duration=2.,
seg_shift=0.25,
filter_type="gate",
collar_duration=0.1,
framerate=16000,
epochs=100, epochs=100,
batch_size=32,
lr=0.0001, lr=0.0001,
loss="cross_validation",
patience=10, patience=10,
model_name=None,
tmp_model_name=None, tmp_model_name=None,
best_model_name=None, best_model_name=None,
multi_gpu=True, multi_gpu=True,
opt='sgd', opt='sgd',
filter_type="gate",
collar_duration=0.1,
framerate=16000,
output_rate=100,
batch_size=32,
log_interval=10,
num_thread=10 num_thread=10
): ):
""" """
...@@ -251,8 +254,14 @@ def seqTrain(data_dir, ...@@ -251,8 +254,14 @@ def seqTrain(data_dir,
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Start from scratch # Start from scratch
model = SeqToSeq() if model_name is None:
# TODO implement a model adaptation model = SeqToSeq(model_yaml)
# If we start from an existing model
else:
# Load the model
logging.critical(f"*** Load model from = {model_name}")
checkpoint = torch.load(model_name)
model = SeqToSeq(model_yaml)
if torch.cuda.device_count() > 1 and multi_gpu: if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!") print("Let's use", torch.cuda.device_count(), "GPUs!")
...@@ -264,16 +273,31 @@ def seqTrain(data_dir, ...@@ -264,16 +273,31 @@ def seqTrain(data_dir,
""" """
Create two dataloaders for training and evaluation Create two dataloaders for training and evaluation
""" """
training_set, validation_set = None, None with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SeqSet(dataset_yaml,
set_type="train",
dataset_df=training_df)
training_loader = DataLoader(training_set, training_loader = DataLoader(training_set,
batch_size=batch_size, batch_size=dataset_params["batch_size"],
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
pin_memory=True,
num_workers=num_thread) num_workers=num_thread)
validation_set = SeqSet(dataset_yaml,
set_type="validation",
dataset_df=validation_df)
validation_loader = DataLoader(validation_set, validation_loader = DataLoader(validation_set,
batch_size=batch_size, batch_size=dataset_params["batch_size"],
drop_last=True, drop_last=True,
pin_memory=True,
num_workers=num_thread) num_workers=num_thread)
""" """
...@@ -333,7 +357,7 @@ def seqTrain(data_dir, ...@@ -333,7 +357,7 @@ def seqTrain(data_dir,
log_interval, log_interval,
device=device) device=device)
# Add the cross validation here # Cross validation here
accuracy, val_loss = cross_validation(model, validation_loader, device=device) accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy)) logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
......
...@@ -303,6 +303,7 @@ class SeqSet(Dataset): ...@@ -303,6 +303,7 @@ class SeqSet(Dataset):
Object creates a dataset for sequence to sequence training Object creates a dataset for sequence to sequence training
""" """
def __init__(self, def __init__(self,
dataset_yaml,
wav_dir, wav_dir,
mdtm_dir, mdtm_dir,
mode, mode,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment