Commit 19ddf37e authored by Colleen Beaumard

Delete scoring_cross_validation.py

parent abe485d4
import torch
from tqdm import tqdm
import os
import seaborn as sns
from math import sqrt
import torchaudio
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
from sidekit.nnet.xvector import Xtractor
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("model", help="The model used")
parser.add_argument("categories", help="The number of categories")
parser.add_argument("batchs", help="The number of batches used during training")
parser.add_argument("lr", help="The learning rate used during training")
parser.add_argument("--emotions", help="The emotions considered during training (has to respect the order in the "
"dictionnary index:emotion: neu ang sad hap+exc fea exc hap dis fru sur)",
default="neu ang sad hap+exc")
parser.add_argument("--freeze", help="If some parts of the mode were frozen")
args = parser.parse_args()
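# Example invocation (hypothetical values, adjust to your own runs):
#   python scoring_cross_validation.py custom-v1 4 128 0.0001 --emotions "neu ang sad hap+exc" --freeze yes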
# The emotion labels considered during training, in the order given by --emotions
labels = args.emotions.split(" ")
nb_batch = str(args.batchs)
model_type = args.model
if "-" in model_type:
model_type, model_2nd = model_type.split("-")[0], "-" + model_type.split("-")[1]
else:
model_2nd = ""
lr = str(float(args.lr))
cates = str(args.categories)
if args.freeze is not None:
freeze = "_freeze"
else:
freeze = ""
# Path to save the confusion matrix for cross-validation
if model_type == "custom":
    path = "model_{}/Sess_all_cross-valid/{}/{}emo_{}batch_lr-{}{}".format(model_type, model_2nd[1:], cates, nb_batch, lr, freeze)
else:
    path = "model_{}/Sess_all_cross-valid/{}emo_{}batch_lr-{}{}".format(model_type, cates, nb_batch, lr, freeze)
# makedirs also creates any missing intermediate directories, unlike os.mkdir
os.makedirs(path, exist_ok=True)
def load_model(model_path, device):
    """
    Load an Xtractor model from a checkpoint.
    :param model_path: path (str) to the checkpoint
    :param device: "cpu", "cuda", etc.
    :return: the model in eval mode and its configuration dictionary
    """
    device = torch.device(device)
    model_config = torch.load(model_path, map_location=device)
    model_opts = model_config["model_archi"]
    if "embedding_size" not in model_opts:
        model_opts["embedding_size"] = 256
    xtractor = Xtractor(
        model_config["speaker_number"],
        model_archi=model_opts["model_type"],
        loss=model_opts["loss"]["type"],
        embedding_size=model_opts["embedding_size"],
    )
    xtractor.load_state_dict(model_config["model_state_dict"], strict=True)
    xtractor.eval()
    return xtractor, model_config
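# Example (hypothetical checkpoint path):
#   xtract, config = load_model("model_custom/Sess1_test/best_model.pt", "cuda")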
# We store each prediction (the emotion mapped to the argmax index) to compare it later to the gold annotation
predictions = []
gold_anno = []
# Read the file mapping each emotion to its index
index = open("list/emos_index.txt", "r")
index_emo = index.readlines()
index.close()
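# Each line of emos_index.txt is assumed to look like "neu: 0" (emotion label, then its index)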
dic_index_emo = {}
if model_type == "custom":
    dico = {}  # Needed when model_type == custom
for elmt in index_emo:
    dic_index_emo[int(elmt.split(": ")[1].replace("\n", ""))] = elmt.split(": ")[0]
print("\nAll emotions:", dic_index_emo, "\n\n## Beginning of extraction ##")
for i in range(1, 6):
    print("Session {} is processing...".format(i))
    # load_model returns the model and its configuration (weights, speaker number, etc.)
    if model_type == "custom":
        xtract, config = load_model("model_{}/Sess{}_test/{}/{}emo_{}batch_lr-{}{}/best_{}{}{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, i, model_2nd[1:], cates, nb_batch, lr, freeze, model_type, model_2nd, freeze, cates, nb_batch, lr, i), "cuda")
    else:
        xtract, config = load_model("model_{}/Sess{}_test/{}emo_{}batch_lr-{}{}/best_{}{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, i, cates, nb_batch, lr, freeze, model_type, freeze, cates, nb_batch, lr, i), "cuda")
    path_wav = "data/IEMOCAP/Session{}/sentences/wav".format(i)
    # Read the gold annotations and filter them to keep only the current session
    recap = open("data/recap_emo_file.txt", "r")
    recap_emo = recap.readlines()
    recap_emo = [line for line in recap_emo if str(i) in line.split("\t")[0]]
    recap.close()
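    # Each line of recap_emo_file.txt is assumed to be tab-separated:
    # session<TAB>wav filename<TAB>emotion label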
    for line in tqdm(recap_emo):
        # Retrieve the gold emotion and the index associated with it
        gold = line.split("\t")[2].replace("\n", "")
        if "hap+exc" in labels:
            if gold == "hap" or gold == "exc":
                gold = "hap+exc"
        if gold in labels:
            gold_anno.append(gold)
            # We extract the predicted emotion ONLY if the gold emotion is among the considered emotions
            file = line.split("\t")[1]
            folder = line.split("\t")[1].rsplit("_", 1)[0]
            signal, sr = torchaudio.load(os.path.join(path_wav, folder, file))
            if model_type == "custom":
                # The custom model needs a dictionary with the signal under the "speech" key
                dico["speech"] = signal
                outModel = xtract(dico)
            else:
                outModel = xtract(signal)
            # outModel[0][1] is assumed to hold the per-emotion scores; argmax gives the predicted index
            predictions.append(dic_index_emo[outModel[0][1].argmax().item()])
    print("\n")
assert len(predictions) == len(gold_anno)
# We compare the predictions and gold_anno lists
UAR = metrics.recall_score(gold_anno, predictions, average="macro")
p = round(UAR, 2)  # For the confidence interval
accuracy = metrics.accuracy_score(gold_anno, predictions)
accuracy_percent = round(accuracy * 100, 2)
UARPercent = round(UAR * 100, 2)
confMatrix = metrics.confusion_matrix(gold_anno, predictions, labels=labels)
print(confMatrix)
gold_dic = []  # Number of gold examples per emotion (row sums of the confusion matrix)
for i in range(len(confMatrix)):
    gold_dic.append(sum(confMatrix[i]))
n = sum(gold_dic)  # Total number of test samples, for the confidence interval
conf_inter = round(1.96 * sqrt((p * (1 - p)) / n), 2)  # 95% confidence interval
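# The half-width above is the 95% normal-approximation (Wald) interval for a proportion,
# 1.96 * sqrt(p * (1 - p) / n); applying it to the UAR rather than raw accuracy is an approximation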
print("\nUAR:", UARPercent, "% ±", conf_inter, "\t Accuracy: ", accuracy_percent,"%\n")
[print("Total", labels[i], ":", gold_dic[i]) for i in range(len(labels))]
print("Total", n)
# Build the annotation of each cell: row-normalised percentage and raw count
annot = []
for i in range(0, len(confMatrix)):  # row
    annot.append([])
    tot = 0
    for j in range(0, len(confMatrix[i])):  # column
        nbr = confMatrix[i][j]
        percent = round(nbr / gold_dic[i], 2) * 100
        tot += percent
        # Adjust the last cell of the row so the percentages sum to exactly 100
        if j == len(confMatrix[i]) - 1:
            if tot > 100:
                percent -= (tot - 100)
            elif tot < 100:
                percent += (100 - tot)
        full = str(int(percent)) + "% (" + str(nbr) + ")"
        annot[i].append(full)
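# e.g. a cell annotated "42% (103)" means 42% of that row's gold examples (103 files)
# were predicted as the column's emotion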
# fmt="10" pads each string annotation to a width of 10 characters
sns.heatmap(confMatrix, annot=annot, fmt="10", cmap="Blues", vmin=0, vmax=1000,
            xticklabels=labels, yticklabels=labels)
plt.title("Model: {}{}{}_{}emo_{}batch_lr: {}\nData: Test-IEMOCAP-cross_validation"
          " UAR = {}% ±{} Accuracy= {}%".format(model_type, model_2nd, freeze, config["speaker_number"],
                                                nb_batch, lr, UARPercent, conf_inter, accuracy_percent))
plt.xlabel("Prediction")
plt.ylabel("Ground truth")
plt.savefig(os.path.join(path, "confusion_matrix_{}{}{}_{}emo_{}batch_lr-{}_IEMOCAP-cross_validation.png".format(
    model_type, model_2nd, freeze, config["speaker_number"], nb_batch, lr)))
plt.show()
plt.clf()
print("\nConfusion matrix done!")