diff --git a/libs/community/langchain_community/document_compressors/rankllm_rerank.py b/libs/community/langchain_community/document_compressors/rankllm_rerank.py index bcf18652928e2..d0115956c3e5c 100644 --- a/libs/community/langchain_community/document_compressors/rankllm_rerank.py +++ b/libs/community/langchain_community/document_compressors/rankllm_rerank.py @@ -2,12 +2,14 @@ from copy import deepcopy from enum import Enum +from importlib.metadata import version from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.utils import get_from_dict_or_env +from packaging.version import Version from pydantic import ConfigDict, Field, PrivateAttr, model_validator if TYPE_CHECKING: @@ -49,6 +51,10 @@ def validate_environment(cls, values: Dict) -> Any: if not values.get("client"): client_name = values.get("model", "zephyr") + is_pre_rank_llm_revamp = Version( + version=version(("rank_llm").version) + ) <= Version("0.12.8") + try: model_enum = ModelType(client_name.lower()) except ValueError: @@ -58,15 +64,29 @@ def validate_environment(cls, values: Dict) -> Any: try: if model_enum == ModelType.VICUNA: - from rank_llm.rerank.vicuna_reranker import VicunaReranker + if is_pre_rank_llm_revamp: + from rank_llm.rerank.vicuna_reranker import VicunaReranker + else: + from rank_llm.rerank.listwise.vicuna_reranker import ( + VicunaReranker, + ) values["client"] = VicunaReranker() elif model_enum == ModelType.ZEPHYR: - from rank_llm.rerank.zephyr_reranker import ZephyrReranker + if is_pre_rank_llm_revamp: + from rank_llm.rerank.zephyr_reranker import ZephyrReranker + else: + from rank_llm.rerank.listwise.zephyr_reranker import ( + ZephyrReranker, + ) values["client"] = ZephyrReranker() elif model_enum == ModelType.GPT: - from rank_llm.rerank.rank_gpt import SafeOpenai + if is_pre_rank_llm_revamp: + from rank_llm.rerank.rank_gpt import SafeOpenai + else: + from rank_llm.rerank.listwise.rank_gpt import SafeOpenai + from rank_llm.rerank.reranker import Reranker openai_api_key = get_from_dict_or_env(