Commit 9bf349fb authored by Anthony Larcher's avatar Anthony Larcher

refactoring

parent 3a559ff9
@@ -1226,7 +1226,11 @@ def new_xtrain(dataset_description,
"""
REFACTORING
- refine the logging
- when restarting from an existing model, reload the optimizer and the scheduler (see the sketch below)
"""
# Optimization experiment: keep NVTX event emission from the autograd profiler disabled
torch.autograd.profiler.emit_nvtx(enabled=False)
dataset_opts, model_opts, training_opts = update_training_dictionary(dataset_description,
model_description,
training_description,
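
The second item in the REFACTORING list above asks for the optimizer and scheduler to be reloaded when training restarts from an existing model. Below is a minimal sketch of such a reload, assuming the checkpoint was saved with torch.save() as a dictionary; the key names and the resume_from_checkpoint helper are hypothetical and must match the actual saving code:

import torch

def resume_from_checkpoint(model, optimizer, scheduler, checkpoint_path, device):
    # Load the checkpoint onto the training device (key names are illustrative).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Resume from the stored epoch counter when present.
    return checkpoint.get("epoch", 0)
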
@@ -1258,30 +1262,27 @@ def new_xtrain(dataset_description,
monitor.logger.info("\n*********************************\nTraining options\n*********************************\n")
monitor.logger.info(yaml.dump(training_opts, default_flow_style=False))
# Optimization experiment: keep NVTX event emission from the autograd profiler disabled
torch.autograd.profiler.emit_nvtx(enabled=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the model
model = get_network(model_opts)
speaker_number = model.speaker_number
embedding_size = model.embedding_size
# Set the device and manage parallel processing
if torch.cuda.device_count() > 1 and training_opts["multi_gpu"]:
model = torch.nn.DataParallel(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
monitor.logger.info(f"Start process at {time.strftime('%H:%M:%S', time.localtime())}")
monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")
# Initialise data loaders
training_loader, validation_loader, validation_tar_indices, validation_non_indices = get_loaders(dataset_opts,
training_opts,
speaker_number)
monitor.logger.info(f"Start training process")
monitor.logger.info(f"Use \t{torch.cuda.device_count()} \tgpus")
monitor.logger.info(f"Use \t{training_opts['num_cpu']} \tcpus")
monitor.logger.info(f"Validation EER will be measured using")
monitor.logger.info(f"\t {numpy.sum(validation_tar_indices)} target trials and")
monitor.logger.info(f"\t {numpy.sum(validation_non_indices)} non-target trials")
@@ -1303,13 +1304,13 @@ def new_xtrain(dataset_description,
break
model = new_train_epoch(model,
monitor,
training_loader,
optimizer,
scheduler,
device,
scaler=scaler
)
training_opts,
monitor,
training_loader,
optimizer,
scheduler,
device,
scaler=scaler)
# Cross validation
if math.fmod(epoch, training_opts["validation_frequency"]) == 0:
@@ -1760,7 +1761,8 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
optimizer.step()
running_loss += loss.item()
accuracy += (torch.argmax(output.data, 1) == target).sum()
if batch_idx % log_interval == 0:
if math.fmod(batch_idx, log_interval) == 0:
batch_size = target.shape[0]
logging.critical('{}, Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
time.strftime('%H:%M:%S', time.localtime()),
@@ -1786,6 +1788,7 @@ def train_epoch(model, epoch, training_loader, optimizer, scheduler, log_interva
def new_train_epoch(model,
training_opts,
training_monitor,
training_loader,
optimizer,
@@ -1796,6 +1799,7 @@ def new_train_epoch(model,
"""
:param model:
:param training_opts:
:param training_monitor:
:param training_loader:
:param optimizer:
@@ -1860,7 +1864,7 @@ def new_train_epoch(model,
running_loss += loss.item()
accuracy += (torch.argmax(output.data, 1) == target).sum()
if batch_idx % training_monitor.log_interval == 0:
if math.fmod(batch_idx, training_opts["log_interval"]) == 0:
batch_size = target.shape[0]
training_monitor.update(training_loss=loss.item(),
training_acc=100.0 * accuracy.item() / ((batch_idx + 1) * batch_size))
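
The scaler keyword passed into new_train_epoch points to mixed-precision training. Assuming it is a torch.cuda.amp.GradScaler, a single training step typically looks like the sketch below (criterion and the variable names are illustrative, not the module's actual code):

import torch

def amp_step(model, data, target, criterion, optimizer, scaler):
    optimizer.zero_grad()
    # Autocast runs the forward pass in float16 where it is safe to do so.
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    # Scale the loss so float16 gradients do not underflow, then unscale and step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss
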
@@ -2127,34 +2131,6 @@ def extract_embeddings_per_speaker(idmap_name,
return embeddings
def second_logger():
# create a dedicated 'monitoring' logger
logger = logging.getLogger('monitoring')
logger.setLevel(logging.INFO)
# create a file handler for INFO-level messages
fh = logging.FileHandler('monitor.log')
fh.setLevel(logging.INFO)
logger.addHandler(fh)
logger.info("EER = 100")
logger.info("EER = 90")
logger.info("EER = 80")
logger.info("EER = 70")
logger.info("EER = 60")
logger.info("EER = 50")
logger.info("EER = 40")
def test_logger():
init_logging(filename="./test.log")
logging.info('info logging test.')
logging.debug('debug logging test.')
logging.critical('critical logging test')
logging.warning("warning logging test")
logging.error('test error')
second_logger()
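
The two helpers above are throwaway logging experiments that this hunk removes (it shrinks from 34 to 6 lines). For reference, the pattern second_logger exercised, a named logger writing to its own file alongside the root logger, can be written more compactly as follows (the function name, logger name, and path are illustrative):

import logging

def make_monitor_logger(log_file="monitor.log"):
    logger = logging.getLogger("monitoring")
    logger.setLevel(logging.INFO)
    # Attach a file handler with a timestamped format to the named logger.
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
    logger.addHandler(handler)
    return logger
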
def extract_sliding_embedding(idmap_name,
window_len,
window_shift,
......