Commit e0cfb4ec authored by Anthony Larcher's avatar Anthony Larcher
Browse files

Merge branch 'dev_al' of https://git-lium.univ-lemans.fr/Larcher/sidekit into dev_al

parents 7ce25f50 bca82ed4
...@@ -282,7 +282,7 @@ class ArcMarginProduct(torch.nn.Module): ...@@ -282,7 +282,7 @@ class ArcMarginProduct(torch.nn.Module):
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
output = output * self.s output = output * self.s
return output return output, cosine * self.s
class SoftmaxAngularProto(torch.nn.Module): class SoftmaxAngularProto(torch.nn.Module):
......
...@@ -1340,6 +1340,7 @@ def xtrain(dataset_description, ...@@ -1340,6 +1340,7 @@ def xtrain(dataset_description,
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
# Set all the seeds # Set all the seeds
random.seed(training_opts["random_seed"])
numpy.random.seed(training_opts["numpy_seed"]) # Set the random seed of numpy for the data split. numpy.random.seed(training_opts["numpy_seed"]) # Set the random seed of numpy for the data split.
torch.manual_seed(training_opts["torch_seed"]) torch.manual_seed(training_opts["torch_seed"])
torch.cuda.manual_seed(training_opts["torch_seed"]) torch.cuda.manual_seed(training_opts["torch_seed"])
...@@ -1519,6 +1520,7 @@ def train_epoch(model, ...@@ -1519,6 +1520,7 @@ def train_epoch(model,
accuracy = 0.0 accuracy = 0.0
running_loss = 0.0 running_loss = 0.0
batch_count = 0
for batch_idx, (data, target) in enumerate(training_loader): for batch_idx, (data, target) in enumerate(training_loader):
data = data.squeeze().to(device) data = data.squeeze().to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment