Start executor refactor
blythed committed Jan 1, 2025
1 parent a7a4094 commit 9aa4231
Showing 66 changed files with 1,686 additions and 3,052 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add create events waiting on db apply.
 - Refactor secrets loading method.
 - Add db.load in db wait
+- Deprecate "free" queries

 #### New Features & Functionality

@@ -37,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add a standalone flag in Streamlit to mark the page as independent.
 - Add secrets directory mount for loading secret env vars.
 - Remove components recursively
+- Enforce a strict and developer-friendly query contract

 #### Bug Fixes

3 changes: 1 addition & 2 deletions plugins/ibis/superduper_ibis/__init__.py
@@ -1,6 +1,5 @@
 from .data_backend import IbisDataBackend as DataBackend
-from .query import IbisQuery as Query

 __version__ = "0.4.7"

-__all__ = ["Query", "DataBackend"]
+__all__ = ["DataBackend"]
230 changes: 114 additions & 116 deletions plugins/ibis/superduper_ibis/data_backend.py
@@ -1,5 +1,6 @@
 import glob
 import os
+import uuid
 import typing as t
 from warnings import warn

@@ -11,18 +12,16 @@
 from superduper import CFG, logging
 from superduper.backends.base.data_backend import BaseDataBackend
 from superduper.backends.base.metadata import MetaDataStoreProxy
+from superduper.backends.base.query import Query, QueryPart
 from superduper.backends.local.artifacts import FileSystemArtifactStore
 from superduper.base import exceptions
 from superduper.components.datatype import BaseDataType
 from superduper.components.schema import Schema
 from superduper.components.table import Table

 from superduper_ibis.db_helper import get_db_helper
 from superduper_ibis.field_types import FieldType, dtype
-from superduper_ibis.query import IbisQuery
 from superduper_ibis.utils import convert_schema_to_fields

-BASE64_PREFIX = "base64:"
 # TODO make this a global variable in main project
 INPUT_KEY = "_source"


@@ -95,11 +94,17 @@ def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None):

         self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}

-        if uri.startswith('snowflake://'):
+        if uri.startswith('snowflake://') or uri.startswith('clickhouse://'):
             self.bytes_encoding = 'base64'
-            self.datatype_presets = {
+            self.datatype_presets.update({
                 'vector': 'superduper.components.datatype.NativeVector'
-            }
+            })
+
+    def random_id(self):
+        return str(uuid.uuid4())
+
+    def to_id(self, id):
+        return str(id)

     def _setup(self, conn):
         self.dialect = getattr(conn, "name", "base")
@@ -135,116 +140,9 @@ def build_metadata(self):
         logging.warn(f"Falling back to using the uri: {self.uri}.")
         return MetaDataStoreProxy(SQLAlchemyMetadata(uri=self.uri))

-    def insert(self, table_name, raw_documents):
-        """Insert data into the database.
-        :param table_name: The name of the table.
-        :param raw_documents: The data to insert.
-        """
-        for doc in raw_documents:
-            for k, v in doc.items():
-                doc[k] = self.db_helper.convert_data_format(v)
-        table_name, raw_documents = self.db_helper.process_before_insert(
-            table_name,
-            raw_documents,
-            self.conn,
-        )
-        if not self.in_memory:
-            self.conn.insert(table_name, raw_documents)
-        else:
-            # CAUTION: The following is only tested with pandas.
-            if table_name in self.conn.tables:
-                t = self.conn.tables[table_name]
-                df = pandas.concat([t.to_pandas(), raw_documents])
-                self.conn.create_table(table_name, df, overwrite=True)
-            else:
-                df = pandas.DataFrame(raw_documents)
-                self.conn.create_table(table_name, df)
-
-            if self.conn.backend_table_type == DataFrame:
-                df.to_csv(os.path.join(self.name, table_name + ".csv"), index=False)
-
-    def check_ready_ids(
-        self, query: IbisQuery, keys: t.List[str], ids: t.Optional[t.List[t.Any]] = None
-    ):
-        """Check if all the keys are ready in the ids.
-        :param query: The query object.
-        :param keys: The keys to check.
-        :param ids: The ids to check.
-        """
-        if ids:
-            query = query.filter(query[query.primary_id].isin(ids))
-        conditions = []
-        for key in keys:
-            conditions.append(query[key].notnull())
-
-        # TODO: Hotfix, will be removed by the refactor PR
-        try:
-            docs = query.filter(*conditions).select(query.primary_id).execute()
-        except Exception as e:
-            if "Table not found" in str(e) or "Can't find table" in str(e):
-                return []
-            else:
-                raise e
-        ready_ids = [doc[query.primary_id] for doc in docs]
-        self._log_check_ready_ids_message(ids, ready_ids)
-        return ready_ids
-
-    def drop_outputs(self):
+    def drop_table(self, table):
         """Drop the outputs."""
-        for table in self.conn.list_tables():
-            logging.info(f"Dropping table: {table}")
-            if CFG.output_prefix in table:
-                self.conn.drop_table(table)
-
-    def drop_table_or_collection(self, name: str):
-        """Drop the table or collection.
-        Please use with caution as you will lose all data.
-        :param name: Table name to drop.
-        """
-        try:
-            return self.db.databackend.conn.drop_table(name)
-        except Exception as e:
-            msg = "Object found is of type 'VIEW'"
-            if msg in str(e):
-                return self.db.databackend.conn.drop_view(name)
-            raise
-
-    def create_output_dest(
-        self,
-        predict_id: str,
-        datatype: t.Union[FieldType, BaseDataType],
-        flatten: bool = False,
-    ):
-        """Create a table for the output of the model.
-        :param predict_id: The identifier of the prediction.
-        :param datatype: The data type of the output.
-        :param flatten: Whether to flatten the output.
-        """
-        # TODO: Support output schema
-        msg = (
-            "Model must have an encoder to create with the"
-            f" {type(self).__name__} backend."
-        )
-        assert datatype is not None, msg
-        if isinstance(datatype, FieldType):
-            output_type = dtype(datatype.identifier)
-        else:
-            output_type = datatype
-
-        fields = {
-            INPUT_KEY: "string",
-            "_source": "string",
-            "id": "string",
-            f"{CFG.output_prefix}{predict_id}": output_type,
-        }
-        return Table(
-            identifier=f"{CFG.output_prefix}{predict_id}",
-            schema=Schema(identifier=f"_schema/{predict_id}", fields=fields),
-        )
+        self.conn.drop_table(table)

     def check_output_dest(self, predict_id) -> bool:
         """Check if the output destination exists.
@@ -310,3 +208,103 @@ def disconnect(self):
     def list_tables(self):
         """List all tables or collections in the database."""
         return self.conn.list_tables()
+
+    def insert(self, table_name, raw_documents):
+        """Insert data into the database.
+        :param table_name: The name of the table.
+        :param raw_documents: The data to insert.
+        """
+        for doc in raw_documents:
+            for k, v in doc.items():
+                doc[k] = self.db_helper.convert_data_format(v)
+
+        table_name, raw_documents = self.db_helper.process_before_insert(
+            table_name,
+            raw_documents,
+            self.conn,
+        )
+        if not self.in_memory:
+            self.conn.insert(table_name, raw_documents)
+        else:
+            # CAUTION: The following is only tested with pandas.
+            if table_name in self.conn.tables:
+                t = self.conn.tables[table_name]
+                df = pandas.concat([t.to_pandas(), raw_documents])
+                self.conn.create_table(table_name, df, overwrite=True)
+            else:
+                df = pandas.DataFrame(raw_documents)
+                self.conn.create_table(table_name, df)
+
+            if self.conn.backend_table_type == DataFrame:
+                df.to_csv(os.path.join(self.name, table_name + ".csv"), index=False)
+
+    def insert(self, table, documents):
+        primary_id = self.primary_id(self.db[table])
+        for r in documents:
+            if primary_id not in r:
+                r[primary_id] = str(uuid.uuid4())
+        ids = [r[primary_id] for r in documents]
+        self.conn.insert(table, documents)
+        return ids
+
+    def missing_outputs(self, query, predict_id: str) -> t.List[t.Dict]:
+        # Resolve the primary id from the table metadata before the query
+        # is lowered to a native ibis expression.
+        pid = self.primary_id(query)
+        query = self._build_native_query(query)
+        output_table = self.conn.table(f"{CFG.output_prefix}{predict_id}")
+        q = query.anti_join(output_table, output_table['_source'] == query[pid])
+        return q.execute().to_dict(orient='records')
+
+    def primary_id(self, query):
+        return self.db.load('table', query.table).primary_id
+
+    def select(self, query):
+        native_query = self._build_native_query(query)
+        return native_query.execute().to_dict(orient='records')
+
+    def _build_native_query(self, query):
+        q = self.conn.table(query.table)
+        pid = None
+
+        for part in query.parts:
+            if isinstance(part, QueryPart) and part.name != 'outputs':
+                args = []
+                for a in part.args:
+                    if isinstance(a, Query):
+                        args.append(self._build_native_query(a))
+                    else:
+                        args.append(a)
+                kwargs = {}
+                for k, v in part.kwargs.items():
+                    if isinstance(v, Query):
+                        kwargs[k] = self._build_native_query(v)
+                    else:
+                        kwargs[k] = v
+                # A bare select() is a no-op; anything else is applied by name.
+                if part.name == 'select' and len(args) == 0:
+                    pass
+                else:
+                    q = getattr(q, part.name)(*args, **kwargs)
+
+            elif isinstance(part, QueryPart) and part.name == 'outputs':
+                if pid is None:
+                    pid = self.primary_id(query)
+
+                original_q = q
+                for predict_id in part.args:
+                    output_t = self.conn.table(f"{CFG.output_prefix}{predict_id}")
+                    q = q.join(
+                        output_t, output_t['_source'] == original_q[pid]
+                    )
+
+            elif isinstance(part, str):
+                if part == 'primary_id':
+                    if pid is None:
+                        pid = self.primary_id(query)
+                    part = pid
+                q = q[part]
+            else:
+                raise ValueError(f'Unknown query part: {part}')
+        return q

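The new executor methods above pin down a narrow query contract: every query is rooted at a concrete table (`q = self.conn.table(query.table)`), its parts are replayed one by one onto the native ibis expression, and `select` materializes the result as a list of records. (Note that the diff leaves two `insert` definitions in place; the later `(table, documents)` variant shadows the earlier one.) The sketch below replays that loop by hand on a throwaway connection; the DuckDB backend, the `documents` table, and its columns are illustrative assumptions, not part of this commit.

import ibis
import pandas as pd

# A throwaway in-memory connection standing in for IbisDataBackend.conn.
conn = ibis.duckdb.connect()
conn.create_table("documents", pd.DataFrame({"id": ["1", "2"], "txt": ["a", "b"]}))

# Every query starts from its root table -- table-less ("free") queries are
# deprecated, per the CHANGELOG entry above.
q = conn.table("documents")

# A QueryPart is applied by name, mirroring
# `q = getattr(q, part.name)(*args, **kwargs)` in _build_native_query.
q = getattr(q, "filter")(q["txt"] == "a")

# select() executes the native expression and returns records,
# as IbisDataBackend.select does.
print(q.execute().to_dict(orient="records"))  # -> [{'id': '1', 'txt': 'a'}]
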
1 change: 1 addition & 0 deletions plugins/ibis/superduper_ibis/db_helper.py
@@ -1,3 +1,4 @@
+# TODO remove, no longer relevant
 import base64
 import collections

