Commit ed17bd6a authored by Anthony Larcher
Browse files

new HAL cross show

parent cc93cb3e
......@@ -277,9 +277,7 @@ def cross_show_HAL(previous_vec,
:param lim:
:param user:
:param file_info:
:param uem:
:param ref:
:param human_in_the_loop:
:param archive_file_info:
:return:
"""
within_vec_backup = copy.deepcopy(within_vec)
......@@ -294,10 +292,79 @@ def cross_show_HAL(previous_vec,
This matrix is normalized to enable/disable clustering between previous/previous
and within/within clusters
"""
ll_vec, scores = compute_distance_cross_show(previous_vec_mean, previous_diar, within_vec_mean)
# merge the mean_per_model for previous and within
ll_vec = concat_statservers(previous_vec, within_vec)
# Compute the score matrix
ndx = sidekit.Ndx(models=ll_vec.modelset, testsegs=ll_vec.modelset)
scores = sidekit.iv_scoring.cosine_scoring(ll_vec,
ll_vec,
ndx,
wccn=None,
check_missing=False,
device=torch.device("cuda"))
previous_locked_spk = []
linkage_speaker_dict = {}
# For each speaker in the current file
for ii in range(previous_vec_mean.modelset.shape[0], ll_vec.modelset.shape[0]):
question_number = 0
# Get the current name of the speaker
current_speaker_name = scores.modelset[ii]
# get the scores obtained with all previous speakers and rank them
sorted_idx = numpy.argsort(scores.scoremat[ii][:previous_vec_mean.modelset.shape[0]])
sorted_scores_current_speaker = scores.scoremat[ii, sorted_idx]
# If one score is above th_x and the corresponding previous speaker is not locked
for jj, previous_spk_idx in enumerate(sorted_idx):
tdict = {}
previous_spk_name = ll_vec.modelset[previous_spk_idx]
# There are scores higher than the threshold
if sorted_scores_current_speaker[previous_spk_idx] > th_x:
if not ll_vec.modelset[previous_spk_idx] in previous_locked_spk:
# ---> link the speakers
linkage_speaker_dict[current_speaker_name] = ll_vec.modelset[previous_spk_idx]
# ---> lock the previous speaker
previous_locked_spk.append(ll_vec.modelset[previous_spk_idx])
# move to next speaker
break
# There are no more scores higher than the threshold
else:
if previous_spk_idx not in previous_locked_spk:
# Get the time of the middle of the longest segment for the longest seg within_diar_id in within_diar
tmp_diar = copy.deepcopy(within_diar)
tmp_diar.filter("cluster", "==", current_speaker_name).add_duration().sort(["duration"], reverse=True)
show1 = tmp_diar[0]["show"]
t1 = (tmp_diar[0]["stop"] - tmp_diar[0]["start"]) / 200.
# Get the time of the middle of the longest segment for previous_diar_id in previous_diar
tmp_diar = copy.deepcopy(previous_diar)
tmp_diar.filter("cluster", "==", previous_spk_name).add_duration().sort(["duration"], reverse=True)
show2 = tmp_diar[0]["show"]
t2 = (tmp_diar[0]["stop"] - tmp_diar[0]["start"]) / 200.
# Ask the question to the user
complete_hyp = copy.deepcopy(previous_diar)
complete_hyp.append_diar(within_diar)
message_to_user = MessageToUser(file_info,
s4d_to_allies(complete_hyp),
Request('same', t1, t2, archive_file_info[show2]))
keep_questioning, answer = user.validate(message_to_user)
question_number += 1
if question_number > lim:
break
"""
linkage_speaker_dict = {}
# For each speaker in the current diarization
for ii in range(previous_vec_mean.modelset.shape[0], ll_vec.modelset.shape[0]):
......@@ -366,6 +433,7 @@ def cross_show_HAL(previous_vec,
else:
previous_spk_idx += 1
"""
# concatenate previous_vec and within_vec
new_previous_vec = concat_statservers(previous_vec_backup, within_vec_backup)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment