From ca3772c4b0ff4ce86320621e725576ad38b98cfe Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Thu, 21 Nov 2024 22:40:02 +0800
Subject: [PATCH] [Bugfix] Embedding model pooling_type equals ALL and multi
 input's bug (#10494)

Signed-off-by: Tyler Michael Smith
---
 vllm/model_executor/layers/pooler.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index bfe2d7d0f382e..df1978241340b 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -94,14 +94,10 @@ def forward(
             pooled_data = hidden_states[last_token_flat_indices]
         elif self.pooling_type == PoolingType.ALL:
             offset = 0
-            pooled_data_lst = []
+            pooled_data = []
             for prompt_len in prompt_lens:
-                pooled_data_i = hidden_states[offset:offset + prompt_len]
-
-                pooled_data_lst.append(pooled_data_i)
+                pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
-
-            pooled_data = torch.stack(pooled_data_lst)
         elif self.pooling_type == PoolingType.MEAN:
             # Calculate mean pooling
             cumsum = torch.cumsum(hidden_states, dim=0)
@@ -121,7 +117,7 @@ def forward(
             step_tag_id = self.step_tag_id
 
             offset = 0
-            pooled_data_lst = []
+            pooled_data = []
             for prompt_len, seq_data_i in zip(
                     prompt_lens, pooling_metadata.seq_data.values()):
                 pooled_data_i = hidden_states[offset:offset + prompt_len]
@@ -130,17 +126,26 @@ def forward(
                     pooled_data_i = pooled_data_i[token_ids == step_tag_id]
 
                 offset += prompt_len
-                pooled_data_lst.append(pooled_data_i)
-
-            pooled_data = torch.stack(pooled_data_lst)
+                pooled_data.append(pooled_data_i)
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
         if self.normalize:
-            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.normalize(data, p=2, dim=1)
+                    for data in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         if self.softmax:
-            pooled_data = nn.functional.softmax(pooled_data, dim=-1)
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.softmax(data, dim=-1) for data in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.softmax(pooled_data, dim=-1)
 
         pooled_outputs = [
             EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
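
Note (not part of the patch): a minimal standalone sketch of the failure mode this patch fixes. With PoolingType.ALL and a batch containing prompts of different lengths, torch.stack over the per-prompt slices raises a RuntimeError because the slices have unequal shapes; the fix keeps a plain Python list of per-prompt tensors and applies normalize/softmax per item. The names hidden_size and prompt_lens below are illustrative values for the sketch, not vLLM API.

import torch
import torch.nn.functional as F

hidden_size = 8
prompt_lens = [3, 5]  # two prompts of different lengths in one batch
hidden_states = torch.randn(sum(prompt_lens), hidden_size)

# Old behaviour: torch.stack over unequal-length slices fails.
# New behaviour: keep one tensor per prompt in a list.
offset = 0
pooled_data = []
for prompt_len in prompt_lens:
    pooled_data.append(hidden_states[offset:offset + prompt_len])
    offset += prompt_len

# Mirrors the isinstance(pooled_data, list) branches added by the patch:
# normalization is applied element-wise when pooled_data is a list.
pooled_data = [F.normalize(data, p=2, dim=1) for data in pooled_data]
print([d.shape for d in pooled_data])  # [torch.Size([3, 8]), torch.Size([5, 8])]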