Skip to content

Commit

Permalink
[Model] Support for fairseq2 Llama (vllm-project#11442)
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gleize <[email protected]>
Co-authored-by: mgleize user <[email protected]>
  • Loading branch information
MartinGleize and mgleize user authored Jan 19, 2025
1 parent 81763c5 commit bbe5f9d
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 21 deletions.
1 change: 1 addition & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class _HfExamplesInfo:
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
Expand Down
3 changes: 2 additions & 1 deletion tests/weight_loading/models.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
None, mgleize/fairseq2-dummy-Llama-3.2-1B, main
13 changes: 7 additions & 6 deletions tests/weight_loading/test_weight_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
"""
Test parameter weight loading with tp>1.
"""
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:
with vllm_runner(
model_name=MODEL_NAME,
revision=REVISION,
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
quantization=None if QUANTIZATION == "None" else QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:

output = model.generate_greedy("Hello world!", max_tokens=20)
print(output)
Expand Down
34 changes: 22 additions & 12 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

param_data = param.data
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
Expand Down Expand Up @@ -546,6 +548,11 @@ def weight_loader(self,

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
Expand All @@ -554,9 +561,7 @@ def weight_loader(self,
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
Expand Down Expand Up @@ -941,6 +946,11 @@ def weight_loader(self,

use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
Expand All @@ -964,9 +974,7 @@ def weight_loader(self,
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size

# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

Expand Down Expand Up @@ -1070,6 +1078,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
Expand All @@ -1085,9 +1097,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
Expand Down
15 changes: 13 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ class Source:
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""

allow_patterns_overrides: Optional[list[str]] = None
"""If defined, weights will load exclusively using these patterns."""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
Expand Down Expand Up @@ -218,6 +221,7 @@ def _prepare_weights(
model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool,
allow_patterns_overrides: Optional[list[str]],
) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
Expand Down Expand Up @@ -249,6 +253,9 @@ def _prepare_weights(
if fall_back_to_pt:
allow_patterns += ["*.pt"]

if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides

if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
Expand Down Expand Up @@ -298,7 +305,8 @@ def _get_weights_iterator(
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt)
source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides)
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
Expand Down Expand Up @@ -340,6 +348,8 @@ def _get_all_weights(
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
None),
)
yield from self._get_weights_iterator(primary_weights)

Expand All @@ -353,7 +363,8 @@ def _get_all_weights(
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model,
model_config.revision,
fall_back_to_pt=True)
fall_back_to_pt=True,
allow_patterns_overrides=None)

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
Expand Down
151 changes: 151 additions & 0 deletions vllm/model_executor/models/fairseq2_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright 2024 The vLLM team.
# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Llama model for fairseq2 weights."""

from typing import Iterable, Set, Tuple

import torch
from torch.nn import Parameter

from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import set_weight_attrs
from vllm.model_executor.models.llama import LlamaForCausalLM

from .utils import AutoWeightsLoader, WeightsMapper


class Fairseq2LlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
# For the model loader to read only the relevant checkpoint files
self.allow_patterns_overrides = [
# either the full checkpoint
"model.pt",
# or the tp-sharded checkpoint of the current rank
f"model.{self.tp_rank}.pt",
]

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
# { "model_key": my_model_name, "my_model_name": state_dict }
# which we first need to unpack
weights_wrapped = dict(weights)
weights = weights_wrapped[
weights_wrapped["model_key"]].items() # type: ignore

# remap keys
fs2_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"decoder_frontend.embed.": "model.embed_tokens.",
"decoder.": "model.",
"final_proj.": "lm_head.",
},
orig_to_new_substr={
".self_attn_layer_norm.": ".input_layernorm.",
".ffn_layer_norm.": ".post_attention_layernorm.",
".self_attn.output_proj.": ".self_attn.o_proj.",
".ffn.gate_proj.": ".mlp.gate_proj.",
".ffn.inner_proj.": ".mlp.up_proj.",
".ffn.output_proj.": ".mlp.down_proj.",
".layer_norm.": ".norm.",
},
)
weights = fs2_to_vllm_mapper.apply(weights)

params = dict(self.named_parameters())

loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(
(self.reshape_fairseq2_weights(name, loaded_weight, params)
for name, loaded_weight in weights))

def flag_sharded_weights(self, params: dict[str, Parameter]):
"""Sets the `is_sharded_weight` flag to True for all sharded weights"""
for name, param in params.items():
modules = name.split(".")
if "norm" in name and len(param.size()) < 2:
# layer norms are not sharded
continue
elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
# for now we repeat embedding layers for compatibility
continue
else:
# all other layers are sharded
set_weight_attrs(param, {"is_sharded_weight": True})

def reshape_fairseq2_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params: dict[str, Parameter],
) -> Tuple[str, torch.Tensor]:
"""Reshape fairseq2's weights."""

def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
attn_in = self.config.head_dim * n_heads
# check for a sharded weight on dim 0
if attn_in // self.tp_size == w.size()[0]:
attn_in //= self.tp_size
n_heads //= self.tp_size
attn_out = self.config.hidden_size
return (w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1,
2).reshape(attn_in, attn_out))

modules = name.split(".")

# rotary embeds should be sliced
if "k_proj" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)

elif "q_proj" in modules:
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)

# We make the loaded weights compatible with both
# full checkpoints and tp sharded checkpoints.
# Embeddings are repeated to fit the vocab size.
# Other weights are flagged for the weight_loader calls.
if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
# Embeddings are sharded on dim 0
dim = 0
# In fairseq2, vocab size has to be divisible by tp_size
# so we don't worry about padding
if self.tp_size > 1 and loaded_weight.shape[
dim] < self.config.vocab_size:
assert loaded_weight.shape[
dim] * self.tp_size == self.config.vocab_size, \
"vocab_size should be divisible by tp_size."
repeats = [1] * len(loaded_weight.size())
repeats[dim] = self.tp_size
# repeat to match vocab size and to be easily 'narrow'able
loaded_weight = loaded_weight.repeat(repeats)
set_weight_attrs(params[name], {"is_sharded_weight": False})
# if embeddings are sharded, the rest is too
if "embed_tokens" in modules:
self.flag_sharded_weights(params)

return name, loaded_weight
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
Expand Down

0 comments on commit bbe5f9d

Please sign in to comment.