diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index f2f8067bd1175..bbba05de37d2c 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -32,20 +32,82 @@ def encode_batch(self, texts: List[str]) -> List[List[str]]: class MistralAIEmbeddings(BaseModel, Embeddings): - """MistralAI embedding models. + """MistralAI embedding model integration. - To use, set the environment variable `MISTRAL_API_KEY` is set with your API key or - pass it as a named parameter to the constructor. + Setup: + Install ``langchain_mistralai`` and set environment variable + ``MISTRAL_API_KEY``. - Example: + .. code-block:: bash + + pip install -U langchain_mistralai + export MISTRAL_API_KEY="your-api-key" + + Key init args — completion params: + model: str + Name of MistralAI model to use. + + Key init args — client params: + api_key: Optional[SecretStr] + The API key for the MistralAI API. If not provided, it will be read from the + environment variable `MISTRAL_API_KEY`. + max_retries: int + The number of times to retry a request if it fails. + timeout: int + The number of seconds to wait for a response before timing out. + max_concurrent_requests: int + The maximum number of concurrent requests to make to the Mistral API. + + See full list of supported init args and their descriptions in the params section. + + Instantiate: .. code-block:: python - from langchain_mistralai import MistralAIEmbeddings + from __module_name__ import MistralAIEmbeddings - mistral = MistralAIEmbeddings( + embed = MistralAIEmbeddings( model="mistral-embed", - api_key="my-api-key" + # api_key="...", + # other params... ) + + Embed single text: + .. code-block:: python + + input_text = "The meaning of life is 42" + vector = embed.embed_query(input_text) + print(vector[:3]) + + .. code-block:: python + + [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] + + Embed multiple text: + .. code-block:: python + + input_texts = ["Document 1...", "Document 2..."] + vectors = embed.embed_documents(input_texts) + print(len(vectors)) + # The first 3 coordinates for the first vector + print(vectors[0][:3]) + + .. code-block:: python + + 2 + [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] + + Async: + .. code-block:: python + + vector = await embed.aembed_query(input_text) + print(vector[:3]) + + # multiple: + # await embed.aembed_documents(input_texts) + + .. code-block:: python + + [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188] """ client: httpx.Client = Field(default=None) #: :meta private: