Commit 0f14371f authored by Gaëtan Caillaut

fix freeze_attention when not set

parent 295cfeb0
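
For context, the fix wraps the attribute lookup in try/except so that an argparse namespace which never defines the flag (e.g. a subparser that does not register the option) falls back to False instead of raising AttributeError. A minimal sketch of the pattern, using a hypothetical resolve_freeze_attention helper and SimpleNamespace stand-ins for the parsed arguments:

    from types import SimpleNamespace

    def resolve_freeze_attention(args):
        # EAFP fallback: default to False when the flag was never parsed.
        try:
            return args.freeze_attention
        except AttributeError:
            return False

    print(resolve_freeze_attention(SimpleNamespace(freeze_attention=True)))  # True
    print(resolve_freeze_attention(SimpleNamespace()))                       # False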
@@ -60,7 +60,12 @@ def t2_run_name_from_params(args):
         "nonorm" if args.dont_normalize else "norm"
     ])
-    if args.freeze_attention:
+    try:
+        freeze_attention = args.freeze_attention
+    except AttributeError:
+        freeze_attention = False
+    if freeze_attention:
         s = f"{s}_frozen"
     return s
@@ -378,7 +383,12 @@ def finetune_t1(args):
         args.model, args.d, attention_type, position_type, tokenizer, max_seq_size, mask_token, pad_token, device, checkpoint_path=args.checkpoint)
     run_name = t1_run_name_from_params(args)
-    if args.freeze_attention:
+    try:
+        freeze_attention = args.freeze_attention
+    except AttributeError:
+        freeze_attention = False
+    if freeze_attention:
         model.minibert.freeze()
     if args.logdir is None:
@@ -396,7 +406,7 @@ def finetune_t1(args):
     print("BEGIN TRAINING", flush=True)
     for epoch in range(prev_epoch + 1, prev_epoch + 1 + args.epochs):
         model.train()
-        if args.freeze_attention:
+        if freeze_attention:
             model.minibert.freeze()
         cumloss = 0
@@ -675,7 +685,12 @@ def finetune_t2(args):
         args.model, args.d, attention_type, position_type, tokenizer, max_seq_size, mask_token, pad_token, device, checkpoint_path=args.checkpoint)
     run_name = t2_run_name_from_params(args)
-    if args.freeze_attention:
+    try:
+        freeze_attention = args.freeze_attention
+    except AttributeError:
+        freeze_attention = False
+    if freeze_attention:
         model.minibert.freeze()
     if args.logdir is None:
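
The same three-line fallback is repeated before each call site; an equivalent one-liner, shown here only as a possible simplification rather than what the commit does, is:

    freeze_attention = getattr(args, "freeze_attention", False)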