Commit 316eb23f authored by Colleen Beaumard's avatar Colleen Beaumard
Browse files

Modification of scoring scripts (consideration of any emotions)

parent ad0138b5
......@@ -5,12 +5,6 @@ This repository contains the framework for training emotion recognition models o
### Data preparation
IEMOCAP dataset must be in data/IEMOCAP.
<<<<<<< HEAD
=======
Note: the scripts were written on the basis of 4 emotion categories (neu, ang, sad, hap+exc). If you want to change these, you will have to edit the scripts manually.
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
```bash
# To activate env.sh
source ../../env.sh
......@@ -63,11 +57,7 @@ release_model.sh model_half_resnet34/best_model_half_resnet34_cuda_JIT.pt model_
### Evaluation
To launch the evaluation of the model, run:
```bash
<<<<<<< HEAD
python ./local/scoring.py #model #session_test #nb_categories #batch #lr #--emotions(default:neu ang sad hap+exc, specific order needed) #--freeze(if parts of the model were frozen)
=======
python ./local/scoring.py #model #session_test #nb_categories #batch #lr #epoch #--emotions(default:neu ang sad hap+exc, specific order needed) #--freeze(if parts of the model were frozen)
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
```
A confusion matrix and a losses plot will be generated, and all the files will be moved to a dedicated directory (example: "model\_half\_resnet34/Sess1\_test/4emo\_100batch\_lr-0.0001").
......
......@@ -16,14 +16,9 @@ parser.add_argument("session_test", help="The session considered as a test (for
parser.add_argument("categories", help="The number of categories")
parser.add_argument("batchs", help="The number of batches used during training")
parser.add_argument("lr", help="The learning rate used during training")
<<<<<<< HEAD
parser.add_argument("--emotions", help="The emotions considered during training (has to respect the order in the "
"dictionnary index:emotion!): neu ang sad hap+exc fea exc hap dis fru sur",
default="neu ang sad hap+exc")
=======
parser.add_argument("epoch", help="The last epoch of training")
parser.add_argument("--emotions", help="The emotions considered during training (has to respect the order in the dictionnary index:emotion!)", default="neu ang sad hap+exc")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
parser.add_argument("--freeze", help="If some parts of the mode were frozen")
args = parser.parse_args()
......@@ -38,10 +33,6 @@ def load_model(model_path, device):
device = torch.device(device)
model_config = torch.load(model_path, map_location=device)
model_opts = model_config["model_archi"]
<<<<<<< HEAD
=======
#ptrix = metrics.confusion_matrix(gold_anno, predictions, labels=emoData["emoList"])rint(model_config["model_state_dict"])
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
if "embedding_size" not in model_opts:
model_opts["embedding_size"] = 256
......@@ -51,103 +42,60 @@ def load_model(model_path, device):
loss=model_opts["loss"]["type"],
embedding_size=model_opts["embedding_size"],
)
<<<<<<< HEAD
xtractor.load_state_dict(model_config["model_state_dict"], strict=True)
xtractor.eval()
return xtractor, model_config
=======
xtractor.load_state_dict(model_config["model_state_dict"], strict=True)
#xtractor = xtractor.to(device)
xtractor.eval()
return xtractor, model_config
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
# We store the prediction (the argmax value's index) to compare it later to the golden annotation
ses_nb = args.session_test
labels = list(args.emotions.split(" "))
nb_batch = str(args.batchs)
model_type = args.model
<<<<<<< HEAD
lr = str(args.lr)
=======
#patience = str(args.patience)
lr = str(args.lr)
epoch = args.epoch
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
cates = str(args.categories)
if args.freeze is not None:
freeze = "_freeze"
else:
freeze = ""
<<<<<<< HEAD
### For the confusion matrix ###
# 1st is the model, 2nd is the weights and all
xtract, config = load_model("model_{}/best_{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt"
.format(model_type, model_type, cates, nb_batch, lr, ses_nb), "cuda")
=======
## For the confusion matrix ##
# 1st is the model, 2nd is the weights and all
xtract, config = load_model("model_{}/best_{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, model_type, cates, nb_batch, lr, ses_nb), "cuda")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
predictions = []
gold_anno = []
path = "data/IEMOCAP/Session{}/sentences/wav".format(ses_nb)
<<<<<<< HEAD
# We open the file to obtain the golden annotation (and we sort it out to keep only the current session)
recap = open("data/recap_emo_file.txt", "r")
=======
#outModel[0][1].argmax() -> Get the tensor([indice]) of the highest value
#outModel[0][1][0][outModel[0][1].argmax()] -> Get the score of the highest value according to its index
# We open the file to obtain the golden annotation (and we sort it out to keep only the current session)
recap = open("data/recap_emo_file.txt","r")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
recap_emo = recap.readlines()
recap_emo = [line for line in recap_emo if ses_nb in line.split("\t")[0]]
recap.close()
# We open the file with the index corresponding to emotions
<<<<<<< HEAD
index = open("list/emos_index.txt", "r")
=======
index = open("list/emos_index.txt","r")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
index_emo = index.readlines()
index.close()
dic_index_emo = {}
<<<<<<< HEAD
# Needed when the model_type == custom
dico = {}
for elmt in index_emo:
dic_index_emo[int(elmt.split(": ")[1].replace("\n", ""))] = elmt.split(": ")[0]
=======
for elmt in index_emo:
dic_index_emo[int(elmt.split(": ")[1].replace("\n",""))] = elmt.split(": ")[0]
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
print("\nAll emotions:", dic_index_emo, "\n\n## Beginning of extraction ##")
for line in tqdm(recap_emo):
# We retrieve the golden emotion and retrieve the index associated to it
<<<<<<< HEAD
gold = line.split("\t")[2].replace("\n", "")
=======
gold = line.split("\t")[2].replace("\n","")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
if "hap+exc" in labels:
if gold == "hap" or gold == "exc":
gold = "hap+exc"
if gold in labels:
gold_anno.append(gold)
<<<<<<< HEAD
# We extract the predicted emotion ONLY if the gold emotion is within the considered emotions
file = line.split("\t")[1]
folder = line.split("\t")[1].rsplit("_", 1)[0]
......@@ -160,13 +108,6 @@ for line in tqdm(recap_emo):
outModel = xtract(dico)
else:
outModel = xtract(signal)
=======
# We extract the predicted emotion ONLY if the gold emotion is within the considered emotions
file = line.split("\t")[1]
folder = line.split("\t")[1].rsplit("_",1)[0]
signal, sr = torchaudio.load(os.path.join(path,folder,file))
outModel = xtract(signal)
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
predictions.append(dic_index_emo[outModel[0][1].argmax().item()])
assert len(predictions) == len(gold_anno)
......@@ -174,16 +115,11 @@ assert len(predictions) == len(gold_anno)
# We start to compare the predictions and gold_anno lists
UAR = metrics.recall_score(gold_anno, predictions, average="macro")
UARPercent = round(UAR * 100, 2)
<<<<<<< HEAD
print("\nUAR:", UARPercent, "%\n")
=======
print("UAR:", UARPercent, "%")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
confMatrix = metrics.confusion_matrix(gold_anno, predictions, labels=labels)
print(confMatrix)
<<<<<<< HEAD
gold_dic = {key: 0 for key in labels}
dico = {key: gold_dic for key in labels}
......@@ -192,78 +128,32 @@ for gold, pred in zip(gold_anno, predictions):
gold_dic[gold] += 1
[print("Total", key, ":", value) for key, value in gold_dic.items()]
=======
gold_dic = {"neu":0, "ang":0,"sad":0,"hap+exc":0}
dico = {"neu":{"ang":0, "sad":0,"neu":0,"hap+exc":0},
"ang":{"ang":0, "sad":0,"neu":0,"hap+exc":0},
"sad":{"ang":0, "sad":0,"neu":0,"hap+exc":0},
"hap+exc":{"ang":0, "sad":0,"neu":0,"hap+exc":0}}
for gold, pred in zip(gold_anno, predictions):
dico[gold][pred] +=1
gold_dic[gold] += 1
print(gold_dic, "\n")
print(dico)
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
annot = []
em = [value for value in gold_dic.values()]
<<<<<<< HEAD
for i in range(0, len(confMatrix)): # row
annot.append([])
tot = 0
for j in range(0, len(confMatrix[i])): # column
nbr = confMatrix[i][j]
percent = round(nbr/em[i], 2)*100
=======
for i in range(0,len(confMatrix)): # row
annot.append([])
tot = 0
for j in range(0,len(confMatrix[i])): # column
nbr = confMatrix[i][j]
percent = round(nbr/em[i],2)*100
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
tot += percent
if j == len(confMatrix[i])-1:
if tot > 100:
percent -= (tot - 100)
elif tot < 100:
percent += (100 - tot)
<<<<<<< HEAD
full = str(int(percent)) + "% (" + str(nbr) + ")"
annot[i].append(full)
### For the losses ###
=======
full = str(int(percent)) + "% (" + str(nbr) + ")"
annot[i].append(full)
path = "model_{}/Sess{}_test/{}emo_{}batch_lr-{}{}".format(model_type, ses_nb, cates, nb_batch, lr, freeze)
if not os.path.isdir(path.rsplit("/",1)[0]):
os.mkdir(path.rsplit("/",1)[0])
if not os.path.isdir(path):
os.mkdir(path)
sns.heatmap(confMatrix, annot=annot, fmt="10", cmap="Blues", vmin=0, vmax=350, xticklabels=labels, yticklabels=labels)
plt.title("Model: " + str(config["model_archi"]["model_type"]) + "{}_".format(freeze) + str(config["speaker_number"]) + "emo_{}batch\nepoch: {} lr: {} Data: Test-IEMOCAP {}".format(nb_batch, epoch, lr, ses_nb) + " UAR = " + str(UARPercent) + "%")
plt.xlabel("Prediction")
plt.ylabel("Ground truth")
plt.savefig(os.path.join(path, "confusion_matrix_" + str(config["model_archi"]["model_type"]) + "{}_".format(freeze) + str(config["speaker_number"]) + "emo_{}batch_epoch-{}_lr-{}_Test-IEMOCAP{}.png".format(nb_batch, epoch, lr, ses_nb)))
plt.show()
plt.clf()
print("Confusion matrix done!")
## For the losses ##
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
fil = open("logs/{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.log".format(model_type, cates, nb_batch, lr, ses_nb), "r")
file = fil.readlines()
fil.close()
# Search for "Loss:", "Validation Loss", "Epoch" and "reducing" in all lines
valid_loss = [line for line in file if "Validation Loss" in line]
<<<<<<< HEAD
# "Loss:" and "Epoch" in same line
if model_type == "custom":
......@@ -313,49 +203,17 @@ print("\nConfusion matrix done!")
# Plot losses
ticks = [nb for nb in range(0, 9)]
eticks = [nb for nb in range(0, len(aepoch), 5)]
=======
loss_epoch = [line for line in file if "Epoch" in line] [2:] # "Loss:" and "Epoch:" in same line
reduce_lr = [line for line in file if "reducing" in line]
# Epoch 9: reducing learning rate of group 0 to 2.5000e-04.
vloss = []
tloss = []
aepoch = []
for linev, linel in zip(valid_loss, loss_epoch): # 2022-03-31 12:11:10,774 - INFO - Epoch: 23 [1/192 (0%)] Loss: 1.209560 Accuracy: 90.000
linel = linel.split(":")
linev = linev.split("=")
vloss.append(round(float(linev[3].replace("\n", "").replace(" ", "")),2)) # 0.7448596954345703
aepoch.append(linel[3].split(" ")[1]) # 23
tloss.append(round(float(linel[4].split("\t")[0].replace(" ","")),2)) # 1.209560
# See if the learning rate has been reduced, and if yes, when
assert len(aepoch) == len(tloss) == len(vloss)
for e,t,v in zip(aepoch, tloss, vloss):
print("Epoch:", e, "\ttloss:", t, "\tvloss:", v)
ticks = [nb for nb in range(0,9)]
eticks = [nb for nb in range(0,len(aepoch),5)]
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
plt.plot(aepoch, tloss, label="Training loss")
plt.plot(aepoch, vloss, label="Validation loss")
plt.yticks(ticks)
plt.xticks(eticks)
<<<<<<< HEAD
if len(reduce_lr) != 0:
colors = ["b", "g", "y", "c", "m", "r"]
=======
if len(reduce_lr) != 0:
colors = ["b","g","y","c","m","r"]
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
i = 0
for key, value in reduce_lr.items():
label = "lr: " + str(value)
plt.axvline(x=key, color=colors[i], linestyle='--', label=label)
<<<<<<< HEAD
i += 1
plt.legend()
plt.title("Model: {}{}_{}emo_{}batch\nEpoch: {} lr: {} Data: Test-IEMOCAP {}".format(model_type, freeze, cates,
......@@ -381,16 +239,3 @@ os.replace("model_{}/tmp_{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_
os.replace("logs/{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.log".format(model_type, cates, nb_batch, lr, ses_nb),
os.path.join(path, "{}{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.log".format(model_type, freeze, cates,
nb_batch, lr, ses_nb)))
=======
i+= 1
plt.legend()
plt.title("Model: {}{}_{}emo_{}batch\nEpoch: {} lr: {} Data: Test-IEMOCAP {}".format(model_type, freeze, cates, nb_batch, epoch, lr, ses_nb))
plt.savefig(os.path.join(path, "losses_{}{}_{}emo_{}batch_epoch-{}_lr-{}_Test-IEMOCAP{}.png".format(model_type, freeze, cates, nb_batch, epoch, lr, ses_nb)))
plt.show()
plt.clf()
print("Losses plotted!")
os.replace("model_{}/best_{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, model_type, cates, nb_batch, lr, ses_nb), os.path.join(path, "best_{}{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, freeze, cates, nb_batch, lr, ses_nb)))
os.replace("model_{}/tmp_{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, model_type, cates, nb_batch, lr, ses_nb), os.path.join(path, "tmp_{}{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, freeze, cates, nb_batch, lr, ses_nb)))
os.replace("logs/half_resnet34_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.log".format(cates, nb_batch, lr, ses_nb), os.path.join(path, "half_resnet34{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.log".format(freeze, cates, nb_batch, lr, ses_nb)))
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
......@@ -15,13 +15,9 @@ parser.add_argument("model", help="The model used")
parser.add_argument("categories", help="The number of categories")
parser.add_argument("batchs", help="The number of batches used during training")
parser.add_argument("lr", help="The learning rate used during training")
<<<<<<< HEAD
parser.add_argument("--emotions", help="The emotions considered during training (has to respect the order in the "
"dictionnary index:emotion: neu ang sad hap+exc fea exc hap dis fru sur)",
default="neu ang sad hap+exc")
=======
parser.add_argument("--emotions", help="The emotions considered during training (has to respect the order in the dictionnary index:emotion!)", default="neu ang sad hap+exc")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
parser.add_argument("--freeze", help="If some parts of the mode were frozen")
args = parser.parse_args()
......@@ -38,20 +34,12 @@ else:
# Path to save the confusion matrix for cross-validation
path = "model_{}/Sess_all_cross-valid/{}emo_{}batch_lr-{}{}".format(model_type, cates, nb_batch, lr, freeze)
<<<<<<< HEAD
if not os.path.isdir(path.rsplit("/", 1)[0]):
os.mkdir(path.rsplit("/", 1)[0])
if not os.path.isdir(path):
os.mkdir(path)
=======
if not os.path.isdir(path.rsplit("/",1)[0]):
os.mkdir(path.rsplit("/",1)[0])
if not os.path.isdir(path):
os.mkdir(path)
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
def load_model(model_path, device):
"""
Load a model
......@@ -76,15 +64,11 @@ def load_model(model_path, device):
xtractor.eval()
return xtractor, model_config
<<<<<<< HEAD
=======
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
predictions = []
gold_anno = []
# We open the file with the index corresponding to emotions
<<<<<<< HEAD
index = open("list/emos_index.txt", "r")
index_emo = index.readlines()
index.close()
......@@ -101,44 +85,18 @@ for i in range(1, 6):
xtract, config = load_model("model_{}/Sess{}_test/{}emo_{}batch_lr-{}{}/best_{}{}_{}emo_{}batch_lr-{}_"
"Test-IEMOCAP{}.pt".format(model_type, i, cates, nb_batch, lr, freeze, model_type,
freeze, cates, nb_batch, lr, i), "cuda")
=======
index = open("list/emos_index.txt","r")
index_emo = index.readlines()
index.close()
dic_index_emo = {}
for elmt in index_emo:
dic_index_emo[int(elmt.split(": ")[1].replace("\n",""))] = elmt.split(": ")[0]
print("\nAll emotions:", dic_index_emo, "\n\n## Beginning of extraction ##")
for i in range(1,6):
print("Session {} is processing...".format(i))
# 1st is the model, 2nd is the weights and all
xtract, config = load_model("model_{}/Sess{}_test/{}emo_{}batch_lr{}{}/best_{}{}_{}emo_{}batch_lr-{}_Test-IEMOCAP{}.pt".format(model_type, i, cates, nb_batch, lr, freeze, model_type, freeze, cates, nb_batch, lr, i), "cuda")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
path_wav = "data/IEMOCAP/Session{}/sentences/wav".format(i)
# We open the file to obtain the golden annotation (and we sort it out to keep only the current session)
<<<<<<< HEAD
recap = open("data/recap_emo_file.txt", "r")
recap_emo = recap.readlines()
recap_emo = [line for line in recap_emo if str(i) in line.split("\t")[0]]
=======
recap = open("data/recap_emo_file.txt","r")
recap_emo = recap.readlines()
recap_emo = [line for line in recap_emo if i in line.split("\t")[0]]
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
recap.close()
for line in tqdm(recap_emo):
# We retrieve the golden emotion and retrieve the index associated to it
<<<<<<< HEAD
gold = line.split("\t")[2].replace("\n", "")
=======
gold = line.split("\t")[2].replace("\n","")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
if "hap+exc" in labels:
if gold == "hap" or gold == "exc":
gold = "hap+exc"
......@@ -147,7 +105,6 @@ for i in range(1,6):
# We extract the predicted emotion ONLY if the gold emotion is within the considered emotions
file = line.split("\t")[1]
<<<<<<< HEAD
folder = line.split("\t")[1].rsplit("_", 1)[0]
signal, sr = torchaudio.load(os.path.join(path_wav, folder, file))
......@@ -160,28 +117,16 @@ for i in range(1,6):
outModel = xtract(signal)
predictions.append(dic_index_emo[outModel[0][1].argmax().item()])
print("\n")
=======
folder = line.split("\t")[1].rsplit("_",1)[0]
signal, sr = torchaudio.load(os.path.join(path_wav,folder,file))
outModel = xtract(signal)
predictions.append(dic_index_emo[outModel[0][1].argmax().item()])
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
assert len(predictions) == len(gold_anno)
# We start to compare the predictions and gold_anno lists
UAR = metrics.recall_score(gold_anno, predictions, average="macro")
UARPercent = round(UAR * 100, 2)
<<<<<<< HEAD
print("UAR:", UARPercent, "%\n")
=======
print("UAR:", UARPercent, "%")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
confMatrix = metrics.confusion_matrix(gold_anno, predictions, labels=labels)
print(confMatrix)
<<<<<<< HEAD
gold_dic = {key: 0 for key in labels}
dico = {key: gold_dic for key in labels}
......@@ -190,47 +135,22 @@ for gold, pred in zip(gold_anno, predictions):
gold_dic[gold] += 1
[print("Total", key, ":", value) for key, value in gold_dic.items()]
=======
gold_dic = {"neu":0, "ang":0,"sad":0,"hap+exc":0}
dico = {"neu":{"ang":0, "sad":0,"neu":0,"hap+exc":0},
"ang":{"ang":0, "sad":0,"neu":0,"hap+exc":0},
"sad":{"ang":0, "sad":0,"neu":0,"hap+exc":0},
"hap+exc":{"ang":0, "sad":0,"neu":0,"hap+exc":0}}
for gold, pred in zip(gold_anno, predictions):
dico[gold][pred] +=1
gold_dic[gold] += 1
print(gold_dic, "\n")
print(dico)
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
annot = []
em = [value for value in gold_dic.values()]
<<<<<<< HEAD
for i in range(0, len(confMatrix)): # row
annot.append([])
tot = 0
for j in range(0, len(confMatrix[i])): # column
nbr = confMatrix[i][j]
percent = round(nbr/em[i], 2)*100
=======
for i in range(0,len(confMatrix)): # row
annot.append([])
tot = 0
for j in range(0,len(confMatrix[i])): # column
nbr = confMatrix[i][j]
percent = round(nbr/em[i],2)*100
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
tot += percent
if j == len(confMatrix[i])-1:
if tot > 100:
percent -= (tot - 100)
elif tot < 100:
percent += (100 - tot)
<<<<<<< HEAD
full = str(int(percent)) + "% (" + str(nbr) + ")"
annot[i].append(full)
......@@ -245,17 +165,3 @@ plt.savefig(os.path.join(path, "confusion_matrix_{}{}_".format(model_type, freez
plt.show()
plt.clf()
print("\nConfusion matrix done!")
=======
full = str(int(percent)) + "% (" + str(nbr) + ")"
annot[i].append(full)
sns.heatmap(confMatrix, annot=annot, fmt="10", cmap="Blues", vmin=0, vmax=350, xticklabels=labels, yticklabels=labels)
plt.title("Model: " + str(config["model_archi"]["model_type"]) + "{}_".format(freeze) + str(config["speaker_number"]) + "emo_{}batch\nlr: {} Data: Test-IEMOCAP-cross_validation".format(nb_batch, lr) + " UAR = " + str(UARPercent) + "%")
plt.xlabel("Prediction")
plt.ylabel("Ground truth")
plt.savefig(os.path.join(path, "confusion_matrix_" + str(config["model_archi"]["model_type"]) + "{}_".format(freeze) + str(config["speaker_number"]) + "emo_{}batch_lr-{}_IEMOCAP-cross_validation.png".format(nb_batch, lr)))
plt.show()
plt.clf()
print("Confusion matrix done!")
>>>>>>> 9c8935c7502c8fd41da5a3818c80cdd9e37e0812
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment