Skip to content

Commit

Permalink
[V1] Add all_token_ids attribute to Request (vllm-project#10135)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon authored and rickyyx committed Nov 13, 2024
1 parent 4d39bd5 commit dffbff5
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 5 deletions.
2 changes: 1 addition & 1 deletion vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def update_from_output(
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.output_token_ids.append(token_id)
request.append_output_token_ids(token_id)
sampled.append((request, 1))
# TODO: Update the KV cache manager for prefix caching.

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
)
for req, num_tokens in sampled:
inputs.req_ids.append(req.request_id)
if len(req.output_token_ids) == num_tokens:
if req.num_output_tokens == num_tokens:
# The request is first detokenized.
inputs.prompt_token_ids.append(req.prompt_token_ids)
else:
Expand Down
29 changes: 26 additions & 3 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
Expand Down Expand Up @@ -40,17 +41,39 @@ def __init__(
self.prompt = inputs.get("prompt")
self.prompt_token_ids = inputs["prompt_token_ids"]
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.output_text = ""
self.num_computed_tokens = 0

@property
def output_token_ids(self) -> ConstantList[int]:
    """Return a read-only view of the generated token ids.

    The list is wrapped rather than exposed directly so that callers
    cannot append to it; all_token_ids has to be updated at the same
    time as output_token_ids.
    """
    read_only_view = ConstantList(self._output_token_ids)
    return read_only_view

@property
def all_token_ids(self) -> ConstantList[int]:
    """Return a read-only view of the prompt and output token ids.

    The list is wrapped rather than exposed directly so that callers
    cannot append to it; output_token_ids has to be updated at the same
    time as all_token_ids.
    """
    read_only_view = ConstantList(self._all_token_ids)
    return read_only_view

def append_output_token_ids(
    self,
    token_ids: Union[int, List[int]],
) -> None:
    """Record newly generated token id(s) for this request.

    Args:
        token_ids: A single token id, or a list of token ids. A single
            int is handled as a one-element list.

    The ids are appended to both the output-only list and the combined
    list so that the two never drift apart.
    """
    new_ids = [token_ids] if isinstance(token_ids, int) else token_ids
    # Output list first, then the combined list.
    for target in (self._output_token_ids, self._all_token_ids):
        target.extend(new_ids)

@property
def num_tokens(self) -> int:
# Number of token ids held for this request: prompt tokens plus output tokens.
# NOTE(review): this diff view shows two return lines without +/- markers;
# presumably the first is the removed line and the second is its
# replacement -- confirm against the commit before relying on either.
return self.num_prompt_tokens + len(self.output_token_ids)
return len(self._all_token_ids)

@property
def num_output_tokens(self) -> int:
# Number of output token ids recorded for this request.
# NOTE(review): this diff view shows two return lines without +/- markers;
# presumably the first is the removed line and the second is its
# replacement -- confirm against the commit before relying on either.
return len(self.output_token_ids)
return len(self._output_token_ids)

def is_finished(self) -> bool:
    """Return whether RequestStatus considers this request's status finished."""
    current_status = self.status
    return RequestStatus.is_finished(current_status)
Expand Down
64 changes: 64 additions & 0 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Generic, List, TypeVar, overload

T = TypeVar("T")


class ConstantList(Generic[T]):
    """Read-only wrapper around an existing list.

    The wrapper stores a reference to the list it is given (it does not
    copy it), so changes the owner makes to the underlying list are
    visible through the wrapper. Reading operations are forwarded to the
    underlying list; every mutating operation raises ``Exception``.
    """

    def __init__(self, x: List[T]) -> None:
        # Keep a reference, not a copy: the view must track the owner's list.
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, *args, **kwargs):
        # Accept any call shape, including list-style insert(index, item),
        # so the caller gets this error instead of an argument-count
        # TypeError.
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args, **kwargs):
        # Accept both pop() and pop(index), as list.pop does, so the caller
        # gets this error instead of an argument-count TypeError.
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self, item, *args):
        """Return the position of ``item``; extra args are list.index's start/stop."""
        return self._x.index(item, *args)

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> List[T]:
        ...

    def __getitem__(self, item):
        # An int yields one element; a slice yields a new plain list, as
        # list slicing does.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value):
        ...

    @overload
    def __setitem__(self, s: slice, value, /):
        ...

    def __setitem__(self, item, value):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

0 comments on commit dffbff5

Please sign in to comment.