diff --git a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py
new file mode 100644
index 0000000000..730e303a72
--- /dev/null
+++ b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py
@@ -0,0 +1,170 @@
+import json
+import logging
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+
+from bidict import bidict
+from elasticsearch import Elasticsearch
+from pydantic.typing import Literal
+
+from feast import Entity, FeatureView, RepoConfig
+from feast.infra.online_stores.online_store import OnlineStore
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.repo_config import FeastConfigBaseModel
+from feast.types import (
+    Bool,
+    Bytes,
+    ComplexFeastType,
+    FeastType,
+    Float32,
+    Float64,
+    Int32,
+    Int64,
+    String,
+    UnixTimestamp,
+)
+
+logger = logging.getLogger(__name__)
+
+TYPE_MAPPING = bidict(
+    {
+        Bytes: "binary",
+        Int32: "integer",
+        Int64: "long",
+        Float32: "float",
+        Float64: "double",
+        Bool: "boolean",
+        String: "text",
+        UnixTimestamp: "date_nanos",
+    }
+)
+
+
+class ElasticsearchOnlineStoreConfig(FeastConfigBaseModel):
+    """Online store config for the Elasticsearch online store"""
+
+    type: Literal["elasticsearch"] = "elasticsearch"
+    """Online store type selector"""
+
+    endpoint: str
+    """The HTTP endpoint URL"""
+
+    username: str
+    """Username to connect to Elasticsearch"""
+
+    password: str
+    """Password to connect to Elasticsearch"""
+
+    token: str
+    """Bearer token for authentication"""
+
+
+class ElasticsearchConnectionManager:
+    def __init__(self, online_config: ElasticsearchOnlineStoreConfig):
+        self.online_config = online_config
+
+    def __enter__(self):
+        # Connecting to Elasticsearch
+        logger.info(
+            f"Connecting to Elasticsearch with endpoint {self.online_config.endpoint}"
+        )
+        if len(self.online_config.token) > 0:
+            self.client = Elasticsearch(
+                self.online_config.endpoint, bearer_auth=self.online_config.token
+            )
+        else:
+            self.client = Elasticsearch(
+                self.online_config.endpoint,
+                basic_auth=(self.online_config.username, self.online_config.password),
+            )
+        return self.client
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Disconnecting from Elasticsearch
+        logger.info("Closing the connection to Elasticsearch")
+        self.client.transport.close()
+
+
+class ElasticsearchOnlineStore(OnlineStore):
+    def online_write_batch(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        data: List[
+            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+        ],
+        progress: Optional[Callable[[int], Any]],
+    ) -> None:
+        with ElasticsearchConnectionManager(config):
+            pass
+
+    def online_read(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        pass
+
+    def update(
+        self,
+        config: RepoConfig,
+        tables_to_delete: Sequence[FeatureView],
+        tables_to_keep: Sequence[FeatureView],
+        entities_to_delete: Sequence[Entity],
+        entities_to_keep: Sequence[Entity],
+        partial: bool,
+    ):
+        with ElasticsearchConnectionManager(config) as es:
+            for fv in tables_to_delete:
+                resp = es.indices.exists(index=fv.name)
+                if resp.body:
+                    es.indices.delete(index=fv.name)
+            for fv in tables_to_keep:
+                resp = es.indices.exists(index=fv.name)
+                if not resp.body:
+                    self._create_index(es, fv)
+
+    def teardown(
+        self,
+        config: RepoConfig,
+        tables: Sequence[FeatureView],
+        entities: Sequence[Entity],
+    ):
+        pass
+
+    def _create_index(self, es, fv):
+        index_mapping = {"properties": {}}
+        for feature in fv.schema:
+            is_primary = feature.name in fv.join_keys
+            if "index_type" in feature.tags:
+                dimensions = int(feature.tags.get("dimensions", "0"))
+                metric_type = feature.tags.get("metric_type", "l2_norm")
+                index_mapping["properties"][feature.name] = {
+                    "type": "dense_vector",
+                    "dims": dimensions,
+                    "index": True,
+                    "similarity": metric_type,
+                }
+                index_params = json.loads(feature.tags.get("index_params", "{}"))
+                if len(index_params) > 0:
+                    index_params["type"] = feature.tags.get(
+                        "index_type", "hnsw"
+                    ).lower()
+                    index_mapping["properties"][feature.name][
+                        "index_options"
+                    ] = index_params
+            else:
+                t = self._get_data_type(feature.dtype)
+                t = "keyword" if is_primary and t == "text" else t
+                index_mapping["properties"][feature.name] = {"type": t}
+                if is_primary:
+                    index_mapping["properties"][feature.name]["index"] = True
+        es.indices.create(index=fv.name, mappings=index_mapping)
+
+    def _get_data_type(self, t: FeastType) -> str:
+        if isinstance(t, ComplexFeastType):
+            return "text"
+        return TYPE_MAPPING.get(t, "text")
diff --git a/sdk/python/requirements/py3.10-ci-requirements.txt b/sdk/python/requirements/py3.10-ci-requirements.txt
index 267e27b6c9..223e76b2e6 100644
--- a/sdk/python/requirements/py3.10-ci-requirements.txt
+++ b/sdk/python/requirements/py3.10-ci-requirements.txt
@@ -197,6 +197,10 @@ docker==6.1.3
     #   testcontainers
 docutils==0.19
     # via sphinx
+elastic-transport==8.4.1
+    # via elasticsearch
+elasticsearch==8.8.0
+    # via eg-feast (setup.py)
 entrypoints==0.4
     # via altair
 environs==9.5.0
diff --git a/sdk/python/requirements/py3.8-ci-requirements.txt b/sdk/python/requirements/py3.8-ci-requirements.txt
index ec8875c268..0a1e7d74de 100644
--- a/sdk/python/requirements/py3.8-ci-requirements.txt
+++ b/sdk/python/requirements/py3.8-ci-requirements.txt
@@ -200,6 +200,10 @@ docker==6.1.3
     #   testcontainers
 docutils==0.19
     # via sphinx
+elastic-transport==8.4.1
+    # via elasticsearch
+elasticsearch==8.8.0
+    # via eg-feast (setup.py)
 entrypoints==0.4
     # via altair
 exceptiongroup==1.1.1
diff --git a/sdk/python/requirements/py3.9-ci-requirements.txt b/sdk/python/requirements/py3.9-ci-requirements.txt
index b9d053b2ae..31eb4496c6 100644
--- a/sdk/python/requirements/py3.9-ci-requirements.txt
+++ b/sdk/python/requirements/py3.9-ci-requirements.txt
@@ -197,6 +197,10 @@ docker==6.1.3
     #   testcontainers
 docutils==0.19
     # via sphinx
+elastic-transport==8.4.1
+    # via elasticsearch
+elasticsearch==8.8.0
+    # via eg-feast (setup.py)
 entrypoints==0.4
     # via altair
 environs==9.5.0
diff --git a/sdk/python/tests/expediagroup/elasticsearch_online_store_creator.py b/sdk/python/tests/expediagroup/elasticsearch_online_store_creator.py
new file mode 100644
index 0000000000..6bda8ac0ff
--- /dev/null
+++ b/sdk/python/tests/expediagroup/elasticsearch_online_store_creator.py
@@ -0,0 +1,36 @@
+import logging
+
+from testcontainers.elasticsearch import ElasticSearchContainer
+
+from tests.integration.feature_repos.universal.online_store_creator import (
+    OnlineStoreCreator,
+)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class ElasticsearchOnlineCreator(OnlineStoreCreator):
+    def __init__(self, project_name: str, es_port: int):
+        super().__init__(project_name)
+        self.elasticsearch_container = ElasticSearchContainer(
+            image="docker.elastic.co/elasticsearch/elasticsearch:8.8.2",
+            port_to_expose=es_port,
+        )
+
+    def create_online_store(self):
+        # Start the container
+        self.elasticsearch_container.start()
+        elasticsearch_host = self.elasticsearch_container.get_container_host_ip()
+        elasticsearch_http_port = self.elasticsearch_container.get_exposed_port(9200)
+        return {
+            "host": elasticsearch_host,
+            "port": elasticsearch_http_port,
+            "username": "",
+            "password": "",
+            "token": "",
+        }
+
+    def teardown(self):
+        self.elasticsearch_container.stop()
diff --git a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py
new file mode 100644
index 0000000000..ab33aef499
--- /dev/null
+++ b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py
@@ -0,0 +1,323 @@
+import json
+import logging
+import random
+from datetime import datetime
+
+import pytest
+
+from feast import FeatureView
+from feast.entity import Entity
+from feast.expediagroup.vectordb.elasticsearch_online_store import (
+    ElasticsearchConnectionManager,
+    ElasticsearchOnlineStore,
+    ElasticsearchOnlineStoreConfig,
+)
+from feast.field import Field
+from feast.infra.offline_stores.file import FileOfflineStoreConfig
+from feast.infra.offline_stores.file_source import FileSource
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import FloatList
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.repo_config import RepoConfig
+from feast.types import (
+    Array,
+    Bool,
+    Bytes,
+    Float32,
+    Float64,
+    Int32,
+    Int64,
+    String,
+    UnixTimestamp,
+)
+from tests.expediagroup.elasticsearch_online_store_creator import (
+    ElasticsearchOnlineCreator,
+)
+
+logging.basicConfig(level=logging.INFO)
+
+REGISTRY = "s3://test_registry/registry.db"
+PROJECT = "test_aws"
+PROVIDER = "aws"
+REGION = "us-west-2"
+SOURCE = FileSource(path="some path")
+
+index_param_list = [
+    {"index_type": "HNSW", "index_params": {"m": 16, "ef_construction": 100}},
+    {"index_type": "HNSW"},
+]
+
+
+@pytest.fixture(scope="session")
+def repo_config(embedded_elasticsearch):
+    return RepoConfig(
+        registry=REGISTRY,
+        project=PROJECT,
+        provider=PROVIDER,
+        online_store=ElasticsearchOnlineStoreConfig(
+            endpoint=f"http://{embedded_elasticsearch['host']}:{embedded_elasticsearch['port']}",
+            username=embedded_elasticsearch["username"],
+            password=embedded_elasticsearch["password"],
+            token=embedded_elasticsearch["token"],
+        ),
+        offline_store=FileOfflineStoreConfig(),
+        entity_key_serialization_version=2,
+    )
+
+
+@pytest.fixture(scope="session")
+def embedded_elasticsearch():
+    online_store_creator = ElasticsearchOnlineCreator(PROJECT, 9200)
+    online_store_config = online_store_creator.create_online_store()
+
+    yield online_store_config
+
+    online_store_creator.teardown()
+
+
+class TestElasticsearchOnlineStore:
+    index_to_write = "index_write"
+    index_to_delete = "index_delete"
+    unavailable_index = "abc"
+
+    @pytest.fixture(autouse=True)
+    def setup_method(self, repo_config):
+        # Ensuring that the indexes created are dropped before the tests are run
+        with ElasticsearchConnectionManager(repo_config.online_store) as es:
+            # Dropping indexes if they exist
+            if es.indices.exists(index=self.index_to_delete):
+                es.indices.delete(index=self.index_to_delete)
+            if es.indices.exists(index=self.index_to_write):
+                es.indices.delete(index=self.index_to_write)
+            if es.indices.exists(index=self.unavailable_index):
+                es.indices.delete(index=self.unavailable_index)
+
+        yield
+
+    def create_n_customer_test_samples_elasticsearch_online_read(self, n=10):
+        return [
+            (
+                EntityKeyProto(
+                    join_keys=["film_id"],
+                    entity_values=[ValueProto(int64_val=i)],
+                ),
+                {
+                    "films": ValueProto(
+                        float_list_val=FloatList(
+                            val=[random.random() for _ in range(2)]
+                        )
+                    ),
+                    "film_date": ValueProto(int64_val=n),
+                    "film_id": ValueProto(int64_val=i),
+                },
+                datetime.utcnow(),
+                None,
+            )
+            for i in range(n)
+        ]
+
+    @pytest.mark.parametrize("index_params", index_param_list)
+    def test_elasticsearch_update_add_index(self, repo_config, caplog, index_params):
+        dimensions = 16
+        vector_type = Float32
+        vector_tags = {
+            "is_primary": "False",
+            "description": vector_type.name,
+            "dimensions": str(dimensions),
+            "index_type": index_params["index_type"],
+        }
+        if "index_params" in index_params:
+            vector_tags["index_params"] = json.dumps(
+                index_params.get("index_params", {})
+            )
+        entity = Entity(name="feature2")
+        feast_schema = [
+            Field(
+                name="feature1",
+                dtype=Array(vector_type),
+                tags=vector_tags,
+            ),
+            Field(
+                name="feature2",
+                dtype=String,
+            ),
+            Field(name="feature3", dtype=String),
+            Field(name="feature4", dtype=Bytes),
+            Field(name="feature5", dtype=Int32),
+            Field(name="feature6", dtype=Int64),
+            Field(name="feature7", dtype=Float32),
+            Field(name="feature8", dtype=Float64),
+            Field(name="feature9", dtype=Bool),
+            Field(name="feature10", dtype=UnixTimestamp),
+        ]
+        ElasticsearchOnlineStore().update(
+            config=repo_config.online_store,
+            tables_to_delete=[],
+            tables_to_keep=[
+                FeatureView(
+                    name=self.index_to_write,
+                    entities=[entity],
+                    schema=feast_schema,
+                    source=SOURCE,
+                )
+            ],
+            entities_to_delete=[],
+            entities_to_keep=[],
+            partial=False,
+        )
+
+        mapping = {
+            "properties": {
+                "feature1": {
+                    "type": "dense_vector",
+                    "dims": 16,
+                    "index": True,
+                    "similarity": "l2_norm",
+                },
+                "feature2": {"type": "keyword"},
+                "feature3": {"type": "text"},
+                "feature4": {"type": "binary"},
+                "feature5": {"type": "integer"},
+                "feature6": {"type": "long"},
+                "feature7": {"type": "float"},
+                "feature8": {"type": "double"},
+                "feature9": {"type": "boolean"},
+                "feature10": {"type": "date_nanos"},
+            }
+        }
+        if "index_params" in index_params:
+            mapping["properties"]["feature1"]["index_options"] = {
+                "type": index_params["index_type"].lower(),
+                **index_params["index_params"],
+            }
+        with ElasticsearchConnectionManager(repo_config.online_store) as es:
+            created_index = es.indices.get(index=self.index_to_write)
+            assert created_index.body[self.index_to_write]["mappings"] == mapping
+
+    def test_elasticsearch_update_add_existing_index(self, repo_config, caplog):
+        entity = Entity(name="id")
+        feast_schema = [
+            Field(
+                name="vector",
+                dtype=Array(Float32),
+                tags={
+                    "description": "float32",
+                    "dimensions": "10",
+                    "index_type": "HNSW",
+                },
+            ),
+            Field(
+                name="id",
+                dtype=String,
+            ),
+        ]
+        self._create_index_in_es(self.index_to_write, repo_config)
+        ElasticsearchOnlineStore().update(
+            config=repo_config.online_store,
+            tables_to_delete=[],
+            tables_to_keep=[
+                FeatureView(
+                    name=self.index_to_write,
+                    entities=[entity],
+                    schema=feast_schema,
+                    source=SOURCE,
+                )
+            ],
+            entities_to_delete=[],
+            entities_to_keep=[],
+            partial=False,
+        )
+        with ElasticsearchConnectionManager(repo_config.online_store) as es:
+            assert es.indices.exists(index=self.index_to_write).body is True
+
+    def test_elasticsearch_update_delete_index(self, repo_config, caplog):
+        entity = Entity(name="id")
Entity(name="id") + feast_schema = [ + Field( + name="vector", + dtype=Array(Float32), + tags={ + "description": "float32", + "dimensions": "10", + "index_type": "HNSW", + }, + ), + Field( + name="id", + dtype=String, + ), + ] + self._create_index_in_es(self.index_to_delete, repo_config) + with ElasticsearchConnectionManager(repo_config.online_store) as es: + assert es.indices.exists(index=self.index_to_delete).body is True + + ElasticsearchOnlineStore().update( + config=repo_config.online_store, + tables_to_delete=[ + FeatureView( + name=self.index_to_delete, + entities=[entity], + schema=feast_schema, + source=SOURCE, + ) + ], + tables_to_keep=[], + entities_to_delete=[], + entities_to_keep=[], + partial=False, + ) + with ElasticsearchConnectionManager(repo_config.online_store) as es: + assert es.indices.exists(index=self.index_to_delete).body is False + + def test_elasticsearch_update_delete_unavailable_index(self, repo_config, caplog): + entity = Entity(name="id") + feast_schema = [ + Field( + name="vector", + dtype=Array(Float32), + tags={ + "description": "float32", + "dimensions": "10", + "index_type": "HNSW", + }, + ), + Field( + name="id", + dtype=String, + ), + ] + with ElasticsearchConnectionManager(repo_config.online_store) as es: + assert es.indices.exists(index=self.index_to_delete).body is False + + ElasticsearchOnlineStore().update( + config=repo_config.online_store, + tables_to_delete=[ + FeatureView( + name=self.index_to_delete, + entities=[entity], + schema=feast_schema, + source=SOURCE, + ) + ], + tables_to_keep=[], + entities_to_delete=[], + entities_to_keep=[], + partial=False, + ) + with ElasticsearchConnectionManager(repo_config.online_store) as es: + assert es.indices.exists(index=self.index_to_delete).body is False + + def _create_index_in_es(self, index_name, repo_config): + with ElasticsearchConnectionManager(repo_config.online_store) as es: + mapping = { + "properties": { + "vector": { + "type": "dense_vector", + "dims": 10, + "index": True, + "similarity": "l2_norm", + }, + "id": {"type": "keyword"}, + } + } + es.indices.create(index=index_name, mappings=mapping) diff --git a/setup.py b/setup.py index 40df4e1394..96273be713 100644 --- a/setup.py +++ b/setup.py @@ -153,6 +153,11 @@ "bidict==0.22.1" ] +ELASTICSEARCH_REQUIRED = [ + "elasticsearch==8.8", + "bidict==0.22.1", +] + CI_REQUIRED = ( [ "build", @@ -213,6 +218,7 @@ + ROCKSET_REQUIRED + HAZELCAST_REQUIRED + MILVUS_REQUIRED + + ELASTICSEARCH_REQUIRED ) @@ -546,6 +552,7 @@ def copy_extensions_to_source(self): "rockset": ROCKSET_REQUIRED, "milvus": MILVUS_REQUIRED, "go": GO_REQUIRED, + "elasticsearch": ELASTICSEARCH_REQUIRED, }, include_package_data=True, license="Apache",