Skip to content

Commit

Permalink
Intern2 habana (#489)
Browse files Browse the repository at this point in the history
Ensures the model runs on Habana devices. The original code did not run
due to an error in the split_qkv code: the shape unpacking assumed the
input had no batch dimension. Tested inference with the changes, and
InternLM2 works on Gaudi2 as expected.

---------

Co-authored-by: Stan Kirdey <[email protected]>
  • Loading branch information
skirdey-inflection and Stan Kirdey authored Nov 26, 2024
1 parent b099337 commit b7d75b8
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions vllm/model_executor/models/internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,33 @@ def __init__(
)

def split_qkv(self, qkv: torch.Tensor):
    """Split a fused QKV projection tensor into separate q, k, v tensors.

    Args:
        qkv: Tensor whose last dimension packs the q/k/v values for this
            tensor-parallel rank. Any number of leading batch dimensions
            is supported (e.g. ``(seq, hidden)`` or ``(batch, seq, hidden)``).

    Returns:
        Tuple ``(q, k, v)``; each keeps the leading batch dimensions and
        holds this rank's shard of the corresponding projection along the
        last dimension.
    """
    # Keep every leading dimension intact; only the last (hidden) dim is
    # repacked below, so the method works with or without a batch dim.
    *batch_dims, _ = qkv.shape

    if self.tp_size > 1:
        # Each rank holds interleaved [q, k, v] slices. Gather across all
        # ranks, then regroup so all q chunks come first, then k, then v.
        qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
        qkv = tensor_model_parallel_all_gather(qkv)
        qkv = torch.split(qkv, qkv_map, dim=-1)
        qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
        qkv = torch.cat(qkv, dim=-1)

    qkv = qkv.contiguous()

    # Make kv-heads / groups explicit dims, preserving however many batch
    # dimensions the input carried.
    qkv = qkv.view(*batch_dims, self.total_num_kv_heads,
                   self.key_value_groups + 2, self.head_dim)
    q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
    # NOTE: torch.split along a middle dim yields non-contiguous views, on
    # which .view() raises; .reshape() is equivalent when a view is possible
    # and silently copies otherwise.
    q = q.reshape(*batch_dims, self.q_size * self.tp_size)
    k = k.reshape(*batch_dims, self.kv_size * self.tp_size)
    v = v.reshape(*batch_dims, self.kv_size * self.tp_size)

    if self.tp_size > 1:
        # Keep only this rank's shard of each projection.
        splitter = partial(split_tensor_along_last_dim,
                           num_partitions=self.tp_size)
        q = splitter(q)[self.tp_rank]
        k = splitter(k)[self.tp_rank]
        v = splitter(v)[self.tp_rank]

    return q, k, v

def forward(
Expand Down

0 comments on commit b7d75b8

Please sign in to comment.