[Bug Fixed] The scaling weight is not updated in the optimizer `LBAdamW` (#112)

**Description**
When a parameter is stored as a ScalingTensor, it cannot be updated by the
optimizer `LBAdamW` when `exp_avg_dtype is torch.float32` and
`exp_avg_sq_dtype is torch.float32`.

---------

Co-authored-by: Yuxiang Yang <[email protected]>
wkcn and tocean authored Nov 1, 2023
1 parent 914c0e0 commit c819388
Showing 2 changed files with 33 additions and 14 deletions.
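For context, here is a minimal sketch of the configuration that hit the bug, mirroring the API exercised in `tests/optim/test_adamw.py` below (the `msamp.nn` import path for `LinearReplacer` is an assumption, since its import sits outside the shown hunks):

```python
# Sketch only: FP16 scaling weights via LinearReplacer, with both optimizer
# states kept in fp32 -- the combination described in the bug report above.
import torch

from msamp.common.dtype import Dtypes
from msamp.nn import LinearReplacer   # import path assumed, not shown in the diff
from msamp.optim import LBAdamW

model = LinearReplacer.replace(torch.nn.Linear(4, 4).cuda(), Dtypes.kfloat16)
opt = LBAdamW(model.parameters(),
              exp_avg_dtype=torch.float32,
              exp_avg_sq_dtype=torch.float32)

output = model(torch.randn(4, 4, device='cuda'))
output.sum().backward()
opt.all_reduce_grads(model)
opt.step()  # before this fix, ScalingTensor weights were left unchanged here
```

With both state dtypes in fp32, `adamw_fn` takes the pure-fp32 branch, which is the path patched in `msamp/optim/adamw.py` below.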
13 changes: 11 additions & 2 deletions msamp/optim/adamw.py
@@ -173,8 +173,7 @@ def adamw_fn( # noqa: C901
for i, param in enumerate(params):
param, grad = param.float(), grads[i].float() if not maximize else -grads[i].float()

# Perform stepweight decay
# FP32/16 Tensor * float
# Perform step weight decay
if weight_decay != 0:
if self.use_adam:
grad = grad.add(param, alpha=weight_decay)
@@ -218,6 +217,13 @@ def adamw_fn( # noqa: C901
param, grad = param.float(), grads[i].float() if not maximize else -grads[i].float()
exp_avg_value, exp_avg_sq_value = exp_avgs[i]['state'], exp_avg_sqs[i]['state']

# Perform step weight decay
if weight_decay != 0:
if self.use_adam:
grad = grad.add(param, alpha=weight_decay)
else:
param.mul_(1 - lr * weight_decay)

if self.bias_correction:
bias_correction1 = 1 - beta1**state_steps[i]
bias_correction2 = 1 - beta2**state_steps[i]
@@ -238,3 +244,6 @@ def adamw_fn( # noqa: C901
# param = param - step_size * (exp_avg / denom)
# param.addcdiv_(exp_avg, denom, value=-step_size)
param.add_(exp_avg_value / denom, alpha=-step_size)

if isinstance(params[i], ScalingTensor):
params[i].copy_(param.cast(params[i].qtype, meta=params[i].meta))
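Why the write-back added in the last hunk matters (an analogy in plain PyTorch, not MS-AMP code): `param.float()` creates a new fp32 tensor rather than a view of the stored weight, so in-place updates on the copy never reach the parameter unless they are explicitly copied back:

```python
# Self-contained sketch of the failure mode, using an fp16 tensor as a
# stand-in for the ScalingTensor weight.
import torch

w = torch.ones(4, dtype=torch.float16)   # stands in for the scaling weight
work = w.float()                         # new fp32 tensor, not a view of w
work.add_(torch.full_like(work, -0.1))   # the optimizer update happens here
print(w)                                 # unchanged: the update never reached w
w.copy_(work.to(w.dtype))                # the fix: copy the result back
print(w)                                 # now updated
```

The commit applies the same idea with `param.cast(...)` and `params[i].copy_(...)`, and also repeats the weight-decay step in this branch, which previously skipped it.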
34 changes: 22 additions & 12 deletions tests/optim/test_adamw.py
@@ -8,6 +8,8 @@
import unittest
import torch

from functools import partial

from msamp.common.dtype import Dtypes
from msamp.common.tensor import TensorDist
from msamp.optim import LBAdamW, LBAdam, LBAdamWBase, DSAdam
Expand All @@ -28,32 +30,40 @@ def tearDown(self):
@decorator.cuda_test
def test_adamw_step(self):
"""Test adamw optimizer step function."""
self.check_optimizer_step(LBAdamWBase)
self.check_optimizer_step(LBAdamW)
self.check_optimizer_step(LBAdam)
self.check_optimizer_step(DSAdam)
dtypes = [torch.uint8, torch.float16]
pairs = list(itertools.product(dtypes, dtypes)) + [[torch.float32, torch.float32]]
for exp_avg_dtype, exp_avg_sq_dtype in pairs:
with self.subTest(exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype):
kwargs = dict(exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype)
self.check_optimizer_step(torch.optim.AdamW, partial(LBAdamWBase, **kwargs))
self.check_optimizer_step(torch.optim.AdamW, partial(LBAdamW, **kwargs))
self.check_optimizer_step(torch.optim.AdamW, partial(DSAdam, **kwargs))
self.check_optimizer_step(torch.optim.Adam, partial(LBAdam, **kwargs))

@decorator.cuda_test
def test_state_dict(self):
"""Test state dict of LBAdamW and LBAdam."""
self.check_optimizer_state_dict(LBAdamW)
self.check_optimizer_state_dict(LBAdam)

def check_optimizer_step(self, optimizer_class, diff=3e-4):
"""Check the difference between torch.optim.AdamW and optimizer_class optimizers.
def check_optimizer_step(self, optimizer_class1, optimizer_class2, diff=3e-4):
"""Check the difference between optimizer_class1 and optimizer_class2 optimizers.
Args:
optimizer_class (class): LBAdamWBase, LBAdamW or LBAdam.
diff (float, optional): The difference between torch.optim.AdamW and optimizer_class optimizers.
optimizer_class1 (class): Optimizer Class
optimizer_class2 (class): Optimizer Class
diff (float, optional): The difference between optimizer_class1 and optimizer_class2 optimizers.
"""
input = torch.randn(4, 4, device='cuda')
linear = torch.nn.Linear(4, 4).cuda()
wd = 1e-3
steps = 4

# test torch.optim.AdamW
model1 = copy.deepcopy(linear)
opt1 = torch.optim.AdamW(model1.parameters())
opt1 = optimizer_class1(model1.parameters(), weight_decay=wd)

for _ in range(4):
for _ in range(steps):
output = model1(input)
output.sum().backward()
opt1.step()
@@ -63,9 +73,9 @@ def check_optimizer_step(self, optimizer_class, diff=3e-4):
model2 = copy.deepcopy(linear)
model2 = LinearReplacer.replace(model2, Dtypes.kfloat16)

opt2 = optimizer_class(model2.parameters())
opt2 = optimizer_class2(model2.parameters(), weight_decay=wd)

for _ in range(4):
for _ in range(steps):
output = model2(input)
output.sum().backward()
opt2.all_reduce_grads(model2)