mypy type checking for vllm/worker #11418

Merged · 3 commits · Dec 23, 2024
vllm/worker/cpu_worker.py (3 changes: 1 addition & 2 deletions)
```diff
@@ -333,9 +333,8 @@ def execute_worker(
     def prepare_worker_input(
             self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
         assert execute_model_req is not None
-        virtual_engine = execute_model_req.virtual_engine
+        virtual_engine: int = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
-        blocks_to_copy = execute_model_req.blocks_to_copy
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                       device="cpu",
                                       dtype=torch.int64).view(-1, 2)
```
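For context, a minimal standalone sketch (illustrative code, not from vLLM; `pairs` is a made-up variable) of why the intermediate `blocks_to_copy` binding was dropped: mypy pins a name to the type of its first assignment, so rebinding a list-typed name to a `torch.Tensor` is reported as an incompatible assignment, while constructing the tensor in one step keeps the name bound to a single type.

```python
from typing import List, Tuple

import torch

pairs: List[Tuple[int, int]] = [(0, 1), (2, 3)]

# Rebinding the same name to a different type fails type checking:
#   error: Incompatible types in assignment (expression has type
#   "Tensor", variable has type "List[Tuple[int, int]]")
# pairs = torch.tensor(pairs, dtype=torch.int64).view(-1, 2)

# Building the tensor directly, as the hunk above does, type-checks:
blocks_to_copy = torch.tensor(pairs, dtype=torch.int64).view(-1, 2)
print(blocks_to_copy.shape)  # torch.Size([2, 2])
```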
vllm/worker/multi_step_model_runner.py (13 changes: 7 additions & 6 deletions)
```diff
@@ -406,8 +406,9 @@ def _async_process_outputs(self, model_input: StatefulModelInput,
                 if not cont:
                     break
 
-    def _final_process_outputs(self, model_input: StatefulModelInput,
-                               output_proc_callback: Optional[Callable]):
+    def _final_process_outputs(
+            self, model_input: StatefulModelInput,
+            output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
         assert model_input.frozen_model_input is not None
 
         has_async_callback = output_proc_callback is not None
```
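A quick illustrative sketch (not vLLM code) of what the new `-> List[SamplerOutput]` return annotation buys: mypy checks every return path against the declared type and propagates it to callers, instead of treating the result as `Any`.

```python
from typing import List, Optional


def final_outputs(pending: Optional[List[str]]) -> List[str]:
    # Every return statement is checked against List[str]; returning
    # None or a bare string here would be a mypy error.
    if pending is None:
        return []
    return pending


outputs = final_outputs(["a", "b"])
outputs.append("c")  # mypy knows `outputs` is List[str], not Any
```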
```diff
@@ -594,8 +595,8 @@ def execute_model(
         # should be [SamplerOutput]
         return output
 
-    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
-                                  num_queries):
+    def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
+                                  num_seqs: Optional[int], num_queries: int):
 
         assert sampling_metadata.num_prompts == 0
         assert len(sampling_metadata.seq_groups) == num_queries
```
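Similarly, a small sketch (again illustrative, with made-up names) of the `num_seqs: Optional[int]` parameter annotation: mypy now checks call sites and requires a `None` check before the value is used as an `int`.

```python
from typing import Optional


def check_counts(num_seqs: Optional[int], num_queries: int) -> None:
    if num_seqs is None:
        # mypy narrows Optional[int] to int only after this check.
        num_seqs = num_queries
    assert num_seqs >= num_queries


check_counts(None, 4)
check_counts(8, 4)
# check_counts("8", 4)  # mypy: Argument 1 has incompatible type "str"
```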
```diff
@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
+        seq_outputs: List[SequenceOutput]
 
         if cache is not None:
             completion_seq_group_output: CompletionSequenceGroupOutput = \
                 cache.cached_completion_seq_group_output.get_object()
             completion_seq_group_output.samples.clear()
-            seq_outputs: List[
-                SequenceOutput] = completion_seq_group_output.samples
+            seq_outputs = completion_seq_group_output.samples
         else:
             seq_outputs = []
 
```
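The hunk above uses a PEP 526 bare annotation to declare `seq_outputs` once before the branches, so every later assignment is checked against a single declared type. A minimal sketch of the pattern (illustrative names):

```python
from typing import List


def collect(reuse_cached: bool, cached: List[int]) -> List[int]:
    # Declare the type once, before the if/else; mypy then checks the
    # assignment in each branch against this one annotation.
    items: List[int]
    if reuse_cached:
        items = cached
    else:
        items = []
    return items
```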
vllm/worker/worker_base.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -452,7 +452,7 @@ def init_worker(self, *args, **kwargs):
         self.worker = worker_class(*args, **kwargs)
         assert self.worker is not None
 
-    def execute_method(self, method, *args, **kwargs):
+    def execute_method(self, method: str, *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
             executor = getattr(target, method)
```
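To illustrate the `method: str` annotation on `execute_method` (a toy stand-in, not the vLLM class): the name being dispatched is now type-checked at call sites, even though the `getattr` lookup itself still returns `Any`.

```python
from typing import Any


class Dispatcher:
    def ping(self, name: str) -> str:
        return f"pong, {name}"

    def execute_method(self, method: str, *args: Any, **kwargs: Any) -> Any:
        # `method` must be a str; the dynamically looked-up attribute
        # is still untyped (Any) as far as mypy is concerned.
        executor = getattr(self, method)
        return executor(*args, **kwargs)


d = Dispatcher()
assert d.execute_method("ping", "vllm") == "pong, vllm"
# d.execute_method(123)  # mypy: Argument 1 has incompatible type "int"
```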