Skip to content

Commit

Permalink
mypy type checking for vllm/worker (#11418)
Browse files Browse the repository at this point in the history
Signed-off-by: lucast2021 <[email protected]>
Co-authored-by: lucast2021 <[email protected]>
  • Loading branch information
lucas-tucker and lucast2021 authored Dec 23, 2024
1 parent f30581c commit e51719a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
3 changes: 1 addition & 2 deletions vllm/worker/cpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,8 @@ def execute_worker(
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine
virtual_engine: int = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
Expand Down
13 changes: 7 additions & 6 deletions vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,9 @@ def _async_process_outputs(self, model_input: StatefulModelInput,
if not cont:
break

def _final_process_outputs(self, model_input: StatefulModelInput,
output_proc_callback: Optional[Callable]):
def _final_process_outputs(
self, model_input: StatefulModelInput,
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
assert model_input.frozen_model_input is not None

has_async_callback = output_proc_callback is not None
Expand Down Expand Up @@ -594,8 +595,8 @@ def execute_model(
# should be [SamplerOutput]
return output

def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
num_seqs: Optional[int], num_queries: int):

assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
Expand Down Expand Up @@ -850,13 +851,13 @@ def _pythonize_sampler_output(
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
seq_outputs: List[SequenceOutput]

if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear()
seq_outputs: List[
SequenceOutput] = completion_seq_group_output.samples
seq_outputs = completion_seq_group_output.samples
else:
seq_outputs = []

Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def init_worker(self, *args, **kwargs):
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None

def execute_method(self, method, *args, **kwargs):
def execute_method(self, method: str, *args, **kwargs):
try:
target = self if self.worker is None else self.worker
executor = getattr(target, method)
Expand Down

0 comments on commit e51719a

Please sign in to comment.