Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Martin Lebourdais
s4d
Commits
a066bd43
Commit
a066bd43
authored
Jul 06, 2018
by
Florent Desnous
Browse files
Updated model_iv.train() to use the sidekit.FactorAnalyser i-vector extractor
parent
99e43094
Changes
1
Hide whitespace changes
Inline
Side-by-side
s4d/model_iv.py
View file @
a066bd43
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# You should have received a copy of the GNU Lesser General Public License
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
from
sidekit
import
Mixture
,
StatServer
,
Scores
,
Ndx
,
PLDA_scoring
,
cosine_scoring
,
mahalanobis_scoring
,
two_covariance_scoring
from
sidekit
import
Mixture
,
StatServer
,
FactorAnalyser
,
Scores
,
Ndx
,
PLDA_scoring
,
cosine_scoring
,
mahalanobis_scoring
,
two_covariance_scoring
from
sidekit.sidekit_io
import
*
from
sidekit.sidekit_io
import
*
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
...
@@ -67,12 +67,13 @@ class ModelIV:
...
@@ -67,12 +67,13 @@ class ModelIV:
print
(
'sn_cov: '
,
self
.
sn_cov
.
shape
)
print
(
'sn_cov: '
,
self
.
sn_cov
.
shape
)
def
train
(
self
,
feature_server
,
idmap
,
normalization
=
True
):
def
train
(
self
,
feature_server
,
idmap
,
normalization
=
True
):
#stat = StatServer(idmap, self.ubm)
stat
=
StatServer
(
idmap
,
distrib_nb
=
self
.
ubm
.
distrib_nb
(),
feature_size
=
self
.
ubm
.
dim
())
stat
=
StatServer
(
idmap
,
distrib_nb
=
self
.
ubm
.
distrib_nb
(),
feature_size
=
self
.
ubm
.
dim
())
stat
.
accumulate_stat
(
ubm
=
self
.
ubm
,
feature_server
=
feature_server
,
seg_indices
=
range
(
stat
.
segset
.
shape
[
0
]),
num_thread
=
self
.
nb_thread
)
stat
.
accumulate_stat
(
ubm
=
self
.
ubm
,
feature_server
=
feature_server
,
seg_indices
=
range
(
stat
.
segset
.
shape
[
0
]),
num_thread
=
self
.
nb_thread
)
stat
=
stat
.
sum_stat_per_model
()[
0
]
stat
=
stat
.
sum_stat_per_model
()[
0
]
self
.
ivectors
=
stat
.
estimate_hidden
(
self
.
tv_mean
,
self
.
tv_sigma
,
V
=
self
.
tv
,
U
=
None
,
D
=
None
,
num_thread
=
self
.
nb_thread
)[
0
]
fa
=
FactorAnalyser
(
mean
=
self
.
tv_mean
,
Sigma
=
self
.
tv_sigma
,
F
=
self
.
tv
)
self
.
ivectors
=
fa
.
extract_ivectors_single
(
self
.
ubm
,
stat
)
if
normalization
:
if
normalization
:
self
.
ivectors
.
spectral_norm_stat1
(
self
.
norm_mean
[:
1
],
self
.
norm_cov
[:
1
])
self
.
ivectors
.
spectral_norm_stat1
(
self
.
norm_mean
[:
1
],
self
.
norm_cov
[:
1
])
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment