From 0640cbf2f126f773b7ae78b0f94c1ba0caabb2c1 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Mon, 21 Oct 2024 17:37:07 -0400 Subject: [PATCH] huggingface[patch]: hide client field in HuggingFaceEmbeddings (#27522) --- .../embeddings/huggingface.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py index a53e0e18b6e36..180a9ed3b5e79 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional # type: ignore[import-not-found] +from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from pydantic import BaseModel, ConfigDict, Field @@ -26,7 +26,6 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): ) """ - client: Any = None #: :meta private: model_name: str = DEFAULT_MODEL_NAME """Model name to use.""" cache_folder: Optional[str] = None @@ -57,7 +56,7 @@ def __init__(self, **kwargs: Any): "Please install it with `pip install sentence-transformers`." ) from exc - self.client = sentence_transformers.SentenceTransformer( + self._client = sentence_transformers.SentenceTransformer( self.model_name, cache_folder=self.cache_folder, **self.model_kwargs ) @@ -79,12 +78,20 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: texts = list(map(lambda x: x.replace("\n", " "), texts)) if self.multi_process: - pool = self.client.start_multi_process_pool() - embeddings = self.client.encode_multi_process(texts, pool) + pool = self._client.start_multi_process_pool() + embeddings = self._client.encode_multi_process(texts, pool) sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool) else: - embeddings = self.client.encode( - texts, show_progress_bar=self.show_progress, **self.encode_kwargs + embeddings = self._client.encode( + texts, + show_progress_bar=self.show_progress, + **self.encode_kwargs, # type: ignore + ) + + if isinstance(embeddings, list): + raise TypeError( + "Expected embeddings to be a Tensor or a numpy array, " + "got a list instead." ) return embeddings.tolist()