diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 4c353ae6ffc13..37b91a803d71e 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig
 from typing_extensions import TypeIs, TypeVar
 
 from vllm.logger import init_logger
@@ -19,9 +18,6 @@
 
 logger = init_logger(__name__)
 
-# The type of HF config
-C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)
-
 # The type of hidden states
 # Currently, T = torch.Tensor for all models except for Medusa
 # which has T = List[torch.Tensor]
@@ -34,7 +30,7 @@
 
 
 @runtime_checkable
-class VllmModel(Protocol[C_co, T_co]):
+class VllmModel(Protocol[T_co]):
     """The interface required for all models in vLLM."""
 
     def __init__(
@@ -97,7 +93,7 @@ def is_vllm_model(
 
 
 @runtime_checkable
-class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
+class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
     """The interface required for all generative models in vLLM."""
 
     def compute_logits(
@@ -143,7 +139,7 @@ def is_text_generation_model(
 
 
 @runtime_checkable
-class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]):
+class VllmModelForPooling(VllmModel[T], Protocol[T]):
     """The interface required for all pooling models in vLLM."""
 
     def pooler(
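
For context, here is a hedged sketch of how annotations at call sites change once the HF-config type parameter `C_co` is dropped: `VllmModel` and its sub-protocols are now generic over the hidden-states type alone. The helper functions below are hypothetical illustrations, not part of vLLM or of this PR, and the assumption that `is_text_generation_model` narrows an instance via `TypeIs` is inferred from the imports shown in this file.

```python
# Illustrative sketch only; helper names here are hypothetical.
import torch

from vllm.model_executor.models.interfaces_base import (
    VllmModel, VllmModelForTextGeneration, is_text_generation_model)


def describe_model(model: VllmModel[torch.Tensor]) -> str:
    # Before this change the annotation also carried the HF config type,
    # e.g. VllmModel[PretrainedConfig, torch.Tensor]; now only the
    # hidden-states type parameter remains.
    return type(model).__name__


def ensure_text_generation_model(
    model: VllmModel[torch.Tensor],
) -> VllmModelForTextGeneration[torch.Tensor]:
    # Assumption: is_text_generation_model() accepts an instance and narrows
    # its type via TypeIs, as suggested by the TypeIs import in this module.
    if not is_text_generation_model(model):
        raise TypeError(f"{type(model).__name__} cannot generate text")
    return model
```

The net effect is purely a typing simplification: since nothing in the protocols constrained or used the config type, dropping `C_co` removes one unused generic parameter from every `VllmModel[...]` annotation without changing runtime behavior.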