Commit 8ccc51fd authored by Gaëtan Caillaut's avatar Gaëtan Caillaut
Browse files

show progress

parent 51eee004
from pathlib import Path from pathlib import Path
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from load import * from load import *
from evaluation import * from evaluation import *
...@@ -112,12 +113,18 @@ if __name__ == "__main__": ...@@ -112,12 +113,18 @@ if __name__ == "__main__":
if args.checkpoint is None: if args.checkpoint is None:
outdir.mkdir() outdir.mkdir()
print("BEGIN TRAINING", flush=True)
for epoch in range(prev_epoch + 1, prev_epoch + 1 + args.epochs): for epoch in range(prev_epoch + 1, prev_epoch + 1 + args.epochs):
model.train() model.train()
cumloss = 0 cumloss = 0
n_train = 0 n_train = 0
for batch_id, (x, attention_mask, wids) in enumerate(train_loader): t0_epoch = datetime.now()
batch_cumulated_time = datetime.timedelta()
for batch_id, (x, attention_mask, wids) in enumerate(train_loader, 1):
t0_batch = datetime.now()
x = x.to(device) x = x.to(device)
attention_mask = attention_mask.to(device) attention_mask = attention_mask.to(device)
wids = wids.to(device) wids = wids.to(device)
...@@ -128,8 +135,19 @@ if __name__ == "__main__": ...@@ -128,8 +135,19 @@ if __name__ == "__main__":
optimizer.step() optimizer.step()
cumloss += loss.item() cumloss += loss.item()
t1_batch = datetime.now()
batch_time = t1_batch - t0_batch
batch_cumulated_time += batch_time
if batch_id % args.progress:
print(
f"BATCH {batch_id:05}/{epoch:04} - LOSS {loss.item()} - TIME {batch_cumulated_time}", flush=True)
batch_cumulated_time = datetime.timedelta()
writer.add_scalar("Loss/train", cumloss / len(train_loader), epoch) writer.add_scalar("Loss/train", cumloss / len(train_loader), epoch)
print(f"EPOCH {epoch:04} - Loss: {cumloss / len(train_loader)}") t1_epoch = datetime.now()
print(
f"EPOCH {epoch:04} - MEAN LOSS {cumloss / len(train_loader)} - TIME {t1_epoch - t0_epoch}", flush=True)
if epoch % args.epochs_between_save == 0: if epoch % args.epochs_between_save == 0:
model.eval() model.eval()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment