From b3c28179aabf2d00ce8ec5d4f321e66da1e6a11b Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Wed, 13 Jan 2021 19:49:59 +0800
Subject: [PATCH] [fix] adamw : `exclude from weight_decay` (#614)

* [fix] adamw : `exclude from weight_decay`

* [fix] fix demo `finetune_classifier.py`

Co-authored-by: chenxuyi
---
 demo/finetune_classifier.py             | 9 +++++----
 demo/finetune_classifier_distributed.py | 2 +-
 demo/finetune_ner.py                    | 2 +-
 demo/finetune_sentiment_analysis.py     | 2 +-
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/demo/finetune_classifier.py b/demo/finetune_classifier.py
index 4e9540d1..ce14d3ef 100644
--- a/demo/finetune_classifier.py
+++ b/demo/finetune_classifier.py
@@ -177,15 +177,15 @@ def map_fn(seg_a, seg_b, label):
             lr_scheduler,
             parameters=model.parameters(),
             weight_decay=args.wd,
-            apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+            apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
             grad_clip=g_clip)
     else:
         lr_scheduler = None
-        opt = P.optimizer.Adam(
+        opt = P.optimizer.AdamW(
             args.lr,
             parameters=model.parameters(),
             weight_decay=args.wd,
-            apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+            apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
             grad_clip=g_clip)

     scaler = P.amp.GradScaler(enable=args.use_amp)
@@ -209,7 +209,8 @@ def map_fn(seg_a, seg_b, label):
             lr_scheduler and lr_scheduler.step()

             if step % 10 == 0:
-                _lr = lr_scheduler.get_lr()
+                _lr = lr_scheduler.get_lr(
+                ) if args.use_lr_decay else args.lr
                 if args.use_amp:
                     _l = (loss / scaler._scale).numpy()
                     msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % (
diff --git a/demo/finetune_classifier_distributed.py b/demo/finetune_classifier_distributed.py
index d1df8675..d4b1195d 100644
--- a/demo/finetune_classifier_distributed.py
+++ b/demo/finetune_classifier_distributed.py
@@ -144,7 +144,7 @@ def map_fn(seg_a, seg_b, label):
     opt = P.optimizer.AdamW(
         learning_rate=lr_scheduler,
         parameters=model.parameters(),
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         weight_decay=args.wd,
         grad_clip=g_clip)
     scaler = P.amp.GradScaler(enable=args.use_amp)
diff --git a/demo/finetune_ner.py b/demo/finetune_ner.py
index 7489f16d..6929afdc 100644
--- a/demo/finetune_ner.py
+++ b/demo/finetune_ner.py
@@ -210,7 +210,7 @@ def evaluate(model, dataset):
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
     scaler = P.amp.GradScaler(enable=args.use_amp)

diff --git a/demo/finetune_sentiment_analysis.py b/demo/finetune_sentiment_analysis.py
index 015d29d4..16087fa0 100644
--- a/demo/finetune_sentiment_analysis.py
+++ b/demo/finetune_sentiment_analysis.py
@@ -126,7 +126,7 @@ def map_fn(seg_a, label):
         lr_scheduler,
         parameters=model.parameters(),
         weight_decay=args.wd,
-        apply_decay_param_fun=lambda n: param_name_to_exclue_from_weight_decay.match(n),
+        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
         grad_clip=g_clip)
     scaler = P.amp.GradScaler(enable=args.use_amp)
     with LogWriter(logdir=str(create_if_not_exists(args.save_dir /
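
The hunks above all make the same one-token change: Paddle's AdamW applies weight decay only to parameters whose names make `apply_decay_param_fun(name)` return True, so passing the raw regex match decayed exactly the LayerNorm/bias parameters the regex was meant to exclude and skipped everything else; negating the match restores the intended behaviour. The classifier demo additionally swaps `P.optimizer.Adam` for `AdamW` in the no-decay-schedule branch and guards the learning-rate logging so it falls back to `args.lr` when no scheduler exists. A minimal sketch of the corrected idiom follows; the toy model, the hyperparameter values, and the exact exclusion regex are illustrative assumptions, only the `apply_decay_param_fun=lambda n: not pattern.match(n)` pattern mirrors the patch.

import re
import paddle as P

# Toy model standing in for the ERNIE classifier built by the demos (assumption).
model = P.nn.Linear(768, 2)

# Parameter names matching this pattern are *excluded* from weight decay;
# the concrete pattern used in the demos may differ (assumption).
exclude_from_weight_decay = re.compile(r'.*layer_norm.*|.*\.b_0')

opt = P.optimizer.AdamW(
    learning_rate=1e-4,           # illustrative value
    parameters=model.parameters(),
    weight_decay=0.01,            # illustrative value
    # AdamW decays a parameter only when this callback returns True for its
    # name, hence the negation: excluded names -> False -> no decay.
    apply_decay_param_fun=lambda n: not exclude_from_weight_decay.match(n))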