Skip to content

Commit

Permalink
community: Fix rank-llm import paths for new 0.20.3 version
Browse files Browse the repository at this point in the history
  • Loading branch information
tymzar committed Jan 11, 2025
1 parent bbc3e3b commit fcdd05b
Showing 1 changed file with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from copy import deepcopy
from enum import Enum
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.utils import get_from_dict_or_env
from packaging.version import Version
from pydantic import ConfigDict, Field, PrivateAttr, model_validator

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +51,10 @@ def validate_environment(cls, values: Dict) -> Any:
if not values.get("client"):
client_name = values.get("model", "zephyr")

is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
"0.12.8"
)

try:
model_enum = ModelType(client_name.lower())
except ValueError:
Expand All @@ -58,15 +64,29 @@ def validate_environment(cls, values: Dict) -> Any:

try:
if model_enum == ModelType.VICUNA:
from rank_llm.rerank.vicuna_reranker import VicunaReranker
if is_pre_rank_llm_revamp:
from rank_llm.rerank.vicuna_reranker import VicunaReranker
else:
from rank_llm.rerank.listwise.vicuna_reranker import (
VicunaReranker,
)

values["client"] = VicunaReranker()
elif model_enum == ModelType.ZEPHYR:
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
if is_pre_rank_llm_revamp:
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
else:
from rank_llm.rerank.listwise.zephyr_reranker import (
ZephyrReranker,
)

values["client"] = ZephyrReranker()
elif model_enum == ModelType.GPT:
from rank_llm.rerank.rank_gpt import SafeOpenai
if is_pre_rank_llm_revamp:
from rank_llm.rerank.rank_gpt import SafeOpenai
else:
from rank_llm.rerank.listwise.rank_gpt import SafeOpenai

from rank_llm.rerank.reranker import Reranker

openai_api_key = get_from_dict_or_env(
Expand Down

0 comments on commit fcdd05b

Please sign in to comment.