Skip to content

Commit

Permalink
Start executor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 2, 2025
1 parent a7a4094 commit c23916a
Show file tree
Hide file tree
Showing 83 changed files with 1,776 additions and 3,266 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add create events waiting on db apply.
- Refactor secrets loading method.
- Add db.load in db wait
- Deprecate "free" queries

#### New Features & Functionality

Expand All @@ -37,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add a standalone flag in Streamlit to mark the page as independent.
- Add secrets directory mount for loading secret env vars.
- Remove components recursively
- Enforce strict and developer friendly query developer contract

#### Bug Fixes

Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
DIRECTORIES ?= superduper test
DIRECTORIES ?= superduper test plugins
SUPERDUPER_CONFIG ?= test/configs/default.yaml
PYTEST_ARGUMENTS ?=
PLUGIN_NAME ?=
Expand Down
12 changes: 3 additions & 9 deletions plugins/ibis/plugin_test/test_databackend.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from superduper.misc.plugins import load_plugin
from test.utils.database import databackend as db_utils

import pytest
Expand All @@ -8,18 +9,11 @@

@pytest.fixture
def databackend():
backend = IbisDataBackend(CFG.data_backend)
plugin = load_plugin('ibis')
backend = IbisDataBackend(CFG.data_backend, plugin=plugin)
yield backend
backend.drop(True)


def test_output_dest(databackend):
db_utils.test_output_dest(databackend)


def test_query_builder(databackend):
db_utils.test_query_builder(databackend)


def test_list_tables_or_collections(databackend):
db_utils.test_list_tables_or_collections(databackend)
39 changes: 0 additions & 39 deletions plugins/ibis/plugin_test/test_end_2_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,42 +141,3 @@ def postprocess(x):
# Get the results
result = list(db.execute(q))
assert listener2.outputs in result[0].unpack()


def test_nested_query(db):
memory_table = False
if CFG.data_backend.endswith("csv"):
memory_table = True
schema = Schema(
identifier="my_table",
fields={
"id": FieldType(identifier="int64"),
"health": FieldType(identifier="int32"),
"age": FieldType(identifier="int32"),
},
)

from superduper.components.table import Table

t = Table(identifier="my_table", schema=schema)

db.apply(t)

t = db["my_table"]
q = t.filter(t.age >= 10)

expr_ = q.compile(db)

if not memory_table:
assert 'WHERE "t0"."age" >=' in str(expr_)
else:
pass
# TODO this doesn't test anything useful and
# is sensitive to version changes
# TODO refactor/ remove
# assert 'Selection[r0]\n predicates:\n r0.age >= 10' in str(expr_)
# assert (
# 'my_table\n _fold string\n id '
# 'int64\n health int32\n age '
# 'int32\n image binary' in str(expr_)
# )
66 changes: 17 additions & 49 deletions plugins/ibis/plugin_test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,17 @@ def test_renamings(db):
add_listeners(db)
t = db["documents"]
listener_uuid = [db.load('listener', k).outputs for k in db.show("listener")][0]
q = t.select("id", "x", "y").outputs(listener_uuid)
data = list(db.execute(q))
q = t.select("id", "x", "y").outputs(listener_uuid.split('__', 1)[-1])
data = q.execute()
assert isinstance(data[0].unpack()[listener_uuid], np.ndarray)


def test_serialize_query(db):
from superduper_ibis.query import IbisQuery

t = IbisQuery(db=db, table="documents", parts=[("select", ("id",), {})])
t = db['documents']
q = t.filter(t['id'] == 1).select('id', 'x')

q = t.filter(t.id == 1).select(t.id, t.x)

print(Document.decode(q.encode()).unpack())


def test_add_fold(db):
add_random_data(db, n=10)
table = db["documents"]
select_train = table.select("id", "x", "_fold").add_fold("train")
result_train = db.execute(select_train)

select_valid = table.select("id", "x", "_fold").add_fold("valid")
result_valid = db.execute(select_valid)
result_train = list(result_train)
result_valid = list(result_valid)
assert len(result_train) + len(result_valid) == 10
print(Document.decode(q.encode(), db=db).unpack())


def test_get_data(db):
Expand All @@ -88,7 +73,7 @@ def test_get_data(db):
def test_insert_select(db):
add_random_data(db, n=5)
q = db["documents"].select("id", "x", "y").limit(2)
r = list(db.execute(q))
r = q.execute()

assert len(r) == 2
assert all(all([k in ["id", "x", "y"] for k in x.unpack().keys()]) for x in r)
Expand All @@ -98,43 +83,27 @@ def test_filter(db):
add_random_data(db, n=5)
t = db["documents"]
q = t.select("id", "y")
r = list(db.execute(q))
r = q.execute()
ys = [x["y"] for x in r]
uq = np.unique(ys, return_counts=True)

q = t.select("id", "y").filter(t.y == uq[0][0])
r = list(db.execute(q))
q = t.select("id", "y").filter(t['y'] == uq[0][0])
r = q.execute()
assert len(r) == uq[1][0]


def test_execute_complex_query_sqldb_auto_schema(db):
import ibis

db.cfg.auto_schema = True

table = db["documents"]
table.insert(
[Document({"this": f"is a test {i}", "id": str(i)}) for i in range(100)]
).execute()

cur = table.select("this").order_by(ibis.desc("this")).limit(10).execute(db)
expected = [f"is a test {i}" for i in range(99, 89, -1)]
cur_this = [r["this"] for r in cur]
assert sorted(cur_this) == sorted(expected)


def test_select_using_ids(db):
db.cfg.auto_schema = True

table = db["documents"]
table.insert(
[Document({"this": f"is a test {i}", "id": str(i)}) for i in range(4)]
).execute()
[{"this": f"is a test {i}", "id": str(i)} for i in range(4)]
)

basic_select = db['documents'].select()

assert len(basic_select.tolist()) == 4
assert len(basic_select.select_using_ids(['1', '2']).tolist()) == 2
assert len(basic_select.execute()) == 4
assert len(basic_select.subset(['1', '2'])) == 2


def test_select_using_ids_of_outputs(db):
Expand All @@ -148,20 +117,19 @@ def my_func(x):

table = db["documents"]
table.insert(
[Document({"this": f"is a test {i}", "id": str(i)}) for i in range(4)]
).execute()
[{"this": f"is a test {i}", "id": str(i)} for i in range(4)]
)

listener = my_func.to_listener(key='this', select=db['documents'].select())
db.apply(listener)

q1 = db[listener.outputs].select()
r1 = q1.tolist()
r1 = q1.execute()

assert len(r1) == 4

ids = [x['id'] for x in r1]

q2 = q1.select_using_ids(ids[:2])
r2 = q2.tolist()
r2 = q1.subset(ids[:2])

assert len(r2) == 2
3 changes: 1 addition & 2 deletions plugins/ibis/superduper_ibis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .data_backend import IbisDataBackend as DataBackend
from .query import IbisQuery as Query

__version__ = "0.4.7"

__all__ = ["Query", "DataBackend"]
__all__ = ["DataBackend"]
Loading

0 comments on commit c23916a

Please sign in to comment.