Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Ambuj Mehrish
sidekit
Commits
7c0073ee
Commit
7c0073ee
authored
Mar 29, 2021
by
Gaël Le Lan
Browse files
bugfix validation
parent
3a2b9935
Changes
2
Hide whitespace changes
Inline
Side-by-side
nnet/xsets.py
View file @
7c0073ee
...
@@ -208,7 +208,7 @@ class SideSet(Dataset):
...
@@ -208,7 +208,7 @@ class SideSet(Dataset):
chunk_nb
=
len
(
possible_starts
)
chunk_nb
=
len
(
possible_starts
)
else
:
else
:
chunk_nb
=
min
(
len
(
possible_starts
),
chunk_per_segment
)
chunk_nb
=
min
(
len
(
possible_starts
),
chunk_per_segment
)
starts
=
numpy
.
random
.
permutation
(
possible_starts
)[:
chunk_nb
]
/
self
.
sample_rate
starts
=
numpy
.
random
.
permutation
(
possible_starts
)[:
chunk_nb
]
# Once we know how many segments are selected, create the other fields to fill the DataFrame
# Once we know how many segments are selected, create the other fields to fill the DataFrame
for
ii
in
range
(
chunk_nb
):
for
ii
in
range
(
chunk_nb
):
...
@@ -256,13 +256,16 @@ class SideSet(Dataset):
...
@@ -256,13 +256,16 @@ class SideSet(Dataset):
# TODO is this required ?
# TODO is this required ?
nfo
=
soundfile
.
info
(
f
"
{
self
.
data_path
}
/
{
current_session
[
'file_id'
]
}{
self
.
data_file_extension
}
"
)
nfo
=
soundfile
.
info
(
f
"
{
self
.
data_path
}
/
{
current_session
[
'file_id'
]
}{
self
.
data_file_extension
}
"
)
original_start
=
int
(
current_session
[
'start'
])
original_start
=
int
(
current_session
[
'start'
])
lowest_shift
=
self
.
overlap
/
2
if
self
.
overlap
>
0
:
highest_shift
=
self
.
overlap
/
2
lowest_shift
=
self
.
overlap
/
2
if
original_start
<
(
current_session
[
'file_start'
]
*
self
.
sample_rate
+
self
.
sample_number
/
2
):
highest_shift
=
self
.
overlap
/
2
lowest_shift
=
int
(
original_start
-
current_session
[
'file_start'
]
*
self
.
sample_rate
)
if
original_start
<
(
current_session
[
'file_start'
]
*
self
.
sample_rate
+
self
.
sample_number
/
2
):
if
original_start
+
self
.
sample_number
>
(
current_session
[
'file_start'
]
+
current_session
[
'file_duration'
])
*
self
.
sample_rate
-
self
.
sample_number
/
2
:
lowest_shift
=
int
(
original_start
-
current_session
[
'file_start'
]
*
self
.
sample_rate
)
highest_shift
=
int
((
current_session
[
'file_start'
]
+
current_session
[
'file_duration'
])
*
self
.
sample_rate
-
(
original_start
+
self
.
sample_number
))
if
original_start
+
self
.
sample_number
>
(
current_session
[
'file_start'
]
+
current_session
[
'file_duration'
])
*
self
.
sample_rate
-
self
.
sample_number
/
2
:
start_frame
=
original_start
+
int
(
random
.
uniform
(
-
lowest_shift
,
highest_shift
))
highest_shift
=
int
((
current_session
[
'file_start'
]
+
current_session
[
'file_duration'
])
*
self
.
sample_rate
-
(
original_start
+
self
.
sample_number
))
start_frame
=
original_start
+
int
(
random
.
uniform
(
-
lowest_shift
,
highest_shift
))
else
:
start_frame
=
original_start
if
start_frame
+
self
.
sample_number
>=
nfo
.
frames
:
if
start_frame
+
self
.
sample_number
>=
nfo
.
frames
:
start_frame
=
numpy
.
min
(
nfo
.
frames
-
self
.
sample_number
-
1
)
start_frame
=
numpy
.
min
(
nfo
.
frames
-
self
.
sample_number
-
1
)
...
...
nnet/xvector.py
View file @
7c0073ee
...
@@ -1105,7 +1105,7 @@ def xtrain(speaker_number,
...
@@ -1105,7 +1105,7 @@ def xtrain(speaker_number,
test_eer
=
100.
test_eer
=
100.
classes
=
torch
.
Byte
Tensor
(
validation_set
.
sessions
[
'speaker_idx'
].
to_numpy
())
classes
=
torch
.
Short
Tensor
(
validation_set
.
sessions
[
'speaker_idx'
].
to_numpy
())
mask
=
classes
.
unsqueeze
(
1
)
==
classes
.
unsqueeze
(
1
).
T
mask
=
classes
.
unsqueeze
(
1
)
==
classes
.
unsqueeze
(
1
).
T
tar_indices
=
torch
.
tril
(
mask
,
-
1
).
numpy
()
tar_indices
=
torch
.
tril
(
mask
,
-
1
).
numpy
()
non_indices
=
torch
.
tril
(
~
mask
,
-
1
).
numpy
()
non_indices
=
torch
.
tril
(
~
mask
,
-
1
).
numpy
()
...
@@ -1302,6 +1302,7 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
...
@@ -1302,6 +1302,7 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
embeddings
=
torch
.
zeros
(
validation_shape
)
embeddings
=
torch
.
zeros
(
validation_shape
)
#classes = torch.zeros([validation_shape[0]])
#classes = torch.zeros([validation_shape[0]])
cursor
=
0
with
torch
.
no_grad
():
with
torch
.
no_grad
():
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
tqdm
.
tqdm
(
validation_loader
,
desc
=
'validation compute'
,
mininterval
=
1
,
disable
=
None
)):
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
tqdm
.
tqdm
(
validation_loader
,
desc
=
'validation compute'
,
mininterval
=
1
,
disable
=
None
)):
target
=
target
.
squeeze
()
target
=
target
.
squeeze
()
...
@@ -1318,8 +1319,9 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
...
@@ -1318,8 +1319,9 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
batch_embeddings
=
l2_norm
(
batch_embeddings
)
batch_embeddings
=
l2_norm
(
batch_embeddings
)
accuracy
+=
(
torch
.
argmax
(
batch_predictions
.
data
,
1
)
==
target
).
sum
()
accuracy
+=
(
torch
.
argmax
(
batch_predictions
.
data
,
1
)
==
target
).
sum
()
loss
+=
criterion
(
batch_predictions
,
target
)
loss
+=
criterion
(
batch_predictions
,
target
)
embeddings
[
batch_idx
*
batch_size
:
batch_idx
*
batch_size
+
batch_predictions
.
shape
[
0
],:]
=
batch_embeddings
.
detach
().
cpu
()
embeddings
[
cursor
:
cursor
+
batch_size
,:]
=
batch_embeddings
.
detach
().
cpu
()
#classes[batch_idx * batch_size:batch_idx * batch_size + batch_predictions.shape[0]] = target.detach().cpu()
#classes[cursor:cursor + batch_size] = target.detach().cpu()
cursor
+=
batch_size
#print(classes.shape[0])
#print(classes.shape[0])
local_device
=
"cpu"
if
embeddings
.
shape
[
0
]
>
3e4
else
device
local_device
=
"cpu"
if
embeddings
.
shape
[
0
]
>
3e4
else
device
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment