Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizer offloading through weight-only offload #867

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions axlearn/common/factorized_rms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from axlearn.common import factorized_rms
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.optimizer_base import (
NestedOptStateSpec,
Nested,
OptParam,
OptStateSpec,
PartitionedGradientTransformation,
)
from axlearn.common.optimizers import OptStateSpec, with_partition_fn
from axlearn.common.optimizers import with_partition_fn
from axlearn.common.test_utils import TestCase
from axlearn.common.utils import PartitionSpec, flatten_items

Expand Down Expand Up @@ -59,7 +60,7 @@ def testParity(self, factored, dtype):

# The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
# factorization spec.
exp_partition: NestedOptStateSpec = exp.partition(param_specs)
exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
# Used for `count`.
count_spec = OptStateSpec(
dtype=jnp.int32,
Expand Down
8 changes: 3 additions & 5 deletions axlearn/common/optimizer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
- weight_decay_scale: control the weight decay rate.
"""
import dataclasses
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Union

import optax
import typing_extensions

from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
from axlearn.common.utils import Tensor, TensorSpec
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.utils import Nested, Tensor, TensorSpec


@dataclasses.dataclass
Expand Down Expand Up @@ -66,8 +65,7 @@ def __call__(

# Specification of an optimizer state array.
OptStateSpec = TensorSpec
NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]
ruomingp marked this conversation as resolved.
Show resolved Hide resolved


class PartitionedGradientTransformation(NamedTuple):
Expand Down
134 changes: 123 additions & 11 deletions axlearn/common/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@
import typing_extensions
from absl import logging
from jax import numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import schedule, struct
from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
from axlearn.common.base_layer import ParameterSpec, PartitionSpec
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.factorized_rms import scale_by_factored_rms
from axlearn.common.module import current_context
Expand All @@ -51,8 +52,8 @@
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
MemoryKind,
Nested,
NestedPartitionSpec,
NestedTensor,
NestedTree,
Tensor,
Expand Down Expand Up @@ -139,19 +140,40 @@ def update_fn(
return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def copy_partition(
param_specs: Nested[ParameterSpec],
*,
pattern: Union[None, str, re.Pattern] = None,
memory_kind: Optional[MemoryKind] = None,
) -> Nested[OptStateSpec]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of coupling creation of OptStateSpec and setting of memory_kind, how about having a separate function for setting memory kind?

def set_memory_kind(opt_state_spec: Nested[OptStateSpec], *, pattern, memory_kind):

This allows set_memory_kind to be called multiple times, maybe for different memory kind. WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see how set_memory_kind will be different from copy_partition. Signature and implementation will be the same.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imagine in the future we have many types of memory kinds, e.g., "remote_host". Then we can do:

opt_state_specs = copy_partition(...)
opt_state_specs = set_memory_kind(..., "pinned_host")
opt_state_specs = set_memory_kind(..., "remote_host")

Copy link
Member Author

@hanzhi713 hanzhi713 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be the same as

opt_state_specs = copy_partition(...)
opt_state_specs = copy_partition(..., "pinned_host")
opt_state_specs = copy_partition(..., "remote_host")

Do you mean that using a separate function is slightly better for readability?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference is that copy_partition also performs the type conversion from Nested[ParameterSpec] to Nested[OptStateSpec].

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I can change the type of param_specs in copy_partition to Nested[OptStateSpec] since ParameterSpec is a subclass of OptStateSpec and copy_partition doesn't use any new fields from ParameterSpec. Does this sound good?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. SG.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"""Creates OptStateSpec from ParameterSpec with possibly a different memory kind.

Args:
param_specs: Nested[ParameterSpec] to copy from.
pattern: Regex to match the full path of each spec. Matched specs will have their memory
kind replaced with `memory_kind`.
memory_kind: New memory kind. Default to None.
Returns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Returns:
Returns:

A Nested[OptStateSpec] with possibly a different memory kind.
"""
return jax.tree.map(
lambda param_spec: OptStateSpec(
dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
lambda path, param_spec: OptStateSpec(
dtype=param_spec.dtype,
shape=param_spec.shape,
mesh_axes=param_spec.mesh_axes,
memory_kind=memory_kind
if pattern and re.fullmatch(pattern, path)
else param_spec.memory_kind,
),
tree_paths(param_specs),
param_specs,
)


def trace_partition(
base: optax.GradientTransformation,
) -> PartitionedGradientTransformation:
def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.TraceState(trace=copy_partition(param_specs))

return with_partition_fn(base, partition_fn)
Expand All @@ -160,7 +182,7 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
state: optax.ScaleByAdamState = base.init({})

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.ScaleByAdamState(
count=OptStateSpec(
dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
Expand Down Expand Up @@ -950,7 +972,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
)
return updates, new_state

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
# Store momentum in accumulator_dtype if it is set and p is not scalar.
if param_spec.shape and accumulator_dtype is not None:
Expand Down Expand Up @@ -1412,7 +1434,7 @@ def _is_valid_step(
drop_stats=new_drop_stats,
)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
if use_adaptive_drop_norm:
one = jnp.ones([], jnp.float32)
dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
Expand Down Expand Up @@ -1571,7 +1593,7 @@ def update_fn(updates, state, params):
)
return updates, ParamEmaState(count=count_inc, ema=new_ema)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return ParamEmaState(
count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
ema=copy_partition(param_specs),
Expand Down Expand Up @@ -1617,7 +1639,7 @@ def update_fn(updates, state, params=None):
updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
return updates, ScaleByLionState(count=count_inc, mu=mu)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
mu_specs = param_specs
if mu_dtype is not None:
mu_specs = jax.tree.map(
Expand Down Expand Up @@ -1993,3 +2015,93 @@ def _update2(u: Tensor, param: OptParam):
partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
)
return named_chain(**tx)


def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
pattern: Union[str, re.Pattern] = ".*",
offload_src: MemoryKind = "device",
offload_dst: MemoryKind = "pinned_host",
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.

Args:
optimizer: The optimizer to offload.
pattern: Regex pattern used to match the path of optimizer states. Matched states will be
offloaded. Default to regex that matches all states.
ruomingp marked this conversation as resolved.
Show resolved Hide resolved
offload_src: Offload-from memory kind. Default to "device".
offload_dst: Offload-to memory kind. Default to "pinned_host".

Returns:
A optimizer whose state is on `offload_dst` and does the same computation as `optimizer`.

Raises:
ValueError: when the `update` function of the returned optimizer is called outside of jit
context.

This function returns a new `PartitionedGradientTransformation` that
1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
during state initialization in the trainer.
2. Copies the matched states to `offload_src` before `optimizer.update` is called.
3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.

The regex pattern is matched against the full path of each optimizer state. An example full
path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
ruomingp marked this conversation as resolved.
Show resolved Hide resolved
pattern should not depend on model structure, you can use ".*mu.*" to offload all `mu`.
ruomingp marked this conversation as resolved.
Show resolved Hide resolved

The .update function of the returned `PartitionedGradientTransformation` must be called within
a jit function.

Example usage:
```python
your_opt = adamw_optimizer(...)
offloaded_opt = offload_optimizer(your_opt)
```

When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the entire
`skip_and_clip_by_global_norm` inside. Do not wrap the inner of `skip_and_clip_by_global_norm`
or you will get errors. Correct example:
```
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
```
The reason is that `skip_and_clip_by_global_norm` conditionally chooses the previous optimizer
state and the updated new optimizer state using `jnp.where`, which doesn't support tensors on
`pinned_host` memory space.
"""
optimizer = maybe_instantiate(optimizer)
if offload_src is None or offload_dst is None:
raise ValueError(
"offload_src and offload_dst cannot be None when using optimizer offloading."
)

logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)

def init_fn(params: NestedOptParam):
return optimizer.init(params)

def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
# TransferToMemoryKind let us change the memory kind of tensors without specifying the full
# sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about it,
# it's specified in the API signature. Reference:
# https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
return jax.tree.map(
lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
if re.fullmatch(pattern, path)
else tensor,
tree_paths(state),
state,
)

def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
state = _move_fn(state, offload_src)
updates, state = optimizer.update(updates, state, params)
state = _move_fn(state, offload_dst)
ruomingp marked this conversation as resolved.
Show resolved Hide resolved
return updates, state

def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return copy_partition(
optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
)

return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
59 changes: 45 additions & 14 deletions axlearn/common/optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ema,
l2_regularizer,
lion_optimizer,
offload_optimizer,
opt_param_values,
param_ema,
per_param_scale_by_path,
Expand Down Expand Up @@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

def _test_optimizer(self, optimizer):
params = OptParam(
value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
factorization_spec=None,
weight_decay_scale=1.0,
)
state = optimizer.init(params)
self._test_optimizer_helper(optimizer, True)
self._test_optimizer_helper(optimizer, False)

def _test_optimizer_helper(self, optimizer, offload):
if offload:
optimizer = offload_optimizer(optimizer)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)

def create_opt_params(x):
return jax.tree.map(
lambda y: OptParam(
value=y,
factorization_spec=None,
weight_decay_scale=1.0,
),
x,
)

