Skip to content

Commit

Permalink
[Bugfix] Embedding model pooling_type equals ALL and multi input's bug (
Browse files Browse the repository at this point in the history
vllm-project#10494)

Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
BBuf authored and tlrmchlsmth committed Nov 23, 2024
1 parent 26c186e commit ca3772c
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,10 @@ def forward(
pooled_data = hidden_states[last_token_flat_indices]
elif self.pooling_type == PoolingType.ALL:
offset = 0
pooled_data_lst = []
pooled_data = []
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]

pooled_data_lst.append(pooled_data_i)
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len

pooled_data = torch.stack(pooled_data_lst)
elif self.pooling_type == PoolingType.MEAN:
# Calculate mean pooling
cumsum = torch.cumsum(hidden_states, dim=0)
Expand All @@ -121,7 +117,7 @@ def forward(
step_tag_id = self.step_tag_id

offset = 0
pooled_data_lst = []
pooled_data = []
for prompt_len, seq_data_i in zip(
prompt_lens, pooling_metadata.seq_data.values()):
pooled_data_i = hidden_states[offset:offset + prompt_len]
Expand All @@ -130,17 +126,26 @@ def forward(
pooled_data_i = pooled_data_i[token_ids == step_tag_id]

offset += prompt_len
pooled_data_lst.append(pooled_data_i)

pooled_data = torch.stack(pooled_data_lst)
pooled_data.append(pooled_data_i)
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")

if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
if isinstance(pooled_data, list):
pooled_data = [
nn.functional.normalize(data, p=2, dim=1)
for data in pooled_data
]
else:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

if self.softmax:
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
if isinstance(pooled_data, list):
pooled_data = [
nn.functional.softmax(data, dim=-1) for data in pooled_data
]
else:
pooled_data = nn.functional.softmax(pooled_data, dim=-1)

pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
Expand Down

0 comments on commit ca3772c

Please sign in to comment.