diff --git a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py
index fa770a1db4..15715c2ab4 100644
--- a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py
+++ b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py
@@ -71,29 +71,41 @@ class ElasticsearchOnlineStoreConfig(FeastConfigBaseModel):
     password: str
     """ password to connect to Elasticsearch """
+    write_batch_size: Optional[int] = 40
+    """ The number of rows to write in a single batch """
 
+
-class ElasticsearchConnectionManager:
-    def __init__(self, online_config: RepoConfig):
-        self.online_config = online_config
-
-    def __enter__(self):
-        # Connecting to Elasticsearch
-        logger.info(
-            f"Connecting to Elasticsearch with endpoint {self.online_config.endpoint}"
-        )
-        self.client = Elasticsearch(
-            self.online_config.endpoint,
-            basic_auth=(self.online_config.username, self.online_config.password),
+class ElasticsearchOnlineStore(OnlineStore):
+    _client: Optional[Elasticsearch] = None
+
+    def _get_client(self, config: RepoConfig) -> Elasticsearch:
+        online_store_config = config.online_store
+        assert isinstance(online_store_config, ElasticsearchOnlineStoreConfig)
+
+        user = online_store_config.username if online_store_config.username is not None else ""
+        password = (
+            online_store_config.password
+            if online_store_config.password is not None
+            else ""
         )
-        return self.client
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        # Disconnecting from Elasticsearch
-        logger.info("Closing the connection to Elasticsearch")
-        self.client.transport.close()
+
+        if self._client:
+            return self._client
+        else:
+            self._client = Elasticsearch(
+                hosts=online_store_config.endpoint,
+                basic_auth=(user, password),
+            )
+            return self._client
+
+    def _get_bulk_documents(self, index_name, data):
+        for entity_key, values, timestamp, created_ts in data:
+            id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
+            document = {entity_key.join_keys[0]: id_val}
+            for feature_name, val in values.items():
+                document[feature_name] = self._get_value_from_value_proto(val)
+            yield {"_index": index_name, "_id": id_val, "_source": document}
 
-
-class ElasticsearchOnlineStore(OnlineStore):
     def online_write_batch(
         self,
         config: RepoConfig,
         table: FeatureView,
@@ -103,24 +115,27 @@ def online_write_batch(
         ],
         progress: Optional[Callable[[int], Any]],
     ) -> None:
-        with ElasticsearchConnectionManager(config.online_store) as es:
+        with self._get_client(config) as es:
             resp = es.indices.exists(index=table.name)
             if not resp.body:
                 self._create_index(es, table)
-            bulk_documents = []
-            for entity_key, values, timestamp, created_ts in data:
-                id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
-                document = {entity_key.join_keys[0]: id_val}
-                for feature_name, val in values.items():
-                    document[feature_name] = self._get_value_from_value_proto(val)
-                bulk_documents.append(
-                    {"_index": table.name, "_id": id_val, "_source": document}
-                )
-
-            successes, errors = helpers.bulk(client=es, actions=bulk_documents)
+
+            successes = 0
+            errors: List[Any] = []
+            error_count = 0
+            for i in range(0, len(data), config.online_store.write_batch_size):
+                batch = data[i : i + config.online_store.write_batch_size]
+                count, errs = helpers.bulk(client=es, actions=self._get_bulk_documents(table.name, batch))
+                successes += count
+                if type(errs) is int:
+                    error_count += errs
+                elif type(errs) is list:
+                    errors.extend(errs)
             logger.info(f"bulk write completed with {successes} successes")
successes") + if error_count: + logger.error(f"bulk write encountered {errors} errors") if errors: - logger.error(f"bulk write return errors: {errors}") + logger.error(f"bulk write returned errors: {errors}") def online_read( self, @@ -129,7 +144,7 @@ def online_read( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - with ElasticsearchConnectionManager(config.online_store) as es: + with self._get_client(config) as es: id_list = [] for entity in entity_keys: for val in entity.entity_values: @@ -182,7 +197,7 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - with ElasticsearchConnectionManager(config.online_store) as es: + with self._get_client(config.online_store) as es: for fv in tables_to_delete: resp = es.indices.exists(index=fv.name) if resp.body: diff --git a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py index 9fe7b54780..d6f81d7e80 100644 --- a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py +++ b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py @@ -5,10 +5,10 @@ import pytest +from elasticsearch import Elasticsearch from feast import FeatureView from feast.entity import Entity from feast.expediagroup.vectordb.elasticsearch_online_store import ( - ElasticsearchConnectionManager, ElasticsearchOnlineStore, ElasticsearchOnlineStoreConfig, ) @@ -48,6 +48,21 @@ ] +class ElasticsearchConnectionManager: + def __init__(self, online_config: RepoConfig): + self.online_config = online_config + def __enter__(self): + # Connecting to Elasticsearch + self.client = Elasticsearch( + self.online_config.endpoint, + basic_auth=(self.online_config.username, self.online_config.password), + ) + return self.client + def __exit__(self, exc_type, exc_value, traceback): + # Disconnecting from Elasticsearch + self.client.transport.close() + + @pytest.fixture(scope="session") def repo_config(embedded_elasticsearch): return RepoConfig( @@ -58,6 +73,7 @@ def repo_config(embedded_elasticsearch): endpoint=f"http://{embedded_elasticsearch['host']}:{embedded_elasticsearch['port']}", username=embedded_elasticsearch["username"], password=embedded_elasticsearch["password"], + write_batch_size=5 ), offline_store=DaskOfflineStoreConfig(), entity_key_serialization_version=2,