Commit 8e6de0d2 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

debug local_rank

parent 2aacbe26
......@@ -276,7 +276,7 @@ class ArcMarginProduct(torch.nn.Module):
"""
# cos(theta)
cosine = torch.nn.functional.linear(torch.nn.functional.normalize(input),
torch.nn.functional.normalize(self.weight))
torch.nn.functional.normalize(self.weight))
if target == None:
return cosine * self.s
# cos(theta + m)
......
......@@ -105,7 +105,6 @@ class SideSampler(torch.utils.data.Sampler):
self.segment_cursors = numpy.zeros((len(self.labels_to_indices),), dtype=numpy.int)
def __iter__(self):
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
......@@ -269,6 +268,8 @@ class SideSet(Dataset):
self.transform["codec"] = []
if "phone_filtering" in transforms:
self.transform["phone_filtering"] = []
if "stretch" in transforms:
self.transform["stretch"] = []
self.noise_df = None
if "add_noise" in self.transform:
......
......@@ -489,9 +489,9 @@ class Xtractor(torch.nn.Module):
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
s=30.,
m=0.2,
easy_margin=False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
......@@ -525,9 +525,9 @@ class Xtractor(torch.nn.Module):
if self.loss == "aam":
self.after_speaker_embedding = ArcMarginProduct(self.embedding_size,
int(self.speaker_number),
s = 30,
m = 0.2,
easy_margin = False)
s=30.,
m=0.2,
easy_margin=False)
elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
self.preprocessor_weight_decay = 0.00002
......@@ -778,6 +778,8 @@ class Xtractor(torch.nn.Module):
:param x:
:param is_eval: False for training
:param target:
:param norm_embedding:
:return:
"""
if self.preprocessor is not None:
......@@ -1089,6 +1091,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
validation_set = SideSet(dataset_opts,
set_type="validation",
chunk_per_segment=1,
dataset_df=validation_df,
output_format="pytorch")
......@@ -1330,7 +1333,6 @@ def xtrain(dataset_description,
REFACTORING
- en cas de redémarrage à partir d'un modèle existant, recharger l'optimizer et le scheduler
"""
local_rank = -1
if "RANK" in os.environ:
local_rank = int(os.environ['RANK'])
......@@ -1429,7 +1431,8 @@ def xtrain(dataset_description,
training_loader, validation_loader,\
sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
training_opts,
model_opts)
model_opts,
local_rank)
if local_rank < 1:
monitor.logger.info(f"Start training process")
......@@ -1669,7 +1672,6 @@ def cross_validation(model, validation_loader, device, validation_shape, tar_ind
return (100. * accuracy.cpu().numpy() / validation_shape[0],
loss.cpu().numpy() / ((batch_idx + 1) * batch_size),
equal_error_rate)
return 0, 0, 0
def extract_embeddings(idmap_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment