Commit 7df158ba authored by Sylvain Meignier

Merge of Sylvain's and PA's changes

Merge branch 'master' of git-lium.univ-lemans.fr:Meignier/s4d
parents afa5e8f2 ad38e1c4
@@ -3,7 +3,6 @@ __author__ = 'meignier'
import copy
import logging
import numpy as np
-from scipy.stats import threshold
from scipy.sparse import csgraph
from collections import namedtuple
from sidekit.bosaris.scores import Scores
@@ -67,7 +66,10 @@ class ConnectedComponent:
logging.debug('threshold the distance matrix')
distances, t = scores2distance(self.scores, self.thr)
-graph = threshold(distances, threshmax=t, newval=np.inf)
+mask = (distances>t)
+graph = distances.copy()
+graph[mask] = np.inf
+#graph = threshold(distances, threshmax=t, newval=np.inf)
logging.debug('get connected components')
cc_nb, cc_list = csgraph.connected_components(graph, directed=False)
diar_out = copy.deepcopy(self.diar)
......
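Context for the change above: scipy.stats.threshold was deprecated and then removed in later SciPy releases, so the commit replaces it with an explicit boolean mask. A minimal sketch of the equivalent operation, on an illustrative distance matrix:

```python
import numpy as np

distances = np.array([[0.0, 0.4, 0.9],
                      [0.4, 0.0, 0.2],
                      [0.9, 0.2, 0.0]])   # illustrative symmetric distances
t = 0.5

# Equivalent of the removed scipy.stats.threshold(distances, threshmax=t, newval=np.inf):
graph = distances.copy()
graph[distances > t] = np.inf   # distances above the threshold become non-edges
```

Masking a copy rather than mutating `distances` keeps the original matrix intact for later use, which is why the committed code calls `.copy()` before assigning.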
@@ -5,7 +5,7 @@ from sidekit import Mixture, FeaturesServer
from s4d.clustering.hac_utils import argmin, roll
from s4d.diar import Diar
from sidekit.statserver import StatServer
-from bottleneck import argpartition as argpartsort
+from bottleneck import argpartition
class HAC_CLR:
"""
@@ -45,7 +45,7 @@ class HAC_CLR:
if argtop is None:
#logging.info('compute argtop '+speaker)
-argtop = argpartsort(lp*-1.0 , self.ntop, axis=1)[:, :self.ntop]
+argtop = argpartition(lp*-1.0 , self.ntop, axis=1)[:, :self.ntop]
#logging.info(argtop.shape)
if self.ntop is not None:
#logging.info('use ntop '+speaker)
......
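The rename above tracks bottleneck itself, whose newer releases renamed argpartsort to argpartition to match NumPy's naming. A minimal sketch of the top-n selection pattern used here, with illustrative scores:

```python
import numpy as np
import bottleneck as bn

lp = np.array([[0.1, 0.9, 0.3, 0.7],
               [0.8, 0.2, 0.6, 0.4]])   # illustrative log-likelihood values
ntop = 2

# Negating turns "largest" into "smallest"; after partitioning, the first
# ntop columns hold the indices of the ntop best scores in each row.
argtop = bn.argpartition(lp * -1.0, ntop, axis=1)[:, :ntop]
```

Note that the indices within the top-n are not sorted; that only matters if the caller needs the best scores in order.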
@@ -4,8 +4,8 @@ from scipy.cluster import hierarchy as hac
import matplotlib.pyplot as plt
from scipy import stats
import logging
-import subprocess
import time
+import os
from s4d.clustering.hac_utils import *
class ILP_IV:
@@ -28,11 +28,14 @@ class ILP_IV:
f = open(filename, 'w')
self._ilp_write(f)
f.close()
-cmd = 'glpsol --lp {} -o {} &> {}'.format(
-filename, filename + '.out', filename + '.err')
+cmd = 'glpsol --lp {} -o {} &> {}'.format(filename, filename + '.out', filename + '.err')
#print(cmd)
-subprocess.call(cmd, shell=True)
-sleep(5)
+if os.path.exists(filename + '.out'):
+    os.remove(filename + '.out')
+os.system(cmd)
+time.sleep(1)
+while not os.path.exists(filename + '.out'):
+    time.sleep(1)
f = open(filename + '.out', 'r')
cluster_dict = self._ilp_read(f)
f.close()
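One caveat with the new invocation: the polling loop guards against the output file appearing late (e.g., on a networked filesystem), but it has no upper bound, so a glpsol run that never writes the file would hang the caller. A bounded variant (a sketch, not part of the commit; `wait_for_output` is a hypothetical helper):

```python
import os
import time

def wait_for_output(path, timeout=60.0, poll=1.0):
    # Poll until `path` exists or `timeout` seconds have elapsed.
    deadline = time.time() + timeout
    while not os.path.exists(path):
        if time.time() >= deadline:
            raise TimeoutError('solver output not found: ' + path)
        time.sleep(poll)
```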
@@ -54,7 +57,11 @@ class ILP_IV:
f.write(' + {}{}{}'.format(cluster, self.sep, cluster))
# zero out distances > thr in the lower triangular part of the distance matrix, then sum
-s = np.sum(stats.threshold(np.tril(distances, -1), threshmax=t, newval=0)) + 1
+mask = (np.tril(distances, -1)>t)
+threshold = np.tril(distances, -1).copy()
+threshold[mask] = 0
+s = np.sum(threshold) + 1
+#s = np.sum(stats.threshold(np.tril(distances, -1), threshmax=t, newval=0)) + 1
logging.debug('ilp sum scores: '+str(s))
l = len(cluster_list)
for i in range(l):
......
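The same masking pattern replaces stats.threshold in the hunk above, this time on the strict lower triangle so each pair is counted once. A worked check with illustrative values:

```python
import numpy as np

distances = np.array([[0.0, 0.4, 0.9],
                      [0.4, 0.0, 0.2],
                      [0.9, 0.2, 0.0]])
t = 0.5

low = np.tril(distances, -1)   # strict lower triangle: 0.4, 0.9, 0.2 survive
low[low > t] = 0               # zero out distances above the threshold (drops 0.9)
s = np.sum(low) + 1            # 0.4 + 0.2 + 1 = 1.6
```

Reusing the name `threshold` for the masked array, as the committed code does, is legal but worth renaming later, since it collides with the old SciPy function name.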
@@ -623,7 +623,7 @@ class Diar():
:param coveringOverlap: a boolean value
"""
-#index = self.make_index(['show'])
+#index = self.make_index(['show'])
#lst = list()
#for show in index:
# diar = index[show]
@@ -681,11 +681,11 @@ class Diar():
self.sort(['start'])
i = 0
if len(self.segments) > 1:
-self.segments[i]['stop'] = min(max(self.segments[i + 1]['start'] - (epsilon // 2), 0), self.segments[i]['stop'] + epsilon)
+self.segments[i]['stop'] = max(self.segments[i]['start'], min(max(self.segments[i + 1]['start'] - (epsilon // 2), 0), self.segments[i]['stop'] + epsilon))
i += 1
while i < len(self.segments)-1:
self.segments[i]['start'] = max(self.segments[i - 1]['stop'], self.segments[i]['start'] - epsilon, 0)
-self.segments[i]['stop'] = min(max(self.segments[i + 1]['start'] - (epsilon // 2), 0), self.segments[i]['stop'] + epsilon)
+self.segments[i]['stop'] = max(self.segments[i]['start'], min(max(self.segments[i + 1]['start'] - (epsilon // 2), 0), self.segments[i]['stop'] + epsilon))
i += 1
def collar(self, epsilon=0, warning=False):
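The two new max(start, ...) clamps above guard the boundary adjustment: when the next segment starts shortly after the current one, shrinking stop toward the next start could previously push stop below start. A toy check (units are frame indices; values are illustrative):

```python
start, stop, next_start, epsilon = 100, 150, 104, 25

new_stop = max(start, min(max(next_start - epsilon // 2, 0), stop + epsilon))
# Without the outer max: min(max(104 - 12, 0), 150 + 25) = 92, which is below start.
# The clamp lifts the result back to start, avoiding an inverted segment.
assert new_stop == 100
```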
@@ -735,8 +735,9 @@ class Diar():
if segment['start'] < first:
first = segment['start']
return first
@classmethod
-def read_seg(cls, filename, normalize_cluster=False):
+def read_seg(cls, filename, normalize_cluster=False, encoding="utf8"):
"""
Read a segmentation file
:param filename: the str input filename
@@ -744,7 +745,7 @@
case and accents
:return: a diarization object
"""
-fic = open(filename, 'r', encoding="utf8")
+fic = open(filename, 'r', encoding=encoding)
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
@@ -775,7 +776,7 @@
return diarization
@classmethod
-def read_ctm(cls, filename, normalize_cluster=False):
+def read_ctm(cls, filename, normalize_cluster=False, encoding="utf8"):
"""
Read a segmentation file
:param filename: the str input filename
@@ -783,7 +784,7 @@
and accents
:return: a diarization object
"""
-fic = open(filename, 'r', encoding="utf8")
+fic = open(filename, 'r', encoding=encoding)
diarization = Diar()
try:
for line in fic:
@@ -806,7 +807,57 @@ class Diar():
return diarization
@classmethod
-    def read_mdtm(cls, filename, normalize_cluster=False):
+    def read_stm(cls, filename, normalize_cluster=False, encoding="ISO-8859-1"):
+        """
+        Read a segmentation file in STM format
+        :param filename: the str input filename
+        :param normalize_cluster: normalize the cluster by removing upper case
+        and accents
+        :param encoding: the file encoding
+        :return: a diarization object
+        """
+        fic = open(filename, 'r', encoding=encoding)
+        diarization = Diar()
+        if not diarization._attributes.exist('gender'):
+            diarization.add_attribut(new_attribut='gender', default='U')
+        try:
+            for line in fic:
+                line = re.sub(r'\s+', ' ', line)
+                line = line.strip()
+                # logging.debug(line)
+                if line.startswith('#') or line.startswith(';;'):
+                    continue
+                # split line into fields
+                split = line.split()
+                show = split[0]
+                loc = split[2]
+                if normalize_cluster:
+                    loc = str2str_normalize(loc)
+                start = int(float(split[3])*100)
+                stop = int(float(split[4])*100)
+                # strip the <...> label block and read the gender field from it
+                addon = split[5].replace(">", "").replace("<", "").replace(",", " ")
+                lineBis = re.sub(r'\s+', ' ', addon)
+                lineBis = lineBis.strip()
+                gender = lineBis.split()[2]
+                if gender == "female":
+                    diarization.append(show=show, cluster=loc, start=start,
+                                       stop=stop, gender="F")
+                elif gender == "male":
+                    diarization.append(show=show, cluster=loc, start=start,
+                                       stop=stop, gender="M")
+                else:
+                    diarization.append(show=show, cluster=loc, start=start,
+                                       stop=stop)
+        except Exception as e:
+            logging.error(sys.exc_info()[0])
+            logging.error(line)
+        fic.close()
+        return diarization
+
+    @classmethod
+    def read_mdtm(cls, filename, normalize_cluster=False, encoding="utf8"):
"""
Read a MDTM file
:param filename: the str input filename
@@ -815,7 +866,7 @@
:return: a diarization object
"""
-fic = open(filename, 'r', encoding="utf8")
+fic = open(filename, 'r', encoding=encoding)
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
@@ -838,13 +889,13 @@
return diarization
@classmethod
-def read_uem(cls, filename):
+def read_uem(cls, filename, encoding="utf8"):
"""
Read a UEM file
:param filename: the str input filename
:return: a diarization object
"""
-fic = open(filename, 'r', encoding="utf8")
+fic = open(filename, 'r', encoding=encoding)
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
@@ -869,14 +920,14 @@
return diarization
@classmethod
-def read_rttm(cls, filename, normalize_cluster=False):
+def read_rttm(cls, filename, normalize_cluster=False, encoding="utf8"):
"""
Read rttm file
:param filename: str input filename
:param normalize_cluster: normalize the cluster by removing upper case and accents
:return: a diarization object
"""
-fic = open(filename, 'r', encoding="utf8")
+fic = open(filename, 'r', encoding=encoding)
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
......
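All the Diar readers above gain an encoding keyword with the old behavior as the default: utf8 everywhere except the new read_stm, which defaults to ISO-8859-1. A usage sketch with hypothetical filenames:

```python
from s4d.diar import Diar

seg = Diar.read_seg('show.seg')                               # utf8, as before
legacy = Diar.read_seg('legacy.seg', encoding='ISO-8859-1')   # per-file override
stm = Diar.read_stm('show.stm')                               # ISO-8859-1 by default
```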
@@ -67,7 +67,8 @@ class ModelIV:
print('sn_cov: ', self.sn_cov.shape)
def train(self, feature_server, idmap, normalization=True):
-stat = StatServer(idmap, self.ubm)
+#stat = StatServer(idmap, self.ubm)
+stat = StatServer(idmap, distrib_nb=self.ubm.distrib_nb(), feature_size=self.ubm.dim())
stat.accumulate_stat(ubm=self.ubm, feature_server=feature_server, seg_indices=range(stat.segset.shape[0]), num_thread=self.nb_thread)
stat = stat.sum_stat_per_model()[0]
......
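The ModelIV change presumably tracks a newer sidekit StatServer constructor that takes explicit sizes instead of a Mixture instance; distrib_nb() and dim() are the Mixture accessors used in the commit. A sketch under that assumption, with hypothetical file names:

```python
from sidekit import Mixture, IdMap
from sidekit.statserver import StatServer

ubm = Mixture()
ubm.read('ubm.h5')          # hypothetical pre-trained UBM
idmap = IdMap('idmap.h5')   # hypothetical id map

stat = StatServer(idmap, distrib_nb=ubm.distrib_nb(), feature_size=ubm.dim())
```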