Commit 139ccbdc authored by Anthony Larcher
Browse files

fix minor bugs

parents 80e03e74 d3c430ed
......@@ -68,7 +68,7 @@ class Database:
Can select according to the TV> channel : BFM or LCP and depending on the show:
"""
def __init__(self, root_folder, parameters, show=None):
def __init__(self, root_folder, parameters, show=""):
"""
:param root_folder:
......
#!/usr/bin/python
import copy
import numpy
import pandas
import scipy
from s4d.clustering.hac_utils import scores2distance
from scipy.cluster import hierarchy as hac
from scipy.spatial.distance import squareform
from .utils import s4d_to_allies
from ..user_simulation import MessageToUser
from ..user_simulation import Request
def ask_question(node_to_check, scores, current_diar, user, file_info):
    """
    Ask the user whether two points in time belong to the same speaker.

    The selection of representative segments for the node is not implemented
    yet: ``node_to_check`` and ``scores`` are currently unused and the
    question is always asked for the placeholder times 0.0 and 1.0.

    :param node_to_check: row of the linkage matrix to check with the user
    :param scores: the corresponding score object
    :param current_diar: the current diarization, converted with
        ``s4d_to_allies`` before being sent to the user
    :param user: the user simulation; its ``validate`` method is called with
        the message and must return a ``(hal, answer)`` pair
    :param file_info: file description forwarded to ``MessageToUser``
    :return: a tuple ``(answer.answer, hal)`` -- the user's answer to the
        "same" request followed by the first value returned by
        ``user.validate`` (presumably whether the interaction may continue;
        verify against the user simulation)
    """
    # TODO: pick one segment on each side of the node. For a leaf this is
    #       direct; for an existing cluster, list all of its segments and
    #       take the most central one according to the score matrix.
    #       This must yield two indices into the scores.scoremat matrix.
    # TODO: from those indices, get the segment borders in the Diar object.
    # TODO: for each segment, compute its centre in seconds.
    first_time = 0.
    second_time = 1.

    # Build the request and submit it to the user
    hypothesis = s4d_to_allies(current_diar)
    same_request = Request('same', first_time, second_time)
    question = MessageToUser(file_info, hypothesis, same_request)

    hal, answer = user.validate(question)
    return answer.answer, hal
