Commit fc1b0895 authored by Anthony Larcher
Browse files

merge

parent d694e28b
......@@ -54,7 +54,14 @@ class SideSampler(torch.utils.data.Sampler):
Data Sampler used to generate uniformly distributed batches
"""
def __init__(self, data_source, spk_count, examples_per_speaker, samples_per_speaker, batch_size):
def __init__(self, data_source,
spk_count,
examples_per_speaker,
samples_per_speaker,
batch_size,
seed=0,
rank=0,
num_replicas=1):
"""[summary]
Args:
......@@ -69,8 +76,15 @@ class SideSampler(torch.utils.data.Sampler):
self.spk_count = spk_count
self.examples_per_speaker = examples_per_speaker
self.samples_per_speaker = samples_per_speaker
self.epoch = 0
self.seed = seed
self.rank = rank
self.num_replicas = num_replicas
assert batch_size % examples_per_speaker == 0
self.batch_size = batch_size//examples_per_speaker
assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_replicas == 0
self.batch_size = batch_size // examples_per_speaker
# reference all segment indexes per speaker
for idx in range(self.spk_count):
......@@ -78,8 +92,11 @@ class SideSampler(torch.utils.data.Sampler):
for idx, value in enumerate(self.train_sessions):
self.labels_to_indices[value].append(idx)
# shuffle segments per speaker
for ldlist in self.labels_to_indices.values():
random.shuffle(ldlist)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
for idx, ldlist in enumerate(self.labels_to_indices.values()):
ldlist = numpy.array(ldlist)
self.labels_to_indices[idx] = ldlist[torch.randperm(ldlist.shape[0]).numpy()]
self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)
......@@ -115,18 +132,26 @@ class SideSampler(torch.utils.data.Sampler):
# we want to convert the speaker indexes into segment indexes
self.index_iterator = numpy.zeros_like(batch_matrix)
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# keep track of next segment index to sample for each speaker
for idx, value in enumerate(batch_matrix):
if self.segment_cursors[value] > len(self.labels_to_indices[value]) - 1:
random.shuffle(self.labels_to_indices[value])
self.labels_to_indices[value] = self.labels_to_indices[value][torch.randperm(self.labels_to_indices[value].shape[0])]
self.segment_cursors[value] = 0
self.index_iterator[idx] = self.labels_to_indices[value][self.segment_cursors[value]]
self.segment_cursors[value] += 1
self.index_iterator = self.index_iterator.reshape(-1, self.num_replicas * self.examples_per_speaker)[:, self.rank * self.examples_per_speaker:(self.rank + 1) * self.examples_per_speaker].flatten()
return iter(self.index_iterator)
def __len__(self) -> int:
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker)
return (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) // self.num_replicas
def set_epoch(self, epoch: int) -> None:
    """
    Store the epoch number on the sampler.

    The stored value is added to ``self.seed`` when the sampler seeds its
    torch generators, so the training loop calls this before each epoch.

    :param epoch: index of the epoch about to start
    """
    self.epoch = epoch
class SideSet(Dataset):
......@@ -148,9 +173,6 @@ class SideSet(Dataset):
:param chunk_per_segment: number of chunks to select for each segment
default is 1 and -1 means select all possible chunks
"""
#with open(data_set_yaml, "r") as fh:
# dataset = yaml.load(fh, Loader=yaml.FullLoader)
self.data_path = dataset["data_path"]
self.sample_rate = int(dataset["sample_rate"])
self.data_file_extension = dataset["data_file_extension"]
......@@ -259,7 +281,8 @@ class SideSet(Dataset):
# Check the size of the file
current_session = self.sessions.iloc[index]
nfo = soundfile.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
# TODO is this required ?
nfo = torchaudio.info(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}")
original_start = int(current_session['start'])
if self.overlap > 0:
lowest_shift = self.overlap/2
......@@ -272,18 +295,23 @@ class SideSet(Dataset):
else:
start_frame = original_start
if start_frame + self.sample_number >= nfo.frames:
start_frame = numpy.min(nfo.frames - self.sample_number - 1)
conversion_rate = nfo.sample_rate // self.sample_rate
if start_frame + conversion_rate * self.sample_number >= nfo.num_frames:
start_frame = numpy.min(nfo.num_frames - conversion_rate * self.sample_number - 1)
speech, speech_fs = torchaudio.load(f"{self.data_path}/{current_session['file_id']}{self.data_file_extension}",
frame_offset=start_frame,
num_frames=self.sample_number)
frame_offset=conversion_rate*start_frame,
num_frames=conversion_rate*self.sample_number)
if nfo.sample_rate != self.sample_rate:
speech = torchaudio.transforms.Resample(nfo.sample_rate, self.sample_rate).forward(speech)
speech += 10e-6 * torch.randn(speech.shape)
if len(self.transform) > 0:
speech = data_augmentation(speech,
speech_fs,
self.sample_rate,
self.transform,
self.transform_number,
noise_df=self.noise_df,
......@@ -304,6 +332,14 @@ class SideSet(Dataset):
"""
return self.len
def get_sample(path, resample=None):
    """
    Load an audio file through torchaudio's sox effect chain.

    The chain always starts with the sox ``remix 1`` effect (keep only the
    first channel); a ``rate`` effect is appended when ``resample`` is truthy.

    :param path: location of the audio file, handed unchanged to
        ``torchaudio.sox_effects.apply_effects_file``
    :param resample: value given to the sox ``rate`` effect; the effect is
        skipped when this is falsy (``None``, ``0``, ...)
    :return: whatever ``torchaudio.sox_effects.apply_effects_file`` returns
        for the file and the assembled effect chain
    """
    # Optional tail of the chain: empty unless a resampling rate was requested
    rate_effect = [["rate", f'{resample}']] if resample else []
    return torchaudio.sox_effects.apply_effects_file(
        path,
        effects=[["remix", "1"], *rate_effect],
    )
class IdMapSet(Dataset):
"""
......
......@@ -68,9 +68,6 @@ from .loss import SoftmaxAngularProto, ArcLinear
from .loss import l2_norm
from .loss import ArcMarginProduct
from ..sidekit_io import init_logging
torch.backends.cudnn.benchmark = True
os.environ['MKL_THREADING_LAYER'] = 'GNU'
......@@ -199,7 +196,6 @@ def eer(negatives, positives):
def test_metrics(model,
device,
speaker_number,
num_thread,
mixed_precision):
"""Compute model metrics
......@@ -795,7 +791,7 @@ class Xtractor(torch.nn.Module):
:return:
"""
if self.preprocessor is not None:
x = self.preprocessor(x)
x = self.preprocessor(x, is_eval)
x = self.sequence_network(x)
......@@ -817,10 +813,11 @@ class Xtractor(torch.nn.Module):
return self.after_speaker_embedding(x), x
elif self.loss in ['aam', 'aps']:
if is_eval:
x = torch.nn.functional.normalize(x, dim=1)
else:
x = self.after_speaker_embedding(x, target=target), torch.nn.functional.normalize(x, dim=1)
x = self.after_speaker_embedding(x, target=target), torch.nn.functional.normalize(x, dim=1)
elif self.loss == 'smn':
if not is_eval:
x = self.after_speaker_embedding(x, target=target), x
return x
......@@ -1059,7 +1056,7 @@ def get_network(model_opts):
return model
def get_loaders(dataset_opts, training_opts, speaker_number):
def get_loaders(dataset_opts, training_opts, model_opts, speaker_number):
"""
:param dataset_yaml:
......@@ -1081,6 +1078,7 @@ def get_loaders(dataset_opts, training_opts, speaker_number):
training_set = SideSet(dataset_opts,
set_type="train",
chunk_per_segment=-1,
transform_number=dataset_opts['train']['transform_number'],
overlap=dataset_opts['train']['overlap'],
dataset_df=training_df,
output_format="pytorch",
......@@ -1091,23 +1089,41 @@ def get_loaders(dataset_opts, training_opts, speaker_number):
dataset_df=validation_df,
output_format="pytorch")
side_sampler = SideSampler(training_set.sessions['speaker_idx'],
speaker_number,
dataset_opts["train"]["sampler"]["examples_per_speaker"],
dataset_opts["train"]["sampler"]["samples_per_speaker"],
dataset_opts["batch_size"])
if model_opts["loss"]["type"] == 'aps':
samples_per_speaker = 2
else:
samples_per_speaker = 1
if training_opts["multi_gpu"]:
assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
assert dataset_opts["batch_size"] % samples_per_speaker == 0
batch_size = dataset_opts["batch_size"]//torch.cuda.device_count()
side_sampler = SideSampler(training_set.sessions['speaker_idx'],
speaker_number,
dataset_opts["train"]["sampler"]["examples_per_speaker"],
dataset_opts["train"]["sampler"]["samples_per_speaker"],
dataset_opts["batch_size"])
else:
batch_size = dataset_opts["batch_size"]
side_sampler = SideSampler(training_set.sessions['speaker_idx'],
speaker_number,
samples_per_speaker,
batch_size,
batch_size,
seed=dataset_opts['seed'])
training_loader = DataLoader(training_set,
batch_size=dataset_opts["batch_size"],
batch_size=batch_size,
shuffle=False,
drop_last=True,
pin_memory=True,
sampler=side_sampler,
num_workers=training_opts["num_cpu"],
persistent_workers=True)
persistent_workers=False)
validation_loader = DataLoader(validation_set,
batch_size=dataset_opts["batch_size"],
batch_size=batch_size,
drop_last=False,
pin_memory=True,
num_workers=training_opts["num_cpu"],
......@@ -1123,7 +1139,7 @@ def get_loaders(dataset_opts, training_opts, speaker_number):
tar_non_ratio = numpy.sum(tar_indices)/numpy.sum(non_indices)
non_indices *= (numpy.random.rand(*non_indices.shape) < tar_non_ratio)
return training_loader, validation_loader, tar_indices, non_indices
return training_loader, validation_loader, side_sampler, tar_indices, non_indices
def get_optimizer(model, model_opts, train_opts):
......@@ -1178,9 +1194,10 @@ def get_optimizer(model, model_opts, train_opts):
if train_opts["scheduler"]["type"] == 'CyclicLR':
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
base_lr=1e-8,
max_lr=1e-3,
mode="triangular2",
step_size_up=75000)
max_lr=train_opts["lr"],
step_size_up=model_opts["speaker_number"] * 2,
step_size_down=None,
mode="triangular2")
elif train_opts["scheduler"]["type"] == "MultiStepLR":
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
milestones=[10000,50000,100000],
......@@ -1244,6 +1261,7 @@ def save_model(model, training_monitor, model_opts, training_opts, optimizer, sc
def new_xtrain(dataset_description,
model_description,
training_description,
local_rank=-1,
**kwargs):
"""
REFACTORING
......@@ -1277,6 +1295,7 @@ def new_xtrain(dataset_description,
torch.cuda.manual_seed(training_opts["seed"])
# Display the entire configurations as YAML dictionaries
if local_rank < 1:
monitor.logger.info("\n*********************************\nDataset options\n*********************************\n")
monitor.logger.info(yaml.dump(dataset_opts, default_flow_style=False))
monitor.logger.info("\n*********************************\nModel options\n*********************************\n")
......@@ -1290,24 +1309,50 @@ def new_xtrain(dataset_description,
embedding_size = model.embedding_size
# Set the device and manage parallel processing
if torch.cuda.device_count() > 1 and training_opts["multi_gpu"]:
model = torch.nn.DataParallel(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(local_rank)
device = torch.device("cuda")
model.to(device)
# If multi-gpu
""" [HOW TO] from https://gist.github.com/sgraaf/5b0caa3a320f28c27c12b5efeb35aa4c
- Add the following line right after "if __name__ == '__main__':" in your main script :
parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.')
- Then, in your shell :
export NUM_NODES=1
export NUM_GPUS_PER_NODE=2
export NODE_RANK=0
export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))
python -m torch.distributed.launch \
--nproc_per_node=$NUM_GPUS_PER_NODE \
--nnodes=$NUM_NODES \
--node_rank $NODE_RANK \
train_xvector.py ...
"""
if training_opts["multi_gpu"]:
if local_rank < 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
torch.distributed.init_process_group(backend='nccl', init_method='env://')
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank
)
else:
print("Train on a single GPU")
# Initialise data loaders
training_loader, validation_loader, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
training_opts,
speaker_number)
training_loader, validation_loader, sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
training_opts,
speaker_number)
monitor.logger.info(f"Start training process")
monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")
if local_rank < 1:
monitor.logger.info(f"Start training process")
monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")
monitor.logger.info(f"Validation EER will be measured using")
monitor.logger.info(f"\t {numpy.sum(validation_tar_indices)} target trials and")
monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")
monitor.logger.info(f"Validation EER will be measured using")
monitor.logger.info(f"\t {numpy.sum(validation_tar_indices)} target trials and")
monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")
# Create optimizer and scheduler
optimizer, scheduler = get_optimizer(model, model_opts, training_opts)
......@@ -1325,6 +1370,8 @@ def new_xtrain(dataset_description,
print(f"Stopping at epoch {epoch} for cause of patience")
break
sampler.set_epoch(epoch)
model = new_train_epoch(model,
training_opts,
monitor,
......@@ -1345,7 +1392,7 @@ def new_xtrain(dataset_description,
training_opts["mixed_precision"])
test_eer = None
if training_opts["compute_test_eer"]:
if training_opts["compute_test_eer"] and local_rank < 1:
test_eer = new_test_metrics(model, device, model_opts, dataset_opts, training_opts)
monitor.update(test_eer=test_eer,
......@@ -1353,17 +1400,20 @@ def new_xtrain(dataset_description,
val_loss=val_loss,
val_acc=val_acc)
monitor.display()
if local_rank < 1:
monitor.display()
# Save the current model and if needed update the best one
# TODO add an option to keep the models at certain epochs (for example before the LR change)
save_model(model, monitor, model_opts, training_opts, optimizer, scheduler)
if local_rank < 1:
save_model(model, monitor, model_opts, training_opts, optimizer, scheduler)
for ii in range(torch.cuda.device_count()):
monitor.logger.info(torch.cuda.memory_summary(ii))
# TODO handle the display using the training_monitor
monitor.display_final()
if local_rank < 1:
monitor.display_final()
return monitor.best_eer
......@@ -1753,6 +1803,10 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
if loss_criteria == 'aam':
output, _ = model(data, target=target)
loss = criterion(output, target)
elif loss_criteria == 'smn':
output_tuple, _ = model(data, target=target)
loss, output = output_tuple
loss += criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
cos_sim_matx, output = output_tuple
......@@ -1843,7 +1897,6 @@ def new_train_epoch(model,
accuracy = 0.0
running_loss = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
data = data.squeeze().to(device)
target = target.squeeze()
target = target.to(device)
......@@ -1854,10 +1907,13 @@ def new_train_epoch(model,
if loss_criteria == 'aam':
output, _ = model(data, target=target)
loss = criterion(output, target)
elif loss_criteria == 'smn':
output_tuple, _ = model(data, target=target)
loss, output = output_tuple
loss += criterion(output, target)
elif loss_criteria == 'aps':
output_tuple, _ = model(data, target=target)
cos_sim_matx, output = output_tuple
loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0]/2), device=device)) + criterion(output, target)
loss, output = output_tuple
else:
output, _ = model(data, target=None)
loss = criterion(output, target)
......@@ -1873,7 +1929,8 @@ def new_train_epoch(model,
output, _ = model(data, target=None)
loss = criterion(output, target)
if not torch.isnan(loss):
#if not torch.isnan(loss):
if True:
if scaler is not None:
scaler.scale(loss).backward()
if clipping:
......@@ -1900,18 +1957,18 @@ def new_train_epoch(model,
loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
else:
save_checkpoint({
'epoch': training_monitor.current_epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': 0.0,
'scheduler': 0.0
}, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
with open("batch_loss_NAN.pkl", "wb") as fh:
pickle.dump(data.cpu(), fh)
import sys
sys.exit()
#else:
# save_checkpoint({
# 'epoch': training_monitor.current_epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'accuracy': 0.0,
# 'scheduler': 0.0
# }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
# with open("batch_loss_NAN.pkl", "wb") as fh:
# pickle.dump(data.cpu(), fh)
# import sys
# sys.exit()
running_loss = 0.0
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
......@@ -1939,7 +1996,6 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
embeddings = torch.zeros(validation_shape)
#classes = torch.zeros([validation_shape[0]])
cursor = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(tqdm.tqdm(validation_loader, desc='validation compute', mininterval=1, disable=None)):
......@@ -1948,7 +2004,13 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
batch_size = target.shape[0]
data = data.squeeze().to(device)
with torch.cuda.amp.autocast(enabled=mixed_precision):
batch_predictions, batch_embeddings = model(data, target=None, is_eval=False)
output, batch_embeddings = model(data, target=None, is_eval=True)
if loss_criteria == 'cce':
batch_embeddings = l2_norm(batch_embeddings)
if loss_criteria == 'smn':
batch_embeddings, batch_predictions = output
else:
batch_predictions = output
accuracy += (torch.argmax(batch_predictions.data, 1) == target).sum()
loss += criterion(batch_predictions, target)
embeddings[cursor:cursor + batch_size,:] = batch_embeddings.detach().cpu()
......@@ -2068,12 +2130,12 @@ def extract_embeddings(idmap_name,
with torch.no_grad():
for idx, (data, mod, seg, start, stop) in enumerate(tqdm.tqdm(dataloader,
desc='xvector extraction',
mininterval=1)):
mininterval=1,
disable=None)):
if data.shape[1] > 20000000:
data = data[...,:20000000]
with torch.cuda.amp.autocast(enabled=mixed_precision):
vec = model(x=data.to(device), is_eval=True)
_, vec = model(x=data.to(device), is_eval=True)
embeddings.stat1[idx, :] = vec.detach().cpu()
return embeddings
......@@ -2100,7 +2162,7 @@ def extract_embeddings_per_speaker(idmap_name,
else:
model = model_filename
print(model)
model = model.to(memory_format=torch.channels_last)
min_duration = (model.context_size() - 1) * frame_shift + frame_duration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment