Skip to content

Commit

Permalink
[Bugfix] Fix num_heads value for simple connector when tp enabled (vllm-project#12074)
Browse files Browse the repository at this point in the history

Signed-off-by: Shangming Cai <[email protected]>
  • Loading branch information
ShangmingCai authored Jan 20, 2025
1 parent bbe5f9d commit df450aa
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
):

self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size

if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
Expand Down Expand Up @@ -161,7 +162,7 @@ def send_kv_caches_and_hidden_states(
end_layer = model_executable.model.end_layer

model_config = model_executable.model.config
num_heads = model_config.num_key_value_heads
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
Expand Down

0 comments on commit df450aa

Please sign in to comment.