def active_learning_tree(current_diar,
                         current_vec,
                         scores,
                         threshold,
                         user,
                         file_info,
                         clustering_method="complete"):
    """
    Review, with the user, the HAC merges that lie closest to the threshold.

    A hierarchical clustering is computed from the scores, then every node of
    the resulting tree is visited from the closest to the farthest from the
    clustering threshold. Nodes below the threshold (already merged) may be
    split and nodes above the threshold (not merged) may be merged, depending
    on the user's answers. Each direction is abandoned as soon as the user
    confirms the automatic decision, and the review stops when both
    directions are abandoned or the user stops answering.

    Applying the corrected list of merges to the diarization and renaming the
    vectors is not implemented yet, so the inputs are returned unchanged.

    :param current_diar: the segmentation that comes out from the automatic pass
    :param current_vec: StatServer of vectors not YET normalized
    :param scores: Scores object with symmetric matrix of PLDA scores inside
    :param threshold: the clustering threshold
    :param user: user simulation developed in the ALLIES package
    :param file_info: file description forwarded to the user simulation
    :param clustering_method: linkage method passed to
        ``scipy.cluster.hierarchy.linkage`` (default "complete")
    :return: a tuple ``(False, current_diar, current_vec)``
    """
    # Work on copies so that the caller's objects are left untouched
    ldiar = copy.deepcopy(current_diar)
    lscores = copy.deepcopy(scores)

    # Turn the scores into distances; `th` is the second value returned by
    # scores2distance (presumably the threshold expressed as a distance --
    # TODO confirm against s4d.clustering.hac_utils)
    distances, th = scores2distance(lscores, threshold)
    distance_sym = squareform(distances)

    # Cluster the data
    link = hac.linkage(distance_sym, method=clustering_method)

    # Extend each linkage row with two extra columns:
    #   [-2] signed gap to the threshold (negative: below, positive: above)
    #   [-1] absolute gap to the threshold
    tmp = numpy.zeros((link.shape[0], link.shape[1] + 2))
    tmp[:, :-2] = link
    tmp[:, -2] = link[:, 2] - th
    tmp[:, -1] = numpy.abs(link[:, 2] - th)

    # Sort the nodes by proximity to the threshold, closest first
    links_to_check = tmp[numpy.argsort(tmp[:, -1])]

    # Start from the merges the automatic system already performs
    final_links = [row for row in link if row[2] < th]

    # Review node by node, from the closest to the farthest from the
    # threshold, handling both directions (splitting and merging)
    no_more_clustering = False
    no_more_separation = False

    for ltc in links_to_check:
        # Nothing left to review once both directions have been abandoned
        if no_more_clustering and no_more_separation:
            break

        below_threshold = ltc[-2] < 0   # already merged
        above_threshold = ltc[-2] > 0   # not merged yet

        if not (below_threshold or above_threshold):
            continue
        if below_threshold and no_more_separation:
            continue
        if above_threshold and no_more_clustering:
            continue

        # ask_question returns (answer, hal): unpack it, a bare tuple would
        # always evaluate to True
        is_same_speaker, hal = ask_question(ltc, scores, ldiar, user, file_info)

        if below_threshold:
            if is_same_speaker:
                # The merge is confirmed: stop trying to split
                no_more_separation = True
            else:
                # The user rejects the merge: drop this link from the list.
                # A new list is built rather than popping while iterating.
                final_links = [fl for fl in final_links
                               if not numpy.array_equal(fl, ltc[:4])]
        else:
            if is_same_speaker:
                # The user asks for a new merge
                final_links.append(ltc[:4])
            else:
                # Stop looking for new merges
                no_more_clustering = True

        # NOTE(review): `hal` is assumed to tell whether the user accepts
        # further questions -- confirm against the user simulation
        if not hal:
            break

    # Apply clustering on top of the current_diar
    # TODO
    # Rename current_vec accordingly
    # TODO
    return False, current_diar, current_vec
......
......@@ -31,6 +31,7 @@ from ..user_simulation import MessageToUser
from ..user_simulation import Request
from .interactive import apply_correction
from .interactive import active_learning_tree
#from s4d import viterbi, segmentation
#from s4d.clustering import hac_bic
......@@ -851,33 +852,31 @@ def allies_within_show_HAL(current_diar,
user,
uem,
ref):
"""
At the moment only correct the diarization using Active Learning
:param current_diar:
:param current_vec:
:param scores:
:param th:
:param model_cfg:
:param model:
:param file_info:
:param user:
:param uem:
:param ref:
:return:
"""
hal = True
while hal:
# get correction from the user
message_to_user = MessageToUser(file_info,
s4d_to_allies(current_diar),
Request('same', 0., 1.))
hal, answer = user.validate(message_to_user)
answer.time_1 = int(answer.time_1 * 100)
answer.time_2 = int(answer.time_2 * 100)
print(answer)
# Get the index of the corrected segments in the diarization
for idx, seg in enumerate(current_diar.segments):
if seg['start'] <= answer.time_1 <= seg["stop"]:
idx_1 = idx
cluster_1 = current_diar[idx_1]['cluster']
if seg["start"] <= answer.time_2 <= seg["stop"]:
idx_2 = idx
cluster_2 = current_diar[idx_2]['cluster']
# apply corretion and produce the new segmentation
# Only correct when the two segments have to be grouped
if answer.response_type == "same" and answer.answer:
current_diar = apply_correction(current_diar, scores, idx_1, idx_2)
# If the mode is "active learning"
hal, current_diar, current_vec = active_learning_tree(current_diar,
current_vec,
scores,
th,
user,
file_info)
return current_diar, current_vec
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment