From c81938820f370f8b5f6e944438d9255d0d70fb20 Mon Sep 17 00:00:00 2001
From: JackieWu
Date: Wed, 1 Nov 2023 15:25:59 +0800
Subject: [PATCH] [Bug Fixed] The scaling weight is not updated in the optimizer `LBAdamW` (#112)

**Description**

When the parameter is stored as a ScalingTensor, it cannot be updated in the optimizer `LBAdamW` when `exp_avg_dtype is torch.float32` and `exp_avg_sq_dtype is torch.float32`.

---------

Co-authored-by: Yuxiang Yang
---
 msamp/optim/adamw.py      | 13 +++++++++++--
 tests/optim/test_adamw.py | 34 ++++++++++++++++++++++------------
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py
index be82fced..711919c8 100644
--- a/msamp/optim/adamw.py
+++ b/msamp/optim/adamw.py
@@ -173,8 +173,7 @@ def adamw_fn( # noqa: C901
             for i, param in enumerate(params):
                 param, grad = param.float(), grads[i].float() if not maximize else -grads[i].float()
 
-                # Perform stepweight decay
-                # FP32/16 Tensor * float
+                # Perform step weight decay
                 if weight_decay != 0:
                     if self.use_adam:
                         grad = grad.add(param, alpha=weight_decay)
@@ -218,6 +217,13 @@ def adamw_fn( # noqa: C901
                 param, grad = param.float(), grads[i].float() if not maximize else -grads[i].float()
                 exp_avg_value, exp_avg_sq_value = exp_avgs[i]['state'], exp_avg_sqs[i]['state']
 
+                # Perform step weight decay
+                if weight_decay != 0:
+                    if self.use_adam:
+                        grad = grad.add(param, alpha=weight_decay)
+                    else:
+                        param.mul_(1 - lr * weight_decay)
+
                 if self.bias_correction:
                     bias_correction1 = 1 - beta1**state_steps[i]
                     bias_correction2 = 1 - beta2**state_steps[i]
@@ -238,3 +244,6 @@ def adamw_fn( # noqa: C901
                 # param = param - step_size * (exp_avg / denom)
                 # param.addcdiv_(exp_avg, denom, value=-step_size)
                 param.add_(exp_avg_value / denom, alpha=-step_size)
+
+                if isinstance(params[i], ScalingTensor):
+                    params[i].copy_(param.cast(params[i].qtype, meta=params[i].meta))
diff --git a/tests/optim/test_adamw.py b/tests/optim/test_adamw.py
index e7e8a2da..f7fbe0dc 100644
--- a/tests/optim/test_adamw.py
+++ b/tests/optim/test_adamw.py
@@ -8,6 +8,8 @@
 import unittest
 
 import torch
+from functools import partial
+
 from msamp.common.dtype import Dtypes
 from msamp.common.tensor import TensorDist
 from msamp.optim import LBAdamW, LBAdam, LBAdamWBase, DSAdam
@@ -28,10 +30,15 @@ def tearDown(self):
     @decorator.cuda_test
     def test_adamw_step(self):
         """Test adamw optimizer step function."""
-        self.check_optimizer_step(LBAdamWBase)
-        self.check_optimizer_step(LBAdamW)
-        self.check_optimizer_step(LBAdam)
-        self.check_optimizer_step(DSAdam)
+        dtypes = [torch.uint8, torch.float16]
+        pairs = list(itertools.product(dtypes, dtypes)) + [[torch.float32, torch.float32]]
+        for exp_avg_dtype, exp_avg_sq_dtype in pairs:
+            with self.subTest(exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype):
+                kwargs = dict(exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype)
+                self.check_optimizer_step(torch.optim.AdamW, partial(LBAdamWBase, **kwargs))
+                self.check_optimizer_step(torch.optim.AdamW, partial(LBAdamW, **kwargs))
+                self.check_optimizer_step(torch.optim.AdamW, partial(DSAdam, **kwargs))
+                self.check_optimizer_step(torch.optim.Adam, partial(LBAdam, **kwargs))
 
     @decorator.cuda_test
     def test_state_dict(self):
         """Test state dict of optimizer."""
         self.check_optimizer_state_dict(LBAdamW)
         self.check_optimizer_state_dict(LBAdam)
 
-    def check_optimizer_step(self, optimizer_class, diff=3e-4):
-        """Check the difference between torch.optim.AdamW and optimizer_class optimizers.
+    def check_optimizer_step(self, optimizer_class1, optimizer_class2, diff=3e-4):
+        """Check the difference between optimizer_class1 and optimizer_class2 optimizers.
 
         Args:
-            optimizer_class (class): LBAdamWBase, LBAdamW or LBAdam.
-            diff (float, optional): The difference between torch.optim.AdamW and optimizer_class optimizers.
+            optimizer_class1 (class): Optimizer Class
+            optimizer_class2 (class): Optimizer Class
+            diff (float, optional): The difference between optimizer_class1 and optimizer_class2 optimizers.
         """
         input = torch.randn(4, 4, device='cuda')
         linear = torch.nn.Linear(4, 4).cuda()
+        wd = 1e-3
+        steps = 4
 
         # test torch.optim.AdamW
         model1 = copy.deepcopy(linear)
-        opt1 = torch.optim.AdamW(model1.parameters())
+        opt1 = optimizer_class1(model1.parameters(), weight_decay=wd)
 
-        for _ in range(4):
+        for _ in range(steps):
             output = model1(input)
             output.sum().backward()
             opt1.step()
@@ -63,9 +73,9 @@ def check_optimizer_step(self, optimizer_class, diff=3e-4):
 
         model2 = copy.deepcopy(linear)
         model2 = LinearReplacer.replace(model2, Dtypes.kfloat16)
-        opt2 = optimizer_class(model2.parameters())
+        opt2 = optimizer_class2(model2.parameters(), weight_decay=wd)
 
-        for _ in range(4):
+        for _ in range(steps):
             output = model2(input)
             output.sum().backward()
             opt2.all_reduce_grads(model2)
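
A minimal verification sketch for the behaviour this patch fixes, pieced together from the test above. It assumes a single CUDA device, the `exp_avg_dtype`/`exp_avg_sq_dtype` keyword arguments exercised in `test_adamw_step`, and the `msamp.nn.LinearReplacer` import path (not shown in the visible hunks); before the fix, the FP8 scaling weight stayed unchanged after `step()` when both state dtypes were `torch.float32`.

```python
import torch

from msamp.common.dtype import Dtypes
from msamp.nn import LinearReplacer  # assumed import path
from msamp.optim import LBAdamW

# Replace the linear layer so its weight becomes a ScalingTensor (FP8).
model = LinearReplacer.replace(torch.nn.Linear(4, 4).cuda(), Dtypes.kfloat16)

# Keep both optimizer states in FP32 -- the configuration that triggered the bug.
opt = LBAdamW(model.parameters(), exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32)

weight_before = model.weight.float().clone()

model(torch.randn(4, 4, device='cuda')).sum().backward()
opt.all_reduce_grads(model)  # mirrors the call in check_optimizer_step
opt.step()

# With the fix, the scaling weight is written back via params[i].copy_(...),
# so the parameter must have moved after one step.
assert not torch.equal(weight_before, model.weight.float())
```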