state = optimizer.init(create_opt_params(params))

param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
state_partition_spec = optimizer.partition(param_spec)
Expand All @@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):

jax.tree.map(check_partition_spec, state_partition_spec, state)

def compute_loss(x):
return -jax.nn.log_softmax(x)[1]
@jax.jit
def jit_fn(params, state):
def compute_loss(x):
return -jax.nn.log_softmax(x)[1]

loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
new_loss = compute_loss(updated_params)
params = create_opt_params(params)
loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
return loss, compute_loss(updated_params)

if offload:
self.assertIn(
"TransferToMemoryKind(memory_kind='pinned_host')",
str(jax.make_jaxpr(jit_fn)(params, state)),
)
loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

@parameterized.product(
Expand Down Expand Up @@ -788,14 +812,17 @@ def loss_fn(x):
config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
),
offload=(True, False),
)
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
clip = skip_and_clip_by_global_norm(
inner=_counter(),
drop_norm=drop_norm,
max_norm=max_norm,
grad_norm_ema_decay=0.99,
)
if offload:
clip = offload_optimizer(clip)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
state = clip.init(params)
init_ema = state.grad_norm_ema
Expand All @@ -821,7 +848,11 @@ def loss_fn(x):
else:
is_valid_step = drop_norm is None or g_norm < drop_norm

updates, state = clip.update(grads, state=state, params=params)
@jax.jit
def jit_fn(grads, state, params):
return clip.update(grads, state=state, params=params)

updates, state = jit_fn(grads, state, params)
if is_valid_step:
if max_norm is None or g_norm < max_norm:
np.testing.assert_allclose(updates, grads, atol=1e-6)
Expand Down
12 changes: 6 additions & 6 deletions axlearn/common/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@
HybridMeshShape,
MeshShape,
Nested,
NestedPartitionSpec,
NestedTensor,
PartitionSpec,
Tensor,
TensorSpec,
count_model_params,
flatten_items,
match_regex_rules,
Expand All @@ -62,9 +62,9 @@


class TrainerState(NamedTuple):
prng_key: Union[Tensor, NestedPartitionSpec]
model: Union[NestedTensor, NestedPartitionSpec]
learner: Union[NestedTensor, NestedPartitionSpec]
prng_key: Union[Tensor, TensorSpec, jax.sharding.NamedSharding]
model: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]
learner: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]


# pylint: disable-next=too-many-instance-attributes
Expand Down Expand Up @@ -309,8 +309,8 @@ def __init__(
model=self._model_param_specs,
learner=self._learner_state_partition_specs,
)
self._trainer_state_partition_specs = jax.tree.map(
lambda spec: spec.mesh_axes, self._trainer_state_specs
self._trainer_state_partition_specs: TrainerState = jax.tree.map(
lambda spec: spec.sharding, self._trainer_state_specs
)
# Create evalers, which depend on model_param_partition_specs.
self._evalers = {}
Expand Down
Loading
Loading