diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index a2f95c71e3dc5..08bb679a552b3 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any from langchain._api import create_importer +from langchain.embeddings.base import init_embeddings from langchain.embeddings.cache import CacheBackedEmbeddings if TYPE_CHECKING: @@ -221,4 +222,5 @@ def __getattr__(name: str) -> Any: "VertexAIEmbeddings", "VoyageEmbeddings", "XinferenceEmbeddings", + "init_embeddings", ] diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py index 9e648a342eab1..a8c8a97939676 100644 --- a/libs/langchain/langchain/embeddings/base.py +++ b/libs/langchain/langchain/embeddings/base.py @@ -1,4 +1,224 @@ +import functools +from importlib import util +from typing import Any, List, Optional, Tuple, Union + +from langchain_core._api import beta from langchain_core.embeddings import Embeddings +from langchain_core.runnables import Runnable + +_SUPPORTED_PROVIDERS = { + "azure_openai": "langchain_openai", + "bedrock": "langchain_aws", + "cohere": "langchain_cohere", + "google_vertexai": "langchain_google_vertexai", + "huggingface": "langchain_huggingface", + "mistralai": "langchain_mistralai", + "openai": "langchain_openai", +} + + +def _get_provider_list() -> str: + """Get formatted list of providers and their packages.""" + return "\n".join( + f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items() + ) + + +def _parse_model_string(model_name: str) -> Tuple[str, str]: + """Parse a model string into provider and model name components. + + The model string should be in the format 'provider:model-name', where provider + is one of the supported providers. + + Args: + model_name: A model string in the format 'provider:model-name' + + Returns: + A tuple of (provider, model_name) + + .. code-block:: python + + _parse_model_string("openai:text-embedding-3-small") + # Returns: ("openai", "text-embedding-3-small") + + _parse_model_string("bedrock:amazon.titan-embed-text-v1") + # Returns: ("bedrock", "amazon.titan-embed-text-v1") + + Raises: + ValueError: If the model string is not in the correct format or + the provider is unsupported + """ + if ":" not in model_name: + providers = _SUPPORTED_PROVIDERS + raise ValueError( + f"Invalid model format '{model_name}'.\n" + f"Model name must be in format 'provider:model-name'\n" + f"Example valid model strings:\n" + f" - openai:text-embedding-3-small\n" + f" - bedrock:amazon.titan-embed-text-v1\n" + f" - cohere:embed-english-v3.0\n" + f"Supported providers: {providers}" + ) + + provider, model = model_name.split(":", 1) + provider = provider.lower().strip() + model = model.strip() + + if provider not in _SUPPORTED_PROVIDERS: + raise ValueError( + f"Provider '{provider}' is not supported.\n" + f"Supported providers and their required packages:\n" + f"{_get_provider_list()}" + ) + if not model: + raise ValueError("Model name cannot be empty") + return provider, model + + +def _infer_model_and_provider( + model: str, *, provider: Optional[str] = None +) -> Tuple[str, str]: + if not model.strip(): + raise ValueError("Model name cannot be empty") + if provider is None and ":" in model: + provider, model_name = _parse_model_string(model) + else: + provider = provider + model_name = model + + if not provider: + providers = _SUPPORTED_PROVIDERS + raise ValueError( + "Must specify either:\n" + "1. A model string in format 'provider:model-name'\n" + " Example: 'openai:text-embedding-3-small'\n" + "2. Or explicitly set provider from: " + f"{providers}" + ) + + if provider not in _SUPPORTED_PROVIDERS: + raise ValueError( + f"Provider '{provider}' is not supported.\n" + f"Supported providers and their required packages:\n" + f"{_get_provider_list()}" + ) + return provider, model_name + + +@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS)) +def _check_pkg(pkg: str) -> None: + """Check if a package is installed.""" + if not util.find_spec(pkg): + raise ImportError( + f"Could not import {pkg} python package. " + f"Please install it with `pip install {pkg}`" + ) + + +@beta() +def init_embeddings( + model: str, + *, + provider: Optional[str] = None, + **kwargs: Any, +) -> Union[Embeddings, Runnable[Any, List[float]]]: + """Initialize an embeddings model from a model name and optional provider. + + **Note:** Must have the integration package corresponding to the model provider + installed. + + Args: + model: Name of the model to use. Can be either: + - A model string like "openai:text-embedding-3-small" + - Just the model name if provider is specified + provider: Optional explicit provider name. If not specified, + will attempt to parse from the model string. Supported providers + and their required packages: + + {_get_provider_list()} + + **kwargs: Additional model-specific parameters passed to the embedding model. + These vary by provider, see the provider-specific documentation for details. + + Returns: + An Embeddings instance that can generate embeddings for text. + + Raises: + ValueError: If the model provider is not supported or cannot be determined + ImportError: If the required provider package is not installed + + .. dropdown:: Example Usage + :open: + + .. code-block:: python + + # Using a model string + model = init_embeddings("openai:text-embedding-3-small") + model.embed_query("Hello, world!") + + # Using explicit provider + model = init_embeddings( + model="text-embedding-3-small", + provider="openai" + ) + model.embed_documents(["Hello, world!", "Goodbye, world!"]) + + # With additional parameters + model = init_embeddings( + "openai:text-embedding-3-small", + api_key="sk-..." + ) + + .. versionadded:: 0.3.9 + """ + if not model: + providers = _SUPPORTED_PROVIDERS.keys() + raise ValueError( + "Must specify model name. " + f"Supported providers are: {', '.join(providers)}" + ) + + provider, model_name = _infer_model_and_provider(model, provider=provider) + pkg = _SUPPORTED_PROVIDERS[provider] + _check_pkg(pkg) + + if provider == "openai": + from langchain_openai import OpenAIEmbeddings + + return OpenAIEmbeddings(model=model_name, **kwargs) + elif provider == "azure_openai": + from langchain_openai import AzureOpenAIEmbeddings + + return AzureOpenAIEmbeddings(model=model_name, **kwargs) + elif provider == "google_vertexai": + from langchain_google_vertexai import VertexAIEmbeddings + + return VertexAIEmbeddings(model=model_name, **kwargs) + elif provider == "bedrock": + from langchain_aws import BedrockEmbeddings + + return BedrockEmbeddings(model_id=model_name, **kwargs) + elif provider == "cohere": + from langchain_cohere import CohereEmbeddings + + return CohereEmbeddings(model=model_name, **kwargs) + elif provider == "mistralai": + from langchain_mistralai import MistralAIEmbeddings + + return MistralAIEmbeddings(model=model_name, **kwargs) + elif provider == "huggingface": + from langchain_huggingface import HuggingFaceEmbeddings + + return HuggingFaceEmbeddings(model_name=model_name, **kwargs) + else: + raise ValueError( + f"Provider '{provider}' is not supported.\n" + f"Supported providers and their required packages:\n" + f"{_get_provider_list()}" + ) + -# This is for backwards compatibility -__all__ = ["Embeddings"] +__all__ = [ + "init_embeddings", + "Embeddings", # This one is for backwards compatibility +] diff --git a/libs/langchain/tests/integration_tests/embeddings/__init__.py b/libs/langchain/tests/integration_tests/embeddings/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain/tests/integration_tests/embeddings/test_base.py b/libs/langchain/tests/integration_tests/embeddings/test_base.py new file mode 100644 index 0000000000000..204754642fdf5 --- /dev/null +++ b/libs/langchain/tests/integration_tests/embeddings/test_base.py @@ -0,0 +1,44 @@ +"""Test embeddings base module.""" + +import importlib + +import pytest +from langchain_core.embeddings import Embeddings + +from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings + + +@pytest.mark.parametrize( + "provider, model", + [ + ("openai", "text-embedding-3-large"), + ("google_vertexai", "text-embedding-gecko@003"), + ("bedrock", "amazon.titan-embed-text-v1"), + ("cohere", "embed-english-v2.0"), + ], +) +async def test_init_embedding_model(provider: str, model: str) -> None: + package = _SUPPORTED_PROVIDERS[provider] + try: + importlib.import_module(package) + except ImportError: + pytest.skip(f"Package {package} is not installed") + + model_colon = init_embeddings(f"{provider}:{model}") + assert isinstance(model_colon, Embeddings) + + model_explicit = init_embeddings( + model=model, + provider=provider, + ) + assert isinstance(model_explicit, Embeddings) + + text = "Hello world" + + embedding_colon = await model_colon.aembed_query(text) + assert isinstance(embedding_colon, list) + assert all(isinstance(x, float) for x in embedding_colon) + + embedding_explicit = await model_explicit.aembed_query(text) + assert isinstance(embedding_explicit, list) + assert all(isinstance(x, float) for x in embedding_explicit) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_base.py b/libs/langchain/tests/unit_tests/embeddings/test_base.py new file mode 100644 index 0000000000000..5ca919497458b --- /dev/null +++ b/libs/langchain/tests/unit_tests/embeddings/test_base.py @@ -0,0 +1,111 @@ +"""Test embeddings base module.""" + +import pytest + +from langchain.embeddings.base import ( + _SUPPORTED_PROVIDERS, + _infer_model_and_provider, + _parse_model_string, +) + + +def test_parse_model_string() -> None: + """Test parsing model strings into provider and model components.""" + assert _parse_model_string("openai:text-embedding-3-small") == ( + "openai", + "text-embedding-3-small", + ) + assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == ( + "bedrock", + "amazon.titan-embed-text-v1", + ) + assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == ( + "huggingface", + "BAAI/bge-base-en:v1.5", + ) + + +def test_parse_model_string_errors() -> None: + """Test error cases for model string parsing.""" + with pytest.raises(ValueError, match="Model name must be"): + _parse_model_string("just-a-model-name") + + with pytest.raises(ValueError, match="Invalid model format "): + _parse_model_string("") + + with pytest.raises(ValueError, match="is not supported"): + _parse_model_string(":model-name") + + with pytest.raises(ValueError, match="Model name cannot be empty"): + _parse_model_string("openai:") + + with pytest.raises( + ValueError, match="Provider 'invalid-provider' is not supported" + ): + _parse_model_string("invalid-provider:model-name") + + for provider in _SUPPORTED_PROVIDERS: + with pytest.raises(ValueError, match=f"{provider}"): + _parse_model_string("invalid-provider:model-name") + + +def test_infer_model_and_provider() -> None: + """Test model and provider inference from different input formats.""" + assert _infer_model_and_provider("openai:text-embedding-3-small") == ( + "openai", + "text-embedding-3-small", + ) + + assert _infer_model_and_provider( + model="text-embedding-3-small", provider="openai" + ) == ("openai", "text-embedding-3-small") + + assert _infer_model_and_provider( + model="ft:text-embedding-3-small", provider="openai" + ) == ("openai", "ft:text-embedding-3-small") + + assert _infer_model_and_provider(model="openai:ft:text-embedding-3-small") == ( + "openai", + "ft:text-embedding-3-small", + ) + + +def test_infer_model_and_provider_errors() -> None: + """Test error cases for model and provider inference.""" + # Test missing provider + with pytest.raises(ValueError, match="Must specify either"): + _infer_model_and_provider("text-embedding-3-small") + + # Test empty model + with pytest.raises(ValueError, match="Model name cannot be empty"): + _infer_model_and_provider("") + + # Test empty provider with model + with pytest.raises(ValueError, match="Must specify either"): + _infer_model_and_provider("model", provider="") + + # Test invalid provider + with pytest.raises(ValueError, match="is not supported"): + _infer_model_and_provider("model", provider="invalid") + + # Test provider list is in error + with pytest.raises(ValueError) as exc: + _infer_model_and_provider("model", provider="invalid") + for provider in _SUPPORTED_PROVIDERS: + assert provider in str(exc.value) + + +@pytest.mark.parametrize( + "provider", + sorted(_SUPPORTED_PROVIDERS.keys()), +) +def test_supported_providers_package_names(provider: str) -> None: + """Test that all supported providers have valid package names.""" + package = _SUPPORTED_PROVIDERS[provider] + assert "-" not in package + assert package.startswith("langchain_") + assert package.islower() + + +def test_is_sorted() -> None: + assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys()) diff --git a/libs/langchain/tests/unit_tests/embeddings/test_imports.py b/libs/langchain/tests/unit_tests/embeddings/test_imports.py index c6d7a8207d1c5..b44acf1a6032d 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_imports.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_imports.py @@ -55,6 +55,7 @@ "JohnSnowLabsEmbeddings", "VoyageEmbeddings", "BookendEmbeddings", + "init_embeddings", ]