Commit c11e4144 authored by Anthony Larcher's avatar Anthony Larcher
Browse files

merge and sidesampler

parent acb426e0
...@@ -35,7 +35,7 @@ import sys ...@@ -35,7 +35,7 @@ import sys
# Read environment variable if it exists # Read environment variable if it exists
SIDEKIT_CONFIG={"libsvm":True, SIDEKIT_CONFIG={"libsvm":False,
"mpi":False, "mpi":False,
"cuda":True "cuda":True
} }
......
...@@ -103,10 +103,10 @@ def cosine_scoring(enroll, test, ndx, wccn=None, check_missing=True, device=None ...@@ -103,10 +103,10 @@ def cosine_scoring(enroll, test, ndx, wccn=None, check_missing=True, device=None
device = torch.device("cuda:0" if torch.cuda.is_available() and s_size_in_bytes < 3e9 else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() and s_size_in_bytes < 3e9 else "cpu")
else: else:
device = device if torch.cuda.is_available() and s_size_in_bytes < 3e9 else torch.device("cpu") device = device if torch.cuda.is_available() and s_size_in_bytes < 3e9 else torch.device("cpu")
s = torch.mm(torch.FloatTensor(enroll_copy.stat1).to(device), torch.FloatTensor(test_copy.stat1).to(device).T).cpu().numpy()
score = Scores() score = Scores()
score.scoremat = s score.scoremat = torch.einsum('ij,kj', torch.FloatTensor(enroll_copy.stat1).to(device),
torch.FloatTensor(test_copy.stat1).to(device)).cpu().numpy()
score.modelset = clean_ndx.modelset score.modelset = clean_ndx.modelset
score.segset = clean_ndx.segset score.segset = clean_ndx.segset
score.scoremask = clean_ndx.trialmask score.scoremask = clean_ndx.trialmask
......
...@@ -33,9 +33,16 @@ from .feed_forward import kaldi_to_hdf5 ...@@ -33,9 +33,16 @@ from .feed_forward import kaldi_to_hdf5
from .xsets import IdMapSetPerSpeaker from .xsets import IdMapSetPerSpeaker
from .xsets import SideSet from .xsets import SideSet
from .xsets import SideSampler from .xsets import SideSampler
from .xvector import Xtractor, xtrain, extract_embeddings, extract_sliding_embedding, MeanStdPooling from .xvector import Xtractor
from .res_net import ResBlock, PreResNet34 from .xvector import xtrain
from .rawnet import prepare_voxceleb1, Vox1Set, PreEmphasis from .xvector import extract_embeddings
from .xvector import extract_sliding_embedding
from .pooling import MeanStdPooling
from .pooling import AttentivePooling
from .pooling import GruPooling
from .res_net import ResBlock
from .res_net import PreResNet34
from .res_net import PreFastResNet34
from .sincnet import SincNet from .sincnet import SincNet
from .preprocessor import RawPreprocessor from .preprocessor import RawPreprocessor
from .preprocessor import MfccFrontEnd from .preprocessor import MfccFrontEnd
......
...@@ -164,26 +164,6 @@ def data_augmentation(speech, ...@@ -164,26 +164,6 @@ def data_augmentation(speech,
aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number) aug_idx = random.sample(range(len(transform_dict.keys())), k=transform_number)
augmentations = numpy.array(list(transform_dict.keys()))[aug_idx] augmentations = numpy.array(list(transform_dict.keys()))[aug_idx]
if "phone_filtering" in augmentations:
speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
speech,
sample_rate,
effects=[
["lowpass", "4000"],
["compand", "0.02,0.05", "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8", "-8", "-7", "0.05"],
["rate", "16000"],
])
if "filtering" in augmentations:
effects = [
["bandpass","2000","3500"],
["bandstop","200","500"]]
speech,sample_rate = torchaudio.sox_eefects.apply_effects_tensor(
speech,
sample_rate,
effects = [effects[random.randint(0,1)]],
)
if "stretch" in augmentations: if "stretch" in augmentations:
strech = torchaudio.functional.TimeStretch() strech = torchaudio.functional.TimeStretch()
rate = random.uniform(0.8,1.2) rate = random.uniform(0.8,1.2)
...@@ -242,6 +222,28 @@ def data_augmentation(speech, ...@@ -242,6 +222,28 @@ def data_augmentation(speech,
scale = snr * noise_power / speech_power scale = snr * noise_power / speech_power
speech = (scale * speech + noise) / 2 speech = (scale * speech + noise) / 2
if "phone_filtering" in augmentations:
final_shape = speech.shape[1]
speech, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
speech,
sample_rate,
effects=[
["lowpass", "4000"],
["compand", "0.02,0.05", "-60,-60,-30,-10,-20,-8,-5,-8,-2,-8", "-8", "-7", "0.05"],
["rate", "16000"],
])
speech = speech[:, :final_shape]
if "filtering" in augmentations:
effects = [
["bandpass","2000","3500"],
["bandstop","200","500"]]
speech,sample_rate = torchaudio.sox_eefects.apply_effects_tensor(
speech,
sample_rate,
effects = [effects[random.randint(0,1)]],
)
if "codec" in augmentations: if "codec" in augmentations:
final_shape = speech.shape[1] final_shape = speech.shape[1]
configs = [ configs = [
...@@ -273,7 +275,8 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path): ...@@ -273,7 +275,8 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
if noise_duration * sample_rate > speech_shape[1]: if noise_duration * sample_rate > speech_shape[1]:
# It is recommended to split noise files (especially speech noise type) in shorter subfiles # It is recommended to split noise files (especially speech noise type) in shorter subfiles
# When frame_offset is too high, loading the segment can take much longer # When frame_offset is too high, loading the segment can take much longer
frame_offset = random.randrange(noise_start * sample_rate, int((noise_start + noise_duration) * sample_rate - speech_shape[1])) frame_offset = random.randrange(noise_start * sample_rate,
int((noise_start + noise_duration) * sample_rate - speech_shape[1]))
else: else:
frame_offset = noise_start * sample_rate frame_offset = noise_start * sample_rate
...@@ -281,10 +284,10 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path): ...@@ -281,10 +284,10 @@ def load_noise_seg(noise_row, speech_shape, sample_rate, data_path):
if noise_duration * sample_rate > speech_shape[1]: if noise_duration * sample_rate > speech_shape[1]:
noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1])) noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(speech_shape[1]))
else: else:
noise_seg, noise_sr = torchaudio.load(noise_fn, frame_offset=int(frame_offset), num_frames=int(noise_duration * sample_rate)) noise_seg, noise_sr = torchaudio.load(noise_fn,
frame_offset=int(frame_offset),
num_frames=int(noise_duration * sample_rate))
assert noise_sr == sample_rate assert noise_sr == sample_rate
#if numpy.random.randint(0, 2) == 1:
# noise = torch.flip(noise, dims=[0, 1])
if noise_seg.shape[1] < speech_shape[1]: if noise_seg.shape[1] < speech_shape[1]:
noise_seg = torch.tensor(numpy.resize(noise_seg.numpy(), speech_shape)) noise_seg = torch.tensor(numpy.resize(noise_seg.numpy(), speech_shape))
......
...@@ -304,5 +304,62 @@ class SoftmaxAngularProto(torch.nn.Module): ...@@ -304,5 +304,62 @@ class SoftmaxAngularProto(torch.nn.Module):
cos_sim_matrix = torch.nn.functional.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2)) cos_sim_matrix = torch.nn.functional.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
torch.clamp(self.w, 1e-6) torch.clamp(self.w, 1e-6)
cos_sim_matrix = cos_sim_matrix * self.w + self.b cos_sim_matrix = cos_sim_matrix * self.w + self.b
loss = self.criterion(cos_sim_matrix, torch.arange(0, cos_sim_matrix.shape[0], device=x.device)) + self.criterion(cce_prediction, target)
return loss, cce_prediction
return cos_sim_matrix, cce_prediction
class AngularProximityMagnet(torch.nn.Module):
# from https://github.com/clovaai/voxceleb_trainer/blob/3bfd557fab5a3e6cd59d717f5029b3a20d22a281/loss/angleproto.py
def __init__(self, spk_count, emb_dim=256, batch_size=512, init_w=10.0, init_b=-5.0, **kwargs):
super(AngularProximityMagnet, self).__init__()
self.test_normalize = True
self.w = torch.nn.Parameter(torch.tensor(init_w))
self.b1 = torch.nn.Parameter(torch.tensor(init_b))
self.b2 = torch.nn.Parameter(torch.tensor(+5.54))
#last_linear = torch.nn.Linear(512, 1)
#last_linear.bias.data += 1
#self.magnitude = torch.nn.Sequential(OrderedDict([
# ("linear9", torch.nn.Linear(emb_dim, 512)),
# ("relu9", torch.nn.ReLU()),
# ("linear10", torch.nn.Linear(512, 512)),
# ("relu10", torch.nn.ReLU()),
# ("linear11", last_linear),
# ("relu11", torch.nn.ReLU())
# ]))
self.cce_backend = torch.nn.Sequential(OrderedDict([
("linear8", torch.nn.Linear(emb_dim, spk_count))
]))
self.criterion = torch.nn.CrossEntropyLoss()
self.magnet_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
def forward(self, x, target=None):
assert x.size()[1] >= 2
cce_prediction = self.cce_backend(x)
#x = self.magnitude(x) * torch.nn.functional.normalize(x)
if target==None:
return x, cce_prediction
x = x.reshape(-1,2,x.size()[-1]).squeeze(1)
out_anchor = torch.mean(x[:,1:,:],1)
out_positive = x[:,0,:]
ap_sim_matrix = torch.nn.functional.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))
torch.clamp(self.w, 1e-6)
ap_sim_matrix = ap_sim_matrix * self.w + self.b1
labels = torch.arange(0, int(out_positive.shape[0]), device=torch.device("cuda:0")).unsqueeze(1)
cos_sim_matrix = torch.mm(out_positive, out_anchor.T)
cos_sim_matrix = cos_sim_matrix + self.b2
cos_sim_matrix = cos_sim_matrix + numpy.log(1/out_positive.shape[0] / (1 - 1/out_positive.shape[0]))
mask = (torch.tile(labels, (1, labels.shape[0])) == labels.T).float()
batch_loss = self.criterion(ap_sim_matrix, torch.arange(0, int(out_positive.shape[0]), device=torch.device("cuda:0"))) \
+ self.magnet_criterion(cos_sim_matrix.flatten().unsqueeze(1), mask.flatten().unsqueeze(1))
return batch_loss, cce_prediction
...@@ -162,9 +162,9 @@ class MelSpecFrontEnd(torch.nn.Module): ...@@ -162,9 +162,9 @@ class MelSpecFrontEnd(torch.nn.Module):
n_fft=1024, n_fft=1024,
f_min=90, f_min=90,
f_max=7600, f_max=7600,
win_length=400, win_length=1024,
window_fn=torch.hann_window, window_fn=torch.hann_window,
hop_length=160, hop_length=256,
power=2.0, power=2.0,
n_mels=80): n_mels=80):
...@@ -227,7 +227,6 @@ class MelSpecFrontEnd(torch.nn.Module): ...@@ -227,7 +227,6 @@ class MelSpecFrontEnd(torch.nn.Module):
return out return out
class RawPreprocessor(torch.nn.Module): class RawPreprocessor(torch.nn.Module):
""" """
......
...@@ -268,7 +268,28 @@ class ResBlock(torch.nn.Module): ...@@ -268,7 +268,28 @@ class ResBlock(torch.nn.Module):
return out return out
class SELayer(torch.nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
self.fc = torch.nn.Sequential(
torch.nn.Linear(channel, channel // reduction, bias=False),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(channel // reduction, channel, bias=False),
torch.nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class BasicBlock(torch.nn.Module): class BasicBlock(torch.nn.Module):
"""
"""
expansion = 1 expansion = 1
def __init__(self, in_planes, planes, stride=1): def __init__(self, in_planes, planes, stride=1):
...@@ -280,6 +301,8 @@ class BasicBlock(torch.nn.Module): ...@@ -280,6 +301,8 @@ class BasicBlock(torch.nn.Module):
stride=1, padding=1, bias=False) stride=1, padding=1, bias=False)
self.bn2 = torch.nn.BatchNorm2d(planes) self.bn2 = torch.nn.BatchNorm2d(planes)
self.se = SELayer(planes)
self.shortcut = torch.nn.Sequential() self.shortcut = torch.nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes: if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = torch.nn.Sequential( self.shortcut = torch.nn.Sequential(
...@@ -291,6 +314,7 @@ class BasicBlock(torch.nn.Module): ...@@ -291,6 +314,7 @@ class BasicBlock(torch.nn.Module):
def forward(self, x): def forward(self, x):
out = torch.nn.functional.relu(self.bn1(self.conv1(x))) out = torch.nn.functional.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out)) out = self.bn2(self.conv2(out))
out = self.se(out)
out += self.shortcut(x) out += self.shortcut(x)
out = torch.nn.functional.relu(out) out = torch.nn.functional.relu(out)
return out return out
...@@ -463,11 +487,13 @@ class PreFastResNet34(torch.nn.Module): ...@@ -463,11 +487,13 @@ class PreFastResNet34(torch.nn.Module):
def forward(self, x): def forward(self, x):
out = x.unsqueeze(1) out = x.unsqueeze(1)
out = out.contiguous(memory_format=torch.channels_last)
out = torch.nn.functional.relu(self.bn1(self.conv1(out))) out = torch.nn.functional.relu(self.bn1(self.conv1(out)))
out = self.layer1(out) out = self.layer1(out)
out = self.layer2(out) out = self.layer2(out)
out = self.layer3(out) out = self.layer3(out)
out = self.layer4(out) out = self.layer4(out)
out = out.contiguous(memory_format=torch.contiguous_format)
out = torch.flatten(out, start_dim=1, end_dim=2) out = torch.flatten(out, start_dim=1, end_dim=2)
return out return out
...@@ -475,119 +501,3 @@ class PreFastResNet34(torch.nn.Module): ...@@ -475,119 +501,3 @@ class PreFastResNet34(torch.nn.Module):
def ResNet34(): def ResNet34():
return ResNet(BasicBlock, [3, 1, 3, 1, 5, 1, 2]) return ResNet(BasicBlock, [3, 1, 3, 1, 5, 1, 2])
def restrain(args):
"""
Initialize and train an ResNet for Speaker Recognition
:param args:
:return:
"""
# Initialize a first model and save to disk
model = ResNet18(args.class_number,
entry_conv_kernel_size=(7,7),
entry_conv_out_channels=64,
megablock_out_channels=(64, 128, 128, 128),
megablock_size=(2, 2, 2, 2),
block_type = ResBlock)
current_model_file_name = "initial_model"
torch.save(model.state_dict(), current_model_file_name)
for epoch in range(1, args.epochs + 1):
current_model_file_name = train_resnet_epoch(epoch, args, current_model_file_name)
# Add the cross validation here
accuracy = resnet_cross_validation(args, current_model_file_name)
print("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate after every epoch
args.lr = args.lr * 0.9
print(" Decrease learning rate: {}".format(args.lr))
def train_resnet_epoch(model, epoch, train_seg_df, speaker_dict, args):
"""
:param model:
:param epoch:
:param train_seg_df:
:param args:
:return:
"""
device = torch.device("cuda:0")
torch.manual_seed(args.seed)
train_transform = []
if not args.train_transformation == '':
trans = args.train_transformation.split(',')
for t in trans:
if "CMVN" in t:
train_transform.append(CMVN())
if "FrequencyMask" in t:
a = t.split(",")[0].split("(")[1]
b = t.split(",")[1].split("(")[0]
train_transform.append(FrequencyMask(a, b))
if "TemporalMask" in t:
a = t.split(",")[0].split("(")[1]
train_transform.append(TemporalMask(a, b))
train_set = VoxDataset(train_seg_df, speaker_dict, 500, transform=transforms.Compose(train_transform),
spec_aug_ratio=args.spec_aug, temp_aug_ratio=args.temp_aug)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=15)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()
accuracy = 0.0
for batch_idx, (data, target, _, __) in enumerate(train_loader):
target = target.squeeze()
optimizer.zero_grad()
output = model(data.to(device))
loss = criterion(output, target.to(device))
loss.backward()
optimizer.step()
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
if batch_idx % args.log_interval == 0:
logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
epoch, batch_idx + 1, train_loader.__len__(),
100. * batch_idx / train_loader.__len__(), loss.item(),
100.0 * accuracy.item() / ((batch_idx + 1) * args.batch_size)))
return model
def resnet_cross_validation(args, model, cv_seg_df, speaker_dict):
"""
:param args:
:param model:
:param cv_seg_df:
:return:
"""
cv_transform = []
if not args.cv_transformation == '':
trans = args.cv_transformation.split(',')
for t in trans:
if "CMVN" in t:
cv_transform.append(CMVN())
if "FrequencyMask" in t:
a = t.split(",")[0].split("(")[1]
b = t.split(",")[1].split("(")[0]
cv_transform.append(FrequencyMask(a, b))
if "TemporalMask" in t:
a = t.split(",")[0].split("(")[1]
cv_transform.append(TemporalMask(a, b))
cv_set = VoxDataset(cv_seg_df, speaker_dict, 500, transform=transforms.Compose(cv_transform),
spec_aug_ratio=args.spec_aug, temp_aug_ratio=args.temp_aug)
cv_loader = DataLoader(cv_set, batch_size=args.batch_size, shuffle=False, num_workers=15)
model.eval()
device = torch.device("cuda:0")
model.to(device)
accuracy = 0.0
print(cv_set.__len__())
for batch_idx, (data, target, _, __) in enumerate(cv_loader):
target = target.squeeze()
output = model(data.to(device))
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * args.batch_size)
...@@ -389,7 +389,6 @@ class IdMapSet(Dataset): ...@@ -389,7 +389,6 @@ class IdMapSet(Dataset):
if "add_noise" in self.transformation: if "add_noise" in self.transformation:
# Load the noise dataset, filter according to the duration # Load the noise dataset, filter according to the duration
noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"]) noise_df = pandas.read_csv(self.transformation["add_noise"]["noise_db_csv"])
#tmp_df = noise_df.loc[noise_df['duration'] > self.duration]
self.noise_df = noise_df.set_index(noise_df.type) self.noise_df = noise_df.set_index(noise_df.type)
self.rir_df = None self.rir_df = None
......
...@@ -30,13 +30,10 @@ import logging ...@@ -30,13 +30,10 @@ import logging
import math import math
import os import os
import numpy import numpy
import random
import pandas import pandas
import pickle
import shutil import shutil
import tabulate
import time
import torch import torch
import torchaudio
import tqdm import tqdm
import yaml import yaml
...@@ -64,9 +61,11 @@ from ..statserver import StatServer ...@@ -64,9 +61,11 @@ from ..statserver import StatServer
from ..iv_scoring import cosine_scoring from ..iv_scoring import cosine_scoring
from .sincnet import SincNet from .sincnet import SincNet
from ..bosaris.detplot import rocch, rocch2eer from ..bosaris.detplot import rocch, rocch2eer
from .loss import SoftmaxAngularProto, ArcLinear from .loss import SoftmaxAngularProto
from .loss import l2_norm from .loss import l2_norm
from .loss import ArcMarginProduct from .loss import ArcMarginProduct
from .loss import ArcLinear
from .loss import AngularProximityMagnet
os.environ['MKL_THREADING_LAYER'] = 'GNU' os.environ['MKL_THREADING_LAYER'] = 'GNU'
...@@ -80,17 +79,25 @@ __status__ = "Production" ...@@ -80,17 +79,25 @@ __status__ = "Production"
__docformat__ = 'reS' __docformat__ = 'reS'
def eer(negatives, positives): def seed_worker():
"""Logarithmic complexity EER computation """
Args: :param worker_id:
negative_scores (numpy array): impostor scores :return:
positive_scores (numpy array): genuine scores """
worker_seed = torch.initial_seed() % 2**32
numpy.random.seed(worker_seed)
random.seed(worker_seed)
Returns:
float: Equal Error Rate (EER) def eer(negatives, positives):
""" """
Logarithmic complexity EER computation
:param negatives: negative_scores (numpy array): impostor scores
:param positives: positive_scores (numpy array): genuine scores
:return: float: Equal Error Rate (EER)
"""
positives = numpy.sort(positives) positives = numpy.sort(positives)
negatives = numpy.sort(negatives)[::-1] negatives = numpy.sort(negatives)[::-1]
...@@ -234,7 +241,6 @@ def test_metrics(model, ...@@ -234,7 +241,6 @@ def test_metrics(model,
device=device device=device
).get_tar_non(Key(data_opts["test"]["key"])) ).get_tar_non(Key(data_opts["test"]["key"]))
#test_eer = eer(numpy.array(non).astype(numpy.double), numpy.array(tar).astype(numpy.double))
pmiss, pfa = rocch(tar, non) pmiss, pfa = rocch(tar, non)
return rocch2eer(pmiss, pfa) return rocch2eer(pmiss, pfa)
...@@ -476,7 +482,7 @@ class Xtractor(torch.nn.Module): ...@@ -476,7 +482,7 @@ class Xtractor(torch.nn.Module):
self.before_speaker_embedding = torch.nn.Linear(in_features = 2560, self.before_speaker_embedding = torch.nn.Linear(in_features = 2560,
out_features = self.embedding_size) out_features = self.embedding_size)
self.stat_pooling = MeanStdPooling() self.stat_pooling = AttentivePooling(128, 80, global_context=False)
self.stat_pooling_weight_decay = 0 self.stat_pooling_weight_decay = 0
self.loss = loss self.loss = loss
...@@ -489,6 +495,8 @@ class Xtractor(torch.nn.Module): ...@@ -489,6 +495,8 @@ class Xtractor(torch.nn.Module):
elif self.loss == 'aps': elif self.loss == 'aps':
self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number)) self.after_speaker_embedding = SoftmaxAngularProto(int(self.speaker_number))
elif self.loss == 'smn':
self.after_speaker_embedding = AngularProximityMagnet(int(self.speaker_number))
self.preprocessor_weight_decay = 0.00002 self.preprocessor_weight_decay = 0.00002
self.sequence_network_weight_decay = 0.00002 self.sequence_network_weight_decay = 0.00002
...@@ -908,7 +916,9 @@ def update_training_dictionary(dataset_description, ...@@ -908,7 +916,9 @@ def update_training_dictionary(dataset_description,
# Initialize training options # Initialize training options
training_opts["log_file"] = "sidekit.log" training_opts["log_file"] = "sidekit.log"
training_opts["seed"] = 42 training_opts["numpy_seed"] = 0
training_opts["torch_seed"] = 0
training_opts["random_seed"] = 0
training_opts["deterministic"] = False training_opts["deterministic"] = False