Skip to content

Commit

Permalink
Add batch/microbatch transforms (#3703)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Nov 11, 2024
1 parent 93678e8 commit 18da725
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 15 deletions.
46 changes: 37 additions & 9 deletions composer/core/data_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from composer.utils import dist, ensure_tuple
from composer.utils import VersionedDeprecationWarning, dist, ensure_tuple

if TYPE_CHECKING:
from composer.core.types import Batch
Expand Down Expand Up @@ -126,16 +126,16 @@ def _default_split_batch(batch: Any, microbatch_size: Union[int, float]) -> Sequ
class DataSpec:
"""Specifications for operating and training on data.
An example of constructing a :class:`DataSpec` object with a ``device_transforms``
An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
callable and then using it with :class:`~.Trainer`:
.. doctest::
>>> # Construct DataSpec and subtract mean from the batch
>>> device_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
>>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
>>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
>>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
>>> # The same function can be used for eval dataloader as well
>>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
>>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
>>> # Use this DataSpec object to construct trainer
>>> trainer = Trainer(
... model=model,
Expand All @@ -155,11 +155,20 @@ class DataSpec:
num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
:class:`.Timestamp` (training progress tracker).
device_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
batch once it has been moved onto the device. For example, this function can be used for GPU-based
device_transforms ((Batch) -> Batch, optional): Deprecated argument. Please use ``batch_transforms`` for batch
level transformations on CPU and ``microbatch_transforms`` for microbatch level transformations on target
device.
batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
batch before it is moved onto the device. For example, this function can be used for CPU-based
normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
the batch is not modified.
microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
microbatch before it is moved onto the device. For example, this function can be used for GPU-based
normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
specified, the microbatch is not modified.
split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
split a batch (the first parameter) into microbatches of a given size (the second parameter). If
the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
Expand All @@ -186,13 +195,32 @@ def __init__(
num_samples: Optional[int] = None,
num_tokens: Optional[int] = None,
device_transforms: Optional[Callable[[Batch], Batch]] = None,
batch_transforms: Optional[Callable[[Batch], Batch]] = None,
microbatch_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
) -> None:
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
self.num_tokens = num_tokens
self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
if device_transforms is not None:
if batch_transforms is not None:
raise ValueError(
'Cannot specify both `device_transforms` and `batch_transforms`. Please use `batch_transforms` for '
'batch level transformations on CPU and `microbatch_transforms` for microbatch level transformations '
'on target device.',
)
warnings.warn(
VersionedDeprecationWarning(
'The `device_transforms` argument is deprecated. Please use `batch_transforms` for batch level '
'transformations on CPU and `microbatch_transforms` for microbatch level transformations on target '
'device.',
'v0.29.0',
),
)
self.batch_transforms = device_transforms
self.batch_transforms = self._default_transforms if batch_transforms is None else batch_transforms
self.microbatch_transforms = self._default_transforms if microbatch_transforms is None else microbatch_transforms
self.split_batch = default_split_batch if split_batch is None else split_batch
self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
Expand Down Expand Up @@ -242,7 +270,7 @@ def __init__(
'For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.',
)

def _default_device_transforms(self, batch: Batch):
def _default_transforms(self, batch: Batch):
return batch

def _default_get_num_samples_in_batch(self, batch: Batch) -> int:
Expand Down
12 changes: 7 additions & 5 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2622,7 +2622,7 @@ def _train_loop(self) -> None:
self._rng_state = None
continue

self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
self.state.batch = self._train_data_spec.batch_transforms(self.state.batch)
rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)

Expand Down Expand Up @@ -3034,6 +3034,7 @@ def _train_microbatches(

for microbatch_idx, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = self._train_data_spec.microbatch_transforms(self.state.batch)
is_final_microbatch = microbatch_idx + 1 == len(microbatches)
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)

Expand Down Expand Up @@ -3306,11 +3307,11 @@ def predict_batch_end(self, state: State, logger: Logger) -> None:
self.engine.run_event(Event.PREDICT_START)

for self.state.batch in self._iter_dataloader(TrainerMode.PREDICT):

# Move the batch onto the device
self.state.batch = data_spec.batch_transforms(self.state.batch)
self.state.batch = self.state.device.batch_to_device(self.state.batch)

# Perform any device transforms
self.state.batch = data_spec.device_transforms(self.state.batch)
self.state.batch = data_spec.microbatch_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
Expand Down Expand Up @@ -3586,7 +3587,7 @@ def _eval_loop(
)

for self.state.batch in self._iter_dataloader(TrainerMode.EVAL):
self.state.batch = data_spec.device_transforms(self.state.batch)
self.state.batch = data_spec.batch_transforms(self.state.batch)

# Count the batch size and num tokens before any events run
rank_num_samples = data_spec.get_num_samples_in_batch(self.state.batch)
Expand Down Expand Up @@ -3616,6 +3617,7 @@ def _eval_loop(
microbatches = data_spec.split_batch(device_batch, evaluator.device_eval_microbatch_size)
for i, self.state.batch in enumerate(microbatches):
self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = data_spec.microbatch_transforms(self.state.batch)
last_microbatch = i == len(microbatches) - 1
skip_metric_update = False
# Distributed samplers pad batches to be the same size. If using a
Expand Down
25 changes: 24 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from composer import Callback, Evaluator, Trainer
from composer.algorithms import CutOut, LabelSmoothing
from composer.core import Event, Precision, State, Time, TimeUnit
from composer.core import DataSpec, Event, Precision, State, Time, TimeUnit
from composer.devices import Device
from composer.loggers import InMemoryLogger, Logger, RemoteUploaderDownloader
from composer.loss import soft_cross_entropy
Expand Down Expand Up @@ -1733,3 +1733,26 @@ def test_empty_eval_dataloader(self):
max_duration='1ba',
)
trainer.fit()


@device('cpu', 'gpu')
def test_transforms(device: str):

def get_transform(device: str):

def transform(batch: list[torch.Tensor]):
batch_device = 'gpu' if batch[0].device.type == 'cuda' else 'cpu'
assert batch_device == device
return batch

return transform

dataloader = _get_classification_dataloader()
data_spec = DataSpec(
dataloader,
batch_transforms=get_transform('cpu'),
microbatch_transforms=get_transform(device),
)
model = SimpleModel()
trainer = Trainer(model=model, train_dataloader=data_spec, max_duration='1ba')
trainer.fit()

0 comments on commit 18da725

Please sign in to comment.