diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index d7eec818cbba4..c27b1cf6ac7b9 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -3,6 +3,7 @@ import torch +from vllm.platforms import current_platform from vllm.utils import seed_everything @@ -28,4 +29,25 @@ def set_weight_attrs( for key, value in weight_attrs.items(): assert not hasattr( weight, key), (f"Overwriting existing tensor attribute: {key}") + + # NOTE(woosuk): During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. + # TODO(woosuk): Remove this hack once we have a better solution. + if current_platform.is_tpu() and key == "weight_loader": + value = _make_synced_weight_loader(value) setattr(weight, key, value) + + +def _make_synced_weight_loader(original_weight_loader): + + def _synced_weight_loader(param, *args, **kwargs): + original_weight_loader(param, *args, **kwargs) + torch._sync(param) + + return _synced_weight_loader