From 1191485f675ebdb03c71a6720eebed5bd2f271c4 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 20 Jan 2025 17:05:55 +0000
Subject: [PATCH] Fix vocab_parallel_embedding sharding

When an embedding weight is materialized from an UninitializedParameter,
size the sharded dimension with num_embeddings_per_partition instead of
dividing the checkpoint vocab size by tp_size; the two differ whenever the
vocabulary is padded or not evenly divisible across ranks. Add a TP=2
regression test comparing vLLM against HF on a model that triggers the
mismatch.

Signed-off-by: NickLucche
---
 .../decoder_only/language/test_models.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 .../layers/vocab_parallel_embedding.py   |  4 ++--
 2 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py
index c7efa4edbbc0a..1ab9a82275d2e 100644
--- a/tests/models/decoder_only/language/test_models.py
+++ b/tests/models/decoder_only/language/test_models.py
@@ -87,3 +87,46 @@ def print_model(model):
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        # regression test for a VocabParallelEmbedding sharding crash
+        pytest.param("cognitivecomputations/TinyDolphin-2.8-1.1b"),
+    ])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp", [2])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_tp_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp: int,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(model, dtype=dtype,
+                     tensor_parallel_size=tp) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+        # Also check that the model's extra_repr prints correctly.
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 65920aa61ba15..3eb5c39ccf580 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -355,7 +355,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         elif isinstance(param, UninitializedParameter):
             shape = list(loaded_weight.shape)
             if output_dim is not None:
-                shape[output_dim] = shape[output_dim] // self.tp_size
+                shape[output_dim] = self.num_embeddings_per_partition
             param.materialize(tuple(shape), dtype=loaded_weight.dtype)
 
         # If parameter does not have output dim, then it should
@@ -381,7 +381,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size
 
-        # Copy the data.
+        # Copy the data. Select the chunk corresponding to the current shard.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         if current_platform.is_hpu():
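
As context for the weight_loader change above, here is a minimal standalone
sketch (not vLLM code) of the partition-size arithmetic the fix relies on.
It assumes vLLM's default vocab padding of 64; the vocab size and
tensor-parallel degree below are illustrative, not taken from a specific model.

    # Why num_embeddings_per_partition can differ from vocab_size // tp_size.
    def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
        """Round the vocab size up to the next multiple of pad_to."""
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    org_vocab_size = 32002  # vocab size stored in the checkpoint (illustrative)
    tp_size = 2             # tensor-parallel degree

    padded = pad_vocab_size(org_vocab_size)  # 32064
    per_partition = padded // tp_size        # 16032 rows owned by each rank
    naive = org_vocab_size // tp_size        # 16001 rows (what the old code used)

    # Materializing the shard with `naive` rows makes it smaller than the
    # padded partition size, which is the kind of shape mismatch the weight
    # loader could hit before this change.
    print(per_partition, naive)  # 16032 16001

The patch instead sizes the materialized parameter with
num_embeddings_per_partition, so it is always large enough for the shard the
loader copies in.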