[TPU] Call torch._sync(param) during weight loading (#9437)
WoosukKwon authored Oct 17, 2024
1 parent 5e443b5 commit 8e1cddc
Showing 1 changed file with 22 additions and 0 deletions.
vllm/model_executor/utils.py: 22 additions & 0 deletions
@@ -3,6 +3,7 @@

 import torch

+from vllm.platforms import current_platform
 from vllm.utils import seed_everything


@@ -28,4 +29,25 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
+
+        # NOTE(woosuk): During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to the redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
+        # TODO(woosuk): Remove this hack once we have a better solution.
+        if current_platform.is_tpu() and key == "weight_loader":
+            value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
+
+
+def _make_synced_weight_loader(original_weight_loader):
+
+    def _synced_weight_loader(param, *args, **kwargs):
+        original_weight_loader(param, *args, **kwargs)
+        torch._sync(param)
+
+    return _synced_weight_loader
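
For illustration, below is a minimal, runnable sketch of the wrapper pattern this commit introduces: an existing weight loader that narrows a view of param.data and copies a shard into it, wrapped by a function that would force the parameter to materialize after loading. The names load_column_slice and make_synced_weight_loader are hypothetical, the example runs on eager CPU tensors, and a no-op stands in for the TPU-only torch._sync(param) call; it shows the wrapping shape only, not vLLM's actual loading path.

    # Illustrative sketch only; names are hypothetical, not vLLM's API.
    import torch


    def load_column_slice(param: torch.nn.Parameter, real_weight: torch.Tensor,
                          offset: int) -> None:
        # Typical loader: narrow a view of param.data and copy the shard into
        # it, relying on the view sharing storage with param.data.
        narrowed = param.data.narrow(0, offset, real_weight.shape[0])
        narrowed.copy_(real_weight)


    def make_synced_weight_loader(original_weight_loader):
        # Same shape as _make_synced_weight_loader in the diff: run the
        # original loader, then sync the parameter so the lazily recorded
        # copy is materialized on the base tensor.
        def synced_weight_loader(param, *args, **kwargs):
            original_weight_loader(param, *args, **kwargs)
            # On TPU, vLLM calls torch._sync(param) here; eager CPU tensors
            # are already materialized, so this sketch does nothing extra.

        return synced_weight_loader


    if __name__ == "__main__":
        param = torch.nn.Parameter(torch.zeros(4, 2), requires_grad=False)
        loader = make_synced_weight_loader(load_column_slice)
        loader(param, torch.ones(2, 2), offset=2)  # fills rows 2-3 with ones
        print(param.data)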
