diff --git a/libs/community/langchain_community/embeddings/baichuan.py b/libs/community/langchain_community/embeddings/baichuan.py index a346fa2b57d30..d0f54fff0d36b 100644 --- a/libs/community/langchain_community/embeddings/baichuan.py +++ b/libs/community/langchain_community/embeddings/baichuan.py @@ -4,6 +4,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from requests import RequestException BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings" @@ -22,11 +23,23 @@ # NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding. # Multi-language support is coming soon. class BaichuanTextEmbeddings(BaseModel, Embeddings): - """Baichuan Text Embedding models.""" + """Baichuan Text Embedding models. + + To use, you should set the environment variable ``BAICHUAN_API_KEY`` to + your API key or pass it as a named parameter to the constructor. + + Example: + .. code-block:: python + + from langchain_community.embeddings import BaichuanTextEmbeddings + + baichuan = BaichuanTextEmbeddings(baichuan_api_key="my-api-key") + """ session: Any #: :meta private: model_name: str = "Baichuan-Text-Embedding" baichuan_api_key: Optional[SecretStr] = None + """Automatically inferred from env var `BAICHUAN_API_KEY` if not provided.""" @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: @@ -65,29 +78,26 @@ def _embed(self, texts: List[str]) -> Optional[List[List[float]]]: A list of list of floats representing the embeddings, or None if an error occurs. """ - try: - response = self.session.post( - BAICHUAN_API_URL, json={"input": texts, "model": self.model_name} + response = self.session.post( + BAICHUAN_API_URL, json={"input": texts, "model": self.model_name} + ) + # Raise exception if response status code from 400 to 600 + response.raise_for_status() + # Check if the response status code indicates success + if response.status_code == 200: + resp = response.json() + embeddings = resp.get("data", []) + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) + # Return just the embeddings + return [result.get("embedding", []) for result in sorted_embeddings] + else: + # Log error or handle unsuccessful response appropriately + # Handle 100 <= status_code < 400, not include 200 + raise RequestException( + f"Error: Received status code {response.status_code} from " + "`BaichuanEmbedding` API" ) - # Check if the response status code indicates success - if response.status_code == 200: - resp = response.json() - embeddings = resp.get("data", []) - # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) - # Return just the embeddings - return [result.get("embedding", []) for result in sorted_embeddings] - else: - # Log error or handle unsuccessful response appropriately - print( # noqa: T201 - f"Error: Received status code {response.status_code} from " - "embedding API" - ) - return None - except Exception as e: - # Log the exception or handle it as needed - print(f"Exception occurred while trying to get embeddings: {str(e)}") # noqa: T201 - return None def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override] """Public method to get embeddings for a list of documents.