Commit e349f9c7 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

new xvector and dataset

parent b7b11a97
......@@ -362,11 +362,11 @@ class SideSet(Dataset):
self.augmentation = dataset["eval"]["augmentation"]
# Load the dataset description as pandas.dataframe
if data_set_df is None:
if dataset_df is None:
df = pandas.read_csv(dataset["dataset_description"])
else:
assert isinstance(data_set_df, pandas.DataFrame)
df = data_set_df
assert isinstance(dataset_df, pandas.DataFrame)
df = dataset_df
# Select all segments which duration is longer than the chosen one
if set_type == "train":
......@@ -429,10 +429,14 @@ class SideSet(Dataset):
:return:
"""
# Open
random_start = numpy.random.randint(int(self.sessions.iloc[index]['start'] * self.sample_rate),
int(self.sessions.iloc[index]['start'] + self.sessions.iloc[index]['duration'] * 16000) - self.sample_number)
sig, _ = soundfile.read(f"{self.data_path}/{self.sessions.iloc[index]['speaker_id']}/{self.sessions.iloc[index]['file_id']}{self.data_file_extension}",
start=self.sessions.iloc[index]['start'],
stop=self.sessions.iloc[index]['start'] + self.sample_number
start=random_start,
stop=random_start + self.sample_number
)
sig += 0.0001 * numpy.random.randn(sig.shape[0])
speaker_idx = self.sessions.iloc[index]["speaker_idx"]
# TODO: add data augmentation here!
......@@ -440,7 +444,7 @@ class SideSet(Dataset):
if self.transform_pipeline:
sig, speaker_idx, _, __ = self.transforms((sig, speaker_idx, self.spec_aug[index], self.temp_aug[index]))
return sig, speaker_idx
return torch.from_numpy(sig).type(torch.FloatTensor), speaker_idx
def __len__(self):
"""
......
......@@ -27,6 +27,7 @@ Copyright 2014-2020 Yevhenii Prokopalo, Anthony Larcher
import logging
import numpy
import pandas
import pickle
import shutil
import torch
......@@ -240,6 +241,8 @@ def xtrain(speaker_number,
:param args:
:return:
"""
device = torch.device("cuda:0")
# If we start from an existing model
if model_name is not None:
# Load the model
......@@ -273,7 +276,7 @@ def xtrain(speaker_number,
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(seed)
torch.manual_seed(dataset_params['seed'])
training_set = SideSet(dataset_yaml, set_type="train", dataset_df=training_df)
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
......@@ -312,10 +315,10 @@ def xtrain(speaker_number,
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
model = train_epoch(model, epoch, training_loader, optimizer, dataset_params["log_interval"])
model = train_epoch(model, epoch, training_loader, optimizer, dataset_params["log_interval"], device=device)
# Add the cross validation here
accuracy, val_loss = cross_validation(model, validation_loader)
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
......@@ -338,7 +341,7 @@ def xtrain(speaker_number,
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def train_epoch(model, epoch, training_loader, optimizer, log_interval):
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device):
"""
:param model:
......@@ -352,7 +355,7 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval):
criterion = torch.nn.CrossEntropyLoss()
accuracy = 0.0
for batch_idx, (data, target, _, __) in enumerate(training_loader):
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
optimizer.zero_grad()
output = model(data.to(device))
......@@ -362,14 +365,15 @@ def train_epoch(model, epoch, training_loader, optimizer, log_interval):
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
if batch_idx % log_interval == 0:
batch_size = target.shape[0]
logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
epoch, batch_idx + 1, train_loader.__len__(),
100. * batch_idx / train_loader.__len__(), loss.item(),
epoch, batch_idx + 1, training_loader.__len__(),
100. * batch_idx / training_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
return model
def cross_validation(model, validation_loader):
def cross_validation(model, validation_loader, device):
"""
:param args:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment