Commit ffd57577 authored by Anthony Larcher

monitor display

parent 63209176
@@ -325,24 +325,31 @@ class TrainingMonitor():
     def __init__(self,
                  output_file,
                  log_interval=10,
                  patience=numpy.inf,
                  best_accuracy=0.0,
                  best_accuracy_epoch=1,
                  best_eer_epoch=1,
                  best_eer=100,
                  compute_test_eer=False
                  ):
+        # Store lists rather than scalars in order to keep the full training history
         self.current_epoch = 0
         self.log_interval = log_interval
         self.init_patience = patience
         self.current_patience = patience
         self.best_accuracy = best_accuracy
         self.best_accuracy_epoch = best_accuracy_epoch
         self.best_eer_epoch = best_eer_epoch
         self.best_eer = best_eer
+        self.compute_test_eer = compute_test_eer
+        self.test_eer = []
+        self.val_eer = []
         self.training_loss = []
         self.training_acc = []
         self.val_loss = []
         self.val_acc = []
-        self.compute_test_eer = compute_test_eer
-        self.val_eer = []
         self.is_best = True
         # Initialize the logger
@@ -371,32 +378,43 @@ class TrainingMonitor():
     def update(self,
                epoch,
                training_acc=None,
                training_loss=None,
                test_eer=None,
                val_eer=None,
                val_loss=None,
                val_acc=None):
         self.current_epoch = epoch
-        self.val_eer.append(val_eer)
-        self.val_loss.append(val_loss)
-        self.val_acc.append(val_acc)
         if training_acc:
             self.training_acc.append(training_acc)
         if training_loss:
             self.training_loss.append(training_loss)
+        if val_eer:
+            self.val_eer.append(val_eer)
+        if val_loss:
+            self.val_loss.append(val_loss)
+        if val_acc:
+            self.val_acc.append(val_acc)
         # remember best accuracy and save checkpoint
-        if self.compute_test_eer:
+        if self.compute_test_eer and test_eer:
             self.test_eer.append(test_eer)
             self.is_best = test_eer < self.best_eer
             self.best_eer = min(test_eer, self.best_eer)
-        else:
+            if self.is_best:
+                self.best_eer_epoch = epoch
+                self.current_patience = self.init_patience
+            else:
+                self.current_patience -= 1
+        elif val_eer:
             self.is_best = val_eer < self.best_eer
             self.best_eer = min(val_eer, self.best_eer)
             self.best_accuracy = max(val_acc, self.best_accuracy)
+            if self.is_best:
+                self.best_accuracy_epoch = epoch
+                self.current_patience = self.init_patience
+            else:
+                self.current_patience -= 1
-        if self.is_best:
-            self.best_eer_epoch = epoch
-            self.current_patience = self.init_patience
-        else:
-            self.current_patience -= 1

 class MeanStdPooling(torch.nn.Module):
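For clarity, here is a minimal sketch of how the reworked monitor drives early stopping. TrainingMonitor, update() and current_patience are exactly as in the hunk above; the output_file value and the EER sequence are illustrative only:

monitor = TrainingMonitor(output_file="training.log", patience=2)

# Feed a made-up validation EER per epoch; val_acc is passed explicitly (not None)
# because the validation branch of update() computes max(val_acc, self.best_accuracy).
for epoch, eer in enumerate([8.0, 6.5, 6.9, 7.2], start=1):
    monitor.update(epoch, val_eer=eer, val_acc=0.0)
    print(epoch, monitor.best_eer, monitor.current_patience)

best_eer settles at 6.5 after epoch 2; epochs 3 and 4 bring no improvement, so current_patience falls from 2 to 0 and the caller can stop training once it reaches 0.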
@@ -1854,6 +1872,110 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interval
     return model
+
+
+def new_train_epoch(model,
+                    training_monitor,
+                    training_loader,
+                    optimizer,
+                    scheduler,
+                    device,
+                    scaler=None,
+                    clipping=False):
+    """
+    Train a model for one epoch and report the statistics to a TrainingMonitor.
+
+    :param model: the network to train, an Xtractor or a (Distributed)DataParallel wrapper around one
+    :param training_monitor: TrainingMonitor used for batch-level logging
+    :param training_loader: DataLoader yielding (data, target) batches
+    :param optimizer: optimizer stepped after each batch
+    :param scheduler: learning-rate scheduler stepped once at the end of the epoch
+    :param device: torch.device on which the batches are processed
+    :param scaler: optional torch.cuda.amp.GradScaler enabling mixed-precision training
+    :param clipping: if True, clip the gradient norm to 1 before the optimizer step
+    :return: the trained model
+    """
+    model.train()
+    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
+
+    if isinstance(model, Xtractor):
+        loss_criteria = model.loss
+    else:
+        loss_criteria = model.module.loss
+
+    accuracy = 0.0
+    running_loss = 0.0
+    for batch_idx, (data, target) in enumerate(training_loader):
+        data = data.squeeze().to(device)
+        target = target.squeeze()
+        target = target.to(device)
+        optimizer.zero_grad(set_to_none=True)
+
+        if scaler is not None:
+            # Mixed-precision forward pass; the loss is scaled at backward time below
+            with torch.cuda.amp.autocast():
+                if loss_criteria == 'aam':
+                    output, _ = model(data, target=target)
+                    loss = criterion(output, target)
+                elif loss_criteria == 'aps':
+                    # APS returns a cosine-similarity matrix whose i-th row should peak
+                    # at column i, hence the arange target over half the batch
+                    output_tuple, _ = model(data, target=target)
+                    cos_sim_matx, output = output_tuple
+                    loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0] / 2), device=device)) + criterion(output, target)
+                else:
+                    output, _ = model(data, target=None)
+                    loss = criterion(output, target)
+        else:
+            if loss_criteria == 'aam':
+                output, _ = model(data, target=target)
+                loss = criterion(output, target)
+            elif loss_criteria == 'aps':
+                output_tuple, _ = model(data, target=target)
+                cos_sim_matx, output = output_tuple
+                loss = criterion(cos_sim_matx, torch.arange(0, int(data.shape[0] / 2), device=device)) + criterion(output, target)
+            else:
+                output, _ = model(data, target=None)
+                loss = criterion(output, target)
+
+        if not torch.isnan(loss):
+            if scaler is not None:
+                scaler.scale(loss).backward()
+                if clipping:
+                    # Unscale first so the clipping threshold applies to the true gradients
+                    scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                optimizer.step()
+
+            running_loss += loss.item()
+            accuracy += (torch.argmax(output.data, 1) == target).sum()
+
+            if batch_idx % training_monitor.log_interval == 0:
+                batch_size = target.shape[0]
+                # update() requires the epoch argument; re-pass the current one since
+                # only the training statistics are refreshed here
+                training_monitor.update(training_monitor.current_epoch,
+                                        training_loss=loss.item(),
+                                        training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
+                training_monitor.logger.info('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
+                    training_monitor.current_epoch,
+                    batch_idx + 1,
+                    len(training_loader),
+                    100. * batch_idx / len(training_loader),
+                    loss.item(),
+                    100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
+        else:
+            # The loss went NaN: save a debugging checkpoint and the offending batch, then abort
+            save_checkpoint({
+                'epoch': training_monitor.current_epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'accuracy': 0.0,
+                'scheduler': 0.0
+            }, False, filename="model_loss_NAN.pt", best_filename='toto.pt')
+            with open("batch_loss_NAN.pkl", "wb") as fh:
+                pickle.dump(data.cpu(), fh)
+            import sys
+            sys.exit()
+        running_loss = 0.0
+
+    scheduler.step()
+    return model
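To show how the new function is meant to be driven, here is a hedged sketch of an outer training loop. The model, training_loader and run_validation() names are placeholders not defined in this commit; GradScaler, SGD and StepLR are standard torch APIs, and the hyperparameter values are illustrative:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

monitor = TrainingMonitor(output_file="training.log", patience=10)
for epoch in range(1, 101):
    monitor.update(epoch)  # refresh current_epoch before the batch-level logging
    model = new_train_epoch(model, monitor, training_loader, optimizer, scheduler,
                            device, scaler=scaler, clipping=True)
    val_eer, val_loss, val_acc = run_validation(model)  # placeholder validation pass
    monitor.update(epoch, val_eer=val_eer, val_loss=val_loss, val_acc=val_acc)
    if monitor.current_patience <= 0:
        break

The per-batch logging inside new_train_epoch then reports the correct epoch number, and the patience counter maintained by update() decides when to stop.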
def cross_validation(model, validation_loader, device, validation_shape, tar_indices, non_indices, mixed_precision=False):
"""
......