Commit 00ec1954 authored by Anthony Larcher

add halffastresnet

parent b6038a2b
@@ -85,8 +85,7 @@ class SideSampler(torch.utils.data.Sampler):
         self.num_process = num_process
         self.num_replicas = num_replicas
         assert batch_size % (examples_per_speaker * self.num_replicas) == 0
-        #assert batch_size % examples_per_speaker == 0
         assert batch_size % examples_per_speaker == 0
         assert (self.samples_per_speaker * self.spk_count * self.examples_per_speaker) % self.num_process == 0
         self.batch_size = batch_size // (self.examples_per_speaker * self.num_replicas)
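The asserts above encode the constraint that a global batch must split evenly across distributed replicas into whole groups of same-speaker examples. A minimal standalone sketch of that arithmetic, with assumed values (256, 4 and 2 are illustrative, not from the commit):

    # Illustration only: the divisibility the sampler asserts, with assumed values.
    batch_size = 256          # global batch size (assumed)
    examples_per_speaker = 4  # utterances drawn per speaker in a batch (assumed)
    num_replicas = 2          # number of distributed replicas (assumed)

    # Mirrors the asserts in the hunk above.
    assert batch_size % (examples_per_speaker * num_replicas) == 0
    assert batch_size % examples_per_speaker == 0

    # Per-replica count of speaker groups, as on the last line of the hunk:
    # 256 examples / 2 replicas = 128 per replica = 32 groups of 4.
    batch_size_per_replica = batch_size // (examples_per_speaker * num_replicas)
    print(batch_size_per_replica)  # -> 32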
@@ -995,7 +995,7 @@ def get_network(model_opts, local_rank):
     :return:
     """
-    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34"]:
+    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
         model = Xtractor(model_opts["speaker_number"], model_opts["model_type"], loss=model_opts["loss"]["type"])
     else:
         # Custom type of model
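For context, a hypothetical call that now reaches the predefined-architecture branch with the new model type; the model_opts values are made up, and the Xtractor import path is an assumption about this repository's layout:

    # Hypothetical usage sketch; only the Xtractor call signature is taken
    # from the hunk above, everything else is assumed for illustration.
    from sidekit.nnet import Xtractor  # assumption: actual module path may differ

    model_opts = {
        "model_type": "halfresnet34",  # the newly accepted predefined architecture
        "speaker_number": 5994,        # illustrative speaker count (e.g. VoxCeleb2)
        "loss": {"type": "aam"},
    }

    if model_opts["model_type"] in ["xvector", "rawnet2", "resnet34", "fastresnet34", "halfresnet34"]:
        # All predefined architectures are built through the same Xtractor entry point.
        model = Xtractor(model_opts["speaker_number"],
                         model_opts["model_type"],
                         loss=model_opts["loss"]["type"])
    else:
        raise NotImplementedError("custom model types are handled further down in get_network")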
@@ -1191,7 +1191,7 @@ def get_optimizer(model, model_opts, train_opts, training_loader):
     if train_opts["scheduler"]["type"] == 'CyclicLR':
         cycle_momentum = True
-        if train_opts["optimizer"]["type"] == "aam":
+        if train_opts["optimizer"]["type"] in ["aam", "aps"]:
             cycle_momentum = False
         scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer,
                                                       base_lr=1e-8,
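The guard exists because PyTorch's CyclicLR can only cycle momentum when the wrapped optimizer exposes a momentum hyper-parameter; Adam-style optimizers do not, and CyclicLR raises an error if cycle_momentum is left enabled. A standalone sketch of the same rule, using plain Adam as a stand-in for the "aam"/"aps" optimizer types (an assumption about their Adam-like internals; max_lr and step_size_up are assumed, as the diff is truncated before those arguments):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer=optimizer,
        base_lr=1e-8,          # same base_lr as in the hunk above
        max_lr=1e-3,           # assumed; not visible in the truncated diff
        step_size_up=2000,     # assumed
        cycle_momentum=False,  # required: Adam has no `momentum` to cycle
    )

    # Typical usage: step the scheduler once per optimizer step.
    for _ in range(5):
        optimizer.step()
        scheduler.step()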