Commit 9ab7a6b8 authored by Le Lan Gaël

local_rank fix

parent 3de11dde
@@ -164,6 +164,9 @@ def data_augmentation(speech,
     aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
     augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
     if "none" in augmentations:
         pass
+    if "stretch" in augmentations:
+        stretched_length = int(speech.shape[1] * random.uniform(0.95,1.05))
+        speech = torch.zeros_like(speech)
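
Note on the new "stretch" branch: the hunk is truncated right after the zeroed buffer, so the copy-back step is not visible here. Below is a minimal sketch of one way such a tempo stretch could work, assuming speech is a (channels, samples) float tensor; the stretch_speech helper and the use of torch.nn.functional.interpolate are illustrative, not the committed implementation.

    import random
    import torch
    import torch.nn.functional as F

    def stretch_speech(speech):
        # Draw a random tempo factor in [0.95, 1.05], as in the diff above.
        stretched_length = int(speech.shape[1] * random.uniform(0.95, 1.05))
        # Linearly resample the waveform to the stretched length.
        stretched = F.interpolate(speech.unsqueeze(0),
                                  size=stretched_length,
                                  mode="linear",
                                  align_corners=False).squeeze(0)
        # Write into a zeroed buffer of the original length: shorter results
        # are zero-padded, longer ones cropped, so the output shape is stable.
        out = torch.zeros_like(speech)
        n = min(stretched_length, speech.shape[1])
        out[:, :n] = stretched[:, :n]
        return out

    out = stretch_speech(torch.randn(1, 16000))  # one second at 16 kHz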
......
@@ -177,7 +177,7 @@ class SideSet(Dataset):
                  overlap=0.,
                  dataset_df=None,
                  min_duration=0.165,
-                 output_format="pytorch",
+                 output_format="pytorch"
                  ):
         """
......
@@ -536,12 +536,13 @@ class Xtractor(torch.nn.Module):
                                                           m = 0.20,
                                                           easy_margin = False)
         elif self.loss == 'aps':
-            self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
+            self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number),
+                                                               emb_dim=self.embedding_size)
         self.preprocessor_weight_decay = 0.00002
         self.sequence_network_weight_decay = 0.00002
         self.stat_pooling_weight_decay = 0.00002
         self.before_speaker_embedding_weight_decay = 0.00002
-        self.after_speaker_embedding_weight_decay = 0 #0.0002
+        self.after_speaker_embedding_weight_decay = 0.0002
     elif model_archi == "rawnet2":
@@ -1094,7 +1095,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                            transform_number=dataset_opts['train']['transform_number'],
                            overlap=dataset_opts['train']['overlap'],
                            dataset_df=training_df,
-                           output_format="pytorch",
+                           output_format="pytorch"
                            )
     validation_set = SideSet(dataset_opts,
@@ -1109,15 +1110,14 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
         samples_per_speaker = 1
     if training_opts["multi_gpu"]:
-        assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
-        assert dataset_opts["batch_size"] % samples_per_speaker == 0
-        batch_size = dataset_opts["batch_size"]//(torch.cuda.device_count() * dataset_opts["train"]["sampler"]["examples_per_speaker"])
+        #assert dataset_opts["batch_size"] % torch.cuda.device_count() == 0
+        #assert dataset_opts["batch_size"] % samples_per_speaker == 0
     side_sampler = SideSampler(data_source=training_set.sessions['speaker_idx'],
                                spk_count=model_opts["speaker_number"],
                                examples_per_speaker=dataset_opts["train"]["sampler"]["examples_per_speaker"],
                                samples_per_speaker=dataset_opts["train"]["sampler"]["samples_per_speaker"],
-                               batch_size=batch_size,
+                               batch_size=dataset_opts["batch_size"] * torch.cuda.device_count(),
                                seed=training_opts['torch_seed'],
                                rank=local_rank,
                                num_process=torch.cuda.device_count(),
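
Reading of the batch-size change (an assumption about SideSampler's contract, which this diff does not document): the sampler is now handed the global batch, per-GPU batch times GPU count, and shards it by rank, so each process still consumes dataset_opts["batch_size"] examples per step.

    # Illustrative arithmetic only; the numbers are made up.
    per_gpu_batch = 256                          # dataset_opts["batch_size"]
    world_size = 4                               # torch.cuda.device_count()
    global_batch = per_gpu_batch * world_size    # the new batch_size= argument

    # One plausible rank-wise split of a global batch of indices.
    batch = list(range(global_batch))
    for rank in range(world_size):
        shard = batch[rank::world_size]
        assert len(shard) == per_gpu_batch  # every GPU keeps its full batch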
@@ -1131,7 +1131,7 @@ def get_loaders(dataset_opts, training_opts, model_opts, local_rank=0):
                                          batch_size=dataset_opts["batch_size"],
                                          seed=training_opts['torch_seed'],
                                          rank=0,
-                                         num_process=torch.cuda.device_count(),
+                                         num_process=1,
                                          num_replicas=dataset_opts["train"]["sampler"]["augmentation_replica"]
                                          )
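
Pinning the validation sampler to rank=0 and num_process=1 means validation data is never sharded across GPUs. PyTorch's DistributedSampler has analogous num_replicas/rank semantics (the analogy is an assumption; SideSampler's internals are not shown in this diff):

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    val_ds = TensorDataset(torch.arange(100))
    # With a single replica at rank 0, the sampler yields the whole set,
    # so evaluation sees identical data regardless of GPU count.
    sampler = DistributedSampler(val_ds, num_replicas=1, rank=0, shuffle=False)
    assert len(list(sampler)) == len(val_ds)  # nothing is sharded away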
@@ -1221,7 +1221,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
                                                       max_lr=train_opts["lr"],
-                                                      step_size_up=training_loader.__len__() * 16,
+                                                      step_size_up=training_loader.__len__() * 10,
                                                       step_size_down=None,
                                                       cycle_momentum=cycle_momentum,
                                                       mode="triangular2")
@@ -1427,7 +1427,8 @@ def xtrain(dataset_description,
     training_loader, validation_loader,\
     sampler, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
                                                                           training_opts,
-                                                                          model_opts)
+                                                                          model_opts,
+                                                                          local_rank)
     if local_rank < 1:
         monitor.logger.info(f"Start training process")
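
This last hunk is the fix named in the commit message: get_loaders previously fell back to its local_rank=0 default, so under multi-GPU training every process built a rank-0 sampler and drew identical batches. A sketch of the assumed launch-side wiring (LOCAL_RANK is set by torchrun; the snippet assumes the surrounding xtrain scope for dataset_opts, training_opts and model_opts):

    import os

    # Each DDP process reads its own rank and forwards it, so the training
    # sampler can shard data per process instead of defaulting to rank 0.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    training_loader, validation_loader, sampler, tar_idx, non_idx = \
        get_loaders(dataset_opts, training_opts, model_opts, local_rank)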
......