Skip to content

Commit

Permalink
[fix] adamw : exclude from weight_decay (#614)
Browse files Browse the repository at this point in the history
* [fix] adamw : `exclude from weight_decay`

* [fix] fix demo `finetune_classifier.py`

Co-authored-by: chenxuyi <[email protected]>
  • Loading branch information
Meiyim and chenxuyi authored Jan 13, 2021
1 parent 738e368 commit b3c2817
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions demo/finetune_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ def map_fn(seg_a, seg_b, label):
lr_scheduler,
parameters=model.parameters(),
weight_decay=args.wd,
apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
grad_clip=g_clip)
else:
lr_scheduler = None
opt = P.optimizer.Adam(
opt = P.optimizer.AdamW(
args.lr,
parameters=model.parameters(),
weight_decay=args.wd,
apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
grad_clip=g_clip)

scaler = P.amp.GradScaler(enable=args.use_amp)
Expand All @@ -209,7 +209,8 @@ def map_fn(seg_a, seg_b, label):
lr_scheduler and lr_scheduler.step()

if step % 10 == 0:
_lr = lr_scheduler.get_lr()
_lr = lr_scheduler.get_lr(
) if args.use_lr_decay else args.lr
if args.use_amp:
_l = (loss / scaler._scale).numpy()
msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % (
Expand Down
2 changes: 1 addition & 1 deletion demo/finetune_classifier_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def map_fn(seg_a, seg_b, label):
opt = P.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
weight_decay=args.wd,
grad_clip=g_clip)
scaler = P.amp.GradScaler(enable=args.use_amp)
Expand Down
2 changes: 1 addition & 1 deletion demo/finetune_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def evaluate(model, dataset):
lr_scheduler,
parameters=model.parameters(),
weight_decay=args.wd,
apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
grad_clip=g_clip)

scaler = P.amp.GradScaler(enable=args.use_amp)
Expand Down
2 changes: 1 addition & 1 deletion demo/finetune_sentiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def map_fn(seg_a, label):
lr_scheduler,
parameters=model.parameters(),
weight_decay=args.wd,
apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
grad_clip=g_clip)
scaler = P.amp.GradScaler(enable=args.use_amp)
with LogWriter(logdir=str(create_if_not_exists(args.save_dir /
Expand Down

0 comments on commit b3c2817

Please sign in to comment.