diff --git a/vllm/config.py b/vllm/config.py index 197f20c1ec9a5..e5fc464328b62 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -75,6 +75,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + hf_kwargs: Optional[dict] = None ) -> None: self.model = model self.tokenizer = tokenizer @@ -100,7 +101,7 @@ def __init__( self.download_dir = model_path self.tokenizer = model_path - self.hf_config = get_config(self.model, trust_remote_code, revision) + self.hf_config = get_config(self.model, trust_remote_code, revision, **(hf_kwargs or {})) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 231ce3321cdc4..c738c12847f0e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -43,6 +43,7 @@ class EngineArgs: lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None + hf_kwargs: Optional[dict] = None def __post_init__(self): if self.tokenizer is None: @@ -275,7 +276,7 @@ def create_engine_configs( self.dtype, self.seed, self.revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, - self.max_context_len_to_capture) + self.max_context_len_to_capture, hf_kwargs=(self.hf_kwargs or {})) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8b16e559b24f2..15faec1275c4a 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -18,10 +18,11 @@ def get_config(model: str, trust_remote_code: bool, - revision: Optional[str] = None) -> PretrainedConfig: + revision: Optional[str] = None, + **hf_kwargs) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( - model, 
trust_remote_code=trust_remote_code, revision=revision) + model, trust_remote_code=trust_remote_code, revision=revision, **hf_kwargs) except ValueError as e: if (not trust_remote_code and "requires you to execute the configuration file" in str(e)):