Skip to content

Commit

Permalink
implement write batch method for elasticsearch online store
Browse files: browse the repository at this point in the history
  • Loading branch information
piket committed Oct 25, 2023
1 parent 1ddee5f commit 396fd50
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import base64
import json
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from bidict import bidict
from elasticsearch import Elasticsearch
from elasticsearch import Elasticsearch, helpers
from pydantic.typing import Literal

from feast import Entity, FeatureView, RepoConfig
Expand Down Expand Up @@ -96,8 +97,22 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
with ElasticsearchConnectionManager(config):
pass
with ElasticsearchConnectionManager(config) as es:
resp = es.indices.exists(index=table.name)
if not resp.body:
self._create_index(es, table)
bulk_documents = []
for entity_key, values, timestamp, created_ts in data:
id_val = self._get_value_from_value_proto(entity_key.entity_values[0])
document = {entity_key.join_keys[0]: id_val}
for feature_name, val in values.items():
document[feature_name] = self._get_value_from_value_proto(val)
bulk_documents.append(
{"_index": table.name, "_id": id_val, "doc": document}
)

helpers.bulk(client=es, actions=bulk_documents)
es.indices.refresh(index=table.name)

def online_read(
self,
Expand Down Expand Up @@ -168,3 +183,22 @@ def _get_data_type(self, t: FeastType) -> str:
if isinstance(t, ComplexFeastType):
return "text"
return TYPE_MAPPING.get(t, "text")

def _get_value_from_value_proto(self, proto: "ValueProto") -> Any:
    """
    Get the raw, JSON-friendly value from a value proto.

    Parameters:
        proto (ValueProto): the value proto that contains the data.

    Returns:
        value (Any): the extracted value. Bytes are returned as base64 text,
            list-typed values as plain Python lists, and a proto with no
            value set as None.
    """
    val_type = proto.WhichOneof("val")
    if val_type is None:
        # Nothing in the "val" oneof is set. getattr(proto, None) would raise
        # a TypeError, so report the missing value explicitly instead.
        return None

    value = getattr(proto, val_type)
    if val_type == "bytes_val":
        # Raw bytes are not JSON serializable; store them as base64 text.
        return base64.b64encode(value).decode()
    if val_type.endswith("_list_val"):
        # List members wrap their items in a `val` field (as FloatList does).
        # NOTE(review): assumed to hold for every *_list_val member of the
        # Value proto -- confirm against the proto definition.
        items = list(value.val)
        if val_type == "bytes_list_val":
            return [base64.b64encode(item).decode() for item in items]
        return items

    return value
126 changes: 104 additions & 22 deletions sdk/python/tests/expediagroup/test_elasticsearch_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,28 +94,6 @@ def setup_method(self, repo_config):

yield

def create_n_customer_test_samples_elasticsearch_online_read(self, n=10):
    """Build ``n`` sample film rows for the online store tests.

    Each row is a 4-tuple: an entity key proto keyed on ``film_id`` (valued
    with the row index), a dict mapping feature names to value protos, the
    current ``datetime.utcnow()`` and ``None``.

    Parameters:
        n (int): number of rows to generate.
    Returns:
        rows (list): the generated tuples; empty when ``n`` is not positive.
    """
    rows = []
    for film_number in range(n):
        key = EntityKeyProto(
            join_keys=["film_id"],
            entity_values=[ValueProto(int64_val=film_number)],
        )
        # Two random floats per row serve as the "films" vector.
        embedding = FloatList(val=[random.random() for _ in range(2)])
        features = {
            "films": ValueProto(float_list_val=embedding),
            # Both scalar features carry the row count `n`, not the row index.
            "film_date": ValueProto(int64_val=n),
            "film_id": ValueProto(int64_val=n),
        }
        rows.append((key, features, datetime.utcnow(), None))
    return rows

@pytest.mark.parametrize("index_params", index_param_list)
def test_elasticsearch_update_add_index(self, repo_config, caplog, index_params):
dimensions = 16
Expand Down Expand Up @@ -307,6 +285,28 @@ def test_elasticsearch_update_delete_unavailable_index(self, repo_config, caplog
with ElasticsearchConnectionManager(repo_config.online_store) as es:
assert es.indices.exists(index=self.index_to_delete).body is False

def test_elasticsearch_online_write_batch(self, repo_config, caplog):
    """Write a batch through the online store and verify it reached the index.

    Generates sample rows together with their feature view, writes them with
    ``online_write_batch``, then checks the document count of the target
    index and that the document with id "0" holds every schema field.
    """
    total_rows_to_write = 100
    generated = self._create_n_customer_test_samples_elasticsearch_online_read(
        n=total_rows_to_write
    )
    feature_view, data = generated

    store = ElasticsearchOnlineStore()
    store.online_write_batch(
        config=repo_config.online_store,
        table=feature_view,
        data=data,
        progress=None,
    )

    with ElasticsearchConnectionManager(repo_config.online_store) as es:
        # The count is compared as a string, matching the JSON cat output.
        count_rows = es.cat.count(
            index=self.index_to_write, params={"format": "json"}
        )
        assert count_rows[0]["count"] == "100"

        # Written features are read back from under the "doc" key of _source.
        hit = es.get(index=self.index_to_write, id="0")
        stored_doc = hit["_source"]["doc"]
        for field in feature_view.schema:
            assert field.name in stored_doc

def _create_index_in_es(self, index_name, repo_config):
with ElasticsearchConnectionManager(repo_config.online_store) as es:
mapping = {
Expand All @@ -321,3 +321,85 @@ def _create_index_in_es(self, index_name, repo_config):
}
}
es.indices.create(index=index_name, mappings=mapping)

def _create_n_customer_test_samples_elasticsearch_online_read(self, n=10):
    """Build a feature view and ``n`` matching sample rows for write tests.

    The feature view is named after ``self.index_to_write`` and declares one
    vector field plus one field per scalar type exercised by the test. Each
    row is a 4-tuple: an entity key proto keyed on ``id`` (the row index as a
    string), a dict of feature name to value proto, the current
    ``datetime.utcnow()`` and ``None``.

    Parameters:
        n (int): number of rows to generate.
    Returns:
        (feature_view, rows): the feature view and the list of row tuples.
    """
    # Local import: only `datetime` is known to be imported at file level.
    from datetime import timezone

    vector_field = Field(
        name="vector",
        dtype=Array(Float32),
        tags={
            "description": "float32",
            "dimensions": "10",
            "index_type": "HNSW",
        },
    )
    scalar_fields = [
        Field(name=field_name, dtype=field_type)
        for field_name, field_type in (
            ("id", String),
            ("text", String),
            ("int", Int32),
            ("long", Int64),
            ("float", Float32),
            ("double", Float64),
            ("binary", Bytes),
            ("bool", Bool),
            ("timestamp", UnixTimestamp),
        )
    ]
    fv = FeatureView(
        name=self.index_to_write,
        source=SOURCE,
        entities=[Entity(name="id")],
        schema=[vector_field, *scalar_fields],
    )

    rows = []
    for i in range(n):
        entity_key = EntityKeyProto(
            join_keys=["id"],
            entity_values=[ValueProto(string_val=str(i))],
        )
        features = {
            "vector": ValueProto(
                float_list_val=FloatList(val=[random.random() for _ in range(10)])
            ),
            "text": ValueProto(string_val="text"),
            # Scalar features carry the row count `n`, so every row is equal
            # apart from its vector, timestamp and entity key.
            "int": ValueProto(int32_val=n),
            "long": ValueProto(int64_val=n),
            "float": ValueProto(float_val=n * 0.3),
            "double": ValueProto(double_val=n * 1.2),
            "binary": ValueProto(bytes_val=b"binary"),
            "bool": ValueProto(bool_val=True),
            # An aware datetime is required here: .timestamp() on the naive
            # result of utcnow() is interpreted as local time, which shifts
            # the epoch value by the host's UTC offset.
            "timestamp": ValueProto(
                unix_timestamp_val=int(
                    datetime.now(tz=timezone.utc).timestamp() * 1000
                )
            ),
        }
        rows.append((entity_key, features, datetime.utcnow(), None))

    return fv, rows

0 comments on commit 396fd50

Please sign in to comment.