Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Add all_token_ids attribute to Request #10135

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def update_from_output(
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.output_token_ids.append(token_id)
request.append_output_token_ids(token_id)
sampled.append((request, 1))
# TODO: Update the KV cache manager for prefix caching.

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
)
for req, num_tokens in sampled:
inputs.req_ids.append(req.request_id)
if len(req.output_token_ids) == num_tokens:
if req.num_output_tokens == num_tokens:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

len is supported by ConstantList, but I fixed this for clarity.

# The request is first detokenized.
inputs.prompt_token_ids.append(req.prompt_token_ids)
else:
Expand Down
29 changes: 26 additions & 3 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
Expand Down Expand Up @@ -40,17 +41,39 @@ def __init__(
self.prompt = inputs.get("prompt")
self.prompt_token_ids = inputs["prompt_token_ids"]
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.output_text = ""
self.num_computed_tokens = 0

@property
def output_token_ids(self) -> ConstantList[int]:
# Prevent directly appending to the output_token_ids since
# all_token_ids should also be updated simultaneously.
return ConstantList(self._output_token_ids)

@property
def all_token_ids(self) -> ConstantList[int]:
# Prevent directly appending to the all_token_ids since
# output_token_ids should also be updated simultaneously
return ConstantList(self._all_token_ids)

def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
token_ids = [token_ids]
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)

@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
return len(self._all_token_ids)

@property
def num_output_tokens(self) -> int:
return len(self.output_token_ids)
return len(self._output_token_ids)

def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)
Expand Down
64 changes: 64 additions & 0 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Generic, List, TypeVar, overload

T = TypeVar("T")


class ConstantList(Generic[T]):

def __init__(self, x: List[T]) -> None:
self._x = x

def append(self, item):
raise Exception("Cannot append to a constant list")

def extend(self, item):
raise Exception("Cannot extend a constant list")

def insert(self, item):
raise Exception("Cannot insert into a constant list")

def pop(self, item):
raise Exception("Cannot pop from a constant list")

def remove(self, item):
raise Exception("Cannot remove from a constant list")

def clear(self):
raise Exception("Cannot clear a constant list")

def index(self, item):
return self._x.index(item)

@overload
def __getitem__(self, item) -> T:
...

@overload
def __getitem__(self, s: slice, /) -> List[T]:
...

def __getitem__(self, item):
return self._x[item]

@overload
def __setitem__(self, item, value):
...

@overload
def __setitem__(self, s: slice, value, /):
...

def __setitem__(self, item, value):
raise Exception("Cannot set item in a constant list")

def __delitem__(self, item):
raise Exception("Cannot delete item from a constant list")

def __iter__(self):
return iter(self._x)

def __contains__(self, item):
return item in self._x

def __len__(self):
return len(self._x)