Commit b0835fe2 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

clean version of VAD miss computation of validation accuracy

parent 218ade91
...@@ -441,30 +441,30 @@ def cross_validation(model, validation_loader, device): ...@@ -441,30 +441,30 @@ def cross_validation(model, validation_loader, device):
output = model(data.to(device)) output = model(data.to(device))
output = output.permute(1, 2, 0) output = output.permute(1, 2, 0)
target = target.permute(1, 0) target = target.permute(1, 0)
nbpoint = output.shape[0]
loss = criterion(output, target.to(device)) loss += criterion(output, target.to(device))
rc, pr, acc = calc_recall(output.data, target, device) rc, pr, acc = calc_recall(output.data, target, device)
recall += rc.item() recall += rc.item()
precision += pr.item() precision += pr.item()
accuracy += acc.item() accuracy += acc.item()
batch_size = target.shape[0] batch_size = target.shape[0]
if precision != 0 or recall != 0: if precision != 0 or recall != 0:
f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \ f_measure = 2 * (precision / ((batch_idx + 1))) * (recall / ((batch_idx + 1))) / \
((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1)))) ((precision / ((batch_idx + 1))) + (recall / ((batch_idx + 1))))
logging.critical( logging.critical(
'Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} ' \ 'Validation: [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.3f} ' \
'Recall: {:.3f} Precision: {:.3f} "\ 'Recall: {:.3f} Precision: {:.3f} "\
F-Measure: {:.3f}'.format(batch_idx + 1, F-Measure: {:.3f}'.format(batch_idx + 1,
validation_loader.__len__(), validation_loader.__len__(),
100. * batch_idx / validation_loader.__len__(), loss.item(), 100. * batch_idx / validation_loader.__len__(), loss.item(),
100.0 * accuracy / ((batch_idx + 1)), 100.0 * accuracy / ((batch_idx + 1)),
100.0 * recall / ((batch_idx + 1)), 100.0 * recall / ((batch_idx + 1)),
100.0 * precision / ((batch_idx + 1)), 100.0 * precision / ((batch_idx + 1)),
f_measure) f_measure)
) )
return accuracy, loss
def calc_recall(output,target,device): def calc_recall(output,target,device):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment