diff --git a/vllm/outputs.py b/vllm/outputs.py index 63df7dcf519b5..25b2265285d16 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Dict, Generic, List, Optional +from typing import Dict, Generic, List, MutableSequence, Optional from typing import Sequence as GenericSequence from typing import Union @@ -162,6 +162,26 @@ def new( finished=finished, ) + def add(self, next_output: "RequestOutput") -> None: + """Merge subsequent RequestOutput into this one""" + + self.prompt = next_output.prompt + self.prompt_token_ids = next_output.prompt_token_ids + self.prompt_logprobs = next_output.prompt_logprobs + self.finished |= next_output.finished + + #TODO assuming n == 1 for now + completion = self.outputs[0] + next_completion = next_output.outputs[0] + completion.text += next_completion.text + if not isinstance(completion.token_ids, MutableSequence): + completion.token_ids = list(completion.token_ids) + completion.token_ids.extend(next_completion.token_ids) + if next_completion.logprobs: + assert completion.logprobs is not None + completion.logprobs.extend(next_completion.logprobs) + completion.cumulative_logprob = next_completion.cumulative_logprob + @classmethod def from_seq_group( cls, seq_group: SequenceGroup, use_cache: bool, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a74699f7513e6..7a80114f49c29 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -12,7 +12,7 @@ from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext @@ -205,17 +205,22 @@ async def generate( # The output_handler task pushes items into the queue. # This task pulls from the queue and yields to caller. - while True: + finished = False + while not finished: # Note: drain queue without await if possible (avoids # task switching under load which helps performance). - out = q.get_nowait() if q.qsize() > 0 else await q.get() + out = q.get_nowait() if not q.empty() else await q.get() + + # Coalesce any additional queued outputs + while not q.empty(): + if sampling_params.output_kind == RequestOutputKind.DELTA: + out.add(q.get_nowait()) + else: + out = q.get_nowait() # Note: both OutputProcessor and EngineCore handle their # own request cleanup based on finished. - if out.finished: - yield out - break - + finished = out.finished yield out # If the request is disconnected by the client, the