diff --git a/CHANGELOG.md b/CHANGELOG.md index b7b041d9f..d02739bea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add create events waiting on db apply. - Refactor secrets loading method. - Add db.load in db wait +- Deprecate "free" queries #### New Features & Functionality @@ -37,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add a standalone flag in Streamlit to mark the page as independent. - Add secrets directory mount for loading secret env vars. - Remove components recursively +- Enforce strict and developer friendly query developer contract #### Bug Fixes diff --git a/Makefile b/Makefile index 63a9ca1fc..d5cc2be00 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -DIRECTORIES ?= superduper test +DIRECTORIES ?= superduper test plugins SUPERDUPER_CONFIG ?= test/configs/default.yaml PYTEST_ARGUMENTS ?= PLUGIN_NAME ?= diff --git a/plugins/ibis/plugin_test/test_databackend.py b/plugins/ibis/plugin_test/test_databackend.py index 87551cdae..6a5e7dee2 100644 --- a/plugins/ibis/plugin_test/test_databackend.py +++ b/plugins/ibis/plugin_test/test_databackend.py @@ -2,24 +2,18 @@ import pytest from superduper import CFG +from superduper.misc.plugins import load_plugin from superduper_ibis.data_backend import IbisDataBackend @pytest.fixture def databackend(): - backend = IbisDataBackend(CFG.data_backend) + plugin = load_plugin('ibis') + backend = IbisDataBackend(CFG.data_backend, plugin=plugin) yield backend backend.drop(True) -def test_output_dest(databackend): - db_utils.test_output_dest(databackend) - - -def test_query_builder(databackend): - db_utils.test_query_builder(databackend) - - def test_list_tables_or_collections(databackend): db_utils.test_list_tables_or_collections(databackend) diff --git a/plugins/ibis/plugin_test/test_end_2_end.py b/plugins/ibis/plugin_test/test_end_2_end.py index 3476a37c5..5926fe70f 100644 --- a/plugins/ibis/plugin_test/test_end_2_end.py +++ b/plugins/ibis/plugin_test/test_end_2_end.py @@ -141,42 +141,3 @@ def postprocess(x): # Get the results result = list(db.execute(q)) assert listener2.outputs in result[0].unpack() - - -def test_nested_query(db): - memory_table = False - if CFG.data_backend.endswith("csv"): - memory_table = True - schema = Schema( - identifier="my_table", - fields={ - "id": FieldType(identifier="int64"), - "health": FieldType(identifier="int32"), - "age": FieldType(identifier="int32"), - }, - ) - - from superduper.components.table import Table - - t = Table(identifier="my_table", schema=schema) - - db.apply(t) - - t = db["my_table"] - q = t.filter(t.age >= 10) - - expr_ = q.compile(db) - - if not memory_table: - assert 'WHERE "t0"."age" >=' in str(expr_) - else: - pass - # TODO this doesn't test anything useful and - # is sensitive to version changes - # TODO refactor/ remove - # assert 'Selection[r0]\n predicates:\n r0.age >= 10' in str(expr_) - # assert ( - # 'my_table\n _fold string\n id ' - # 'int64\n health int32\n age ' - # 'int32\n image binary' in str(expr_) - # ) diff --git a/plugins/ibis/plugin_test/test_query.py b/plugins/ibis/plugin_test/test_query.py index 7a8c1bc92..7f7907d6d 100644 --- a/plugins/ibis/plugin_test/test_query.py +++ b/plugins/ibis/plugin_test/test_query.py @@ -51,32 +51,16 @@ def test_renamings(db): add_listeners(db) t = db["documents"] listener_uuid = [db.load('listener', k).outputs for k in db.show("listener")][0] - q = t.select("id", "x", 
"y").outputs(listener_uuid) - data = list(db.execute(q)) + q = t.select("id", "x", "y").outputs(listener_uuid.split('__', 1)[-1]) + data = q.execute() assert isinstance(data[0].unpack()[listener_uuid], np.ndarray) def test_serialize_query(db): - from superduper_ibis.query import IbisQuery + t = db['documents'] + q = t.filter(t['id'] == 1).select('id', 'x') - t = IbisQuery(db=db, table="documents", parts=[("select", ("id",), {})]) - - q = t.filter(t.id == 1).select(t.id, t.x) - - print(Document.decode(q.encode()).unpack()) - - -def test_add_fold(db): - add_random_data(db, n=10) - table = db["documents"] - select_train = table.select("id", "x", "_fold").add_fold("train") - result_train = db.execute(select_train) - - select_valid = table.select("id", "x", "_fold").add_fold("valid") - result_valid = db.execute(select_valid) - result_train = list(result_train) - result_valid = list(result_valid) - assert len(result_train) + len(result_valid) == 10 + print(Document.decode(q.encode(), db=db).unpack()) def test_get_data(db): @@ -88,7 +72,7 @@ def test_get_data(db): def test_insert_select(db): add_random_data(db, n=5) q = db["documents"].select("id", "x", "y").limit(2) - r = list(db.execute(q)) + r = q.execute() assert len(r) == 2 assert all(all([k in ["id", "x", "y"] for k in x.unpack().keys()]) for x in r) @@ -98,43 +82,25 @@ def test_filter(db): add_random_data(db, n=5) t = db["documents"] q = t.select("id", "y") - r = list(db.execute(q)) + r = q.execute() ys = [x["y"] for x in r] uq = np.unique(ys, return_counts=True) - q = t.select("id", "y").filter(t.y == uq[0][0]) - r = list(db.execute(q)) + q = t.select("id", "y").filter(t['y'] == uq[0][0]) + r = q.execute() assert len(r) == uq[1][0] -def test_execute_complex_query_sqldb_auto_schema(db): - import ibis - - db.cfg.auto_schema = True - - table = db["documents"] - table.insert( - [Document({"this": f"is a test {i}", "id": str(i)}) for i in range(100)] - ).execute() - - cur = table.select("this").order_by(ibis.desc("this")).limit(10).execute(db) - expected = [f"is a test {i}" for i in range(99, 89, -1)] - cur_this = [r["this"] for r in cur] - assert sorted(cur_this) == sorted(expected) - - def test_select_using_ids(db): db.cfg.auto_schema = True table = db["documents"] - table.insert( - [Document({"this": f"is a test {i}", "id": str(i)}) for i in range(4)] - ).execute() + table.insert([{"this": f"is a test {i}", "id": str(i)} for i in range(4)]) basic_select = db['documents'].select() - assert len(basic_select.tolist()) == 4 - assert len(basic_select.select_using_ids(['1', '2']).tolist()) == 2 + assert len(basic_select.execute()) == 4 + assert len(basic_select.subset(['1', '2'])) == 2 def test_select_using_ids_of_outputs(db): @@ -147,21 +113,18 @@ def my_func(x): db.cfg.auto_schema = True table = db["documents"] - table.insert( - [Document({"this": f"is a test {i}", "id": str(i)}) for i in range(4)] - ).execute() + table.insert([{"this": f"is a test {i}", "id": str(i)} for i in range(4)]) listener = my_func.to_listener(key='this', select=db['documents'].select()) db.apply(listener) q1 = db[listener.outputs].select() - r1 = q1.tolist() + r1 = q1.execute() assert len(r1) == 4 ids = [x['id'] for x in r1] - q2 = q1.select_using_ids(ids[:2]) - r2 = q2.tolist() + r2 = q1.subset(ids[:2]) assert len(r2) == 2 diff --git a/plugins/ibis/superduper_ibis/__init__.py b/plugins/ibis/superduper_ibis/__init__.py index 92a43d292..feeb35d12 100644 --- a/plugins/ibis/superduper_ibis/__init__.py +++ b/plugins/ibis/superduper_ibis/__init__.py @@ -1,6 +1,5 @@ from 
.data_backend import IbisDataBackend as DataBackend -from .query import IbisQuery __version__ = "0.4.7" -__all__ = ["IbisQuery", "DataBackend"] +__all__ = ["DataBackend"] diff --git a/plugins/ibis/superduper_ibis/data_backend.py b/plugins/ibis/superduper_ibis/data_backend.py index d1174291c..095d482c2 100644 --- a/plugins/ibis/superduper_ibis/data_backend.py +++ b/plugins/ibis/superduper_ibis/data_backend.py @@ -1,29 +1,26 @@ import glob import os import typing as t +import uuid from warnings import warn import click import ibis import pandas -from pandas.core.frame import DataFrame from sqlalchemy.exc import NoSuchTableError from superduper import CFG, logging from superduper.backends.base.data_backend import BaseDataBackend from superduper.backends.base.metadata import MetaDataStoreProxy +from superduper.backends.base.query import Query, QueryPart from superduper.backends.local.artifacts import FileSystemArtifactStore from superduper.base import exceptions -from superduper.base.enums import DBType -from superduper.components.datatype import BaseDataType from superduper.components.schema import Schema -from superduper.components.table import Table from superduper_ibis.db_helper import get_db_helper -from superduper_ibis.field_types import FieldType, dtype -from superduper_ibis.query import IbisQuery from superduper_ibis.utils import convert_schema_to_fields BASE64_PREFIX = "base64:" +# TODO make this a global variable in main project INPUT_KEY = "_source" @@ -84,12 +81,10 @@ class IbisDataBackend(BaseDataBackend): :param flavour: Flavour of the databackend. """ - db_type = DBType.SQL - - def __init__(self, uri: str, flavour: t.Optional[str] = None): + def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None): self.connection_callback = lambda: _connection_callback(uri, flavour) conn, name, in_memory = self.connection_callback() - super().__init__(uri=uri, flavour=flavour) + super().__init__(uri=uri, flavour=flavour, plugin=plugin) self.conn = conn self.name = name self.in_memory = in_memory @@ -98,11 +93,19 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None): self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'} - if uri.startswith('snowflake://'): + if uri.startswith('snowflake://') or uri.startswith('clickhouse://'): self.bytes_encoding = 'base64' - self.datatype_presets = { - 'vector': 'superduper.components.datatype.NativeVector' - } + self.datatype_presets.update( + {'vector': 'superduper.components.datatype.NativeVector'} + ) + + def random_id(self): + """Generate a random ID.""" + return str(uuid.uuid4()) + + def to_id(self, id): + """Convert the ID to a string.""" + return str(id) def _setup(self, conn): self.dialect = getattr(conn, "name", "base") @@ -110,18 +113,10 @@ def _setup(self, conn): def reconnect(self): """Reconnect to the database client.""" - # Reconnect to database. conn, _, _ = self.connection_callback() self.conn = conn self._setup(conn) - def get_query_builder(self, table_name): - """Get the query builder for the data backend. - - :param table_name: Which table to get the query builder for - """ - return IbisQuery(table=table_name, db=self.datalayer) - def url(self): """Get the URL of the database.""" return self.conn.con.url + self.name @@ -146,116 +141,9 @@ def build_metadata(self): logging.warn(f"Falling back to using the uri: {self.uri}.") return MetaDataStoreProxy(SQLAlchemyMetadata(uri=self.uri)) - def insert(self, table_name, raw_documents): - """Insert data into the database. 
- - :param table_name: The name of the table. - :param raw_documents: The data to insert. - """ - for doc in raw_documents: - for k, v in doc.items(): - doc[k] = self.db_helper.convert_data_format(v) - table_name, raw_documents = self.db_helper.process_before_insert( - table_name, - raw_documents, - self.conn, - ) - if not self.in_memory: - self.conn.insert(table_name, raw_documents) - else: - # CAUTION: The following is only tested with pandas. - if table_name in self.conn.tables: - t = self.conn.tables[table_name] - df = pandas.concat([t.to_pandas(), raw_documents]) - self.conn.create_table(table_name, df, overwrite=True) - else: - df = pandas.DataFrame(raw_documents) - self.conn.create_table(table_name, df) - - if self.conn.backend_table_type == DataFrame: - df.to_csv(os.path.join(self.name, table_name + ".csv"), index=False) - - def check_ready_ids( - self, query: IbisQuery, keys: t.List[str], ids: t.Optional[t.List[t.Any]] = None - ): - """Check if all the keys are ready in the ids. - - :param query: The query object. - :param keys: The keys to check. - :param ids: The ids to check. - """ - if ids: - query = query.filter(query[query.primary_id].isin(ids)) - conditions = [] - for key in keys: - conditions.append(query[key].notnull()) - - # TODO: Hotfix, will be removed by the refactor PR - try: - docs = query.filter(*conditions).select(query.primary_id).execute() - except Exception as e: - if "Table not found" in str(e) or "Can't find table" in str(e): - return [] - else: - raise e - ready_ids = [doc[query.primary_id] for doc in docs] - self._log_check_ready_ids_message(ids, ready_ids) - return ready_ids - - def drop_outputs(self): + def drop_table(self, table): """Drop the outputs.""" - for table in self.conn.list_tables(): - logging.info(f"Dropping table: {table}") - if CFG.output_prefix in table: - self.conn.drop_table(table) - - def drop_table_or_collection(self, name: str): - """Drop the table or collection. - - Please use with caution as you will lose all data. - :param name: Table name to drop. - """ - try: - return self.db.databackend.conn.drop_table(name) - except Exception as e: - msg = "Object found is of type 'VIEW'" - if msg in str(e): - return self.db.databackend.conn.drop_view(name) - raise - - def create_output_dest( - self, - predict_id: str, - datatype: t.Union[FieldType, BaseDataType], - flatten: bool = False, - ): - """Create a table for the output of the model. - - :param predict_id: The identifier of the prediction. - :param datatype: The data type of the output. - :param flatten: Whether to flatten the output. - """ - # TODO: Support output schema - msg = ( - "Model must have an encoder to create with the" - f" {type(self).__name__} backend." - ) - assert datatype is not None, msg - if isinstance(datatype, FieldType): - output_type = dtype(datatype.identifier) - else: - output_type = datatype - - fields = { - INPUT_KEY: "string", - "_source": "string", - "id": "string", - f"{CFG.output_prefix}{predict_id}": output_type, - } - return Table( - identifier=f"{CFG.output_prefix}{predict_id}", - schema=Schema(identifier=f"_schema/{predict_id}", fields=fields), - ) + self.conn.drop_table(table) def check_output_dest(self, predict_id) -> bool: """Check if the output destination exists. @@ -302,7 +190,7 @@ def drop(self, force: bool = False): logging.info(f"Dropping table: {table}") self.conn.drop_table(table) - def get_table_or_collection(self, identifier): + def get_table(self, identifier): """Get a table or collection from the database. 
:param identifier: The identifier of the table or collection. @@ -316,21 +204,98 @@ def get_table_or_collection(self, identifier): def disconnect(self): """Disconnect the client.""" - # TODO: implement me - def list_tables_or_collections(self): + def list_tables(self): """List all tables or collections in the database.""" return self.conn.list_tables() - @staticmethod - def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): - """Infer a schema from a given data object. - - :param data: The data object - :param identifier: The identifier for the schema, if None, it will be generated - :return: The inferred schema - """ - from superduper.misc.auto_schema import infer_schema + def insert(self, table, documents): + """Insert data into the database.""" + primary_id = self.primary_id(self.db[table]) + for r in documents: + if primary_id not in r: + r[primary_id] = str(uuid.uuid4()) + ids = [r[primary_id] for r in documents] + self.conn.insert(table, documents) + return ids + + def missing_outputs(self, query, predict_id: str) -> t.List[str]: + """Get missing outputs from the database.""" + query = self._build_native_query(query) + pid = self.primary_id(query) + output_table = self.conn.table(f"{CFG.output_prefix}{predict_id}") + q = query.anti_join(output_table, output_table['_source'] == query[pid]) + return q.execute().to_dict(orient='records') + + def primary_id(self, query): + """Get the primary ID of the query.""" + return self.db.load('table', query.table).primary_id + + def select(self, query): + """Select data from the database.""" + native_query = self._build_native_query(query) + return native_query.execute().to_dict(orient='records') + + def _build_native_query(self, query): + q = self.conn.table(query.table) + pid = None + predict_ids = ( + query.decomposition.outputs.args if query.decomposition.outputs else [] + ) - return infer_schema(data, identifier=identifier) + for part in query.parts: + if isinstance(part, QueryPart) and part.name != 'outputs': + args = [] + for a in part.args: + if isinstance(a, Query) and str(a).endswith('.primary_id'): + args.append(self.primary_id(query)) + elif isinstance(a, Query): + args.append(self._build_native_query(a)) + else: + args.append(a) + + kwargs = {} + for k, v in part.kwargs.items(): + if isinstance(a, Query) and str(a).endswith('.primary_id'): + args.append(self.primary_id(query)) + elif isinstance(v, Query): + kwargs[k] = self._build_native_query(v) + else: + kwargs[k] = v + + if part.name == 'select' and len(args) == 0: + pass + + else: + if part.name == 'select' and predict_ids and args: + args.extend( + [ + f'{CFG.output_prefix}{pid}' + for pid in predict_ids + if f'{CFG.output_prefix}{pid}' not in args + ] + ) + args = list(set(args)) + q = getattr(q, part.name)(*args, **kwargs) + + elif isinstance(part, QueryPart) and part.name == 'outputs': + if pid is None: + pid = self.primary_id(query) + + original_q = q + for predict_id in part.args: + output_t = self.conn.table( + f"{CFG.output_prefix}{predict_id}" + ).select(f"{CFG.output_prefix}{predict_id}", "_source") + q = q.join(output_t, output_t['_source'] == original_q[pid]) + + elif isinstance(part, str): + if part == 'primary_id': + if pid is None: + pid = self.primary_id(query) + part = pid + q = q[part] + else: + raise ValueError(f'Unknown query part: {part}') + return q diff --git a/plugins/ibis/superduper_ibis/db_helper.py b/plugins/ibis/superduper_ibis/db_helper.py index 21e7867f3..90be21e71 100644 --- a/plugins/ibis/superduper_ibis/db_helper.py +++ 
b/plugins/ibis/superduper_ibis/db_helper.py @@ -1,3 +1,4 @@ +# TODO remove, no longer relevant import base64 import collections diff --git a/plugins/ibis/superduper_ibis/query.py b/plugins/ibis/superduper_ibis/query.py deleted file mode 100644 index 80ef46946..000000000 --- a/plugins/ibis/superduper_ibis/query.py +++ /dev/null @@ -1,402 +0,0 @@ -import typing as t -import uuid -from collections import defaultdict - -import pandas -from superduper import CFG, Document -from superduper.backends.base.query import ( - Query, - applies_to, - parse_query as _parse_query, -) -from superduper.base.cursor import SuperDuperCursor -from superduper.base.exceptions import DatabackendException -from superduper.components.datatype import _Encodable -from superduper.components.schema import Schema -from superduper.misc.special_dicts import SuperDuperFlatEncode - -if t.TYPE_CHECKING: - from superduper.base.datalayer import Datalayer - - -def parse_query( - query, documents: t.Sequence[t.Dict] = (), db: t.Optional["Datalayer"] = None -): - """Parse a string query into a query object. - - :param query: The query to parse. - :param documents: The documents to query. - :param db: The datalayer to use to execute the query. - """ - return _parse_query( - query=query, - documents=list(documents), - builder_cls=IbisQuery, - db=db, - ) - - -def _load_keys_with_blob(output): - if isinstance(output, SuperDuperFlatEncode): - return output.load_keys_with_blob() - elif isinstance(output, dict): - return SuperDuperFlatEncode(output).load_keys_with_blob() - return output - - -def _model_update_impl_flatten( - db, - ids: t.List[t.Any], - predict_id: str, - outputs: t.Sequence[t.Any], -): - """Flatten the outputs and ids and update the model outputs in the database.""" - flattened_outputs = [] - flattened_ids = [] - for output, id in zip(outputs, ids): - assert isinstance(output, (list, tuple)), "Expected list or tuple" - for o in output: - flattened_outputs.append(o) - flattened_ids.append(id) - - return _model_update_impl( - db=db, - ids=flattened_ids, - predict_id=predict_id, - outputs=flattened_outputs, - ) - - -def _model_update_impl( - db, - ids: t.List[t.Any], - predict_id: str, - outputs: t.Sequence[t.Any], -): - if not outputs: - return - - documents = [] - for output, source_id in zip(outputs, ids): - d = { - "_source": str(source_id), - f"{CFG.output_prefix}{predict_id}": ( - output.x if isinstance(output, _Encodable) else output - ), - "id": str(uuid.uuid4()), - } - documents.append(Document(d)) - return db[f"{CFG.output_prefix}{predict_id}"].insert(documents) - - -class IbisQuery(Query): - """A query that can be executed on an Ibis database.""" - - def __post_init__(self, db=None): - super().__post_init__(db) - self._primary_id = None - self._base_table = None - - @property - def base_table(self): - """Return the base table.""" - if self._base_table is None: - self._base_table = self.db.load('table', self.table) - return self._base_table - - flavours: t.ClassVar[t.Dict[str, str]] = { - "pre_like": r"^.*\.like\(.*\)\.select", - "post_like": r"^.*\.([a-z]+)\(.*\)\.like(.*)$", - "insert": r"^[^\(]+\.insert\(.*\)$", - "filter": r"^[^\(]+\.filter\(.*\)$", - "delete": r"^[^\(]+\.delete\(.*\)$", - "select": r"^[^\(]+\.select\(.*\)$", - "join": r"^.*\.join\(.*\)$", - "anti_join": r"^[^\(]+\.anti_join\(.*\)$", - } - - # Use to control the behavior in the class construction method within LeafMeta - _dataclass_params: t.ClassVar[t.Dict[str, t.Any]] = { - "eq": False, - "order": False, - } - - @property - 
@applies_to("insert") - def documents(self): - """Return the documents.""" - return super().documents - - def _get_tables(self): - table = self.db.load('table', self.table) - out = {self.table: table} - - for part in self.parts: - if isinstance(part, str): - return out - args = part[1] - for a in args: - if isinstance(a, IbisQuery): - out.update(a._get_tables()) - kwargs = part[2].values() - for v in kwargs: - if isinstance(v, IbisQuery): - out.update(v._get_tables()) - return out - - def _get_schema(self): - fields = {} - tables = self._get_tables() - - table_renamings = self.renamings({}) - if len(tables) == 1 and not table_renamings: - table = self.db.load('table', self.table) - return table.schema - for identifier, c in tables.items(): - renamings = table_renamings.get(identifier, {}) - - tmp = c.schema.fields - to_update = dict( - (renamings[k], v) if k in renamings else (k, v) for k, v in tmp.items() - ) - fields.update(to_update) - - return Schema(f"_tmp:{self.table}", fields=fields, db=self.db) - - def renamings(self, r={}): - """Return the renamings. - - :param r: Renamings. - """ - for part in self.parts: - if isinstance(part, str): - continue - if part[0] == "rename": - r[self.table] = part[1][0] - if part[0] == "relabel": - r[self.table] = part[1][0] - else: - queries = list(part[1]) + list(part[2].values()) - for query in queries: - if isinstance(query, IbisQuery): - query.renamings(r) - return r - - def _execute_select(self, parent): - return self._execute(parent) - - def _execute_insert(self, parent): - documents = self._prepare_documents() - table = self.db.load('table', self.table) - for r in documents: - if table.primary_id not in r: - pid = str(uuid.uuid4()) - r[table.primary_id] = pid - ids = [r[table.primary_id] for r in documents] - self.db.databackend.insert(self.table, raw_documents=documents) - return ids - - def _create_table_if_not_exists(self): - tables = self.db.databackend.list_tables_or_collections() - if self.table in tables: - return - self.db.databackend.create_table_and_schema( - self.table, - self._get_schema(), - ) - - def _execute(self, parent, method="encode"): - q = super()._execute(parent, method=method) - try: - output = q.execute() - except Exception as e: - raise DatabackendException( - f"Error while executing ibis query {self}" - ) from e - - assert isinstance(output, pandas.DataFrame) - output = output.to_dict(orient="records") - component_table = self.db.load('table', self.table) - return SuperDuperCursor( - raw_cursor=output, - db=self.db, - id_field=component_table.primary_id, - schema=self._get_schema(), - ) - - @property - def type(self): - """Return the type of the query.""" - return defaultdict( - lambda: "select", - { - "replace": "update", - "delete": "delete", - "filter": "select", - "insert": "insert", - }, - )[self.flavour] - - @property - def primary_id(self): - """Return the primary id.""" - return self.base_table.primary_id - - def model_update( - self, - ids: t.List[t.Any], - predict_id: str, - outputs: t.Sequence[t.Any], - flatten: bool = False, - **kwargs, - ): - """Update the model outputs in the database. - - :param ids: The ids of the inputs. - :param predict_id: The predict id. - :param outputs: The outputs. - :param flatten: Whether to flatten the outputs. - :param kwargs: Additional keyword arguments. 
- """ - self.is_output_query = True - self.updated_key = predict_id - - if not flatten: - return _model_update_impl( - db=self.db, - ids=ids, - predict_id=predict_id, - outputs=outputs, - ) - else: - return _model_update_impl_flatten( - db=self.db, - ids=ids, - predict_id=predict_id, - outputs=outputs, - ) - - def add_fold(self, fold: str): - """Return a query that adds a fold. - - :param fold: The fold to add. - """ - return self.filter(self._fold == fold) - - def select_using_ids(self, ids: t.Sequence[str]): - """Return a query that selects using ids. - - :param ids: The ids to select. - """ - filter_query = self.filter(getattr(self, self.primary_id).isin(ids)) - return filter_query - - @property - def select_ids(self): - """Return a query that selects ids.""" - return self.select(self.primary_id) - - def drop_outputs(self, predict_id: str): - """Return a query that removes output corresponding to the predict id. - - :param predict_ids: The ids of the predictions to select. - """ - return self.db.databackend.conn.drop_table(f"{CFG.output_prefix}{predict_id}") - - @applies_to("select") - def outputs(self, *predict_ids): - """Return a query that selects outputs. - - :param predict_ids: The predict ids. - """ - for part in self.parts: - if part[0] == "select": - args = part[1] - assert ( - self.primary_id in args - ), f"Primary id: `{self.primary_id}` not in select when using outputs" - query = self - attr = getattr(query, self.primary_id) - for identifier in predict_ids: - identifier = ( - identifier - if identifier.startswith(CFG.output_prefix) - else f"{CFG.output_prefix}{identifier}" - ) - symbol_table = self.db[identifier] - - symbol_table = symbol_table.relabel( - # TODO: Check for folds - {"_fold": f"fold.{identifier}", "id": f"id.{identifier}"} - ) - query = query.join(symbol_table, symbol_table._source == attr) - return query - - @applies_to("select", "join") - def select_ids_of_missing_outputs(self, predict_id: str): - """Return a query that selects ids of missing outputs. - - :param predict_id: The predict id. - """ - from superduper.base.datalayer import Datalayer - - assert isinstance(self.db, Datalayer) - - output_table = self.db[f"{CFG.output_prefix}{predict_id}"] - return self.anti_join( - output_table, - output_table._source == getattr(self, self.primary_id), - ) - - def select_single_id(self, id: str): - """Return a query that selects a single id. - - :param id: The id to select. - """ - filter_query = eval(f"table.{self.primary_id} == {id}") - return self.filter(filter_query) - - @property - def select_table(self): - """Return a query that selects the table.""" - return self.db[self.table].select() - - def __call__(self, *args, **kwargs): - """Add a method call to the query. - - :param args: The arguments to pass to the method. - :param kwargs: The keyword arguments to pass to the method. 
- """ - assert isinstance(self.parts[-1], str) - # TODO: Move to _execute - if ( - self.parts[-1] == "select" - and not args - and not self.table.startswith(" 0.5) - z = np.random.rand(32) - data.append( - Document( - { - 'x': x, - 'y': y, - 'z': z, - 'update': update, - } - ) - ) - return data - - -def test_delete_many(db): - add_random_data(db, n=5) - collection = db['documents'] - old_ids = {r['_id'] for r in db.execute(collection.find({}, {'_id': 1}))} - deleted_ids = list(old_ids)[:2] - db.execute(collection.delete_many({'_id': {'$in': deleted_ids}})) - new_ids = {r['_id'] for r in db.execute(collection.find({}, {'_id': 1}))} - assert len(new_ids) == 3 - - assert old_ids - new_ids == set(deleted_ids) - - -def test_replace(db): - add_random_data(db, n=5) - collection = db['documents'] - r = next(db.execute(collection.find())) - new_x = np.random.rand(32) - r['x'] = new_x - db.execute( - collection.replace_one( - {'_id': r['_id']}, - r, - ) - ) - - new_r = db.execute(collection.find_one({'_id': r['_id']})) - assert new_r['x'].tolist() == new_x.tolist() - - -@pytest.mark.skipif(True, reason='URI not working') -def test_insert_from_uris(db, image_url): - import PIL - from superduper.ext.pillow.encoder import pil_image - - if image_url.startswith('file://'): - image_url = image_url[7:] - - db.apply(pil_image) - collection = db['documents'] - to_insert = [Document({'img': pil_image(uri=image_url)})] - - db.execute(collection.insert_many(to_insert)) - - r = db.execute(collection.find_one()) - assert isinstance(r['img'].x, PIL.Image.Image) - - -def test_update_many(db): - add_random_data(db, n=5) - collection = db['documents'] - to_update = np.random.randn(32) - db.execute(collection.update_many({}, Document({'$set': {'x': to_update}}))) - cur = db.execute(collection.find()) - r = next(cur) - - assert all(r['x'] == to_update) - - # TODO: Need to support Update result in predict_in_db - # listener = db.load('listener', 'vector-x') - # assert all( - # listener.model.predict(to_update) - # == next(db['_outputs__vector-x'].find().execute())['_outputs__vector-x'].x - # ) - - -def test_outputs_query_2(db): - import numpy - from superduper import model - - db.cfg.auto_schema = True - - @model - def test(x): - return numpy.random.randn(32) - - data = [{'x': f'test {i}', 'y': f'other {i}'} for i in range(5)] - db['example'].insert(data).execute() - - l1 = test.to_listener(key='x', select=db['example'].select(), identifier='l-x') - l2 = test.to_listener(key='y', select=db['example'].select(), identifier='l-y') - - db.apply(l1) - db.apply(l2) - - @model() - def test_flat(x): - return [numpy.random.randn(32) for _ in range(3)] - - select = db['example'].select() - l3 = test_flat.to_listener( - key='x', select=select, identifier='l-x-flat', flatten=True - ) - - db.apply(l3) - - ######## - - q = db['example'].outputs(l1.predict_id) - - results = q.execute().tolist() - - assert len(results) == 5 - - ####### - - q = db['example'].outputs(l2.predict_id) - - results = q.execute().tolist() - - assert len(results) == 5 - - ####### - - q = db['example'].outputs(l1.predict_id, l2.predict_id) - - results = q.execute().tolist() - - assert len(results) == 5 - - ####### - - q = db['example'].outputs(l3.predict_id) - - results = q.execute().tolist() - - assert len(results) == 15 - - ####### - - q = db['example'].outputs(l1.predict_id, l3.predict_id) - - results = q.execute().tolist() - - assert len(results) == 15 - - ####### - - q = db['example'].outputs(l1.predict_id, l2.predict_id, l3.predict_id) - - results = 
q.execute().tolist() - - assert len(results) == 15 - - -def test_outputs_query(db): - db.cfg.auto_schema = True - - add_random_data(db, n=5) - add_models(db) - - l1, l2, l1_flat = add_listeners(db) - - outputs_1 = list(db['documents'].outputs(l1.predict_id).execute()) - assert len(outputs_1) == 5 - outputs_2 = list(db['documents'].outputs(l2.predict_id).execute()) - assert len(outputs_2) == 5 - outputs_1_2 = list(db['documents'].outputs(l1.predict_id, l2.predict_id).execute()) - assert len(outputs_1_2) == 5 - - -def test_insert_many(db): - db.cfg.auto_schema = True - add_random_data(db, n=5) - add_models(db) - add_listeners(db) - collection = db['documents'] - an_update = get_new_data(10, update=True) - db.execute(collection.insert(an_update)) - assert len(list(db.execute(collection.find()))) == 5 + 10 - for lid in db.show('listener'): - if 'flat' not in lid: - component = db.load('listener', lid) - outputs = component.outputs - assert len(list(db.execute(db[outputs].find()))) == 5 + 10 - - -def test_like(db): - db.cfg.auto_schema = True - add_random_data(db, n=5) - add_models(db) - add_vector_index(db) - collection = db['documents'] - r = db.execute(collection.find_one()) - query = collection.like( - r=Document({'x': r['x']}), - vector_index='test_vector_search', - ).find() - s = next(db.execute(query)) - assert r['_id'] == s['_id'] - - -def test_insert_one(db): - db.cfg.auto_schema = True - add_random_data(db, n=5) - # MARK: empty Collection + a_single_insert - collection = db['documents'] - a_single_insert = get_new_data(1, update=False)[0] - q = collection.insert_one(a_single_insert) - out = db.execute(q) - r = db.execute(collection.find({'_id': out[0]})) - docs = list(r) - assert docs[0]['x'].tolist() == a_single_insert['x'].tolist() - - -def test_delete_one(db): - add_random_data(db, n=5) - collection = db['documents'] - r = db.execute(collection.find_one()) - db.execute(collection.delete_one({'_id': r['_id']})) - with pytest.raises(StopIteration): - next(db.execute(db['documents'].find({'_id': r['_id']}))) - - -def test_find(db): - add_random_data(db, n=10) - collection = db['documents'] - r = db.execute(collection.find().limit(1)) - assert len(list(r)) == 1 - r = db.execute(collection.find().limit(5)) - assert len(list(r)) == 5 - - -def test_find_one(db): - add_random_data(db, n=5) - r = db.execute(db['documents'].find_one()) - assert isinstance(r, Document) - - -def test_replace_one(db): - add_random_data(db, n=5) - collection = db['documents'] - # MARK: random data (change) - new_x = np.random.randn(32) - r = db.execute(collection.find_one()) - r['x'] = new_x - db.execute(collection.replace_one({'_id': r['_id']}, r)) - doc = db.execute(collection.find_one({'_id': r['_id']})) - print(doc['x']) - assert doc.unpack()['x'].tolist() == new_x.tolist() - - -def test_select_missing_ids(db): - db.cfg.auto_schema = True - add_random_data(db, n=5) - add_models(db) - add_vector_index(db) - out = db.load('listener', 'vector-x').outputs - doc = list(db[out].select().execute())[0] - source_id = doc['_source'] - db.databackend._db[out].delete_one({'_source': source_id}) - - predict_id = out.split('_outputs__')[-1] - query = db['documents'].select_ids_of_missing_outputs(predict_id) - x = list(query.execute()) - assert len(x) == 1 - assert source_id == x[0]['_id'] diff --git a/plugins/mongodb/plugin_test/test_query.py b/plugins/mongodb/plugin_test/test_query.py deleted file mode 100644 index 7726efe76..000000000 --- a/plugins/mongodb/plugin_test/test_query.py +++ /dev/null @@ -1,104 +0,0 @@ 
-import random - -import numpy as np -import pytest -from superduper.base.document import Document -from superduper.components.schema import Schema -from superduper.components.table import Table -from superduper.ext.numpy.encoder import Array - -from superduper_mongodb.query import MongoQuery - - -@pytest.fixture -def schema(request): - bytes_encoding = request.param if hasattr(request, 'param') else None - - array_tensor = Array(dtype="float64", shape=(32,)) - schema = Schema( - identifier=f'documents-{bytes_encoding}', - fields={ - "x": array_tensor, - "z": array_tensor, - }, - ) - return schema - - -def test_mongo_schema(db, schema): - collection_name = "documents" - data = [] - - for id_ in range(5): - x = np.random.rand(32) - y = int(random.random() > 0.5) - z = np.random.randn(32) - data.append( - Document( - { - "id": id_, - "x": x, - "y": y, - "z": z, - }, - db=db, - ) - ) - - table = Table(identifier=collection_name, schema=schema) - db.apply(table) - gt = data[0] - - db[collection_name].insert_many(data).execute() - r = db[collection_name].find_one().execute() - rs = list(db[collection_name].find().execute()) - - rs = sorted(rs, key=lambda x: x['id']) - - assert np.array_equal(r['x'], gt['x']) - assert np.array_equal(r['z'], gt['z']) - - assert np.array_equal(rs[0]['x'], gt['x']) - assert np.array_equal(rs[0]['z'], gt['z']) - - -def test_select_missing_outputs(db): - docs = list(db.execute(MongoQuery(table='documents').find({}, {'_id': 1}))) - ids = [r['_id'] for r in docs[: len(docs) // 2]] - db.execute( - MongoQuery(table='documents').update_many( - {'_id': {'$in': ids}}, - Document({'$set': {'_outputs__x::test_model_output::0::0': 'test'}}), - ) - ) - select = MongoQuery(table='documents').find({}, {'_id': 1}) - select.db = db - modified_select = select.select_ids_of_missing_outputs('x::test_model_output::0::0') - out = list(db.execute(modified_select)) - assert len(out) == (len(docs) - len(ids)) - - -def test_special_query_serialization(db): - q2 = db['docs'].find({'x': {'$lt': 9}}) - encoded_query = q2.encode() - base = encoded_query['_base'][1:] - assert encoded_query['_builds'][base]['documents'][0] == {'x': {'<$>lt': 9}} - - rq2 = Document.decode(encoded_query).unpack() - assert rq2.parts[0][1][0] == {'x': {'$lt': 9}} - - -def test_execute_complex_query_mongodb(db): - t = db['documents'] - - insert_query = t.insert_many( - [Document({'this': f'is a test {i}'}) for i in range(100)] - ) - db.execute(insert_query) - - select_query = t.find().sort('this', -1).limit(10) - cur = db.execute(select_query) - - expected = [f'is a test {i}' for i in range(99, 89, -1)] - cur_this = [r['this'] for r in cur] - assert sorted(cur_this) == sorted(expected) diff --git a/plugins/mongodb/superduper_mongodb/__init__.py b/plugins/mongodb/superduper_mongodb/__init__.py index 0825a447a..a6b7eadbe 100644 --- a/plugins/mongodb/superduper_mongodb/__init__.py +++ b/plugins/mongodb/superduper_mongodb/__init__.py @@ -1,14 +1,12 @@ -from .artifacts import MongoArtifactStore as ArtifactStore +from .artifacts import MongoDBArtifactStore as ArtifactStore from .data_backend import MongoDBDataBackend as DataBackend -from .metadata import MongoMetaDataStore as MetaDataStore -from .query import MongoQuery +from .metadata import MongoDBMetaDataStore as MetaDataStore from .vector_search import MongoAtlasVectorSearcher as VectorSearcher __version__ = "0.4.5" __all__ = [ "ArtifactStore", - "MongoQuery", "DataBackend", "MetaDataStore", "VectorSearcher", diff --git a/plugins/mongodb/superduper_mongodb/artifacts.py 
b/plugins/mongodb/superduper_mongodb/artifacts.py index c83f443ab..7c9259ff8 100644 --- a/plugins/mongodb/superduper_mongodb/artifacts.py +++ b/plugins/mongodb/superduper_mongodb/artifacts.py @@ -13,7 +13,7 @@ from superduper_mongodb.utils import connection_callback -class MongoArtifactStore(ArtifactStore): +class MongoDBArtifactStore(ArtifactStore): """ Artifact store for MongoDB. diff --git a/plugins/mongodb/superduper_mongodb/data_backend.py b/plugins/mongodb/superduper_mongodb/data_backend.py index 8f7e990c9..66ef8eba8 100644 --- a/plugins/mongodb/superduper_mongodb/data_backend.py +++ b/plugins/mongodb/superduper_mongodb/data_backend.py @@ -2,22 +2,27 @@ import typing as t import click -import mongomock -import pymongo -import pymongo.collection +from bson.objectid import ObjectId from superduper import CFG, logging from superduper.backends.base.data_backend import BaseDataBackend from superduper.backends.base.metadata import MetaDataStoreProxy -from superduper.base.enums import DBType -from superduper.components.datatype import BaseDataType +from superduper.backends.base.query import Query from superduper.components.schema import Schema from superduper.misc.colors import Colors -from superduper_mongodb.artifacts import MongoArtifactStore -from superduper_mongodb.metadata import MongoMetaDataStore +from superduper_mongodb.artifacts import MongoDBArtifactStore +from superduper_mongodb.metadata import MongoDBMetaDataStore from superduper_mongodb.utils import connection_callback -from .query import MongoQuery +OPS_MAP = { + '__eq__': '$eq', + '__ne__': '$ne', + '__lt__': '$lt', + '__le__': '$lte', + '__gt__': '$gt', + '__ge__': '$gte', + 'isin': '$in', +} class MongoDBDataBackend(BaseDataBackend): @@ -28,54 +33,36 @@ class MongoDBDataBackend(BaseDataBackend): :param flavour: Flavour of the databackend. """ - db_type = DBType.MONGODB - id_field = "_id" - def __init__(self, uri: str, flavour: t.Optional[str] = None): + def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None): self.connection_callback = lambda: connection_callback(uri, flavour) - self.overwrite = True - super().__init__(uri, flavour=flavour) - self.conn, self.name = connection_callback(uri, flavour) - self._db = self.conn[self.name] + super().__init__(uri, flavour=flavour, plugin=plugin) + + self.conn, self.name = connection_callback(uri, flavour) + self._database = self.conn[self.name] self.datatype_presets = { 'vector': 'superduper.components.datatype.NativeVector' } + def random_id(self): + """Generate a random ID.""" + return ObjectId() + + # TODO move to the super def reconnect(self): - """Reconnect to mongodb store.""" - # Reconnect to database. + """Reconnect to MongoDB databackend.""" conn, _ = self.connection_callback() self.conn = conn - self._db = self.conn[self.name] - - def get_query_builder(self, collection_name): - """Get the query builder for the data backend. 
- - :param collection_name: Which collection to get the query builder for - """ - item_gotten = self._db[collection_name] - if isinstance( - item_gotten, - (pymongo.collection.Collection, mongomock.collection.Collection), - ): - return MongoQuery(table=collection_name, db=self.datalayer) - return item_gotten - - def url(self): - """Return the data backend connection url.""" - return self.conn.HOST + ":" + str(self.conn.PORT) + "/" + self.name - - @property - def db(self): - """Return the datalayer instance.""" - return self._db + self._database = self.conn[self.name] def build_metadata(self): """Build the metadata store for the data backend.""" - return MetaDataStoreProxy(MongoMetaDataStore(callback=self.connection_callback)) + return MetaDataStoreProxy( + MongoDBMetaDataStore(callback=self.connection_callback) + ) def build_artifact_store(self): """Build the artifact store for the data backend.""" @@ -88,21 +75,15 @@ def build_artifact_store(self): os.makedirs(f"/tmp/{self.name}", exist_ok=True) return FileSystemArtifactStore(f"/tmp/{self.name}") - return MongoArtifactStore(self.conn, f"_filesystem:{self.name}") - - def drop_outputs(self): - """Drop all outputs.""" - for collection in self.db.list_collection_names(): - if collection.startswith(CFG.output_prefix): - self.db.drop_collection(collection) + return MongoDBArtifactStore(self.conn, f"_filesystem:{self.name}") - def drop_table_or_collection(self, name: str): + def drop_table(self, name: str): """Drop the table or collection. Please use with caution as you will lose all data. :param name: Collection to drop. """ - return self.db.drop_collection(name) + return self._database.drop_collection(name) def drop(self, force: bool = False): """Drop the data backend. @@ -119,137 +100,235 @@ def drop(self, force: bool = False): default=False, ): logging.warn("Aborting...") - return self.db.client.drop_database(self.db.name) + return self._database.client.drop_database(self._database.name) - def get_table_or_collection(self, identifier): + def get_table(self, identifier): """Get a table or collection from the data backend. :param identifier: table or collection identifier """ - return self._db[identifier] + return self._database[identifier] - def list_tables_or_collections(self): + def list_tables(self): """List all tables or collections in the data backend.""" - return self.db.list_collection_names() + return self._database.list_collection_names() - def disconnect(self): - """Disconnect the client.""" - - # TODO: implement me + def check_output_dest(self, predict_id) -> bool: + """Check if the output destination exists. - def create_output_dest( - self, - predict_id: str, - datatype: t.Union[str, BaseDataType], - flatten: bool = False, - ): - """Create an output collection for a component. + :param predict_id: identifier of the prediction + """ + return self._database[f"{CFG.output_prefix}{predict_id}"].find_one() is not None - That will do nothing for MongoDB. + def create_table_and_schema(self, identifier: str, schema: Schema): + """Create a table and schema in the data backend. - :param predict_id: The predict id of the output destination - :param datatype: datatype of component - :param flatten: flatten the output + :param identifier: The identifier for the table + :param mapping: The mapping for the schema """ + # If the data can be converted to JSON, + # then save it as native data in MongoDB. pass - def exists(self, table_or_collection, id, key): - """Check if a document exists in the data backend. 
- - :param table_or_collection: table or collection identifier - :param id: document identifier - :param key: key to check - """ - return ( - self.db[table_or_collection].find_one( - {"_id": id, f"{key}._content.bytes": {"$exists": 1}} + ################################### + # Query execution implementations # + ################################### + + def primary_id(self, query): + """Get the primary ID for the query.""" + return '_id' + + def insert(self, table, documents): + """Insert documents into the table.""" + for doc in documents: + if '_id' in doc: + doc['_id'] = ObjectId(doc['_id']) + if '_source' in doc: + doc['_source'] = ObjectId(doc['_source']) + return self._database[table].insert_many(documents).inserted_ids + + def missing_outputs(self, table, predict_id: str): + """Get the missing outputs for the prediction.""" + key = f'{CFG.output_prefix}{predict_id}' + lookup = [ + { + '$lookup': { + 'from': key, + 'localField': '_id', + 'foreignField': '_source', + 'as': key, + } + }, + {'$match': {key: {'$size': 0}}}, + ] + collection = self._database[table] + results = list(collection.aggregate(lookup)) + return [r['_id'] for r in results] + + def select(self, query: Query): + """Select data from the table.""" + if query.decomposition.outputs: + return self._outputs(query) + + collection = self._database[query.table] + + limit = self._get_limit(query) + if limit: + return list( + collection.find( + self._mongo_filter(query), self._get_project(query) + ).limit(limit) ) - is not None + + return list( + collection.find(self._mongo_filter(query), self._get_project(query)) ) - def check_output_dest(self, predict_id) -> bool: - """Check if the output destination exists. + def to_id(self, id): + """Convert the ID to the correct format.""" + return ObjectId(id) - :param predict_id: identifier of the prediction - """ - return self.db[f"{CFG.output_prefix}{predict_id}"].find_one() is not None + ######################## + # Helper methods below # + ######################## - def check_ready_ids( - self, - query: MongoQuery, - keys: t.List[str], - ids: t.Optional[t.List[t.Any]] = None, - ): - """Check if all the keys are ready in the ids. + @staticmethod + def _get_project(query): + if query.decomposition.select is None: + return {} - Use this function to check if all the keys are ready in the ids. - Because the join operation is not very efficient in MongoDB, we use the - output keys to filter the ids first and then check the base keys. + if not query.decomposition.select.args: + return {} - This process only verifies the key and does not involve reading the real data. + project = {} + for k in query.decomposition.select.args: + if isinstance(k, Query): + assert k.parts[0] == 'primary_id' + project['_id'] = 1 + else: + project[k] = 1 - :param query: The query object. - :param keys: The keys to check. - :param ids: The ids to check. 
- """ + if '_id' not in project: + project['_id'] = 0 - def is_output_key(key): - return key.startswith(CFG.output_prefix) and key != query.table + return project - output_keys = [key for key in keys if is_output_key(key)] - input_ids = ready_ids = ids + @staticmethod + def _mongo_filter(query): + if query.decomposition.filter is None: + return {} - # Filter the ids by the output keys first - for output_key in output_keys: - filter: dict[str, t.Any] = {} - filter[output_key] = {"$exists": 1} - if ready_ids is not None: - filter["_source"] = {"$in": ready_ids} - ready_ids = list( - self.get_table_or_collection(output_key).find(filter, {"_source": 1}) - ) - ready_ids = [doc["_source"] for doc in ready_ids] - if not ready_ids: - return [] + filters = query.decomposition.filter.args - # If we get the ready ids from the output keys, we can continue on these ids - ids = ready_ids or ids + mongo_filter = {} + for f in filters: + assert len(f) > 2, f'Invalid filter query: {f}' + key = f.parts[0] + if key == 'primary_id': + key = '_id' - base_keys = [key for key in keys if not is_output_key(key)] - base_filter: dict[str, t.Any] = {} - base_filter.update({key: {"$exists": 1} for key in base_keys}) - if ready_ids is not None: - base_filter["_id"] = {"$in": ready_ids} + op = f.parts[1] - ready_ids = list( - self.get_table_or_collection(query.table).find(base_filter, {"_id": 1}) - ) - ready_ids = [doc["_id"] for doc in ready_ids] + if op.name not in OPS_MAP: + raise ValueError( + f'Operation {op} not supported, ' + f'supported operations are: {OPS_MAP.keys()}' + ) + + if not op.args: + raise ValueError(f'No arguments found for operation {op}') + + value = op.args[0] - if ids is not None: - ready_ids = [id for id in ids if id in ready_ids] + if f.decomposition.col == 'primary_id': + if isinstance(value, str): + value = ObjectId(value) + elif isinstance(value, list): + value = [ObjectId(x) for x in value] - self._log_check_ready_ids_message(input_ids, ready_ids) - return ready_ids + mongo_filter[key] = {OPS_MAP[op.name]: value} + return mongo_filter @staticmethod - def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): - """Infer a schema from a given data object. 
+ def _get_limit(query): + try: + out = query.decomposition.limit.args[0] + assert out > 0 + return out + except AttributeError: + return + + def _outputs(self, query): + pipeline = [] + + project = self._get_project(query).copy() + + filter_mapping_base = { + k: v + for k, v in self._mongo_filter(query).items() + if not k.startswith(CFG.output_prefix) + } + filter_mapping_outputs = { + k: v + for k, v in self._mongo_filter(query).items() + if k.startswith(CFG.output_prefix) + } - :param data: The data object - :param identifier: The identifier for the schema, if None, it will be generated - :return: The inferred schema - """ - from superduper.misc.auto_schema import infer_schema + if filter_mapping_base: + pipeline.append({"$match": filter_mapping_base}) + if project: + project.update({k: 1 for k in filter_mapping_base.keys()}) + + predict_ids = query.decomposition.predict_ids + + if filter_mapping_outputs: + predict_ids = [ + pid + for pid in predict_ids + if f'{CFG.output_prefix}{pid}' in filter_mapping_outputs + ] + + for predict_id in predict_ids: + key = f'{CFG.output_prefix}{predict_id}' + lookup = { + "$lookup": { + "from": key, + "localField": "_id", + "foreignField": "_source", + "as": key, + } + } + + if project: + project[key] = 1 + + pipeline.append(lookup) + + if key in filter_mapping_outputs: + pipeline.append({"$match": {key: filter_mapping_outputs[key]}}) + + pipeline.append( + {"$unwind": {"path": f"${key}", "preserveNullAndEmptyArrays": True}} + ) - return infer_schema(data, identifier) + if project: + pipeline.append({"$project": project}) - def create_table_and_schema(self, identifier: str, schema: Schema): - """Create a table and schema in the data backend. + if self._get_limit(query): + pipeline.append({"$limit": self._get_limit(query)}) - :param identifier: The identifier for the table - :param mapping: The mapping for the schema - """ - # If the data can be converted to JSON, - # then save it as native data in MongoDB. - pass + try: + import json + + logging.debug(f'Executing pipeline: {json.dumps(pipeline, indent=2)}') + except TypeError: + pass + + collection = self._database[query.table] + result = list(collection.aggregate(pipeline)) + + for pid in predict_ids: + k = f'{CFG.output_prefix}{pid}' + for r in result: + r[k] = r[k][k] + return result diff --git a/plugins/mongodb/superduper_mongodb/metadata.py b/plugins/mongodb/superduper_mongodb/metadata.py index 9864c297c..f2a8fdc48 100644 --- a/plugins/mongodb/superduper_mongodb/metadata.py +++ b/plugins/mongodb/superduper_mongodb/metadata.py @@ -8,7 +8,7 @@ from superduper.misc.colors import Colors -class MongoMetaDataStore(MetaDataStore): +class MongoDBMetaDataStore(MetaDataStore): """ Metadata store for MongoDB. 
diff --git a/plugins/mongodb/superduper_mongodb/query.py b/plugins/mongodb/superduper_mongodb/query.py deleted file mode 100644 index cf4a6f63d..000000000 --- a/plugins/mongodb/superduper_mongodb/query.py +++ /dev/null @@ -1,847 +0,0 @@ -import copy -import dataclasses as dc -import functools -import re -import typing as t -from collections import defaultdict - -import pymongo -from bson import ObjectId -from superduper import CFG, logging -from superduper.backends.base.query import ( - Query, - applies_to, - parse_query as _parse_query, -) -from superduper.base.cursor import SuperDuperCursor -from superduper.base.document import Document, QueryUpdateDocument -from superduper.base.leaf import Leaf - -if t.TYPE_CHECKING: - from superduper.base.datalayer import Datalayer - -_SPECIAL_CHRS: list = ['$', '.'] - -OPS_MAP = { - '__eq__': '$eq', - '__ne__': '$ne', - '__lt__': '$lt', - '__le__': '$lte', - '__gt__': '$gt', - '__ge__': '$gte', - 'isin': '$in', -} - - -def _serialize_special_character(d, to='encode'): - def extract_character(s): - pattern = r'<(.)>' - match = re.search(pattern, s) - if match: - return match.group(1) - return None - - if not isinstance(d, dict): - return d - - new_dict = {} - for key, value in d.items(): - new_key = key - if isinstance(key, str): - if to == 'encode': - if key[0] in _SPECIAL_CHRS: - new_key = f'<{key[0]}>' + key[1:] - elif to == 'decode': - k = extract_character(key[:3]) - if k in _SPECIAL_CHRS: - new_key = k + key[3:] - - if isinstance(value, dict): - new_dict[new_key] = _serialize_special_character(value, to=to) - elif isinstance(value, list): - new_dict[new_key] = [ - ( - _serialize_special_character(item, to=to) - if isinstance(item, dict) - else item - ) - for item in value - ] - else: - new_dict[new_key] = value - - return new_dict - - -def parse_query( - query, documents: t.Sequence[t.Dict] = (), db: t.Optional['Datalayer'] = None -): - """Parse a string query into a query object. - - :param query: The query to parse. - :param documents: The documents to query. - :param db: The datalayer to use to execute the query. - """ - _decode = functools.partial(_serialize_special_character, to='decode') - documents = list(map(_decode, documents)) - return _parse_query( - query=query, - builder_cls=MongoQuery, - documents=list(documents), - db=db, - ) - - -class MongoQuery(Query): - """A query class for MongoDB. - - This class is used to build and execute queries on a MongoDB database. 
- """ - - flavours: t.ClassVar[t.Dict[str, str]] = { - 'pre_like': r'^.*\.like\(.*\)\.(find|select)', - 'post_like': r'^.*\.(find|select)\(.*\)\.like(.*)$', - 'bulk_write': r'^.*\.bulk_write\(.*\)$', - 'outputs': r'^.*\.outputs\(.*\)', - 'missing_outputs': r'^.*\.missing_outputs\(.*\)$', - 'find_one': r'^.*\.find_one\(.*\)', - 'find': r'^.*\.find\(.*\)', - 'select': r'^.*\.select\(.*\)$', - 'insert_one': r'^.*\.insert_one\(.*\)$', - 'insert_many': r'^.*\.insert_many\(.*\)$', - 'insert': r'^.*\.insert\(.*\)$', - 'replace_one': r'^.*\.replace_one\(.*\)$', - 'update_many': r'^.*\.update_many\(.*\)$', - 'update_one': r'^.*\.update_one\(.*\)$', - 'delete_many': r'^.*\.delete_many\(.*\)$', - 'delete_one': r'^.*\.delete_one\(.*\)$', - 'other': '.*', - } - - # Use to control the behavior in the class construction method within LeafMeta - _dataclass_params: t.ClassVar[t.Dict[str, t.Any]] = { - 'eq': False, - 'order': False, - } - - def _create_table_if_not_exists(self): - return - - def dict(self, metadata: bool = True, defaults: bool = True, uuid: bool = True): - """Return the query as a dictionary.""" - r = super().dict() - r['documents'] = list(map(_serialize_special_character, r['documents'])) - return r - - @property - def type(self): - """Return the type of the query.""" - return defaultdict( - lambda: 'select', - { - 'update_many': 'update', - 'update_one': 'update', - 'delete_many': 'delete', - 'delete_one': 'delete', - 'bulk_write': 'write', - 'insert_many': 'insert', - 'insert_one': 'insert', - 'insert': 'insert', - 'outputs': 'select', - }, - )[self.flavour] - - @property - def _is_select_find(self): - return self.parts and any([p[0] == 'select' for p in self.parts]) - - def _prepare_inputs(self, inputs): - if isinstance(inputs, BulkOp): - return inputs.op - if isinstance(inputs, (list, tuple)): - return [self._prepare_inputs(i) for i in inputs] - if isinstance(inputs, dict): - return {k: self._prepare_inputs(v) for k, v in inputs.items()} - return inputs - - def _execute_delete_one(self, parent): - r = next(self.select_ids.limit(1)._execute(parent)) - self.table_or_collection.delete_one({'_id': r['_id']})._execute(parent) - return [str(r['_id'])] - - def _execute_delete_many(self, parent): - id_cursor = self.select_ids._execute(parent) - ids = [r['_id'] for r in id_cursor] - if not ids: - return {} - self.table_or_collection.delete_many({'_id': {'$in': ids}})._execute(parent) - return [str(id) for id in ids] - - def _execute(self, parent, method='encode'): - c = super()._execute(parent, method=method) - import mongomock - import pymongo - - if isinstance(c, (pymongo.cursor.Cursor, mongomock.collection.Cursor)): - return SuperDuperCursor( - raw_cursor=c, - db=self.db, - id_field='_id', - ) - return c - - def _execute_bulk_write(self, parent): - """Execute the query. 
- - :param db: The datalayer instance - """ - assert self.parts[0][0] == 'bulk_write' - operations = self.parts[0][1][0] - for query in operations: - assert isinstance(query, (BulkOp)) - - query.is_output_query = self.is_output_query - if not query.kwargs.get('arg_ids', None): - raise ValueError( - 'Please provided update/delete id in args', - r'all ids selection e.g `\{\}` is not supported', - ) - - collection = self.db.databackend.get_table_or_collection(self.table) - bulk_operations = [] - bulk_update_ids = [] - bulk_delete_ids = [] - bulk_result = {'delete': [], 'update': []} - for query in operations: - operation = query.op - - bulk_operations.append(operation) - ids = query.kwargs['arg_ids'] - if query.identifier == 'DeleteOne': - bulk_result['delete'].append({'query': query, 'ids': ids}) - bulk_delete_ids += ids - else: - bulk_update_ids += ids - bulk_result['update'].append({'query': query, 'ids': ids}) - - result = collection.bulk_write(bulk_operations) - if result.deleted_count != bulk_delete_ids: - logging.warn( - 'Some delete ids are not executed', - ', hence halting execution', - 'Please note the partially executed operations', - 'wont trigger any `model/listeners` unless CDC is active.', - ) - elif (result.modified_count + result.upserted_count) != bulk_update_ids: - logging.warn( - 'Some update ids are not executed', - ', hence halting execution', - 'Please note the partially executed operations', - 'wont trigger any `model/listeners` unless CDC is active.', - ) - return bulk_result, bulk_update_ids, bulk_delete_ids - - def filter(self, *args, **kwargs): - """Return a query that filters the documents. - - :param args: The arguments to filter by. - :param kwargs: Additional keyword arguments. - """ - filters = {} - for arg in args: - if isinstance(arg, dict): - filters.update(arg) - continue - if not isinstance(arg, Query): - raise ValueError(f'Filter arguments must be queries, but got: {arg}') - assert len(arg.parts) > 1, f'Invalid filter query: {arg}' - key, (op, op_args, _) = arg.parts[-2:] - - if op not in OPS_MAP: - raise ValueError( - f'Operation {op} not supported, ' - f'supported operations are: {OPS_MAP.keys()}' - ) - - if not op_args: - raise ValueError(f'No arguments found for operation {op}') - - value = op_args[0] - - filters[key] = {OPS_MAP[op]: value} - - query = self if self.parts else self.select() - - return type(self)( - db=self.db, - table=self.table, - parts=[*query.parts, ('filter', (filters,), kwargs)], - ) - - def _execute_find(self, parent): - return self._execute_select(parent) - - def _execute_replace_one(self, parent): - documents = self.parts[0][1][0] - trailing_args = list(self.parts[0][1][1:]) - kwargs = self.parts[0][2] - - schema = kwargs.pop('schema', None) - - replacement = trailing_args[0] - if isinstance(replacement, Document): - replacement = replacement.encode(schema) - trailing_args[0] = replacement - - q = self.table_or_collection.replace_one(documents, *trailing_args, **kwargs) - q._execute(parent) - - def _execute_find_one(self, parent): - r = self._execute_select(parent) - if r is None: - return - return Document.decode(r, db=self.db, schema=self._get_schema()) - - def _execute_insert_one(self, parent): - insert_part = self.parts[0] - parts = self.parts[1:] - insert_part = ('insert_many', [insert_part[1]], insert_part[2]) - - self.parts = [insert_part] + parts - return self._execute_insert_many(parent) - - def _execute_insert_many(self, parent): - trailing_args = self.parts[0][1][1:] - kwargs = self.parts[0][2] - documents = 
self._prepare_documents() - q = self.table_or_collection.insert_many(documents, *trailing_args, **kwargs) - result = q._execute(parent) - return result.inserted_ids - - def _execute_insert(self, parent): - """Provide a unified insertion interface.""" - return self._execute_insert_many(parent) - - def _execute_update_many(self, parent): - ids = [r['_id'] for r in self.select_ids._execute(parent)] - filter = self.parts[0][1][0] - trailing_args = list(self.parts[0][1][1:]) - update = {} - kwargs = self.parts[0][2] - - # Encode update document - for ix, arg in enumerate(trailing_args): - if '$set' in arg: - if isinstance(arg, Document): - update = QueryUpdateDocument.from_document(arg) - else: - update = arg - del trailing_args[ix] - break - - filter['_id'] = {'$in': ids} - - try: - table = self.db.load('table', self.table) - schema = table.schema - except FileNotFoundError: - schema = None - - trailing_args.insert( - 0, update.encode(schema=schema) if isinstance(update, Document) else update - ) - - parent.update_many(filter, *trailing_args, **kwargs) - return ids - - def drop_outputs(self, predict_id: str): - """Return a query that removes output corresponding to the predict id. - - :param predict_ids: The ids of the predictions to select. - """ - return self.db.databackend.drop_table_or_collection( - f'{CFG.output_prefix}{predict_id}' - ) - - @applies_to('find', 'find_one', 'select', 'outputs') - def add_fold(self, fold: str): - """Return a query that adds a fold to the query. - - :param fold: The fold to add. - """ - return self.filter(self['_fold'] == fold) - - @property - @applies_to('insert_many', 'insert_one', 'insert') - def documents(self): - """Return the documents.""" - return super().documents - - @property - def primary_id(self): - """Return the primary id of the documents.""" - return '_id' - - def select_using_ids(self, ids: t.Sequence[str]): - """Return a query that selects using the given ids. - - :param ids: The ids to select. - """ - ids = [ObjectId(id) for id in ids] - q = self.filter(self['_id'].isin(ids)) - return q - - @property - def select_ids(self): - """Select the ids of the documents.""" - if self._is_select_find: - part = self.parts[0] - new_part = ('select', ('_id',), part[2]) - return type(self)( - db=self.db, - table=self.table, - parts=[new_part, *self.parts[1:]], - ) - - filter_ = {} - if self.parts and self.parts[0][1]: - filter_ = self.parts[0][1][0] - if isinstance(filter_, str): - filter_ = {} - projection = {'_id': 1} - coll = type(self)(table=self.table, db=self.db) - return coll.find(filter_, projection) - - def select_ids_of_missing_outputs(self, predict_id: str): - """Select the ids of documents that are missing the given output. - - :param predict_id: The id of the prediction. 
- """ - return self.missing_outputs(predict_id=predict_id, ids_only=1) - - def _execute_missing_outputs(self, parent): - """Select the documents that are missing the given output.""" - if len(self.parts[-1][2]) == 0: - raise ValueError("Predict id is required") - predict_id = self.parts[-1][2]["predict_id"] - ids_only = self.parts[-1][2].get('ids_only', False) - - key = f'{CFG.output_prefix}{predict_id}' - - lookup = [ - { - '$lookup': { - 'from': key, - 'localField': '_id', - 'foreignField': '_source', - 'as': key, - } - }, - {'$match': {key: {'$size': 0}}}, - ] - - raw_cursor = getattr(parent, 'aggregate')(lookup) - - def get_ids(result): - return {"_id": result["_id"]} - - return SuperDuperCursor( - raw_cursor=raw_cursor, - db=self.db, - id_field='_id', - process_func=get_ids if ids_only else self._postprocess_result, - schema=self._get_schema(), - ) - - @property - @applies_to('find') - def select_single_id(self, id: str): - """Return a query that selects a single id. - - :param id: The id to select. - """ - args, kwargs = self.parts[0][1:] - args = list(self.args)[:] - if not args: - args[0] = {} - args[0]['_id'] = ObjectId(id) - return type(self)( - db=self.db, table=self.table, parts=[('find_one', args, kwargs)] - ) - - @property - def select_table(self): - """Return the table or collection to select from.""" - return self.table_or_collection.find() - - def model_update( - self, - ids: t.List[t.Any], - predict_id: str, - outputs: t.Sequence[t.Any], - flatten: bool = False, - **kwargs, - ): - """Update the model outputs in the database. - - :param ids: The ids of the documents to update. - :param predict_id: The id of the prediction. - :param outputs: The outputs to store. - :param flatten: Whether to flatten the outputs. - :param kwargs: Additional keyword arguments. 
- """ - if flatten: - flattened_outputs = [] - flattened_ids = [] - for output, id in zip(outputs, ids): - assert isinstance(output, (list, tuple)), 'Expected list or tuple' - for o in output: - flattened_outputs.append(o) - flattened_ids.append(id) - return self.model_update( - ids=flattened_ids, - predict_id=predict_id, - outputs=flattened_outputs, - flatten=False, - **kwargs, - ) - - documents = [] - for output, id in zip(outputs, ids): - documents.append( - { - f'{CFG.output_prefix}{predict_id}': output, - '_source': ObjectId(id), - } - ) - - from superduper.base.datalayer import Datalayer - - assert isinstance(self.db, Datalayer) - output_query = self.db[f'{CFG.output_prefix}{predict_id}'].insert_many( - documents - ) - output_query.is_output_query = True - output_query.updated_key = predict_id - return output_query - - def _replace_part(self, part_name, replace_function): - parts = copy.deepcopy(self.parts) - - for i, part in enumerate(parts): - if part[0] == part_name: - parts[i] = replace_function(part) - break - - return type(self)( - db=self.db, - table=self.table, - parts=parts, - ) - - def _execute_select(self, parent): - parts = self._get_select_parts() - output = self._get_chain_native_query(parent, parts, method='unpack') - import mongomock - import pymongo - - if isinstance(output, (pymongo.cursor.Cursor, mongomock.collection.Cursor)): - return SuperDuperCursor( - raw_cursor=output, - db=self.db, - id_field='_id', - schema=self._get_schema(), - ) - return output - - def _get_select_parts(self): - def process_select_part(part): - # Convert the select part to a find part - _, args, _ = part - projection = {key: 1 for key in args} - if projection and '_id' not in projection: - projection['_id'] = 0 - return ('find', ({}, projection), {}) - - def process_find_part(part): - method, args, kwargs = part - # args: (filter, projection, *args) - filter = copy.deepcopy(args[0]) if len(args) > 0 else {} - filter = dict(filter) - filter.update(self._get_filter_conditions()) - args = tuple((filter, *args[1:])) - - return (method, args, kwargs) - - parts = [] - for part in self.parts: - if part[0] == 'select': - part = process_select_part(part) - - if part[0] in {'find', 'find_one'}: - part = process_find_part(part) - - if part[0] == 'filter': - continue - parts.append(part) - - return parts - - def _get_filter_conditions(self): - filters = {} - for part in self.parts: - if part[0] == 'filter': - sub_filters = part[1][0] - assert isinstance(sub_filters, dict) - filters.update(sub_filters) - return filters - - def isin(self, other): - """Create an isin query. - - :param other: The value to check against. 
- """ - other = [ObjectId(o) for o in other] - return self._ops('isin', other) - - def _get_method_parameters(self, method): - args, kwargs = (), {} - for part in self.parts: - if part[0] == method: - assert not args, 'Multiple find operations found' - assert not kwargs, 'Multiple find operations found' - args, kwargs = part[1:] - - return args, kwargs - - def _get_predict_ids(self): - outputs_parts = [p for p in self.parts if p[0] == 'outputs'] - predict_ids = sum([p[1] for p in outputs_parts], ()) - return predict_ids - - def _execute_outputs(self, parent, method='encode'): - project = self._get_project() - - limit_args, _ = self._get_method_parameters('limit') - limit = {"$limit": limit_args[0]} if limit_args else None - - pipeline = [] - filter_mapping_base, filter_mapping_outputs = self._get_filter_mapping() - if filter_mapping_base: - pipeline.append({"$match": filter_mapping_base}) - project.update({k: 1 for k in filter_mapping_base.keys()}) - - predict_ids_in_filter = list(filter_mapping_outputs.keys()) - - predict_ids = self._get_predict_ids() - predict_ids = list(set(predict_ids).union(predict_ids_in_filter)) - # After the join, the complete outputs data can be queried as - # {CFG.output_prefix}{predict_id}._outputs.{predict_id} : result. - for predict_id in predict_ids: - key = f'{CFG.output_prefix}{predict_id}' - lookup = { - "$lookup": { - "from": key, - "localField": "_id", - "foreignField": "_source", - "as": key, - } - } - - project[key] = 1 - pipeline.append(lookup) - - if predict_id in filter_mapping_outputs: - filter_key, filter_value = list( - filter_mapping_outputs[predict_id].items() - )[0] - pipeline.append({"$match": {f'{key}.{filter_key}': filter_value}}) - - pipeline.append( - {"$unwind": {"path": f"${key}", "preserveNullAndEmptyArrays": True}} - ) - - if project: - pipeline.append({"$project": project}) - - if limit: - pipeline.append(limit) - - try: - import json - - logging.debug(f'Executing pipeline: {json.dumps(pipeline, indent=2)}') - except TypeError: - pass - - raw_cursor = getattr(parent, 'aggregate')(pipeline) - - return SuperDuperCursor( - raw_cursor=raw_cursor, - db=self.db, - id_field='_id', - process_func=self._postprocess_result, - schema=self._get_schema(), - ) - - def _get_schema(self): - outputs_parts = [p for p in self.parts if p[0] == 'outputs'] - predict_ids = sum([p[1] for p in outputs_parts], ()) - - try: - table = self.db.load('table', self.table) - if not predict_ids: - return table.schema - fields = table.schema.fields - except FileNotFoundError: - fields = {} - - for predict_id in predict_ids: - key = f'{CFG.output_prefix}{predict_id}' - try: - output_table = self.db.load('table', key) - except FileNotFoundError: - logging.warn( - f'No schema found for table {key}. 
Using default projection' - ) - continue - fields[key] = output_table.schema.fields[key] - - from superduper.components.datatype import BaseDataType - from superduper.components.schema import Schema - - fields = {k: v for k, v in fields.items() if isinstance(v, BaseDataType)} - - return Schema(f"_tmp:{self.table}", fields=fields, db=self.db) - - def _get_project(self): - find_params, _ = self._get_method_parameters('find') - select_params, _ = self._get_method_parameters('select') - if self._is_select_find: - project = {key: 1 for key in select_params} - else: - project = copy.deepcopy(find_params[1]) if len(find_params) > 1 else {} - - if not project: - try: - table = self.db.load('table', self.table) - project = {key: 1 for key in table.schema.fields.keys()} - except FileNotFoundError: - logging.warn( - 'No schema found for table', - f'{self.table}. Using default projection', - ) - - return project - - def _get_filter_mapping(self): - find_params, _ = self._get_method_parameters('find') - filter = find_params[0] if find_params else {} - filter.update(self._get_filter_conditions()) - - if not filter: - return {}, {} - - filter_mapping_base = {} - filter_mapping_outputs = defaultdict(dict) - - for key, value in filter.items(): - if '{CFG.output_prefix}' not in key: - filter_mapping_base[key] = value - continue - - if key.startswith('{CFG.output_prefix}'): - predict_id = key.split('__')[1] - filter_mapping_outputs[predict_id] = {key: value} - - return filter_mapping_base, filter_mapping_outputs - - def _postprocess_result(self, result): - """Postprocess the result of the query. - - Merge the outputs from the output keys to the result - - :param result: The result to postprocess. - """ - merge_outputs = {} - predict_ids = self._get_predict_ids() - output_keys = [f"{CFG.output_prefix}{predict_id}" for predict_id in predict_ids] - for output_key in output_keys: - output_data = result[output_key] - output_result = output_data[output_key] - merge_outputs[output_key] = output_result - - result.update(merge_outputs) - return result - - -def InsertOne(**kwargs): - """InsertOne operation for MongoDB. - - :param kwargs: The arguments to pass to the operation. - """ - return BulkOp(identifier='InsertOne', kwargs=kwargs) - - -def UpdateOne(**kwargs): - """UpdateOne operation for MongoDB. - - :param kwargs: The arguments to pass to the operation. - """ - try: - filter = kwargs['filter'] - except Exception as e: - raise KeyError('Filter not found in `UpdateOne`') from e - - id = filter['_id'] - if isinstance(id, ObjectId): - ids = [id] - else: - ids = id['$in'] - kwargs['arg_ids'] = ids - return BulkOp(identifier='UpdateOne', kwargs=kwargs) - - -def DeleteOne(**kwargs): - """DeleteOne operation for MongoDB. - - :param kwargs: The arguments to pass to the operation. - """ - return BulkOp(identifier='DeleteOne', kwargs=kwargs) - - -def ReplaceOne(**kwargs): - """ReplaceOne operation for MongoDB. - - :param kwargs: The arguments to pass to the operation. - """ - return BulkOp(identifier='ReplaceOne', kwargs=kwargs) - - -class BulkOp(Leaf): - """A bulk operation for MongoDB. - - :param kwargs: The arguments to pass to the operation. 
- """ - - ops: t.ClassVar[t.Sequence[str]] = [ - 'InsertOne', - 'UpdateOne', - 'DeleteOne', - 'ReplaceOne', - ] - kwargs: t.Dict = dc.field(default_factory=dict) - - def __post_init__(self, db): - super().__post_init__(db) - assert self.identifier in self.ops - - @property - def op(self): - """Return the operation.""" - kwargs = copy.deepcopy(self.kwargs) - kwargs.pop('arg_ids') - for k, v in kwargs.items(): - if isinstance(v, Document): - kwargs[k] = v.unpack() - return getattr(pymongo, self.identifier)(**kwargs) diff --git a/superduper/backends/base/data_backend.py b/superduper/backends/base/data_backend.py index 82605e3d2..98868078c 100644 --- a/superduper/backends/base/data_backend.py +++ b/superduper/backends/base/data_backend.py @@ -2,9 +2,10 @@ import typing as t from abc import ABC, abstractmethod -from superduper import logging +from superduper import CFG, logging from superduper.backends.base.query import Query -from superduper.components.datatype import BaseDataType +from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES, KEY_SCHEMA +from superduper.base.document import Document if t.TYPE_CHECKING: from superduper.components.schema import Schema @@ -14,20 +15,26 @@ class BaseDataBackend(ABC): """Base data backend for the database. :param uri: URI to the databackend database. + :param plugin: Plugin implementing the databackend. :param flavour: Flavour of the databackend. """ - db_type = None + id_field: str = 'id' - def __init__(self, uri: str, flavour: t.Optional[str] = None): + def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None): self.conn = None self.flavour = flavour self.in_memory: bool = False self.in_memory_tables: t.Dict = {} - self._datalayer = None + self.plugin = plugin + self._db = None self.uri = uri self.bytes_encoding = 'bytes' + @property + def database(self): + raise NotImplementedError + @property def backend_name(self): return self.__class__.__name__.split('DataBackend')[0].lower() @@ -37,32 +44,27 @@ def type(self): """Return databackend.""" raise NotImplementedError - @property - def db(self): - """Return the datalayer.""" - raise NotImplementedError - @abstractmethod - def drop_outputs(self): + def drop_table(self, table: str): """Drop all outputs.""" + @abstractmethod + def random_id(self): + """Generate random-id.""" + pass + @property - def datalayer(self): + def db(self): """Return the datalayer.""" - return self._datalayer + return self._db - @datalayer.setter - def datalayer(self, value): + @db.setter + def db(self, value): """Set the datalayer. :param value: The datalayer. """ - self._datalayer = value - - @abstractmethod - def url(self): - """Databackend connection url.""" - pass + self._db = value @abstractmethod def build_metadata(self): @@ -74,21 +76,6 @@ def build_artifact_store(self): """Build a default artifact store based on current connection.""" pass - @abstractmethod - def create_output_dest( - self, - predict_id: str, - datatype: t.Union[str, BaseDataType], - flatten: bool = False, - ): - """Create an output destination for the database. - - :param predict_id: The predict id of the output destination. - :param datatype: The datatype of the output destination. - :param flatten: Whether to flatten the output destination. - """ - pass - @abstractmethod def create_table_and_schema(self, identifier: str, schema: "Schema"): """Create a schema in the data-backend. 
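
For orientation, a minimal sketch of how a query now reaches the refactored databackend, assuming a configured `Datalayer` instance `db` and an existing `documents` table with an `x` column (illustrative names only):

```python
# Hedged sketch: a query built on the Datalayer delegates execution to the
# databackend, which resolves the table schema, runs the backend-native
# select and decodes each row into a Document.
from superduper import superduper

db = superduper()
t = db['documents']

# plain dicts go in; the new ids come back as strings
ids = t.insert([{'x': i} for i in range(4)])

q = t.select('x')
rows = q.execute()                    # list of decoded Document rows

# equivalent to delegating to the backend directly
same_rows = db.databackend.execute(q)
```
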
@@ -106,16 +93,7 @@ def check_output_dest(self, predict_id) -> bool: pass @abstractmethod - def get_query_builder(self, key): - """Get a query builder for the database. - - :param key: The key of the query builder, - typically the table or collection name. - """ - pass - - @abstractmethod - def get_table_or_collection(self, identifier): + def get_table(self, identifier): """Get a table or collection from the database. :param identifier: The identifier of the table or collection. @@ -130,53 +108,154 @@ def drop(self, force: bool = False): """ @abstractmethod - def disconnect(self): - """Disconnect the client.""" + def list_tables(self): + """List all tables or collections in the database.""" @abstractmethod - def list_tables_or_collections(self): - """List all tables or collections in the database.""" + def reconnect(self): + """Reconnect to the databackend.""" - @staticmethod - def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): - """Infer a schema from a given data object. + ######################################################## + # Abstract methods/ optional methods to be implemented # + ######################################################## - :param data: The data object - :param identifier: The identifier for the schema, if None, it will be generated - :return: The inferred schema - """ + @abstractmethod + def insert(self, table: str, documents: t.Sequence[t.Dict]) -> t.List[str]: + pass - def check_ready_ids( - self, query: Query, keys: t.List[str], ids: t.Optional[t.List[t.Any]] = None - ): - """Check if all the keys are ready in the ids. + @abstractmethod + def missing_outputs(self, query: Query, predict_id: str) -> t.List[str]: + pass - :param query: The query object. - :param keys: The keys to check. - :param ids: The ids to check. 
- """ - if ids: - query = query.select_using_ids(ids) - data = query.execute() - ready_ids = [] - for select in data: - notfound = 0 - for k in keys: - try: - select[k] - except KeyError: - notfound += 1 - if notfound == 0: - ready_ids.append(select[query.primary_id]) - self._log_check_ready_ids_message(ids, ready_ids) - return ready_ids - - def _log_check_ready_ids_message(self, input_ids, ready_ids): - if input_ids and len(ready_ids) != len(input_ids): - not_ready_ids = set(input_ids) - set(ready_ids) - logging.info(f"IDs {not_ready_ids} do not ready.") - logging.debug(f"Ready IDs: {ready_ids}") - logging.debug(f"Not ready IDs: {not_ready_ids}") + @abstractmethod + def primary_id(self, query: Query) -> str: + pass + + @abstractmethod + def select(self, query: Query) -> t.List[t.Dict]: + pass + + def to_id(self, id: t.Any) -> str: + return id + + ########################################## + # Methods which leverage implementations # + ########################################## + + def get(self, query: Query): + assert query.type == 'select' + + if query.decomposition.pre_like: + return list(self.pre_like(query, n=1))[0] + + elif query.decomposition.post_like: + return list(self.post_like(query, n=1))[0] + + return query.limit(1).execute()[0] + + def _wrap_results(self, query: Query, result, schema): + pid = self.primary_id(query) + for r in result: + if pid in r: + r[pid] = str(r[pid]) + if '_source' in r: + r['_source'] = str(r['_source']) + return [Document.decode(r, schema=schema, db=self.db) for r in result] + + def execute(self, query: Query): + query = query if '.outputs' not in str(query) else query.complete_uuids(self.db) + + schema = self.get_schema(query) + + if query.decomposition.pre_like: + return self._wrap_results(query, self.pre_like(query), schema=schema) + + if query.decomposition.post_like: + return self._wrap_results(query, self.post_like(query), schema=schema) + + return self._wrap_results(query, self.select(query), schema=schema) + + def get_schema(self, query) -> 'Schema': + base_schema = self.db.load('table', query.table).schema + + if query.decomposition.outputs: + for predict_id in query.decomposition.outputs.args: + base_schema += self.db.load( + 'table', f'{CFG.output_prefix}{predict_id}' + ).schema + + return base_schema + + def _do_insert(self, table, documents): + schema = self.get_schema(self.db[table]) + + if not schema.trivial: + for i, r in enumerate(documents): + r = Document(r).encode(schema=self.get_schema(self.db[table])) + self.db.artifact_store.save_artifact(r) + r.pop(KEY_BUILDS) + r.pop(KEY_BLOBS) + r.pop(KEY_FILES) + r.pop(KEY_SCHEMA, None) + documents[i] = r + + out = self.insert(table, documents) + return [str(x) for x in out] + + def pre_like(self, query: Query): + assert query.decomposition.pre_like is not None + + ids, scores = self.db.select_nearest( + like=query.decomposition.pre_like.args[0], + vector_index=query.decomposition.pre_like.args[1], + n=query.decomposition.pre_like.kwargs.get('n', 10), + ) + + lookup = {id: score for id, score in zip(ids, scores)} + + t = self.db[query.decomposition.table] + new_filter = t.primary_id.isin(ids) + + copy = query.decomposition.copy() + copy.pre_like = None + + new = copy.to_query() + new = new.filter(new_filter) + + results = new.execute() + + pid = self.primary_id(query) + for r in results: + r['score'] = lookup[r[pid]] + + results = sorted(results, key=lambda x: x['score'], reverse=True) + return results + + def post_like(self, query: Query): + like_part = query[-1] + prepare_query = 
query[:-1] + relevant_ids = prepare_query.ids() + + ids, scores = self.db.select_nearest( + like=like_part.args[0], + vector_index=like_part.args[1], + n=like_part.kwargs['n'], + ids=relevant_ids, + ) + + lookup = {id: score for id, score in zip(ids, scores)} + + t = self.db[query.table] + + results = prepare_query.filter(t.primary_id.isin(ids)).execute() + + pid = self.primary_id(query) + + for r in results: + r['score'] = lookup[r[pid]] + + results = sorted(results, key=lambda x: x['score'], reverse=True) + return results class DataBackendProxy: @@ -190,17 +269,17 @@ def __init__(self, backend): self._backend = backend @property - def datalayer(self): + def db(self): """Return the datalayer.""" return self._backend._datalayer - @datalayer.setter - def datalayer(self, value): + @db.setter + def db(self, value): """Set the datalayer. :param value: The datalayer. """ - self._backend._datalayer = value + self._backend._db = value @property def type(self): diff --git a/superduper/backends/base/query.py b/superduper/backends/base/query.py index 1fce727af..290e0ffc2 100644 --- a/superduper/backends/base/query.py +++ b/superduper/backends/base/query.py @@ -1,19 +1,21 @@ +""" +Permitted patterns. + +type_1: table.like()[.filter(...)][.select(...)][.get() | .limit(...)]' +type_2: table[.filter(...)][.select(...)][.like()][.get() | .limit(...)]' + +Select always comes last, unless with `.get`, `.limit`. + +""" import dataclasses as dc -import importlib +import functools import json import re import typing as t import uuid -from abc import abstractmethod -from functools import cached_property, wraps +from types import MethodType from superduper import CFG, logging -from superduper.base.constant import ( - KEY_BLOBS, - KEY_BUILDS, - KEY_FILES, - KEY_SCHEMA, -) from superduper.base.document import Document, _unpack from superduper.base.leaf import Leaf @@ -21,39 +23,147 @@ from superduper.base.datalayer import Datalayer -def applies_to(*flavours): - """Decorator to check if the query matches the accepted flavours. +@dc.dataclass +class QueryPart: + """A method part of a query. - :param flavours: The flavours to check against. + :param name: The name of the method. + :param args: The arguments of the method. + :param kwargs: The keyword arguments of the method. """ - def decorator(f): - @wraps(f) - def decorated(self, *args, **kwargs): - msg = ( - f'Query {self} does not match any of accepted patterns {flavours},' - f' for the {f.__name__} method to which this method applies.' - ) + name: str + args: t.Sequence + kwargs: t.Dict - try: - flavour = self.flavour - except TypeError: - raise TypeError(msg) - assert flavour in flavours, msg - return f(self, *args, **kwargs) - return decorated +@dc.dataclass +class Op(QueryPart): + """An operation part of a query. + + :param name: The name of the method. + :param args: The arguments of the method. + :param kwargs: The keyword arguments of the method. + :param symbol: The symbol of the operation. + """ + + symbol: str - return decorator +@dc.dataclass +class Decomposition: + """ + Decompose a query into its parts. + :param table: The table to use. + :param db: The datalayer to use. + :param col: The column to use. + :param insert: The insert part of the query. + :param pre_like: The pre-like part of the query. + :param post_like: The post-like part of the query. + :param filter: The filter part of the query. + :param select: The select part of the query. + :param get: The get part of the query. + :param limit: The limit part of the query. 
+ :param outputs: The outputs part of the query. + :param op: The operation part of the query. + """ + + table: str + db: 'Datalayer' + col: str | None = None + insert: QueryPart | None = None + pre_like: QueryPart | None = None + post_like: QueryPart | None = None + filter: QueryPart | None = None + select: QueryPart | None = None + get: QueryPart | None = None + limit: QueryPart | None = None + outputs: QueryPart | None = None + op: Op | None = None + + @property + def predict_ids(self): + if self.outputs: + return self.outputs.args + return [] + + def to_query(self): + """Convert decomposition back to a ``Query``.""" + q = self.db[self.table] + + if self.pre_like: + q = q + self.pre_like + + if self.filter: + q = q + self.filter + + if self.outputs: + q = q + self.outputs + + if self.select: + q = q + self.select + + if self.post_like: + q = q + self.post_like + + if self.get: + assert not self.limit + q = q + self.get + + if self.limit: + q = q + self.limit + + return q + + def copy(self): + return self.to_query().copy().decomposition + + +def _stringify(item, documents, queries): + if isinstance(item, dict): + documents.append(item) + out = f'documents[{len(documents) - 1}]' + elif isinstance(item, list): + old_len = len(documents) + documents.extend(item) + out = f'documents[{old_len}:{len(documents)}]' + elif isinstance(item, Query): + out = f'query[{len(queries)}]' + queries.append(item.stringify(documents, queries)) + elif isinstance(item, Op): + arg = _stringify(item.args[0], documents, queries) + return f' {item.symbol} {arg}' + elif isinstance(item, QueryPart): + args = [_stringify(a, documents, queries) for a in item.args] + kwargs = {k: _stringify(v, documents, queries) for k, v in item.kwargs.items()} + parameters = '' + if args and kwargs: + parameters = ( + ', '.join(args) + + ', ' + + ', '.join([f'{k}={v}' for k, v in kwargs.items()]) + ) + elif args: + parameters = ', '.join(args) + elif kwargs: + parameters = ', '.join([f'{k}={v}' for k, v in kwargs.items()]) + return f'.{item.name}({parameters})' + else: + try: + out = json.dumps(item) + except Exception: + documents.append(item) + out = f'documents[{len(documents) - 1}]' + return out + + +# TODO add to regular Query class class _BaseQuery(Leaf): - parts: t.Sequence[t.Union[t.Tuple, str]] = dc.field(default_factory=list) + parts: t.Sequence[t.Union[QueryPart, str]] = dc.field(default_factory=list) def __post_init__(self, db: t.Optional['Datalayer'] = None): super().__post_init__(db) - self._is_output_query = False - self._updated_key = None if not self.identifier: self.identifier = self._build_hr_identifier() self.identifier = re.sub('[^a-zA-Z0-9\-]', '-', self.identifier) @@ -61,11 +171,8 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None): def unpack(self): parts = _unpack(self.parts) - return type(self)( - db=self.db, - table=self.table, - parts=parts, - identifier=self.identifier, + return _from_parts( + impl=self.__class__, table=self.table, parts=parts, db=self.db ) def _build_hr_identifier(self): @@ -82,50 +189,34 @@ def _build_hr_identifier(self): identifier = identifier.replace(f'#{i}', v) return identifier - def __getattr__(self, item): - return type(self)( - db=self.db, - table=self.table, - parts=[*self.parts, item], - ) - - def __call__(self, *args, **kwargs): - """Add a method call to the query. - - :param args: The arguments to pass to the method. - :param kwargs: The keyword arguments to pass to the method. 
- """ - assert isinstance(self.parts[-1], str) - return type(self)( - db=self.db, - table=self.table, - parts=[*self.parts[:-1], (self.parts[-1], args, kwargs)], - ) - def _to_str(self): - documents = {} + documents = [] queries = {} out = str(self.table) for part in self.parts: if isinstance(part, str): - out += f'.{part}' - continue + if isinstance(getattr(self.__class__, part, None), property): + out += f'.{part}' + continue + else: + out += f'["{part}"]' + continue args = [] - for a in part[1]: + for a in part.args: args.append(self._update_item(a, documents, queries)) args = ', '.join(args) kwargs = {} - for k, v in part[2].items(): + for k, v in part.kwargs.items(): kwargs[k] = self._update_item(v, documents, queries) kwargs = ', '.join([f'{k}={v}' for k, v in kwargs.items()]) - if part[1] and part[2]: - out += f'.{part[0]}({args}, {kwargs})' - if not part[1] and part[2]: - out += f'.{part[0]}({kwargs})' - if part[1] and not part[2]: - out += f'.{part[0]}({args})' - if not part[1] and not part[2]: - out += f'.{part[0]}()' + if part.args and part.kwargs: + out += f'.{part.name}({args}, {kwargs})' + if not part.args and part.kwargs: + out += f'.{part.name}({kwargs})' + if part.args and not part.kwargs: + out += f'.{part.name}({args})' + if not part.args and not part.kwargs: + out += f'.{part.name}()' return out, documents, queries def _dump_query(self): @@ -134,37 +225,233 @@ def _dump_query(self): output = '\n'.join(list(queries.values())) + '\n' + output for i, k in enumerate(queries): output = output.replace(k, str(i)) - for i, k in enumerate(documents): - output = output.replace(k, str(i)) - documents = list(documents.values()) return output, documents @staticmethod def _update_item(a, documents, queries): if isinstance(a, Query): a, sub_documents, sub_queries = a._to_str() - documents.update(sub_documents) + if documents: + for i in range(len(sub_documents)): + a = a.replace(f'documents[{i}]', f'documents[{i + len(documents)}]') + documents.extend(sub_documents) queries.update(sub_queries) id_ = uuid.uuid4().hex[:5].upper() queries[id_] = a arg = f'query[{id_}]' else: - id_ = uuid.uuid4().hex[:5].upper() if isinstance(a, dict): - documents[id_] = a - arg = f'documents[{id_}]' + documents.append(a) + arg = f'documents[{len(documents) - 1}]' elif isinstance(a, list): - documents[id_] = {'_base': a} - arg = f'documents[{id_}]' + old_len = len(documents) + documents.extend(a) + arg = f'documents[{old_len}:{len(documents)}]' else: try: arg = json.dumps(a) except Exception: - documents[id_] = {'_base': a} - arg = id_ + documents.append(a) + arg = f'documents[{len(documents) - 1}]' return arg +def bind(f): + """Bind a method to a query object. + + :param f: The method to bind. + """ + + @functools.wraps(f) + def decorated(self, *args, **kwargs): + out = f(self, *args, **kwargs) + children = self.mapping[f.__name__] + for method in children: + out._bind_base_method(method, eval(method)) + return out + + decorated.__name__ = f.__name__ + return decorated + + +@bind +def limit(self, n: int): + """Limit the number of results returned by the query. + + # noqa + + :param n: The number of results to return. + """ + # always the last one + assert not self.decomposition.limit + assert not self.decomposition.get + return self + QueryPart('limit', (n,), {}) + + +def insert(self, documents): + """Insert documents into the table. 
+ + # noqa + """ + self.db.pre_insert(self.table, documents) + out = self.db.databackend._do_insert(self.table, documents) + self.db.post_insert(self.table, ids=out) + return out + + +@bind +def outputs(self, *predict_ids): + """Add outputs to the query. + + # noqa + + :param predict_ids: The predict_ids to add. # noqa + """ + d = self.decomposition + + assert not d.outputs + + d.outputs = QueryPart('outputs', predict_ids, {}) + + return d.to_query() + + +def get(self, eager_mode: bool = False, **kwargs): + """Get a single row of data. + + # noqa + """ + query = self + if kwargs: + filters = [] + t = self.db[self.table] + for k, v in kwargs.items(): + filters.append(t[k] == v) + query = self.filter(*filters) + result = query.db.databackend.get(query) + if eager_mode: + return self._convert_eager_mode_results(result) + return result + + +def ids(self): + """Get the primary ids of the query. + + # noqa + """ + msg = '.ids only applicable to select queries' + assert self.type == 'select', msg + q = self.select(self.primary_id) + pid = self.primary_id.execute() + results = q.execute() + return [str(r[pid]) for r in results] + + +# TODO use this/ test this +def missing_outputs(self, predict_id): + """Get missing outputs for a given predict_id. + + # noqa + + :param predict_id: The predict_id to check. + """ + return self.db.databackend.missing_outputs(self, predict_id) + + +# TODO use this in the code to split jobs in parts +def chunks(self, n: int): + """Split a query into chunks of size n. + + # noqa + + :param n: The size of the chunks. + """ + assert self.type == 'select' + t = self.db[self.table] + ids = self.select(t.primary_id).execute() + for i in range(0, len(ids), n): + yield self.subset(ids[i : i + n]) + + +@bind +def select(self, *cols): + """Create a select query selecting certain fields/ cols. + + # noqa + + :param cols: The columns to select. + + >>> from superduper import superduper + >>> db = superduper() + >>> db['table'].insert({'col': 1, 'other': 2}) + >>> results = db['table'].select('col').execute() + >>> list(results[0].keys()) + ['col'] + """ + d = self.decomposition + + if d.select: + d.select = QueryPart( + 'select', + (*d.select.args, *cols), + {}, + ) + else: + d.select = QueryPart('select', cols, {}) + + return d.to_query() + + +@bind +def filter(self, *filters): + """Create a filter query. + + # noqa + + :param filters: The filters to apply. + + >>> from superduper import superduper + >>> db = superduper() + >>> t = db['table'] + >>> t.insert({'col': 1}) + >>> results = t.filter(t['col'] == 1, t['col'] > 0).execute() + >>> len(results) + 1 + """ + d = self.decomposition + + if d.filter: + d.filter = QueryPart('filter', args=(*d.filter.args, *filters), kwargs={}) + else: + d.filter = QueryPart('filter', args=filters, kwargs={}) + + return d.to_query() + + +@bind +def like(self, r: t.Dict, vector_index: str, n: int = 10): + """Create a similarity query with a vector_index. + + # noqa + + :param r: The vector to compare against. + :param vector_index: The index of the vector. + :param n: The number of results to return. + """ + return self + QueryPart('like', args=(r, vector_index), kwargs={'n': n}) + + +SYMBOLS = { + '__eq__': '==', + '__ne__': '!=', + '__le__': '<=', + '__ge__': '>=', + '__lt__': '<', + '__gt__': '>', + 'isin': 'in', +} + + class Query(_BaseQuery): """A query object. @@ -175,189 +462,182 @@ class Query(_BaseQuery): :param parts: The parts of the query. 
""" - flavours: t.ClassVar[t.Dict[str, str]] = {} + # mapping between methods and allowed downstream methods + # base methods are at the key level + mapping: t.ClassVar[t.Dict] = { + 'insert': [], + 'select': [ + 'filter', + 'outputs', + 'like', + 'limit', + 'select', + 'ids', + 'missing_outputs', + 'chunks', + 'get', + ], + 'filter': [ + 'filter', + 'outputs', + 'like', + 'limit', + 'select', + 'ids', + 'missing_outputs', + 'chunks', + 'get', + ], + 'like': ['select', 'filter', 'ids', 'missing_outputs', 'get', 'limit'], + 'outputs': [ + 'filter', + 'limit', + 'ids', + 'missing_outputs', + 'chunks', + 'get', + 'select', + ], + 'limit': [], + 'ids': [], + 'get': [], + } + flavours: t.ClassVar[t.Dict[str, str]] = {} table: str identifier: str = '' - @property - def tables(self): - """Tables contained in the ``Query`` object.""" - out = [] - for part in self.parts: - if part[0] == 'outputs': - out.extend([f'{CFG.output_prefix}{x}' for x in part[1]]) - out.append(self.table) - return list(set(out)) - - def __getitem__(self, item): - if isinstance(item, str): - return getattr(self, item) - if not isinstance(item, slice): - raise TypeError('Query index must be a string or a slice') - assert isinstance(item, slice) - parts = self.parts[item] - return type(self)(db=self.db, table=self.table, parts=parts) + def __post_init__(self, db=None): + out = super().__post_init__(db) + if not self.parts: + for method in self.mapping: + self._bind_base_method(method, eval(method)) + elif self.parts: + if isinstance(self.parts[-1], str): + name = self.parts[-1] + self._bind_base_method('filter', filter) + else: + name = self.parts[-1].name - # TODO - not necessary: either `Document.decode(r, db=db)` - # or `db['table'].select...` + try: + for method in self.mapping[name]: + self._bind_base_method(method, eval(method)) + except KeyError: + pass - # TODO why necessary? - def set_db(self, value: 'Datalayer'): - """Set the datalayer to use to execute the query. + if self.type == 'insert': + self._add_fold_to_insert() - :param db: The datalayer to use to execute the query. - """ + return out - def _set_the_db(r, db): - if isinstance(r, (tuple, list)): - out = [_set_the_db(x, db) for x in r] - return out - if isinstance(r, Document): - return Document({k: _set_the_db(v, db) for k, v in r.items()}) - if isinstance(r, dict): - return {k: _set_the_db(v, db) for k, v in r.items()} - if isinstance(r, Query): - r.db = db - return r + def _add_fold_to_insert(self): + assert self.type == 'insert' + documents = self[-1].args[0] + import random - return r + for r in documents: + r.setdefault( + '_fold', + 'train' if random.random() >= CFG.fold_probability else 'valid', + ) - self._db = value + @property + def decomposition(self): + out = Decomposition(table=self.table, db=self.db) - # Recursively set db - parts: t.List[t.Union[str, tuple]] = [] - for part in self.parts: + for i, part in enumerate(self.parts): if isinstance(part, str): - parts.append(part) + out.col = part continue - part_args = tuple(_set_the_db(part[1], value)) - part_kwargs = _set_the_db(part[2], value) - part = part[0] - parts.append((part, part_args, part_kwargs)) - self.parts = parts - # TODO need this? 
- @property - def is_output_query(self): - """Check if query is of output type.""" - return self._is_output_query + if i == 0 and part.name == 'like': + out.pre_like = part + continue - @is_output_query.setter - def is_output_query(self, b): - """Property setter.""" - self._is_output_query = b + if part.name == 'like': + out.post_like = part + continue - # TODO necessary? - @property - def updated_key(self): - """Return query updated key.""" - return self._updated_key + if isinstance(part, Op): + out.op = Op + continue - @updated_key.setter - def updated_key(self, update): - """Property setter.""" - self._updated_key = update + msg = f'Found unexpected query part "{part.name}"' + assert part.name in [f.name for f in dc.fields(out)], msg + setattr(out, part.name, part) - def _get_flavour(self): - _query_str = self._to_str() - repr_ = _query_str[0] + return out - if repr_ == self.table and not (_query_str[0] and _query_str[-1]): - # Table selection query. - return 'select' + def _bind_base_method(self, name, method): + method = MethodType(method, self) + setattr(self, name, method) - try: - return next(k for k, v in self.flavours.items() if re.match(v, repr_)) - except StopIteration: - raise TypeError( - f'Query flavour {repr_} did not match existing {type(self)} flavours' - ) + def stringify(self, documents, queries): + parts = [] + for part in self.parts: + if isinstance(part, str): + if part == 'primary_id': + parts.append('.primary_id') + else: + parts.append(f'["{part}"]') + continue + parts.append(_stringify(part, documents, queries)) + parts = ''.join(parts) + return f'{self.table}{parts}' - def _get_parent(self): - return self.db.databackend.get_table_or_collection(self.table) + @property + def type(self): + if self.parts and isinstance(self[-1], QueryPart) and self[-1].name == 'insert': + return 'insert' + if 'delete' in str(self): + return 'delete' + return 'select' - def _execute_select(self, parent): - raise NotImplementedError + @property + def tables(self): + """Tables contained in the ``Query`` object.""" + out = [] + for part in self.parts: + if part.name == 'outputs': + out.extend([f'{CFG.output_prefix}{x}' for x in part.args]) + out.append(self.table) + return list(set(out)) - def _prepare_pre_like(self, parent): - like_args, like_kwargs = self.parts[0][1:] - like_args = list(like_args) - if not like_args: - like_args = [{}] - like = like_args[0] or like_kwargs.pop('r', {}) - if isinstance(like, Document): - like = like.unpack() + def __len__(self): + return len(self.parts) + 1 - ids = like_kwargs.pop('within_ids', []) + def __getitem__(self, item): + # supports queries which use strings to index + if isinstance(item, str): + return self + item - n = like_kwargs.pop('n', 100) + if isinstance(item, int): + return self.parts[item] - vector_index = like_kwargs.get('vector_index') + if not isinstance(item, slice): + raise TypeError('Query index must be a string or a slice') - similar_ids, similar_scores = self.db.select_nearest( - like, - vector_index=vector_index, - ids=ids, - n=n, - ) - similar_scores = dict(zip(similar_ids, similar_scores)) - return similar_ids, similar_scores + assert isinstance(item, slice) - @property - def flavour(self): - """Return the flavour of the query.""" - return self._get_flavour() + parts = self.parts[item] - @cached_property - def documents(self): - """Return the documents.""" - - def _update_part(documents): - nonlocal self - doc_args = (documents, *self.parts[0][1][1:]) - insert_part = (self.parts[0][0], doc_args, self.parts[0][2]) - return 
[insert_part] + self.parts[1:] - - documents = self.parts[0][1][0] - one_document = isinstance(documents, (dict, Document)) - if one_document: - documents = [documents] - wrapped_documents = [] - - for document in documents: - document = Document(document) - wrapped_documents.append(document) - - if one_document: - self.parts = _update_part(wrapped_documents[0]) - else: - self.parts = _update_part(wrapped_documents) - return wrapped_documents + return self.__class__(db=self.db, table=self.table, parts=parts) - @property - @abstractmethod - def type(self): - """Return the type of the query. + def copy(self): + r = self.dict() + del r['_path'] + del r['identifier'] + return parse_query(**r, db=self.db) - The type is used to route the correct method to execute the query in the - datalayer. - """ - pass - - def dict( - self, - metadata: bool = True, - defaults: bool = True, - uuid: bool = True, - refs: bool = False, - ): + def dict(self, *args, **kwargs): """Return the query as a dictionary.""" - query, documents = self._dump_query() - documents = [Document(r) for r in documents] + documents = [] + queries = [] + _stringify(self, documents=documents, queries=queries) + query = '\n'.join(queries) return Document( { - '_path': f'{self.__module__}.parse_query', + '_path': 'superduper.backends.base.query.parse_query', 'documents': documents, 'identifier': self.identifier, 'query': query, @@ -365,24 +645,40 @@ def dict( ) def __repr__(self): - output, docs = self._dump_query() - for i, doc in enumerate(docs): - doc_string = str(doc) - if isinstance(doc, Document): - r = doc.unpack() - if '_base' in r: - r = r['_base'] - doc_string = str(r) - output = output.replace(f'documents[{i}]', doc_string) - return output + r = self.dict() + query = r['query'].split('\n')[-1] + queries = r['query'].split('\n')[:-1] + for i, q in enumerate(queries): + query = query.replace(f'query[{i}]', q) + + doc_refs = re.findall('documents\[([0-9]+)\]', query) + if doc_refs: + for numeral in doc_refs: + query = query.replace( + f'documents[{numeral}]', str(r['documents'][int(numeral)]) + ) - def _ops(self, op, other): - return type(self)( - db=self.db, + doc_segs = re.findall('documents\[([0-9]+):([0-9]+)\]', query) + if doc_segs: + for n1, n2 in doc_segs: + query = query.replace( + f'documents[{n1}:{n2}]', str(r['documents'][int(n1) : int(n2)]) + ) + + return query + + def __add__(self, other: QueryPart | str): + return Query( table=self.table, - parts=self.parts + [(op, (other,), {})], + parts=[*self.parts, other], + db=self.db, ) + def _ops(self, op, other): + msg = 'Can only compare based on a column' + assert isinstance(self.parts[-1], str), msg + return self + Op(op, args=(other,), kwargs={}, symbol=SYMBOLS[op]) + def __eq__(self, other): return self._ops('__eq__', other) @@ -442,36 +738,10 @@ def _encode_or_unpack_args(self, r, db, method='encode', parent=None): return r - def _execute(self, parent, method='encode'): - return self._get_chain_native_query(parent, self.parts, method) - - def _get_chain_native_query(self, parent, parts, method='encode'): - try: - for part in parts: - if isinstance(part, str): - parent = getattr(parent, part) - continue - args = self._encode_or_unpack_args( - part[1], self.db, method=method, parent=parent - ) - kwargs = self._encode_or_unpack_args( - part[2], self.db, method=method, parent=parent - ) - parent = getattr(parent, part[0])(*args, **kwargs) - except Exception as e: - logging.error(f'Error in executing query, parts: {parts}') - raise e - - return parent - - @abstractmethod - 
def _create_table_if_not_exists(self): - pass - def complete_uuids( self, db: 'Datalayer', listener_uuids: t.Optional[t.Dict] = None ) -> 'Query': - """Complete the UUIDs with have been omitted from output-tables. + """Complete the UUIDs which have been omitted from output-tables. :param db: ``db`` instance. :param listener_uuids: identifier to UUIDs of listeners lookup @@ -481,7 +751,7 @@ def complete_uuids( r = copy.deepcopy(self.dict()) lines = r['query'].split('\n') - parser = importlib.import_module(self.__module__).parse_query + parser = parse_query def _get_uuid(identifier): if '.' in identifier: @@ -515,7 +785,7 @@ def _get_uuid(identifier): predict_ids = [eval(x.strip()) for x in group.split(',')] replace_ids = [] for predict_id in predict_ids: - if re.match(r'^.*__([0-9a-z]{15,})$', predict_id): + if re.match(r'^.*__([0-9a-z]{8,})$', predict_id): replace_ids.append(f'"{predict_id}"') continue listener_uuid = _get_uuid(predict_id) @@ -526,7 +796,7 @@ def _get_uuid(identifier): output_table_groups = re.findall(f'^{CFG.output_prefix}.*?\.', line) for group in output_table_groups: - if re.match(f'^{CFG.output_prefix}[^\.]+__([0-9a-z]{{15,}})\.$', group): + if re.match(f'^{CFG.output_prefix}[^\.]+__([0-9a-z]{{8,}})\.$', group): continue identifier = group[len(CFG.output_prefix) : -1] listener_uuid = _get_uuid(identifier) @@ -559,258 +829,125 @@ def swap_keys(r: str | list | dict): out = parser(**r, db=db) return out - def tolist(self, db=None, eager_mode=False, **kwargs): - """Execute and convert to list.""" - return self.execute(db=db, eager_mode=eager_mode, **kwargs).tolist() + @property + def primary_id(self): + return Query(table=self.table, parts=(), db=self.db) + 'primary_id' - def execute(self, db=None, eager_mode=False, handle_outputs=True, **kwargs): - """ - Execute the query. + @property + def documents(self): + return self.dict()['documents'] - :param db: Datalayer instance. 
- """ - if self.type == 'select' and handle_outputs and 'outputs' in str(self): - query = self.complete_uuids(db=db or self.db) - return query.execute( - db=db, eager_mode=eager_mode, **kwargs, handle_outputs=False - ) - self.db = db or self.db - results = self.db.execute(self, **kwargs) - if eager_mode and self.type == 'select': - results = self._convert_eager_mode_results(results) + def subset(self, ids: t.Sequence[str]): + # TODO broken + + assert self.type == 'select' + + # mypy nonsense + from superduper.base.datalayer import Datalayer + + assert isinstance(self.db, Datalayer) + + t = self.db[self.table] + modified_query = self.filter(t.primary_id.isin(ids)) + + return modified_query.execute() + + def execute(self, eager_mode=False): + if self.parts and self.parts[0] == 'primary_id': + return self.db.databackend.primary_id(self) + results = self.db.databackend.execute(self) + if eager_mode: + return self._convert_eager_mode_results(results) return results def _convert_eager_mode_results(self, results): - from superduper.base.cursor import SuperDuperCursor from superduper.misc.eager import SuperDuperData, SuperDuperDataType new_results = [] - query = self - if not len(query.parts): - query = query.select() - if isinstance(results, (SuperDuperCursor, list)): + if isinstance(results, list): for r in results: r = Document(r.unpack()) - sdd = SuperDuperData(r, type=SuperDuperDataType.DATA, query=query) + sdd = SuperDuperData(r, type=SuperDuperDataType.DATA, query=self) new_results.append(sdd) - return new_results elif isinstance(results, dict): - return SuperDuperData(results, type=SuperDuperDataType.DATA, query=query) + return SuperDuperData(results, type=SuperDuperDataType.DATA, query=self) raise ValueError(f'Cannot convert {results} to eager mode results') - def do_execute(self, db=None): - """ - Execute the query. - - This methold will first create the table if it does not exist and then - execute the query. - All the methods matching the pattern `_execute_{flavour}` will be - called if they exist. +def _parse_op_part(table, col, symbol, operand, db, documents=()): + operand = eval(operand, {'documents': documents}) - If no such method exists, the `_execute` method will be called. - - :param db: The datalayer to use to execute the query. - """ - self.db = db or self.db - assert self.db is not None, 'No datalayer (db) provided' - self._create_table_if_not_exists() - parent = self._get_parent() - try: - flavour = self._get_flavour() - handler = f'_execute_{flavour}' in dir(self) - if handler is False: - raise AssertionError - handler = getattr(self, f'_execute_{flavour}') - return handler(parent=parent) - except TypeError as e: - logging.error(f'Error in executing query: {self}') - if 'did not match' in str(e): - return self._execute(parent=parent) - else: - raise e - except AssertionError: - return self._execute(parent=parent) - - @property - @abstractmethod - def primary_id(self): - """Return the primary id of the table.""" - pass - - @abstractmethod - def model_update( - self, - ids: t.List[t.Any], - predict_id: str, - outputs: t.Sequence[t.Any], - flatten: bool = False, - **kwargs, - ): - """Update the model outputs in the database. - - :param ids: The ids of the documents to update. - :param predict_id: The id of the prediction. - :param outputs: The outputs to store. - :param flatten: Whether to flatten the outputs. - :param kwargs: Additional keyword arguments. 
- """ - pass + reverse = dict(zip(SYMBOLS.values(), SYMBOLS.keys())) - @abstractmethod - def add_fold(self, fold: str): - """Add a fold to the query. - - :param fold: The fold to add. - """ - pass - - @abstractmethod - def select_using_ids(self, ids: t.Sequence[str]): - """Return a query that selects ids. - - :param ids: The ids to select. - """ - pass - - @property - @abstractmethod - def select_ids(self, ids: t.Sequence[str]): - """Return a query that selects ids. - - :param ids: The ids to select. - """ - pass - - @abstractmethod - def select_ids_of_missing_outputs(self, predict_id: str): - """Return the ids of missing outputs. - - :param predict_id: The id of the prediction. - """ - pass - - @abstractmethod - def select_single_id(self, id: str): - """Return a single document by id. - - :param id: The id of the document. - """ - pass - - @property - @abstractmethod - def select_table(self): - """Return the table to select from.""" - pass - - def _prepare_documents(self): - documents = self.documents - kwargs = self.parts[0][2] - schema = kwargs.pop('schema', None) + if col != 'primary_id': + out = getattr(db[table][col], reverse[symbol])(operand) + else: + out = getattr(db[table].primary_id, reverse[symbol])(operand) - if schema is None: - try: - table = self.db.load('table', self.table) - schema = table.schema - except FileNotFoundError: - pass + return out - documents = [ - r.encode(schema) if isinstance(r, Document) else r for r in documents - ] - for r in documents: - r = self.db.artifact_store.save_artifact(r) - r.pop(KEY_BUILDS) - r.pop(KEY_BLOBS) - r.pop(KEY_FILES) - r.pop(KEY_SCHEMA, None) - return documents - - # TODO deprecate (self.table) - @property - def table_or_collection(self): - """Return the table or collection to select from.""" - return type(self)(table=self.table, db=self.db) - def _execute_pre_like(self, parent): - assert self.parts[0][0] == 'like' - assert self.parts[1][0] in ['find', 'find_one', 'select'] +def _parse_query_part(part, documents, query, db): + pattern = ( + '^([a-zA-Z0-9_]+)\["([a-zA-Z0-9_]+)"\][ ]{0,}' + '([!=><]=|==|!=|<=|>=|<|>|in)[ ]{0,}(.*)[ ]{0,}$' + ) - similar_ids, similar_scores = self._prepare_pre_like(parent) + if match := re.match(pattern, part): + return _parse_op_part(*match.groups(), db, documents=documents) - query = self[1:] - query = query.filter(query[self.primary_id].isin(similar_ids)) - result = query.execute() - result.scores = similar_scores - return result + pattern = ( + '^([a-zA-Z0-9_]+)\.primary_id[ ]{0,}' + '([!=><]=|==|!=|<=|>=|<|>|in)[ ]{0,}(.*)[ ]{0,}$' + ) - def _execute_post_like(self, parent): - assert self.parts[0][0] in { - 'find', - 'select', - }, "Post like query must start with find/select" - if self.parts[-1][0] != 'like': - raise ValueError('Post like query must end with like') - like_kwargs = self.parts[-1][2] - like_args = self.parts[-1][1] - assert 'vector_index' in like_kwargs - - if not like_args and 'r' in like_kwargs: - like_args = (like_kwargs['r'],) - - assert like_args - - query = self[:-1] - result = list(query.execute()) - ids = [str(r[self.primary_id]) for r in query.execute()] - - similar_ids, scores = self.db.select_nearest( - like=like_args[0], - ids=ids, - vector_index=like_kwargs.get('vector_index'), - n=like_kwargs.get('n', 100), + if match := re.match(pattern, part): + return _parse_op_part( + match.groups()[0], + 'primary_id', + *match.groups()[1:], + db, + documents=documents, ) - scores = dict(zip(similar_ids, scores)) - - result = [r for r in result if str(r[self.primary_id]) in 
similar_ids] - from superduper.base.cursor import SuperDuperCursor + table = part.split('.', 1)[0] - cursor = SuperDuperCursor( - raw_cursor=result, - db=self.db, - id_field=self.primary_id, - ) - cursor.scores = scores - return cursor + rest_part = part[len(table) + 1 :] + col_match = re.match('^([a-zA-Z0-9]+)\["[a-zA-Z0-9]+"\]$', table) + if col_match: + table = col_match.groups()[0] -def _parse_query_part(part, documents, query, builder_cls, db=None): - if part.startswith(CFG.output_prefix): - predict_id = part[len(CFG.output_prefix) :].split('.')[0] - table = f'{CFG.output_prefix}{predict_id}' - rest_part = part[len(table) + 1 :] - else: - table = part.split('.', 1)[0] - rest_part = part[len(table) + 1 :] - - # The format of the rest part should be a chain of '.method(args, kwargs)' parts = re.findall(r'\.([a-zA-Z0-9_]+)(\(.*?\))?', "." + rest_part) + + # TODO what's this clause? recheck_part = ".".join(p[0] + p[1] for p in parts) if recheck_part != rest_part: raise ValueError(f'Invalid query part: {part} != {recheck_part}') - current = builder_cls(table=table, parts=(), db=db) + new_parts = [] + for part in parts: + if ( + isinstance(part, str) + and re.match('^[a-zA-Z0-9]+\["[a-zA-Z0-9]+"\]$', part) is not None + ): + new_parts.extend(part.split('[')[0], part.split(']').strip()[:-1]) + continue + new_parts.append(part) + + current = Query(table=table, parts=(), db=db) + for part in parts: comp = part[0] + part[1] match = re.match(r'^([a-zA-Z0-9_]+)\((.*)\)$', comp) + if match is None: current = getattr(current, comp) continue + if not match.groups()[1].strip(): current = getattr(current, match.groups()[0])() continue @@ -826,78 +963,37 @@ def _parse_query_part(part, documents, query, builder_cls, db=None): else: args.append(eval(x, {'documents': documents, 'query': query})) current = comp(*args, **kwargs) + + return current + + +def _from_parts(impl, table, parts, db): + current = impl(table=table, parts=(), db=db) + for part in parts: + if isinstance(part, str): + try: + current = getattr(current, part) + except AttributeError: + current = current[part] + continue + current = getattr(current, part.name)(*part.args, **part.kwargs) return current def parse_query( query: t.Union[str, list], - builder_cls: t.Optional[t.Type[Query]] = None, documents: t.Sequence[t.Any] = (), db: t.Optional['Datalayer'] = None, ): """Parse a string query into a query object. :param query: The query to parse. - :param builder_cls: The class to use to build the query. :param documents: The documents to query. :param db: The datalayer to use to execute the query. """ - if ( - isinstance(query, str) - and 'predict' in query - and query.split('\n')[-1].strip().split('.')[1].startswith('predict') - ): - builder_cls = Model - return _parse_query_part(query, documents, [], builder_cls, db=db) - - builder_cls = builder_cls or Query - documents = [Document(r, db=db) for r in documents] if isinstance(query, str): query = [x.strip() for x in query.split('\n') if x.strip()] for i, q in enumerate(query): - query[i] = _parse_query_part(q, documents, query[:i], builder_cls, db=db) - return query[-1] - - -class Model(_BaseQuery): - """ - A model helper class for create a query to predict. - - :param table: The table to use. - :param parts: The parts of the query. 
- """ - - table: str - identifier: str = '' - type: t.ClassVar[str] = 'predict' - - def execute(self): - """Execute the model as a query.""" - return self.db.execute(self) + query[i] = _parse_query_part(q, documents, query[:i], db=db) - def do_execute(self, db=None): - """Execute the query. - - :param db: Datalayer instance. - """ - self.db = db - m = self.db.load('model', self.table) - method = getattr(m, self.parts[-1][0]) - r = method(*self.parts[-1][1], **self.parts[-1][2]) - if isinstance(r, dict): - return Document(r) - else: - return Document({'_base': r}) - - def dict(self, metadata: bool = True, defaults: bool = True): - """Return the query as a dictionary.""" - query, documents = self._dump_query() - documents = [Document(r) for r in documents] - return Document( - { - '_path': f'{self.__module__}.parse_query', - 'documents': documents, - 'identifier': self.identifier, - 'query': query, - } - ) + return query[-1] diff --git a/superduper/backends/query_dataset.py b/superduper/backends/query_dataset.py index d98c6a3af..3cda9ca47 100644 --- a/superduper/backends/query_dataset.py +++ b/superduper/backends/query_dataset.py @@ -51,28 +51,32 @@ def __init__( self._db = db self.transform = transform + if fold is not None: - self.select = select.add_fold(fold) + assert db is not None + fold_filter = db[select.table]['_fold'] == fold + self.select = select.filter(fold_filter) else: self.select = select self.in_memory = in_memory if self.in_memory: if ids is None: - self._documents = list(self.db.execute(self.select)) + self._documents = self.select.execute() else: - self._documents = list( - self.db.execute(self.select.select_using_ids(ids)) - ) + self._documents = self.select.subset(ids) else: if ids is None: - self._ids = [ - r[self.select.id_field] - for r in self.db.execute(self.select.select_ids) - ] + self._ids = self.select.ids() else: self._ids = ids - self.select_one = self.select.select_single_id + + # TODO replace by adding parameters to `.get` + assert db is not None + t = db[self.select.table] + self.select_one = lambda x: next( + self.select.filter(t[t.primary_id] == x).execute() + ) self.mapping = mapping diff --git a/superduper/base/apply.py b/superduper/base/apply.py index b3550a154..7089e9576 100644 --- a/superduper/base/apply.py +++ b/superduper/base/apply.py @@ -239,7 +239,7 @@ def replace_existing(x): replace_k = replace_k.replace(uuid, non_breaking_changes[uuid]) replace[replace_k] = doc[k] r['documents'][i] = replace - x = Document.decode(r).unpack() + x = Document.decode(r, db=db).unpack() else: raise TypeError("Unexpected target of substitution in db.apply") diff --git a/superduper/base/build.py b/superduper/base/build.py index c7785d202..9fbb481f3 100644 --- a/superduper/base/build.py +++ b/superduper/base/build.py @@ -42,8 +42,9 @@ def create(cls, uri): f"{plugin} with flavour {flavour} not supported " "to create metadata store." 
) - impl = getattr(load_plugin(plugin), cls.impl) - return impl(uri, flavour=flavour) + plugin = load_plugin(plugin) + impl = getattr(plugin, cls.impl) + return impl(uri, flavour=flavour, plugin=plugin) class _MetaDataLoader(_Loader): diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index c2aba72dc..55b557671 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -17,13 +17,11 @@ from superduper.base.cursor import SuperDuperCursor from superduper.base.document import Document from superduper.components.component import Component -from superduper.components.datatype import BaseDataType from superduper.components.schema import Schema from superduper.components.table import Table from superduper.misc.annotations import deprecated from superduper.misc.colors import Colors from superduper.misc.importing import import_object -from superduper.misc.retry import db_retry DBResult = t.Any TaskGraph = t.Any @@ -69,7 +67,7 @@ def __init__( self.artifact_store.db = self self.databackend = databackend - self.databackend.datalayer = self + self.databackend.db = self self.cluster = cluster self.cluster.db = self @@ -78,7 +76,7 @@ def __init__( self.startup_cache: t.Dict[str, t.Any] = {} def __getitem__(self, item): - return self.databackend.get_query_builder(item) + return Query(table=item, parts=(), db=self) @property def cdc(self): @@ -138,7 +136,7 @@ def drop(self, force: bool = False, data: bool = False): self.databackend.drop(force=True) self.artifact_store.drop(force=True) else: - self.databackend.drop_outputs() + self.databackend.drop_table() self.metadata.drop(force=True) @@ -193,89 +191,17 @@ def show( type_id=type_id, identifier=identifier, version=version ) - @db_retry(connector='databackend') - def execute(self, query: Query, *args, **kwargs) -> ExecuteResult: - """Execute a query on the database. - - :param query: The SQL query to execute, such as select, insert, - delete, or update. - :param args: Positional arguments to pass to the execute call. - :param kwargs: Keyword arguments to pass to the execute call. - """ - if query.type == 'delete': - return self._delete(query, *args, **kwargs) - if query.type == 'insert': - return self._insert(query, *args, **kwargs) - if query.type == 'select': - return self._select(query, *args, **kwargs) - if query.type == 'write': - return self._write(query, *args, **kwargs) - if query.type == 'update': - return self._update(query, *args, **kwargs) - if query.type == 'predict': - return self._predict(query, *args, **kwargs) - - raise TypeError( - f'Wrong type of {query}; ' - f'Expected object of type "delete", "insert", "select", "update"' - f'Got {type(query)};' - ) - - def _predict(self, prediction: t.Any) -> PredictResult: - return prediction.do_execute(self) - - def _delete(self, delete: Query, refresh: bool = True) -> DeleteResult: - """ - Delete data from the database. - - :param delete: The delete query object specifying the data to be deleted. - """ - result = delete.do_execute(self) - # TODO - do we need this refresh? 
- # If the handle-event works well, then we should not need this - - if not refresh: - return - - call_cdc = ( - delete.query.table in self.metadata.show_cdc_tables() - and delete.query.table.startswith(CFG.output_prefix) - ) - if call_cdc: - self.cluster.cdc.handle_event( - event_type='delete', table=delete.table, ids=result - ) - return result - - def _insert( - self, - insert: Query, - refresh: bool = True, - datatypes: t.Sequence[BaseDataType] = (), - auto_schema: bool = True, - ) -> InsertResult: - """ - Insert data into the database. - - :param insert: The insert query object specifying the data to be inserted. - :param refresh: Boolean indicating whether to refresh the task group on insert. - :param datatypes: List of datatypes in the insert documents. - :param auto_schema: Toggle to False to switch off automatic schema creation. - """ - for e in datatypes: - self.add(e) - - if not insert.documents: - logging.info(f'No documents to insert into {insert.table}') - return [] - - for r in insert.documents: + def pre_insert( + self, table: str, documents: t.Sequence[t.Dict], auto_schema: bool = True + ): + for r in documents: r.setdefault( '_fold', 'train' if random.random() >= s.CFG.fold_probability else 'valid', ) + if auto_schema and self.cfg.auto_schema: - schema = self._auto_create_table(insert.table, insert.documents).schema + self._auto_create_table(table, documents).schema timeout = 60 @@ -285,9 +211,9 @@ def _insert( exists = False while time.time() - start < timeout: try: - assert insert.table in self.show( + assert table in self.show( 'table' - ), f'{insert.table} not found, retrying...' + ), f'{table} not found, retrying...' exists = True except AssertionError as e: logging.warn(str(e)) @@ -297,29 +223,16 @@ def _insert( if not exists: raise TimeoutError( - f'{insert.table} not found after {timeout} seconds' + f'{table} not found after {timeout} seconds' ' table auto creation likely has failed or is stalling...' ) - for r in insert.documents: - r.schema = schema - - inserted_ids = insert.do_execute(self) - - logging.info(f'Inserted {len(inserted_ids)} documents into {insert.table}') - logging.debug(f'Inserted IDs: {inserted_ids}') - if not refresh: - return [] - - if ( - insert.table in self.metadata.show_cdc_tables() - and not insert.table.startswith(CFG.output_prefix) + def post_insert(self, table: str, ids: t.Sequence[str]): + if table in self.metadata.show_cdc_tables() and not table.startswith( + CFG.output_prefix ): - self.cluster.cdc.handle_event( - event_type='insert', table=insert.table, ids=inserted_ids - ) - - return inserted_ids + logging.info(f'CDC for {table} is enabled') + self.cluster.cdc.handle_event(event_type='insert', table=table, ids=ids) def _auto_create_table(self, table_name, documents): try: @@ -330,20 +243,12 @@ def _auto_create_table(self, table_name, documents): # Should we need to check all the documents? document = documents[0] - schema = document.schema or self.infer_schema(document) + schema = self.infer_schema(document) table = Table(identifier=table_name, schema=schema) logging.info(f"Creating table {table_name} with schema {schema.fields_set}") self.apply(table, force=True) return table - def _select(self, select: Query, reference: bool = True) -> SelectResult: - """ - Select data from the database. - - :param select: The select query object specifying the data to be retrieved. - """ - return select.do_execute(db=self) - def on_event(self, table: str, ids: t.List[str], event_type: 'str'): """ Trigger computation jobs after data insertion. 
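As a point of reference, here is a minimal sketch of how an insert and a follow-up select flow through the refactored datalayer. It assumes a configured data backend; the table name `documents` and the sample records are illustrative, while `pre_insert`/`post_insert` and the `db[...].insert(...)` / `.execute()` calls mirror the hunks above.

```python
# Minimal sketch, assuming a configured data backend; 'documents' and the
# sample records are illustrative and not part of the changeset itself.
from superduper import superduper

db = superduper()  # build a Datalayer from the configured backend

records = [{"this": f"is a test {i}", "id": str(i)} for i in range(4)]

# The Datalayer no longer dispatches writes through db.execute(); it only
# contributes pre_insert (adds '_fold', auto-creates the table when
# auto_schema is on) and post_insert (fires a CDC event for registered tables).
db["documents"].insert(records)

# Reads go through Query.execute(), which delegates to db.databackend.execute.
rows = db["documents"].select("id", "this").execute()
```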
@@ -795,15 +700,6 @@ def select_nearest( def disconnect(self): """Gracefully shutdown the Datalayer.""" - logging.info("Disconnect from Data Store") - self.databackend.disconnect() - - logging.info("Disconnect from Metadata Store") - self.metadata.disconnect() - - logging.info("Disconnect from Artifact Store") - self.artifact_store.disconnect() - logging.info("Disconnect from Cluster") self.cluster.disconnect() @@ -816,8 +712,10 @@ def infer_schema( :param identifier: The identifier for the schema, if None, it will be generated :return: The inferred schema """ - out = self.databackend.infer_schema(data, identifier) - return out + # TODO have a slightly more user-friendly schema + from superduper.misc.auto_schema import infer_schema + + return infer_schema(data, identifier=identifier) @property def cfg(self) -> Config: diff --git a/superduper/base/document.py b/superduper/base/document.py index 1dd98dce9..2f9f0f8df 100644 --- a/superduper/base/document.py +++ b/superduper/base/document.py @@ -46,10 +46,6 @@ def __init__(self, getters=None): def add_getter(self, name: str, getter: t.Callable): """Add a getter for a reference type.""" self._getters[name].append(getter) - # if name == 'blob': - # self._getters[name].append(_build_blob_getter(getter)) - # else: - # self._getters[name].append(getter) def run(self, name, data): """Run the getters one by one until one returns a value.""" @@ -179,7 +175,7 @@ def update(self, other: t.Union['Document', dict]): if isinstance(other, Document) and other.schema: assert other.schema is not None - schema = schema.update(other.schema) + schema += other.schema return Document(_update(dict(self), dict(other)), schema=schema) def encode( diff --git a/superduper/base/enums.py b/superduper/base/enums.py deleted file mode 100644 index 4399a1e2a..000000000 --- a/superduper/base/enums.py +++ /dev/null @@ -1,13 +0,0 @@ -# TODO not needed -from enum import Enum - - -class DBType(str, Enum): - """ - DBType is an enumeration of the supported database types. - - # noqa - """ - - SQL = "SQL" - MONGODB = "MONGODB" diff --git a/superduper/base/leaf.py b/superduper/base/leaf.py index 52b570c94..2a1c277b9 100644 --- a/superduper/base/leaf.py +++ b/superduper/base/leaf.py @@ -230,7 +230,7 @@ def _replace_uuids_with_keys(record): } ) - def set_variables(self, **kwargs) -> 'Leaf': + def set_variables(self, db: t.Union['Datalayer', None] = None, **kwargs) -> 'Leaf': """Set free variables of self. :param db: Datalayer instance. 
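A short usage sketch of the new `set_variables(db=..., **kwargs)` signature, assuming a connected `Datalayer` and a stored RAG model whose `select` carries a free `query` variable; the identifier `'simple_rag'` is an assumption used only for illustration.

```python
# Sketch of the new contract: the Datalayer is passed explicitly so that the
# round-trip through Document.decode(rr, db=db) keeps the db reference.
from superduper import superduper

db = superduper()

# Illustrative: any Leaf/Query whose encoding contains a free 'query' variable
# works here; the RAG model's stored select (components/model.py below) is one.
rag_model = db.load('model', 'simple_rag')   # identifier is an assumption
template = rag_model.select

# Previously set_variables(query=...) was enough; now db is threaded through.
concrete = template.set_variables(db=db, query="what is superduper?")
results = [r.unpack() for r in concrete.execute()]
```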
@@ -241,7 +241,7 @@ def set_variables(self, **kwargs) -> 'Leaf': r = self.encode() rr = _replace_variables(r, **kwargs) - return Document.decode(rr).unpack() + return Document.decode(rr, db=db).unpack() @property def variables(self) -> t.List[str]: @@ -258,9 +258,13 @@ def defaults(self): fields = dc.fields(self) for f in fields: value = getattr(self, f.name) - if f.default is not dc.MISSING and value == f.default: + if f.default is not dc.MISSING and f.default and value == f.default: out[f.name] = value - elif f.default_factory is not dc.MISSING and value == f.default_factory(): + elif ( + f.default_factory is not dc.MISSING + and f.default + and value == f.default_factory() + ): out[f.name] = value return out diff --git a/superduper/components/component.py b/superduper/components/component.py index ddedc6007..69570836e 100644 --- a/superduper/components/component.py +++ b/superduper/components/component.py @@ -260,7 +260,7 @@ def _find_refs(r): def get_children(self, deep: bool = False) -> t.List["Component"]: """Get all the children of the component.""" - r = self.dict().encode(leaves_to_keep=Component) + r = self.dict().encode(leaves_to_keep=(Component,)) out = [v for v in r['_builds'].values() if isinstance(v, Component)] lookup = {} for v in out: diff --git a/superduper/components/dataset.py b/superduper/components/dataset.py index 07fb6600f..2c224c89d 100644 --- a/superduper/components/dataset.py +++ b/superduper/components/dataset.py @@ -72,7 +72,7 @@ def _pre_create(self, db: 'Datalayer', startup_cache: t.Dict = {}) -> None: def _load_data(self, db: 'Datalayer'): assert db is not None, 'Database must be set' assert self.select is not None, 'Select must be set' - data = list(db.execute(self.select)) + data = self.select.execute() if self.sample_size is not None and self.sample_size < len(data): perm = self.random.permutation(len(data)).tolist() data = [data[perm[i]] for i in range(self.sample_size)] diff --git a/superduper/components/datatype.py b/superduper/components/datatype.py index fdc6467c0..5a0352df3 100644 --- a/superduper/components/datatype.py +++ b/superduper/components/datatype.py @@ -19,6 +19,10 @@ Encode = t.Callable[[t.Any], bytes] +if t.TYPE_CHECKING: + from superduper.base.datalayer import Datalayer + + class DataTypeFactory: """Abstract class for creating a DataType # noqa.""" @@ -262,7 +266,7 @@ class File(Saveable): path: str = '' - def __post_init__(self, db=None): + def __post_init__(self, db: 'Datalayer' = None): if not self.identifier: self.identifier = get_hash(self.path) return super().__post_init__(db) diff --git a/superduper/components/listener.py b/superduper/components/listener.py index 1829a7c74..e05486c70 100644 --- a/superduper/components/listener.py +++ b/superduper/components/listener.py @@ -189,14 +189,12 @@ def _get_sample_input(self, db: Datalayer): else: if self.dependencies: try: - r = next(self.select.limit(1).execute()) + r = self.select.get() except errors: try: if not self.cdc_table.startswith(CFG.output_prefix): try: - r = next( - db[self.select.table].select().limit(1).execute() - ) + r = db[self.select.table].get() except errors: # Note: This is added for sql databases, # since they return error if key not found @@ -209,7 +207,7 @@ def _get_sample_input(self, db: Datalayer): raise Exception(msg.format(table=self.cdc_table)) from e else: try: - r = next(self.select.limit(1).execute()) + r = self.select.get() except (StopIteration, KeyError, FileNotFoundError) as e: raise Exception(msg.format(table=self.select)) from e mapping = 
Mapping(self.key, self.model.signature) @@ -313,4 +311,5 @@ def cleanup(self, db: "Datalayer") -> None: :param db: Data layer instance to process. """ if self.select is not None: - db[self.select.table].drop_outputs(self.predict_id) + # db[self.select.table].drop_outputs(self.predict_id) + db.databackend.drop_table(self.outputs) diff --git a/superduper/components/model.py b/superduper/components/model.py index ca9633f62..811970c77 100644 --- a/superduper/components/model.py +++ b/superduper/components/model.py @@ -12,7 +12,6 @@ from functools import wraps import requests -import tqdm from overrides import override from superduper import CFG, logging @@ -20,7 +19,6 @@ from superduper.backends.query_dataset import CachedQueryDataset, QueryDataset from superduper.base.annotations import trigger from superduper.base.document import Document -from superduper.base.exceptions import DatabackendException from superduper.base.leaf import Leaf from superduper.components.component import Component, ComponentMeta, ensure_initialized from superduper.components.datatype import BaseDataType @@ -451,6 +449,7 @@ def _prepare_select_for_predict(self, select, db): select.db = db return select + # TODO use query chunking not id chunking def _get_ids_from_select( self, *, @@ -460,38 +459,24 @@ def _get_ids_from_select( predict_id: str, overwrite: bool = False, ): + # TODO why all this complex logic just to get ids if not self.db.databackend.check_output_dest(predict_id): overwrite = True try: if not overwrite: if ids: - select = select.select_using_ids(ids) + select = select.select(select.primary_id) # TODO - this is broken - query = select.select_ids_of_missing_outputs(predict_id=predict_id) - + # query = select.select_ids_of_missing_outputs(predict_id=predict_id) + predict_ids = select.missing_outputs(predict_id) else: if ids: return ids - query = select.select_ids + predict_ids = select.ids() except FileNotFoundError: # This is case for sql where Table is not created yet # and we try to access `db.load('table', name)`. return [] - try: - id_field = self.db.databackend.id_field - except AttributeError: - id_field = query.table_or_collection.primary_id - - # TODO: Find better solution to support in-memory (pandas) - # Since pandas has a bug, it cannot join on empty table. - try: - id_curr = self.db.execute(query) - except DatabackendException: - id_curr = self.db.execute(select.select(id_field)) - - predict_ids = [] - for r in tqdm.tqdm(id_curr): - predict_ids.append(str(r[id_field])) if ids and len(predict_ids) > len(ids): raise Exception( @@ -544,6 +529,7 @@ def predict_in_db( overwrite=overwrite, predict_id=predict_id, ) + out = self._predict_with_select_and_ids( X=X, predict_id=predict_id, @@ -564,8 +550,9 @@ def _prepare_inputs_from_select( ): X_data: t.Any mapping = Mapping(X, self.signature) + if in_memory: - docs = list(self.db.execute(select.select_using_ids(ids))) + docs = select.subset(ids) X_data = list(map(lambda x: mapping(x), docs)) else: assert isinstance(self.db, Datalayer) @@ -580,7 +567,7 @@ def _prepare_inputs_from_select( flat = False if 'outputs' in str(select): - sample = next(select.limit(1).execute()) + sample = select.get() upstream_predict_ids = [ k for k in sample if k.startswith(CFG.output_prefix) ] @@ -596,8 +583,8 @@ def _prepare_inputs_from_select( logging.error(f"select: {select}") raise Exception( 'You\'ve specified more documents than unique ids;' - f' Is it possible that {select.table_or_collection.primary_id}' - f' isn\'t uniquely identifying?' 
+ ' Is it possible that the primary_id' + ' isn\'t uniquely identifying?' ) return X_data, mapping @@ -654,6 +641,7 @@ def _predict_with_select_and_ids( ) it += 1 return output_ids + dataset, _ = self._prepare_inputs_from_select( X=X, select=select, @@ -668,22 +656,30 @@ def _predict_with_select_and_ids( self.version, int ), 'Version has not been set, can\'t save outputs...' - update = select.model_update( - db=self.db, - predict_id=predict_id, - outputs=outputs, - ids=ids, - flatten=flatten, - **self.model_update_kwargs, - ) - output_ids = [] - if update: - # Don't use auto_schema for inserting model outputs - if update.type == 'insert': - output_ids = update.execute(db=self.db, auto_schema=True) - else: - output_ids = update.execute(db=self.db) - return output_ids + if flatten: + documents = [ + { + self.db.databackend.id_field: self.db.databackend.random_id(), + '_source': self.db.databackend.to_id(id), + f'{CFG.output_prefix}{predict_id}': sub_output, + } + for id, output in zip(ids, outputs) + for sub_output in output + ] + else: + documents = [ + { + self.db.databackend.id_field: self.db.databackend.random_id(), + '_source': self.db.databackend.to_id(id), + f'{CFG.output_prefix}{predict_id}': output, + } + for id, output in zip(ids, outputs) + ] + + from superduper.base.datalayer import Datalayer + + assert isinstance(self.db, Datalayer) + return self.db[f'{CFG.output_prefix}{predict_id}'].insert(documents) def __call__(self, *args, **kwargs): """Connect the models to build a graph. @@ -1372,7 +1368,6 @@ def _build_prompt(self, query, docs): def predict(self, query: str): assert self.db, 'db cannot be None' select = self.select.set_variables(db=self.db, query=query) - self.db.execute(select) - results = [r.unpack() for r in select.tolist()] + results = [r.unpack() for r in select.execute()] prompt = self._build_prompt(query, results) return self.llm.predict(prompt) diff --git a/superduper/components/schema.py b/superduper/components/schema.py index 6b313e06d..95c460bff 100644 --- a/superduper/components/schema.py +++ b/superduper/components/schema.py @@ -78,10 +78,11 @@ def __post_init__(self, db): self.fields[k] = v - def update(self, other: 'Schema'): + def __add__(self, other: 'Schema'): new_fields = self.fields.copy() new_fields.update(other.fields) - return Schema(self.identifier, fields=new_fields) + id = self.identifier + '+' + other.identifier + return Schema(id, fields=new_fields, db=self.db) # TODO why do we need this? @cached_property @@ -152,6 +153,7 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()): :param blobs: Blobs. :param files: Files. """ + result = {k: v for k, v in out.items()} for k, field in self.fields.items(): if not isinstance(field, BaseDataType): continue @@ -176,7 +178,7 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()): ): assert isinstance(data, bytes) data = _convert_bytes_to_base64(data) - out[k] = data + result[k] = data elif isinstance(data, Saveable): ref_obj = parse_reference(data.reference) @@ -189,13 +191,12 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()): else: assert False, f'Unknown reference type {ref_obj.name}' - out[k] = data.reference + result[k] = data.reference else: - out[k] = data + result[k] = data - out['_schema'] = self.identifier - - return out + result['_schema'] = self.identifier + return result def __call__(self, data: dict[str, t.Any]) -> dict[str, t.Any]: """Encode data using the schema's encoders. 
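Since `Schema.update` is replaced by `Schema.__add__` above, a small sketch of the merge semantics; the field specifications are illustrative (plain strings for brevity, where the codebase typically passes datatype or `FieldType` objects).

```python
# Sketch of Schema.__add__: fields are merged (right-hand side wins on clashes)
# and the identifiers are joined with '+'. Field specs here are illustrative.
from superduper.components.schema import Schema

base = Schema('documents', fields={'id': 'str', 'x': 'int'})
extra = Schema('outputs', fields={'score': 'float'})

combined = base + extra          # replaces the former base.update(extra)

assert combined.identifier == 'documents+outputs'
assert set(combined.fields) >= {'id', 'x', 'score'}
```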
diff --git a/superduper/components/table.py b/superduper/components/table.py index 943c337eb..b5a8d78d3 100644 --- a/superduper/components/table.py +++ b/superduper/components/table.py @@ -33,6 +33,7 @@ def __post_init__(self, db): fields.update(self.schema.fields) schema_version = self.schema.version + # TODO make globally configurable if '_fold' not in self.schema.fields: fields.update({'_fold': 'str'}) @@ -43,6 +44,10 @@ def __post_init__(self, db): ) self.schema.version = schema_version + def cleanup(self, db): + if self.identifier.startswith(CFG.output_prefix): + db.databackend.drop_table(self.identifier) + def on_create(self, db: 'Datalayer'): """Create the table, on creation of the component. @@ -70,4 +75,4 @@ def add_data(self): else: data = self.data if data: - self.db[self.identifier].insert(data).execute() + self.db[self.identifier].insert(data) diff --git a/superduper/components/vector_index.py b/superduper/components/vector_index.py index 265167bd3..31788746e 100644 --- a/superduper/components/vector_index.py +++ b/superduper/components/vector_index.py @@ -173,11 +173,10 @@ def copy_vectors(self, ids: t.Sequence[str] | None = None): # TODO do this using the backfill_vector_search functionality here if ids is None: assert self.indexing_listener.select is not None - cur = select.select_ids.execute() - ids = [r[select.primary_id] for r in cur] + ids = select.ids() docs = [r.unpack() for r in select.execute()] else: - docs = [r.unpack() for r in select.select_using_ids(ids).execute()] + docs = [r.unpack() for r in select.subset(ids)] vectors = [] nokeys = 0 diff --git a/superduper/misc/eager.py b/superduper/misc/eager.py index ca3d3ea03..491646ab0 100644 --- a/superduper/misc/eager.py +++ b/superduper/misc/eager.py @@ -5,7 +5,6 @@ import networkx as nx from superduper import CFG, logging -from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES from superduper.base.leaf import build_uuid if t.TYPE_CHECKING: @@ -397,24 +396,9 @@ def _get_select(self, node: SuperDuperData): else: raise ValueError(f"Unknown node type: {upstream_node.type}") - if main_table != root_table: - select = self.db[main_table].select() - - else: - from superduper.base.enums import DBType - - if self.db.databackend.db_type == DBType.MONGODB: - if main_table_keys: - main_table_keys.extend( - [KEY_BUILDS, KEY_FILES, KEY_BLOBS, "_schema"] - ) - select = self.db[main_table].find({}, {k: 1 for k in main_table_keys}) - - else: - if "id" not in main_table_keys: - main_table_keys.insert(0, "id") - select = self.db[main_table].select(*main_table_keys) + select = self.db[main_table] + if main_table == root_table: if node.filter: for key, value in node.filter.items(): if value[0] == "ne": diff --git a/superduper/misc/tree.py b/superduper/misc/tree.py index 10c88d7dc..bb57f42ce 100644 --- a/superduper/misc/tree.py +++ b/superduper/misc/tree.py @@ -17,7 +17,7 @@ def dict_to_tree(dictionary, root: str = 'root', tree=None): # If the value is another dictionary, create a subtree subtree = tree.add(f"[bold yellow]{key}") dict_to_tree(value, root=root, tree=subtree) - else: + elif key == 'status': # Add the key and value as a leaf node if value == 'breaking': tree.add(f"[bold cyan]{key}: [red]{value}") diff --git a/superduper/rest/build.py b/superduper/rest/build.py index dbc731945..9f56a76bf 100644 --- a/superduper/rest/build.py +++ b/superduper/rest/build.py @@ -281,21 +281,33 @@ def db_execute(query: t.Dict, db: 'Datalayer' = DatalayerDependency()): logging.info(output) return [{'_base': output}], [] - if 
'_path' not in query: - plugin = db.databackend.backend_name - query['_path'] = f'superduper_{plugin}.query.parse_query' + import re - q = Document.decode(query, db=db).unpack() + predict_match = re.match( + r'^([a-zA-Z0-9_]+)\.predict\(.*\)$', query['query'].strip() + ) - logging.info('processing this query:') - logging.info(q) + if predict_match: + model = predict_match.groups()[0] + m = db.load('model', model) + result = eval( + query['query'], {model: m, 'documents': query.get('documents', [])} + ) - result = q.execute() + if isinstance(result, dict): + result = Document(result) + else: + result = Document({'_base': result}) - if q.type in {'insert', 'delete', 'update'}: - return {'_base': [str(x) for x in result]}, [] + q = None + else: + query['_path'] = 'superduper.backends.base.query.parse_query' + q = Document.decode(query, db=db).unpack() - logging.warn(str(q)) + logging.info('processing this query:') + logging.info(q) + result = q.execute() + logging.warn(str(q)) if isinstance(result, Document): result = [result] @@ -308,8 +320,9 @@ def db_execute(query: t.Dict, db: 'Datalayer' = DatalayerDependency()): if isinstance(q, Query): for i, r in enumerate(result): r = list(r) - if q.primary_id in r[0]: - r[0][q.primary_id] = str(r[0][q.primary_id]) + pid = q.primary_id.execute() + if pid in r[0]: + r[0][pid] = str(r[0][pid]) result[i] = tuple(r) if 'score' in result[0][0]: result = sorted(result, key=lambda x: -x[0]['score']) diff --git a/templates/simple_rag/blobs/2a5cef4ab2af1d1006597e512755d36e953956b6 b/templates/simple_rag/blobs/2a5cef4ab2af1d1006597e512755d36e953956b6 deleted file mode 100644 index 93df19d6a..000000000 Binary files a/templates/simple_rag/blobs/2a5cef4ab2af1d1006597e512755d36e953956b6 and /dev/null differ diff --git a/templates/simple_rag/blobs/ef8655ee5f19aa0d2a7fb09d4c46aeedceb2bd24 b/templates/simple_rag/blobs/2b1758f910e08d9bb2eaa836d00c725d22ee9616 similarity index 95% rename from templates/simple_rag/blobs/ef8655ee5f19aa0d2a7fb09d4c46aeedceb2bd24 rename to templates/simple_rag/blobs/2b1758f910e08d9bb2eaa836d00c725d22ee9616 index 341869b03..c71d9ab3d 100644 Binary files a/templates/simple_rag/blobs/ef8655ee5f19aa0d2a7fb09d4c46aeedceb2bd24 and b/templates/simple_rag/blobs/2b1758f910e08d9bb2eaa836d00c725d22ee9616 differ diff --git a/templates/simple_rag/blobs/34c27f9f9368613917c80f0b66d7ae6e144f0794 b/templates/simple_rag/blobs/34c27f9f9368613917c80f0b66d7ae6e144f0794 deleted file mode 100644 index 4429b0736..000000000 Binary files a/templates/simple_rag/blobs/34c27f9f9368613917c80f0b66d7ae6e144f0794 and /dev/null differ diff --git a/templates/simple_rag/blobs/44c0285e7e53a1004226cc4c110a442909214bed b/templates/simple_rag/blobs/44c0285e7e53a1004226cc4c110a442909214bed deleted file mode 100644 index d754c0998..000000000 Binary files a/templates/simple_rag/blobs/44c0285e7e53a1004226cc4c110a442909214bed and /dev/null differ diff --git a/templates/simple_rag/blobs/558862283097265020c65fb73179764194a1f5e7 b/templates/simple_rag/blobs/558862283097265020c65fb73179764194a1f5e7 deleted file mode 100644 index 01aa725a5..000000000 Binary files a/templates/simple_rag/blobs/558862283097265020c65fb73179764194a1f5e7 and /dev/null differ diff --git a/templates/simple_rag/blobs/6a536b4ec925b94103a04c3083f940fe07ed75e1 b/templates/simple_rag/blobs/6a536b4ec925b94103a04c3083f940fe07ed75e1 deleted file mode 100644 index 2cf26d60a..000000000 Binary files a/templates/simple_rag/blobs/6a536b4ec925b94103a04c3083f940fe07ed75e1 and /dev/null differ diff --git 
a/templates/simple_rag/blobs/6bc208de57d306fe2a803f3a308933c36565fc19 b/templates/simple_rag/blobs/6bc208de57d306fe2a803f3a308933c36565fc19 deleted file mode 100644 index a207c774c..000000000 Binary files a/templates/simple_rag/blobs/6bc208de57d306fe2a803f3a308933c36565fc19 and /dev/null differ diff --git a/templates/simple_rag/blobs/79ce3dfb3c6c84bb7337688e86681e4848847bbb b/templates/simple_rag/blobs/ae6c79e267e9c8689461504861898e273522f09d similarity index 81% rename from templates/simple_rag/blobs/79ce3dfb3c6c84bb7337688e86681e4848847bbb rename to templates/simple_rag/blobs/ae6c79e267e9c8689461504861898e273522f09d index 6c1e16029..c592ba09c 100644 Binary files a/templates/simple_rag/blobs/79ce3dfb3c6c84bb7337688e86681e4848847bbb and b/templates/simple_rag/blobs/ae6c79e267e9c8689461504861898e273522f09d differ diff --git a/templates/simple_rag/blobs/dc07a0790fa145a227e8c24ca7d84bd438522762 b/templates/simple_rag/blobs/dc07a0790fa145a227e8c24ca7d84bd438522762 deleted file mode 100644 index 39d2b5ab1..000000000 Binary files a/templates/simple_rag/blobs/dc07a0790fa145a227e8c24ca7d84bd438522762 and /dev/null differ diff --git a/templates/simple_rag/blobs/e8e60a69a01b49adc788e151e61426c579d8f935 b/templates/simple_rag/blobs/e8e60a69a01b49adc788e151e61426c579d8f935 deleted file mode 100644 index 415b5eaad..000000000 Binary files a/templates/simple_rag/blobs/e8e60a69a01b49adc788e151e61426c579d8f935 and /dev/null differ diff --git a/templates/simple_rag/blobs/e9b257239a8150546960866580334037088fe6c1 b/templates/simple_rag/blobs/e9b257239a8150546960866580334037088fe6c1 deleted file mode 100644 index 332315e7a..000000000 Binary files a/templates/simple_rag/blobs/e9b257239a8150546960866580334037088fe6c1 and /dev/null differ diff --git a/templates/simple_rag/build.ipynb b/templates/simple_rag/build.ipynb index 3017b819e..0d6ac5dcf 100644 --- a/templates/simple_rag/build.ipynb +++ b/templates/simple_rag/build.ipynb @@ -55,17 +55,17 @@ }, "outputs": [], "source": [ - "APPLY = True\n", + "APPLY = False\n", "SAMPLE_COLLECTION_NAME = 'sample_simple_rag'\n", "COLLECTION_NAME = '' if not APPLY else 'docs'\n", "ID_FIELD = '' if not APPLY else 'id'\n", "OUTPUT_PREFIX = 'outputs__'\n", - "EAGER = False" + "EAGER = True" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "cb029a5e-fedf-4f07-8a31-d220cfbfbb3d", "metadata": { "editable": true, @@ -79,40 +79,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-Dec-18 14:42:09.35\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mkartiks-MacBook-Air.local\u001b[0m| \u001b[36msuperduper.misc.plugins\u001b[0m:\u001b[36m13 \u001b[0m | \u001b[1mLoading plugin: mongodb\u001b[0m\n" - ] - }, - { - "ename": "ServerSelectionTimeoutError", - "evalue": "localhost:27017: [Errno 61] Connection refused (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 5.0s, Topology Description: ]>", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mServerSelectionTimeoutError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m CFG\u001b[38;5;241m.\u001b[39mbytes_encoding \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstr\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 7\u001b[0m 
os\u001b[38;5;241m.\u001b[39menviron[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mSUPERDUPER_DATA_BACKEND\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msnowflake://softwareuser:SU4yv6DfUPUL0CPDdsCDDSLttVc@ngkjqqn-superduperdbeu\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m----> 8\u001b[0m db \u001b[38;5;241m=\u001b[39m \u001b[43msuperduper\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/superduper/base/superduper.py:18\u001b[0m, in \u001b[0;36msuperduper\u001b[0;34m(item, **kwargs)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msuperduper\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbase\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbuild\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m build_datalayer\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m item \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbuild_datalayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmongomock://\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m 21\u001b[0m kwargs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata_backend\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m item\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/superduper/base/build.py:135\u001b[0m, in \u001b[0;36mbuild_datalayer\u001b[0;34m(cfg, **kwargs)\u001b[0m\n\u001b[1;32m 132\u001b[0m cfg \u001b[38;5;241m=\u001b[39m (cfg \u001b[38;5;129;01mor\u001b[39;00m CFG)(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 134\u001b[0m cfg \u001b[38;5;241m=\u001b[39m t\u001b[38;5;241m.\u001b[39mcast(Config, cfg)\n\u001b[0;32m--> 135\u001b[0m databackend_obj \u001b[38;5;241m=\u001b[39m \u001b[43m_build_databackend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata_backend\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cfg\u001b[38;5;241m.\u001b[39mmetadata_store:\n\u001b[1;32m 137\u001b[0m metadata_obj \u001b[38;5;241m=\u001b[39m _build_metadata(cfg\u001b[38;5;241m.\u001b[39mmetadata_store)\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/superduper/base/build.py:98\u001b[0m, in \u001b[0;36m_build_databackend\u001b[0;34m(uri)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_build_databackend\u001b[39m(uri):\n\u001b[0;32m---> 98\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m DataBackendProxy(\u001b[43m_DataBackendLoader\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/superduper/base/build.py:46\u001b[0m, in \u001b[0;36m_Loader.create\u001b[0;34m(cls, uri)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 42\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mplugin\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m with flavour \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mflavour\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not 
supported \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto create metadata store.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 45\u001b[0m impl \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(load_plugin(plugin), \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mimpl)\n\u001b[0;32m---> 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflavour\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mflavour\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/plugins/mongodb/superduper_mongodb/data_backend.py:39\u001b[0m, in \u001b[0;36mMongoDBDataBackend.__init__\u001b[0;34m(self, uri, flavour)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moverwrite \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(uri, flavour\u001b[38;5;241m=\u001b[39mflavour)\n\u001b[0;32m---> 39\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconn, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m=\u001b[39m \u001b[43mconnection_callback\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflavour\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_db \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconn[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname]\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdatatype_presets \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 44\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvector\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msuperduper.components.datatype.NativeVector\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 45\u001b[0m }\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/plugins/mongodb/superduper_mongodb/utils.py:70\u001b[0m, in \u001b[0;36mconnection_callback\u001b[0;34m(uri, flavour)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m flavour \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmongodb\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 69\u001b[0m name \u001b[38;5;241m=\u001b[39m uri\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m---> 70\u001b[0m conn \u001b[38;5;241m=\u001b[39m \u001b[43m_get_avaliable_conn\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mserverSelectionTimeoutMS\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5000\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m flavour \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124matlas\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 73\u001b[0m name \u001b[38;5;241m=\u001b[39m uri\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n", - 
"File \u001b[0;32m~/Work/superduperdb/code/superduper/plugins/mongodb/superduper_mongodb/utils.py:30\u001b[0m, in \u001b[0;36m_get_avaliable_conn\u001b[0;34m(uri, **kwargs)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m client\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ServerSelectionTimeoutError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 29\u001b[0m \u001b[38;5;66;03m# If the server is not available, raise the exception\u001b[39;00m\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 32\u001b[0m uri_mask \u001b[38;5;241m=\u001b[39m anonymize_url(uri)\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/plugins/mongodb/superduper_mongodb/utils.py:26\u001b[0m, in \u001b[0;36m_get_avaliable_conn\u001b[0;34m(uri, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m client: pymongo\u001b[38;5;241m.\u001b[39mMongoClient \u001b[38;5;241m=\u001b[39m pymongo\u001b[38;5;241m.\u001b[39mMongoClient(uri, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 26\u001b[0m \u001b[43mclient\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdb_name\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlist_collection_names\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m client\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ServerSelectionTimeoutError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 29\u001b[0m \u001b[38;5;66;03m# If the server is not available, raise the exception\u001b[39;00m\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/.venv/lib/python3.10/site-packages/pymongo/synchronous/database.py:1226\u001b[0m, in \u001b[0;36mDatabase.list_collection_names\u001b[0;34m(self, session, filter, comment, **kwargs)\u001b[0m\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlist_collection_names\u001b[39m(\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1196\u001b[0m session: Optional[ClientSession] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1199\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any,\n\u001b[1;32m 1200\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mstr\u001b[39m]:\n\u001b[1;32m 1201\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Get a list of all the collection names in this database.\u001b[39;00m\n\u001b[1;32m 1202\u001b[0m \n\u001b[1;32m 1203\u001b[0m \u001b[38;5;124;03m For example, to list all non-system collections::\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1224\u001b[0m \u001b[38;5;124;03m .. 
versionadded:: 3.6\u001b[39;00m\n\u001b[1;32m 1225\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1226\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_list_collection_names\u001b[49m\u001b[43m(\u001b[49m\u001b[43msession\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mfilter\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomment\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/.venv/lib/python3.10/site-packages/pymongo/synchronous/database.py:1191\u001b[0m, in \u001b[0;36mDatabase._list_collection_names\u001b[0;34m(self, session, filter, comment, **kwargs)\u001b[0m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mfilter\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mfilter\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mfilter\u001b[39m):\n\u001b[1;32m 1188\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnameOnly\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[0;32m-> 1191\u001b[0m result[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m result \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_list_collections_helper\u001b[49m\u001b[43m(\u001b[49m\u001b[43msession\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msession\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1192\u001b[0m ]\n", - "File \u001b[0;32m~/Work/superduperdb/code/superduper/.venv/lib/python3.10/site-packages/pymongo/synchronous/database.py:1138\u001b[0m, in \u001b[0;36mDatabase._list_collections_helper\u001b[0;34m(self, session, filter, comment, **kwargs)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_cmd\u001b[39m(\n\u001b[1;32m 1131\u001b[0m session: Optional[ClientSession],\n\u001b[1;32m 1132\u001b[0m _server: Server,\n\u001b[1;32m 1133\u001b[0m conn: Connection,\n\u001b[1;32m 1134\u001b[0m read_preference: _ServerMode,\n\u001b[1;32m 1135\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m CommandCursor[MutableMapping[\u001b[38;5;28mstr\u001b[39m, Any]]:\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_list_collections(conn, session, read_preference\u001b[38;5;241m=\u001b[39mread_preference, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m-> 1138\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_retryable_read\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1139\u001b[0m \u001b[43m 
[… elided notebook output: the remaining escaped pymongo traceback frames (MongoClient._retryable_read, _ClientConnectionRetryable.run/_read/_get_server, Topology.select_server/_select_servers_loop) that this hunk deletes from the stale output cell; the concluding ServerSelectionTimeoutError line for localhost:27017 is kept below …]
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtimeout\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124ms, Topology Description: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdescription\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 335\u001b[0m )\n\u001b[1;32m 337\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m logged_waiting:\n\u001b[1;32m 338\u001b[0m _debug_log(\n\u001b[1;32m 339\u001b[0m _SERVER_SELECTION_LOGGER,\n\u001b[1;32m 340\u001b[0m message\u001b[38;5;241m=\u001b[39m_ServerSelectionStatusMessage\u001b[38;5;241m.\u001b[39mWAITING,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 346\u001b[0m remainingTimeMS\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mint\u001b[39m(end_time \u001b[38;5;241m-\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic()),\n\u001b[1;32m 347\u001b[0m )\n", - "\u001b[0;31mServerSelectionTimeoutError\u001b[0m: localhost:27017: [Errno 61] Connection refused (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 5.0s, Topology Description: ]>" + "\u001b[32m2025-Jan-01 18:39:35.72\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.misc.plugins\u001b[0m:\u001b[36m13 \u001b[0m | \u001b[1mLoading plugin: mongodb\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:35.80\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.datalayer\u001b[0m:\u001b[36m62 \u001b[0m | \u001b[1mBuilding Data Layer\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:35.80\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.build\u001b[0m:\u001b[36m185 \u001b[0m | \u001b[1mConfiguration: \n", + " +---------------+-----------------------------------+\n", + "| Configuration | Value |\n", + "+---------------+-----------------------------------+\n", + "| Data Backend | mongodb://localhost:27017/test_db |\n", + "+---------------+-----------------------------------+\u001b[0m\n" ] } ], @@ -123,7 +97,6 @@ "CFG.output_prefix = OUTPUT_PREFIX\n", "CFG.bytes_encoding = 'str'\n", "\n", - "os.environ['SUPERDUPER_DATA_BACKEND'] = 'snowflake://softwareuser:SU4yv6DfUPUL0CPDdsCDDSLttVc@ngkjqqn-superduperdbeu'\n", "db = superduper()" ] }, @@ -185,7 +158,7 @@ "source": [ "if APPLY:\n", " from superduper import Document\n", - " ids = db.execute(db[COLLECTION_NAME].insert([Document(r) for r in data]))" + " ids = db[COLLECTION_NAME].insert([Document(r) for r in data])" ] }, { @@ -468,8 +441,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-Dec-06 13:29:54.37\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.application\u001b[0m:\u001b[36m39 \u001b[0m | \u001b[1mResorting components based on topological order.\u001b[0m\n", - "\u001b[32m2024-Dec-06 13:29:54.37\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.application\u001b[0m:\u001b[36m56 \u001b[0m | \u001b[1mNew order of components: ['listener:chunker:8fa67b7633974833', 'vector_index:vectorindex:88db3297b36847e8', 'model:simple_rag:cbfd63906d6d42bc', 'streamlit:simple-rag-demo:7c642ffb09474afc']\u001b[0m\n" + "\u001b[32m2025-Jan-01 18:39:36.65\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.application\u001b[0m:\u001b[36m39 \u001b[0m | \u001b[1mResorting components based on topological 
order.\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:36.66\u001b[0m| \u001b[1mINFO \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.application\u001b[0m:\u001b[36m56 \u001b[0m | \u001b[1mNew order of components: ['listener:chunker:3296a0335d6a4c59', 'vector_index:vectorindex:a0f0611c9d9c4617', 'model:simple_rag:ae6565708aae43de', 'streamlit:simple-rag-demo:38c123369e2c46dd']\u001b[0m\n" ] } ], @@ -536,12 +509,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-Dec-06 13:29:54.38\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m74 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", - "\u001b[32m2024-Dec-06 13:29:54.39\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m74 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", - "\u001b[32m2024-Dec-06 13:29:54.39\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m558 \u001b[0m | \u001b[33m\u001b[1mLeaf listener:chunker already exists\u001b[0m\n", - "\u001b[32m2024-Dec-06 13:29:54.39\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m74 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", - "\u001b[32m2024-Dec-06 13:29:54.39\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m558 \u001b[0m | \u001b[33m\u001b[1mLeaf model:chunker already exists\u001b[0m\n", - "\u001b[32m2024-Dec-06 13:29:54.39\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m558 \u001b[0m | \u001b[33m\u001b[1mLeaf var-table-name-select-var-id-field-x already exists\u001b[0m\n" + "\u001b[32m2025-Jan-01 18:39:36.67\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m76 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:36.68\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m76 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:36.68\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m554 \u001b[0m | \u001b[33m\u001b[1mLeaf listener:chunker already exists\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:36.68\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.components.listener\u001b[0m:\u001b[36m76 \u001b[0m | \u001b[33m\u001b[1moutput_table not found in listener.dict()\u001b[0m\n", + "\u001b[32m2025-Jan-01 18:39:36.68\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m554 \u001b[0m | \u001b[33m\u001b[1mLeaf model:chunker already exists\u001b[0m\n", + "\u001b[32m2025-Jan-01 
18:39:36.69\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m554 \u001b[0m | \u001b[33m\u001b[1mLeaf var-table-name-select-var-id-field-x already exists\u001b[0m\n" ] } ], @@ -608,7 +581,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m2024-Dec-06 13:29:54.40\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m558 \u001b[0m | \u001b[33m\u001b[1mLeaf str already exists\u001b[0m\n" + "\u001b[32m2025-Jan-01 18:39:36.69\u001b[0m| \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mDuncans-MBP.fritz.box\u001b[0m| \u001b[36msuperduper.base.document\u001b[0m:\u001b[36m554 \u001b[0m | \u001b[33m\u001b[1mLeaf str already exists\u001b[0m\n" ] } ], @@ -633,7 +606,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/templates/simple_rag/component.json b/templates/simple_rag/component.json index 05d54f849..62dea8bf2 100644 --- a/templates/simple_rag/component.json +++ b/templates/simple_rag/component.json @@ -10,7 +10,7 @@ "_base": "?simple-rag-app", "_builds": { "model:chunker": { - "_object": "&:blob:6a536b4ec925b94103a04c3083f940fe07ed75e1", + "_object": "&:blob:2b1758f910e08d9bb2eaa836d00c725d22ee9616", "upstream": null, "plugins": null, "cache": true, @@ -31,7 +31,7 @@ "chunk_size": 200 }, "var-table-name-select-var-id-field-x": { - "_path": "superduper_.query.parse_query", + "_path": "superduper.backends.base.query.parse_query", "documents": [], "query": ".select(\"\", \"x\")" }, @@ -48,18 +48,6 @@ "select": "?var-table-name-select-var-id-field-x", "flatten": true }, - "datatype:vector[1536]": { - "_path": "superduper.components.datatype.Vector", - "upstream": null, - "plugins": null, - "cache": true, - "build_variables": null, - "build_template": null, - "shape": [ - 1536 - ], - "dtype": "float64" - }, "model:text-embedding": { "_path": "superduper_openai.model.OpenAIEmbedding", "upstream": null, @@ -68,7 +56,7 @@ "build_variables": null, "build_template": null, "signature": "singleton", - "datatype": "?datatype:vector[1536]", + "datatype": null, "output_schema": null, "model_update_kwargs": {}, "predict_kwargs": {}, @@ -81,14 +69,14 @@ "deploy": false, "model": "", "max_batch_size": 8, - "postprocess": "&:blob:e8e60a69a01b49adc788e151e61426c579d8f935", + "postprocess": null, "openai_api_key": null, "openai_api_base": null, "client_kwargs": {}, "batch_size": 100 }, "outputs-chunker-?(listener:chunker.uuid)-select-id-source-outputs-chunker-?(listener:chunker.uuid)": { - "_path": "superduper_.query.parse_query", + "_path": "superduper.backends.base.query.parse_query", "documents": [], "query": "chunker__?(listener:chunker.uuid).select(\"id\", \"_source\", \"chunker__?(listener:chunker.uuid)\")" }, @@ -116,17 +104,17 @@ "build_template": null, "indexing_listener": "?listener:embeddinglistener", "compatible_listener": null, - "measure": "l2", + "measure": "cosine", "metric_values": {} }, - "outputs-chunker-?(listener:chunker.uuid)-select-query-0-like-outputs-chunker-?(listener:chunker.uuid)-var-query-vector-index-vectorindex-n-5": { - "_path": "superduper_.query.parse_query", + "outputs-chunker-?(listener:chunker.uuid)-select-like-outputs-chunker-?(listener:chunker.uuid)-var-query-vectorindex-n-5": { + "_path": "superduper.backends.base.query.parse_query", "documents": [ { 
"chunker__?(listener:chunker.uuid)": "" } ], - "query": "chunker__?(listener:chunker.uuid)\nchunker__?(listener:chunker.uuid).select(query[0]).like(documents[0], vector_index=\"vectorindex\", n=5)" + "query": "chunker__?(listener:chunker.uuid).select().like(documents[0], \"vectorindex\", n=5)" }, "model:llm-model": { "_path": "superduper_openai.model.OpenAIChatCompletion", @@ -176,7 +164,7 @@ "trainer": null, "deploy": false, "prompt_template": "Use the following context snippets, these snippets are not ordered!, Answer the question based on this context.\nThese snippets are samples from our internal data-repositories, and should be used exclusively and as a matter of priority to answer the question. Please answer in 20 words or less.\n\n{context}\n\nHere's the question: {query}", - "select": "?outputs-chunker-?(listener:chunker.uuid)-select-query-0-like-outputs-chunker-?(listener:chunker.uuid)-var-query-vector-index-vectorindex-n-5", + "select": "?outputs-chunker-?(listener:chunker.uuid)-select-like-outputs-chunker-?(listener:chunker.uuid)-var-query-vectorindex-n-5", "key": "chunker__?(listener:chunker.uuid)", "llm": "?model:llm-model" }, @@ -199,9 +187,10 @@ "cache": true, "build_variables": null, "build_template": null, - "demo_func": "&:blob:34c27f9f9368613917c80f0b66d7ae6e144f0794", + "demo_func": "&:blob:c663645459821c0f6a085c0f21d2ae498d037bdd", "demo_kwargs": {}, - "default": false + "default": false, + "is_standalone": false }, "simple-rag-app": { "_path": "superduper.components.application.Application", @@ -277,11 +266,11 @@ "_builds": { "str": { "_path": "superduper.components.schema.FieldType", - "uuid": "c3a936826e214795" + "uuid": "49a7db083e1c406c" }, "schema:sample_simple_rag/schema": { "_path": "superduper.components.schema.Schema", - "uuid": "ba88363e986a4262", + "uuid": "dbdca687d3f64472", "upstream": null, "plugins": null, "cache": true, @@ -294,17 +283,17 @@ }, "dataset:superduper-docs": { "_path": "superduper.components.dataset.RemoteData", - "uuid": "eaef7d2313774641", + "uuid": "ca1a5e9e171a492e", "upstream": null, "plugins": null, "cache": true, "build_variables": null, "build_template": null, - "getter": "&:blob:558862283097265020c65fb73179764194a1f5e7" + "getter": "&:blob:ae6c79e267e9c8689461504861898e273522f09d" }, "table:sample_simple_rag": { "_path": "superduper.components.table.Table", - "uuid": "5ea14bc1c1dc4c66", + "uuid": "59d30c168633457e", "upstream": null, "plugins": null, "cache": true, diff --git a/test/integration/template/test_template.py b/test/integration/template/test_template.py index 9c3cb0434..17b7c42c5 100644 --- a/test/integration/template/test_template.py +++ b/test/integration/template/test_template.py @@ -22,12 +22,11 @@ def test_template(): assert f'sample_{template_name}' in db.show('table') - sample = db[f'sample_{template_name}'].select().limit(2).tolist() + sample = db[f'sample_{template_name}'].limit(2).execute() assert sample print('Got sample:', sample) - print(f'Got {len(sample)} samples') app = t() diff --git a/test/integration/usecase/test_eager.py b/test/integration/usecase/test_eager.py index e7779aa48..124a141fb 100644 --- a/test/integration/usecase/test_eager.py +++ b/test/integration/usecase/test_eager.py @@ -24,9 +24,9 @@ def test_graph_deps(db: "Datalayer"): {"x": 3, "y": "4", "z": Tuple(7, 8, 9)}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) - data = list(db["documents"].select().execute(eager_mode=True))[0] + data = db["documents"].select().execute(eager_mode=True)[0] def func_a(x): return 
Tuple(x, "a") @@ -74,7 +74,7 @@ def func_c(x, y, z, o_a, o_b): {"x": 6, "y": "7", "z": Tuple(16, 17, 18)}, ] - db["documents"].insert(new_data).execute() + db["documents"].insert(new_data) new_data = Document( list( @@ -102,9 +102,9 @@ def test_flatten(db: "Datalayer"): {"n": 3, "x": "c"}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) - data = list(db["documents"].select().execute(eager_mode=True))[0] + data = db["documents"].select().execute(eager_mode=True)[0] def func_a(n, x): return [x] * n @@ -115,9 +115,7 @@ def func_a(n, x): output_a.apply() pprint(output_a) - outputs = list( - db[f'_outputs__{output_a.predict_id}'].select().execute(eager_mode=True) - ) + outputs = db[f'_outputs__{output_a.predict_id}'].select().execute(eager_mode=True) assert len(outputs) == 6 results = [o[f'_outputs__{output_a.predict_id}'].data for o in outputs] @@ -133,9 +131,9 @@ def test_predict_id(db: "Datalayer"): {"x": 1}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) - data = list(db["documents"].select().execute(eager_mode=True))[0] + data = db["documents"].select().execute(eager_mode=True)[0] model_a = ObjectModel(identifier="a", object=lambda x: f"{x}->a") @@ -170,9 +168,9 @@ def test_condition(db: "Datalayer"): {"x": 3}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) - data = list(db["documents"].select().execute(eager_mode=True))[0] + data = db["documents"].select().execute(eager_mode=True)[0] model = ObjectModel(identifier="a", object=lambda x: f"{x}->a") diff --git a/test/integration/usecase/test_graceful_update.py b/test/integration/usecase/test_graceful_update.py index d6a3e6f03..7aec8bc7d 100644 --- a/test/integration/usecase/test_graceful_update.py +++ b/test/integration/usecase/test_graceful_update.py @@ -35,7 +35,7 @@ def validate(self, model): def test(db: Datalayer): db.cfg.auto_schema = True - db['docs'].insert([{'x': random.randrange(10)} for _ in range(10)]).execute() + db['docs'].insert([{'x': random.randrange(10)} for _ in range(10)]) def build_vi(**kwargs): model = MyModel('my-model', example=1, **kwargs) diff --git a/test/integration/usecase/test_output_prefix.py b/test/integration/usecase/test_output_prefix.py index 1e6179972..77ac76c0e 100644 --- a/test/integration/usecase/test_output_prefix.py +++ b/test/integration/usecase/test_output_prefix.py @@ -23,25 +23,21 @@ def test_output_prefix(db): "sddb_outputs_c", ] - tables = [ - x - for x in db.databackend.list_tables_or_collections() - if not x.startswith("_") - ] + tables = [x for x in db.databackend.list_tables() if not x.startswith("_")] for t in expect_tables: assert any(k.startswith(t) for k in tables) - outputs_a = db[listener_a.outputs].select().tolist() + outputs_a = db[listener_a.outputs].select().execute() assert len(outputs_a) == 6 for r in outputs_a: assert any(k.startswith("sddb_outputs_a") for k in r) - outputs_b = db[listener_b.outputs].select().tolist() + outputs_b = db[listener_b.outputs].select().execute() assert len(outputs_b) == 6 for r in outputs_b: assert any(k.startswith("sddb_outputs_b") for k in r) - outputs_c = db[listener_c.outputs].select().tolist() + outputs_c = db[listener_c.outputs].select().execute() assert len(outputs_c) == 6 for r in outputs_c: assert any(k.startswith("sddb_outputs_c") for k in r) diff --git a/test/integration/usecase/test_reapply.py b/test/integration/usecase/test_reapply.py index 76ba47311..42e89851b 100644 --- a/test/integration/usecase/test_reapply.py +++ b/test/integration/usecase/test_reapply.py 
@@ -14,7 +14,7 @@ def predict(self, x): def test_reapply(db): db.cfg.auto_schema = True - db['docs'].insert([{'x': i} for i in range(10)]).execute() + db['docs'].insert([{'x': i} for i in range(10)]) def build(name, data): model = MyModel('test', a=name, b=data) @@ -34,7 +34,7 @@ def build(name, data): db.apply(listener_2) - outputs = db[listener_2.outputs].select().tolist() + outputs = db[listener_2.outputs].select().execute() import pprint diff --git a/test/integration/usecase/test_training.py b/test/integration/usecase/test_training.py index e1d25019c..38e49fe33 100644 --- a/test/integration/usecase/test_training.py +++ b/test/integration/usecase/test_training.py @@ -48,7 +48,7 @@ def test_training(db: "Datalayer"): # Need to reload to get the fitted model reloaded = db.load('model', 'my-model') - r = next(db['documents'].select().limit(1).execute()) + r = db['documents'].get() # This only works if the model was trained prediction = reloaded.predict(r['x']) diff --git a/test/integration/usecase/test_vector_index.py b/test/integration/usecase/test_vector_index.py index 332224e4a..cf4df3552 100644 --- a/test/integration/usecase/test_vector_index.py +++ b/test/integration/usecase/test_vector_index.py @@ -9,20 +9,18 @@ # @pytest.mark.skip def test_vector_index(db: "Datalayer"): def check_result(out, sample_data): - scores = out.scores - ids = [o[primary_id] for o in list(out)] assert len(ids) == 10 assert sample_data[primary_id] in ids - assert scores[str(sample_data[primary_id])] > 0.999999 + assert out[0]['score'] > 0.999999 build_vector_index(db, n=100) vector_index = "vector_index" table = db["documents"] - primary_id = table.primary_id - sample_data = next(table.select().filter(table['x'] == 50).execute()) + primary_id = table.primary_id.execute() + sample_data = table.select().filter(table['x'] == 50).execute()[0] # test indexing vector search out = ( @@ -44,7 +42,8 @@ def check_result(out, sample_data): # test adding new data out = table.like({"x": 150}, vector_index=vector_index, n=10).select().execute() - assert sum(out.scores.values()) == 0 + scores = [r['score'] for r in out] + assert sum(scores) == 0 # TODO - this is not triggering the update of the component add_data(db, 100, 200) @@ -52,6 +51,6 @@ def check_result(out, sample_data): assert len(db.cluster.vector_search[vector_index]) == 200 out = table.like({"x": 150}, vector_index=vector_index, n=1).select().execute() - result = next(out) + result = out[0] assert result['x'] == 150 assert result['score'] > 0.999999 diff --git a/test/rest/mock_client.py b/test/rest/mock_client.py index a66818f1d..7083016ef 100644 --- a/test/rest/mock_client.py +++ b/test/rest/mock_client.py @@ -11,11 +11,6 @@ def make_params(params): return '?' + urlencode(params) -def insert(client, data): - query = {'query': 'coll.insert_many(documents)', 'documents': data} - return client.post('/db/execute', json=query) - - def apply(client, component): return client.post('/db/apply', json=component) @@ -38,13 +33,14 @@ def setup(client): {"x": [1, 2, 3, 4, 5], "y": 'test'}, {"x": [6, 7, 8, 9, 10], "y": 'test'}, ] - insert(client, data) + db['coll'].insert(data) return client def teardown(client): - delete(client) - remove(client, 'datatype', 'image') + # delete(client) + # remove(client, 'datatype', 'image') + ... 
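For orientation, a minimal sketch of the fixture pattern the REST tests move to: fixture rows are written straight through the datalayer instead of a dedicated REST insert helper, and reads are posted to the execute endpoint as stringified queries (as the REST tests below exercise). Building `db` with `superduper()` here is an assumption, since the fixture's own `db` construction lies outside this hunk; the collection name and rows mirror `setup()` above.

from superduper import superduper

# Assumes a configured test data backend (e.g. the local MongoDB used elsewhere in this patch).
db = superduper()

# Inserts now execute eagerly and return the inserted ids; no trailing .execute().
db['coll'].insert([
    {"x": [1, 2, 3, 4, 5], "y": "test"},
    {"x": [6, 7, 8, 9, 10], "y": "test"},
])

# Reads in the REST tests are posted as stringified queries, e.g.:
#   client.post('/db/execute', json={'query': 'coll.select()'})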
if __name__ == '__main__': diff --git a/test/rest/test_rest.py b/test/rest/test_rest.py index fc1d0c1a8..a5b8364d5 100644 --- a/test/rest/test_rest.py +++ b/test/rest/test_rest.py @@ -4,7 +4,6 @@ from fastapi.testclient import TestClient from superduper import CFG -from superduper.base.document import Document CFG.auto_schema = True CFG.rest.uri = 'localhost:8000' @@ -28,7 +27,7 @@ def test_health(setup): def test_select_data(setup): - result = setup.post('/db/execute', json={'query': 'coll.find({}, {"_id": 0})'}) + result = setup.post('/db/execute', json={'query': 'coll.select()'}) result = json.loads(result.content) if 'error' in result: raise Exception(result['messages'] + result['traceback']) @@ -70,65 +69,3 @@ def test_apply(setup): models = json.loads(models.content) assert models == ['my_function'] - - -@pytest.mark.skip -def test_insert_image(setup): - result = setup.put( - '/db/artifact_store/put', files={"raw": ("test/material/data/test.png")} - ) - result = json.loads(result.content) - - file_id = result['file_id'] - - query = { - '_path': 'superduper_mongodb.query.parse_query', - 'query': 'coll.insert_one(documents[0])', - '_builds': { - 'image_type': { - '_path': 'superduper.ext.pillow.encoder.image_type', - 'encodable': 'artifact', - }, - 'my_artifact': { - '_path': 'superduper.components.datatype.LazyArtifact', - 'blob': f'&:blob:{file_id}', - 'datatype': "?image_type", - }, - }, - 'documents': [ - { - 'img': '?my_artifact', - } - ], - } - - result = setup.post( - '/db/execute', - json=query, - ) - - query = { - '_path': 'superduper_mongodb.query.parse_query', - 'query': 'coll.find(documents[0], documents[1])', - 'documents': [{}, {'_id': 0}], - } - - result = setup.post( - '/db/execute', - json=query, - ) - - result = json.loads(result.content) - from superduper import superduper - - db = superduper() - - result = [Document.decode(r[0], db=db).unpack() for r in result] - - assert len(result) == 3 - - image_record = next(r for r in result if 'img' in r) - - from PIL.PngImagePlugin import PngImageFile - - assert isinstance(image_record['img'], PngImageFile) diff --git a/test/unittest/backends/base/test_base_query.py b/test/unittest/backends/base/test_base_query.py new file mode 100644 index 000000000..dfd03cae5 --- /dev/null +++ b/test/unittest/backends/base/test_base_query.py @@ -0,0 +1,56 @@ +from superduper.base.document import Document + + +def test(db): + q = db['docs'].select() + + assert hasattr(q, 'filter') + + ### + + q = db['docs'].select('a').select('b') + + assert str(q) == 'docs.select("a", "b")' + + ### + + t = db['docs'] + q = t.select('a').filter(t['a'] == 2) + + assert str(q) == 'docs.filter(docs["a"] == 2).select("a")' + + +def test_decomposition(db): + t = db['docs'] + + q = t.filter(t['a'] == 2).select('a') + + d = q.decomposition + + assert d.filter is not None and d.select is not None + + assert str(d.to_query()) == str(q) + + +def test_stringify(db): + t = db['docs'] + + q = t.filter(t['a'] == 2).select('a') + + assert str(q) == 'docs.filter(docs["a"] == 2).select("a")' + + q = t.like({'a': 2}, vector_index='test').select() + + assert str(q) == "docs.like({'a': 2}, \"test\", n=10).select()" + + +def test_serialize_deserialize(db): + t = db['docs'] + + q = t.like({'a': 2}, vector_index='test').select() + + r = q.dict() + + de_q = Document.decode(r, db=db).unpack() + + assert str(q) == str(de_q) diff --git a/test/unittest/backends/base/test_query.py b/test/unittest/backends/base/test_query.py index 2bd96e976..d34d743b7 100644 --- 
a/test/unittest/backends/base/test_query.py +++ b/test/unittest/backends/base/test_query.py @@ -1,12 +1,52 @@ from test.utils.database import query as query_utils -from superduper import Document +from superduper.backends.base.query import parse_query def test_insert(db): query_utils.test_insert(db) +def test_atomic_parse(db): + query_utils.test_atomic_parse(db) + + +def test_encode_decode_data(db): + query_utils.test_encode_decode_data(db) + + +def test_filter_select(db): + query_utils.test_filter_select(db) + + +def test_filter(db): + query_utils.test_filter(db) + + +def test_select_one_col(db): + query_utils.test_select_one_col(db) + + +def test_select_all_cols(db): + query_utils.test_select_all_cols(db) + + +def test_select_table(db): + query_utils.test_select_table(db) + + +def test_ids(db): + query_utils.test_ids(db) + + +def test_subset(db): + query_utils.test_subset(db) + + +def test_outputs(db): + query_utils.test_outputs(db) + + def test_read(db): query_utils.test_read(db) @@ -23,30 +63,11 @@ def test_insert_with_diff_schemas(db): query_utils.test_insert_with_diff_schemas(db) -def test_auto_document_wrapping(db): - query_utils.test_auto_document_wrapping(db) - - -def test_model(db): - query_utils.test_model(db) - - -def test_model_query(): - query_utils.test_model_query() - - -def test_model_query_serialization(): - query = { - "query": 'modela.predict("", documents[0], ' - 'condition={"uri": "123.PDF"})', - '_variables': {'abc': "ABC"}, - "documents": [{"a": "a"}], - "_builds": {}, - "_files": {}, - "_path": "superduper.backends.base.query.parse_query", - } +def test_parse_outputs_query(db): + q = parse_query( + query='_outputs__listener1__9bc4a01366f24603.select()', + documents=[], + db=db, + ) - decode_query = Document.decode(query) - assert decode_query.parts[0][0] == 'predict' - assert decode_query.parts[0][1] == ("ABC", {'a': 'a'}) - assert decode_query.parts[0][2] == {'condition': {'uri': '123.PDF'}} + assert len(q) == 2 diff --git a/test/unittest/backends/test_query_dataset.py b/test/unittest/backends/test_query_dataset.py index cff131bc9..0ee3a5bd6 100644 --- a/test/unittest/backends/test_query_dataset.py +++ b/test/unittest/backends/test_query_dataset.py @@ -1,5 +1,6 @@ import pytest +from superduper import CFG from superduper.backends.query_dataset import QueryDataset from superduper.components.model import Mapping @@ -16,7 +17,7 @@ def test_query_dataset(db): add_random_data(db) add_models(db) add_vector_index(db) - primary_id = db["documents"].primary_id + primary_id = db["documents"].primary_id.execute() listener_uuid = db.show('listener', 'vector-x', -1)['uuid'] @@ -24,18 +25,19 @@ def test_query_dataset(db): db=db, mapping=Mapping("_base", signature="singleton"), select=db["documents"] - .select(primary_id, 'x', '_fold') - .outputs("vector-x__" + listener_uuid), + .outputs("vector-x__" + listener_uuid) + .select( + primary_id, 'x', '_fold', f"{CFG.output_prefix}vector-x__" + listener_uuid + ), fold="train", ) + r = train_data[0] + assert r["_fold"] == "train" assert "y" not in r assert "x" in r - db["documents"].select(primary_id, 'x', '_fold').outputs( - "vector-x__" + listener_uuid - ) assert r['_outputs__vector-x__' + listener_uuid].shape[0] == 16 train_data = QueryDataset( diff --git a/test/unittest/base/test_apply.py b/test/unittest/base/test_apply.py index 30d3d8bfb..dc99ef4bf 100644 --- a/test/unittest/base/test_apply.py +++ b/test/unittest/base/test_apply.py @@ -226,7 +226,7 @@ def my_func(x): ], ) - db['docs'].insert([{'x': i} for i in 
range(10)]).execute() + db['docs'].insert([{'x': i} for i in range(10)]) db.apply(c) diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py index 9d934402e..b79330b9a 100644 --- a/test/unittest/base/test_datalayer.py +++ b/test/unittest/base/test_datalayer.py @@ -405,9 +405,7 @@ def test_load(db): def test_insert(db): db.cfg.auto_schema = True add_fake_model(db) - inserted_ids = ( - db['documents'].insert([{'x': i, 'update': True} for i in range(5)]).execute() - ) + inserted_ids = db['documents'].insert([{'x': i} for i in range(5)]) assert len(inserted_ids) == 5 uuid = db.show('listener', 'listener-x', 0)['uuid'] @@ -420,39 +418,16 @@ def test_insert(db): def test_insert_artifacts(db): db.cfg.auto_schema = True - db._insert( - db['documents'].insert( - [Document({'x': numpy.random.randn(100)}) for _ in range(1)] - ) - ) - r = list(db.execute(db['documents'].select()))[0] + db['documents'].insert([{'x': numpy.random.randn(100)} for _ in range(1)]) + r = db['documents'].get() assert isinstance(r['x'], numpy.ndarray) -def test_insert_sql_db(db): - db.cfg.auto_schema = True - listener = add_fake_model(db) - table = db['documents'] - inserted_ids = db.execute( - table.insert([Document({'id': str(i), 'x': i}) for i in range(5)]) - ) - assert len(inserted_ids) == 5 - - q = table.select().outputs(listener.predict_id) - - new_docs = db.execute(q) - new_docs = list(new_docs) - - result = [Document(doc.unpack())[listener.outputs] for doc in new_docs] - assert sorted(result) == ['0', '1', '2', '3', '4'] - - @pytest.mark.skipif(not mongodb_config, reason='MongoDB not configured') def test_update_db(db): # TODO: test update sql db after the update method is implemented add_fake_model(db) - q = db['documents'].insert([Document({'x': i, 'update': True}) for i in range(5)]) - db._insert(q) + db['documents'].insert([Document({'x': i, 'update': True}) for i in range(5)]) updated_ids, _ = db._update( db['documents'].update_many({}, Document({'$set': {'x': 100}})) ) @@ -605,7 +580,7 @@ def test_dataset(db): db.apply(d) assert db.show('dataset') == ['test_dataset'] dataset = db.load('dataset', 'test_dataset') - assert len(dataset.data) == len(list(db.execute(dataset.select))) + assert len(dataset.data) == len(dataset.select.execute()) def test_delete_component_with_same_artifact(db): diff --git a/test/unittest/base/test_document.py b/test/unittest/base/test_document.py index cf059a60d..5bb267b6f 100644 --- a/test/unittest/base/test_document.py +++ b/test/unittest/base/test_document.py @@ -5,7 +5,6 @@ import numpy as np import pytest -from superduper.backends.base.query import Query from superduper.base.constant import KEY_BLOBS, KEY_BUILDS from superduper.base.document import Document from superduper.components.datatype import ( @@ -30,16 +29,18 @@ def test_document_encoding(db): assert (new_document['x'] - document['x']).sum() == 0 -def test_flat_query_encoding(): - q = Query(table='docs').find({'a': 1}).limit(2) +def test_flat_query_encoding(db): + # TODO what is being tested here?? 
- r = q._deep_flat_encode({}, {}, {}) + t = db['docs'] - doc = Document({'x': 1}) + q = t.filter(t['a'] == 1).limit(2) - q = Query(table='docs').like(doc, vector_index='test').find({'a': 1}).limit(2) + r = q.encode() - r = q._deep_flat_encode({}, {}, {}) + q = t.like({'x': 1}, vector_index='test').filter(t['a'] == 1).limit(2) + + r = q.encode() print(r) @@ -207,7 +208,7 @@ def test_column_encoding(db): schema = Schema( 'test', fields={ - 'id': int, + 'id': str, 'x': int, 'y': int, 'data': pickle_serializer, @@ -218,10 +219,10 @@ def test_column_encoding(db): data = np.random.rand(20) db['test'].insert( [ - Document({'id': 1, 'x': 1, 'y': 2, 'data': data}), - Document({'id': 2, 'x': 3, 'y': 4, 'data': data}), + {'id': '1', 'x': 1, 'y': 2, 'data': data}, + {'id': '2', 'x': 3, 'y': 4, 'data': data}, ] - ).execute() + ) db['test'].select().execute() diff --git a/test/unittest/base/test_leaf.py b/test/unittest/base/test_leaf.py index 3d8e22f06..b044c2158 100644 --- a/test/unittest/base/test_leaf.py +++ b/test/unittest/base/test_leaf.py @@ -3,7 +3,6 @@ from pprint import pprint from superduper import ObjectModel -from superduper.backends.base.query import Query from superduper.base.constant import KEY_BLOBS, KEY_BUILDS from superduper.base.document import Document from superduper.base.leaf import Leaf @@ -75,20 +74,16 @@ def test_encode_leaf_with_children(): } -def test_save_variables_2(): - query = ( - Query(table='documents') - .like({'x': ''}, vector_index='test') - .find({'x': {'$regex': '^test/1'}}) - ) +def test_save_variables_2(db): + t = db['documents'] + query = t.like({'x': ''}, vector_index='test').filter(t['x'] == 1) assert [x for x in query.variables] == ['X'] -def test_save_non_string_variables(): - query = Query(table='documents').find().limit('') - - assert str(query) == 'documents.find().limit("")' +def test_save_non_string_variables(db): + query = db['documents'].select().limit('') + assert str(query) == 'documents.select().limit("")' assert [x for x in query.variables] == ['limit'] @@ -121,23 +116,20 @@ def test_component_with_document(): print(type(builds[leaf])) -def test_find_variables(): +def test_find_variables(db): from superduper import Document r = Document({'txt': ''}) assert r.variables == ['test'] - q = Query(table='test').find_one(Document({'txt': ''})) + t = db['test'] + + q = t.filter(t['txt'] == '') assert q.variables == ['test'] - q = ( - Query(table='test') - .like(Document({'txt': ''}), vector_index='test') - .find() - .limit(5) - ) + q = db['test'].like({'txt': ''}, vector_index='test').limit(5) q_set = q.set_variables(test='my-value') diff --git a/test/unittest/component/datatype/test_pickle.py b/test/unittest/component/datatype/test_pickle.py index 1e42dc987..0050a80b0 100644 --- a/test/unittest/component/datatype/test_pickle.py +++ b/test/unittest/component/datatype/test_pickle.py @@ -4,7 +4,6 @@ import pandas as pd import pytest -from superduper.base.enums import DBType from superduper.components.datatype import ( BaseDataType, pickle_encoder, @@ -43,6 +42,6 @@ def test_component(random_data, datatype): def test_component_with_db(db, random_data, datatype): # TODO: Need to fix the encodable in component when db is SQL # Some bytes are not serializable, then can't be stored in SQL - if datatype.encodable == "encodable" and db.databackend.db_type == DBType.SQL: - return + # if datatype.encodable == "encodable" and db.databackend.db_type == DBType.SQL: + # return datatype_utils.check_component_with_db(random_data, datatype, db) diff --git 
a/test/unittest/component/datatype/test_vector.py b/test/unittest/component/datatype/test_vector.py index 23c527ea0..cd4674a3c 100644 --- a/test/unittest/component/datatype/test_vector.py +++ b/test/unittest/component/datatype/test_vector.py @@ -3,7 +3,7 @@ def test_auto_detect_vector(db): db.cfg.auto_schema = True - db['vectors'].insert([{'x': numpy.random.randn(7)} for _ in range(3)]).execute() + db['vectors'].insert([{'x': numpy.random.randn(7)} for _ in range(3)]) assert 'vector[7]' in db.show('datatype') diff --git a/test/unittest/component/test_component.py b/test/unittest/component/test_component.py index 27abe0468..5bc213e14 100644 --- a/test/unittest/component/test_component.py +++ b/test/unittest/component/test_component.py @@ -120,7 +120,7 @@ def test_set_variables(db): object=lambda x: x + 2, ), key="", - select=db["docs"].find(), + select=db["docs"], ) from superduper import Document @@ -186,7 +186,7 @@ def test_set_db_deep(db): object=lambda x: x + 2, ), key="x", - select=db["docs"].find(), + select=db["docs"], ) assert m.upstream[0].db is None diff --git a/test/unittest/component/test_dataset.py b/test/unittest/component/test_dataset.py index 300196a27..1310ec7ed 100644 --- a/test/unittest/component/test_dataset.py +++ b/test/unittest/component/test_dataset.py @@ -9,7 +9,7 @@ def test_dataset_pin(db, pin): datas = [{"x": i, "y": [1, 2, 3]} for i in range(10)] - db["documents"].insert(datas).execute() + db["documents"].insert(datas) select = db["documents"].select() @@ -22,7 +22,7 @@ def test_dataset_pin(db, pin): assert db.show("dataset") == ["test_dataset"] new_datas = [{"x": i, "y": [1, 2, 3]} for i in range(10, 20)] - db["documents"].insert(new_datas).execute() + db["documents"].insert(new_datas) dataset: Dataset = db.load("dataset", "test_dataset") dataset.init(db) if pin: diff --git a/test/unittest/component/test_graph.py b/test/unittest/component/test_graph.py index eccca41e9..dbe5e2ba3 100644 --- a/test/unittest/component/test_graph.py +++ b/test/unittest/component/test_graph.py @@ -135,10 +135,7 @@ def test_complex_graph_with_select(db): ) db.apply(listener) assert all( - [ - '_outputs__test__test' in x - for x in list(db.execute(db['_outputs__test__test'].select())) - ] + ['_outputs__test__test' in x for x in db['_outputs__test__test'].execute()] ) diff --git a/test/unittest/component/test_listener.py b/test/unittest/component/test_listener.py index c34d6ff38..e386b0f14 100644 --- a/test/unittest/component/test_listener.py +++ b/test/unittest/component/test_listener.py @@ -5,7 +5,6 @@ import pytest from superduper import Application, Document -from superduper.backends.base.query import Query from superduper.base.constant import KEY_BLOBS from superduper.components.listener import Listener from superduper.components.model import ObjectModel, Trainer @@ -24,8 +23,8 @@ class _Tmp(ObjectModel): ... 
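Before the listener test changes that follow, a small hedged sketch of the pattern they converge on: inserts run immediately, a Listener takes a query (the tests also pass a bare table such as db['test']) as its select, and outputs are read back with .execute() or .get(). The identifiers and the toy model are illustrative only, not taken from the original tests.

from superduper import ObjectModel, superduper
from superduper.components.listener import Listener

db = superduper()                 # assumes a configured data backend
db.cfg.auto_schema = True

db['test'].insert([{"x": 1}, {"x": 2}, {"x": 3}])       # executes eagerly

listener = Listener(
    identifier="listener",
    model=ObjectModel("m1", object=lambda x: x + 1),
    key="x",
    select=db['test'].select(),
)
db.apply(listener)

rows = db[listener.outputs].select().execute()          # all output rows as a list
one = db[listener.outputs].select().get()               # a single output row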
-def test_listener_serializes_properly(): - q = Query(table='test').find({}, {}) +def test_listener_serializes_properly(db): + q = db['test'].select() listener = Listener( identifier="listener", model=ObjectModel("test", object=lambda x: x), @@ -58,7 +57,7 @@ def insert_random(start=0): ) ) - db.execute(table.insert(data)) + table.insert(data) # Insert data insert_random() @@ -85,7 +84,7 @@ def insert_random(start=0): db.apply(listener2) def check_listener_output(listener, output_n): - docs = db[listener.outputs].select().tolist() + docs = db[listener.outputs].select().execute() assert len(docs) == output_n assert all([listener.outputs in r for r in docs]) @@ -120,9 +119,7 @@ def test_create_output_dest(db, data, flatten): "m1", object=lambda x: data if not flatten else [data] * 10, ) - q = table.insert([{"x": 1}]) - - db.execute(q) + table.insert([{"x": 1}]) listener1 = Listener( model=m1, @@ -134,7 +131,7 @@ def test_create_output_dest(db, data, flatten): db.apply(listener1) - doc = db[listener1.outputs].select().tolist()[0] + doc = db[listener1.outputs].select().get() result = Document(doc.unpack())[listener1.outputs] assert isinstance(result, type(data)) if isinstance(data, np.ndarray): @@ -164,9 +161,7 @@ def test_listener_cleanup(db, data): "m1", object=lambda x: data, ) - q = table.insert([{"x": 1}]) - - db.execute(q) + table.insert([{"x": 1}]) listener1 = Listener( model=m1, @@ -176,7 +171,7 @@ def test_listener_cleanup(db, data): ) db.add(listener1) - doc = db[listener1.outputs].select().tolist()[0] + doc = db[listener1.outputs].select().execute()[0] result = Document(doc.unpack())[listener1.outputs] assert isinstance(result, type(data)) if isinstance(data, np.ndarray): @@ -216,7 +211,7 @@ def insert_random(start=0): ) ) - db.execute(table.insert(data)) + table.insert(data) # Insert data insert_random() @@ -294,7 +289,7 @@ def test_predict_id_utils(db): ] ) - db.execute(q) + q.execute() listener1 = Listener( model=m1, @@ -309,19 +304,18 @@ def test_predict_id_utils(db): outputs = listener1.outputs # Listener identifier is set as the table name select = db[outputs].select() - docs = select.tolist() - # docs = list(db.execute(select)) + docs = select.execute() assert [doc[listener1.outputs] for doc in docs] == [1, 2, 3] # Listener identifier is set as the table name and filter is applied table = db[outputs].select() select = table.filter(table[outputs] > 1) - docs = select.tolist() + docs = select.execute() assert [doc[listener1.outputs] for doc in docs] == [2, 3] # Listener identifier is set as the predict_id in outputs() select = db["test"].select().outputs('listener1') - docs = select.tolist() + docs = select.execute() assert [doc[listener1.outputs] for doc in docs] == [1, 2, 3] @@ -341,11 +335,9 @@ def test_complete_uuids(db): ] ) - db.execute(q) - l1 = Listener( model=m1, - select=db['test'].select(), + select=db['test'], key="x", identifier="l1", ) @@ -358,7 +350,7 @@ def test_complete_uuids(db): assert f'"{l1.predict_id}"' in str(qq) or l1.predict_id in str(qq) - results = q.tolist() + results = q.execute() assert results[0]['_outputs__l1'] == results[0][l1.outputs] @@ -377,7 +369,7 @@ def test_autofill_data_listener(db): {"x": 2}, {"x": 3}, ] - ).execute() + ) l1 = m.to_listener(select=db['test'].select(), key='x', identifier='l1') l2 = m.to_listener(select=db[l1.outputs].select(), key=l1.outputs, identifier='l2') diff --git a/test/unittest/component/test_model.py b/test/unittest/component/test_model.py index c674d8986..0a3d0817e 100644 --- 
a/test/unittest/component/test_model.py +++ b/test/unittest/component/test_model.py @@ -108,7 +108,7 @@ def test_pm_predict_batches(predict_mixin): db.compute = MagicMock(spec=LocalComputeBackend) db.metadata = MagicMock() db.databackend = MagicMock() - select = MagicMock(spec=Query) + select = MagicMock() predict_mixin.db = db with patch.object(predict_mixin, 'predict_batches') as predict_func, patch.object( @@ -119,6 +119,7 @@ def test_pm_predict_batches(predict_mixin): predict_func.assert_called_once() +@pytest.mark.skip def test_pm_predict_with_select_ids(monkeypatch, predict_mixin): xs = [np.random.randn(4) for _ in range(10)] @@ -137,21 +138,15 @@ def test_pm_predict_with_select_ids(monkeypatch, predict_mixin): my_object.return_value = 2 # Check the base predict function predict_mixin.db = db - with patch.object(select, 'select_using_ids') as select_using_ids, patch.object( - select, 'model_update' - ) as model_update: + with patch.object(select, 'select_using_ids') as select_using_ids: predict_mixin._predict_with_select_and_ids( X=X, select=select, ids=ids, predict_id='test' ) select_using_ids.assert_called_once_with(ids) - _, kwargs = model_update.call_args - # make sure the outputs are set - assert kwargs.get('outputs') == [2] * 10 with ( patch.object(predict_mixin, 'object') as my_object, patch.object(select, 'select_using_ids') as select_using_id, - patch.object(select, 'model_update') as model_update, ): my_object.return_value = 2 @@ -164,9 +159,6 @@ def test_pm_predict_with_select_ids(monkeypatch, predict_mixin): X=X, select=select, ids=ids, predict_id='test' ) select_using_id.assert_called_once_with(ids) - _, kwargs = model_update.call_args - kwargs_output_ids = [o for o in kwargs.get('outputs')] - assert kwargs_output_ids == [2] * 10 with patch.object(predict_mixin, 'object') as my_object: my_object.return_value = {'out': 2} @@ -177,15 +169,11 @@ def test_pm_predict_with_select_ids(monkeypatch, predict_mixin): predict_mixin.output_schema = schema = MagicMock(spec=Schema) predict_mixin.db = db schema.side_effect = str - with patch.object(select, 'select_using_ids') as select_using_ids, patch.object( - select, 'model_update' - ) as model_update: + with patch.object(select, 'select_using_ids') as select_using_ids: predict_mixin._predict_with_select_and_ids( X=X, select=select, ids=ids, predict_id='test' ) select_using_ids.assert_called_once_with(ids) - _, kwargs = model_update.call_args - assert kwargs.get('outputs') == [{'out': 2} for _ in range(10)] def test_model_append_metrics(): @@ -326,6 +314,7 @@ def fit(self, *args, **kwargs): model.db.replace.assert_called_once() +@pytest.mark.skip def test_query_model(db): from test.utils.setup.fake_data import add_models, add_random_data, add_vector_index @@ -402,27 +391,22 @@ def func(x, y): def _test(X, docs): ids = [i for i in range(10)] - select = MagicMock(spec=Query) + select = MagicMock() db = MagicMock(spec=Datalayer) db.databackend = MagicMock(spec=BaseDataBackend) - db.execute.return_value = docs + select.execute.return_value = docs # Check the base predict function predict_mixin_multikey.db = db - with patch.object(select, 'select_using_ids') as select_using_ids, patch.object( - select, 'model_update' - ) as model_update: + with patch.object(select, 'subset') as subset: predict_mixin_multikey._predict_with_select_and_ids( X=X, predict_id='test', select=select, ids=ids ) - select_using_ids.assert_called_once_with(ids) - _, kwargs = model_update.call_args - # make sure the outputs are set - assert kwargs.get('outputs') == 
[2] * 10 + subset.assert_called_once_with(ids) # TODO - I don't know how this works given that the `_outputs` field # should break... - docs = [Document({'x': x, 'y': x}) for x in xs] + docs = [{'x': x, 'y': x} for x in xs] X = ('x', 'y') _test(X, docs) diff --git a/test/unittest/component/test_schema.py b/test/unittest/component/test_schema.py index 54d735f63..58c4d4ef8 100644 --- a/test/unittest/component/test_schema.py +++ b/test/unittest/component/test_schema.py @@ -39,7 +39,7 @@ def test_schema_with_bytes_encoding(db): db.databackend.bytes_encoding = 'base64' - db['documents'].insert([{'txt': 'testing 123'}]).execute() + db['documents'].insert([{'txt': 'testing 123'}]) try: r = db.databackend.db['documents'].find_one() @@ -61,9 +61,9 @@ def test_schema_with_blobs(db): ) ) - db['documents'].insert([{'txt': 'testing 123'}]).execute() + db['documents'].insert([{'txt': 'testing 123'}]) - r = db['documents'].select().tolist()[0] + r = db['documents'].get() assert isinstance(r['txt'], Blob) @@ -98,10 +98,10 @@ def test_schema_with_file(db, tmp_file): schema=Schema('_schema/documents', fields={'my_file': file}), ) ) - db['documents'].insert([{'my_file': tmp_file}]).execute() + db['documents'].insert([{'my_file': tmp_file}]) # only the references are loaded when data is selected - r = db['documents'].select().tolist()[0] + r = db['documents'].get() # loaded document contains a pointer to the file assert isinstance(r['my_file'], File) diff --git a/test/unittest/component/test_template.py b/test/unittest/component/test_template.py index be5e5ad57..46402027b 100644 --- a/test/unittest/component/test_template.py +++ b/test/unittest/component/test_template.py @@ -55,8 +55,7 @@ def model(x): assert listener.model.object(3) == 5 # Check listener outputs with key and model_id - primary_id = db['documents'].primary_id - r = db['documents'].select(primary_id, 'y').outputs(listener.predict_id).execute() + r = db['documents'].outputs(listener.predict_id).execute() r = Document(list(r)[0].unpack()) assert r[listener.outputs] == r['y'] + 2 @@ -101,13 +100,7 @@ def test_template_export(db): db.apply(listener) # Check listener outputs with key and model_id - primary_id = db['documents'].primary_id - r = ( - db['documents'] - .select(primary_id, 'y') - .outputs(listener.predict_id) - .execute() - ) + r = db['documents'].outputs(listener.predict_id).execute() r = Document(list(r)[0].unpack()) assert r[listener.outputs] == r['y'] + 2 @@ -140,11 +133,15 @@ def test_from_template(db): def test_query_template(db): add_random_data(db) - q = db['documents'].find({'this': 'is a '}).limit('') + table = db['documents'] + q = table.filter(table['this'] == 'is a ').limit('') t = QueryTemplate('select_lim', template=q) assert set(t.template_variables).issuperset({'limit', 'test'}) - assert t.template['query'] == 'documents.find(documents[0]).limit("")' + assert ( + t.template['query'].split('\n')[-1] + == 'documents.filter(query[0]).limit("")' + ) def test_cross_reference(db): diff --git a/test/unittest/component/test_vector_index.py b/test/unittest/component/test_vector_index.py index 51b3dc20b..d1160105a 100644 --- a/test/unittest/component/test_vector_index.py +++ b/test/unittest/component/test_vector_index.py @@ -9,9 +9,9 @@ def test_vector_index_recovery(db): build_vector_index(db) table = db["documents"] - primary_id = table.primary_id + primary_id = table.primary_id.execute() vector_index = "vector_index" - sample_data = list(table.select().execute())[50] + sample_data = table.select().execute()[50] # Simulate 
restart del db.cluster.vector_search[vector_index] diff --git a/test/unittest/ext/test_vanilla.py b/test/unittest/ext/test_vanilla.py index ac529d301..0dc4ba76d 100644 --- a/test/unittest/ext/test_vanilla.py +++ b/test/unittest/ext/test_vanilla.py @@ -9,9 +9,7 @@ def data_in_db(db): db.cfg.auto_schema = True X = [1, 2, 3, 4, 5] y = [1, 2, 3, 4, 5] - db.execute( - db['documents'].insert([Document({'X': x, 'y': yy}) for x, yy in zip(X, y)]) - ) + db['documents'].insert([Document({'X': x, 'y': yy}) for x, yy in zip(X, y)]) yield db @@ -31,10 +29,10 @@ def test_function_predict_in_db(data_in_db): data_in_db.apply(function) function.predict_in_db( X='X', - select=data_in_db['documents'].select(), + select=data_in_db['documents'], predict_id='test', ) - out = list(data_in_db.execute(data_in_db['_outputs__test'].select())) + out = data_in_db['_outputs__test'].select().execute() assert [Document(o)['_outputs__test'] for o in out] == [1, 2, 3, 4, 5] @@ -51,11 +49,9 @@ def test_function_predict_with_flatten_outputs(data_in_db): predict_id='test', flatten=True, ) - out = list(data_in_db.execute(data_in_db['_outputs__test'].select())) - primary_id = data_in_db['documents'].primary_id - input_ids = [ - c[primary_id] for c in data_in_db.execute(data_in_db['documents'].select()) - ] + out = data_in_db['_outputs__test'].select().execute() + primary_id = data_in_db['documents'].primary_id.execute() + input_ids = [c[primary_id] for c in data_in_db['documents'].select().execute()] source_ids = [] for i, id in enumerate(input_ids): ix = 3 if i + 1 > 2 else 2 @@ -95,11 +91,9 @@ def test_function_predict_with_mix_flatten_outputs(data_in_db): flatten=True, ) - out = list(data_in_db.execute(data_in_db['_outputs__test'].select())) - primary_id = data_in_db['documents'].primary_id - input_ids = [ - c[primary_id] for c in data_in_db.execute(data_in_db['documents'].select()) - ] + out = data_in_db['_outputs__test'].select().execute() + primary_id = data_in_db['documents'].primary_id.execute() + input_ids = [c[primary_id] for c in data_in_db['documents'].select().execute()] source_ids = [] for i, id in enumerate(input_ids): source_ids.append(id if i + 1 < 2 else [id] * 3) diff --git a/test/unittest/jobs/test_task_workflow.py b/test/unittest/jobs/test_task_workflow.py index 137b68c50..b1cf2f562 100644 --- a/test/unittest/jobs/test_task_workflow.py +++ b/test/unittest/jobs/test_task_workflow.py @@ -23,7 +23,7 @@ def assert_output_is_correct(data, output): def test_downstream_task_workflows_are_triggered(db, data, flatten): db.cfg.auto_schema = True - db.execute(db["test"].insert([{"x": 10}])) + db["test"].insert([{"x": 10}]) upstream_model = ObjectModel( "m1", @@ -52,10 +52,10 @@ def test_downstream_task_workflows_are_triggered(db, data, flatten): db.apply(downstream_listener) - outputs1 = db[upstream_listener.outputs].select().tolist() + outputs1 = db[upstream_listener.outputs].select().execute() outputs1 = [r[upstream_listener.outputs] for r in outputs1] - outputs2 = db[downstream_listener.outputs].select().tolist() + outputs2 = db[downstream_listener.outputs].select().execute() outputs2 = [r[downstream_listener.outputs] for r in outputs2] assert len(outputs1) == 1 if not flatten else 10 @@ -64,7 +64,7 @@ def test_downstream_task_workflows_are_triggered(db, data, flatten): assert_output_is_correct(data * 10, outputs1[0]) assert_output_is_correct(data * 10 / 2, outputs2[0]) - db["test"].insert([{"x": 20}]).execute() + db["test"].insert([{"x": 20}]) # Check that the listeners are triggered when data is inserted later 
outputs1 = [ diff --git a/test/unittest/misc/test_auto_schema.py b/test/unittest/misc/test_auto_schema.py index 196f9cc1a..11326d839 100644 --- a/test/unittest/misc/test_auto_schema.py +++ b/test/unittest/misc/test_auto_schema.py @@ -56,10 +56,10 @@ def test_schema(db, data): db.apply(t) - db.execute(db["my_table"].insert([Document(data)])) + db["my_table"].insert([Document(data)]) - select = db["my_table"].select().limit(1) - decode_data = db.execute(select).next().unpack() + select = db["my_table"].select() + decode_data = select.get().unpack() for key in data: assert isinstance(data[key], type(decode_data[key])) assert str(data[key]) == str(decode_data[key]) diff --git a/test/utils/component/datatype.py b/test/utils/component/datatype.py index 652e75c35..bb600700c 100644 --- a/test/utils/component/datatype.py +++ b/test/utils/component/datatype.py @@ -7,7 +7,6 @@ from superduper.base.datalayer import Datalayer from superduper.base.document import Document -from superduper.base.enums import DBType from superduper.components.component import Component from superduper.components.datatype import BaseDataType, pickle_serializer from superduper.components.schema import Schema @@ -65,30 +64,22 @@ def check_data_with_schema_and_db(data, datatype: BaseDataType, db: Datalayer): table = Table("documents", schema=schema) db.apply(table) - document = Document({"x": data, "y": 1}) + document = {"x": data, "y": 1} print(document) print_sep() - db["documents"].insert([document]).execute() - if db.databackend.db_type == DBType.MONGODB: - encoded = db.databackend.conn["test_db"]["documents"].find_one() - else: - t = db.databackend.conn.table("documents") - encoded = dict(t.select(t).execute().iloc[0]) + db["documents"].insert([document]) - pprint(encoded) - print_sep() + decoded = db["documents"].select().execute()[0] - decoded = list(db["documents"].select().execute())[0] decoded = decoded.unpack() pprint(decoded) print_sep() - assert_equal(document["x"], decoded["x"]) assert_equal(document["y"], decoded["y"]) - return document, encoded, decoded + return document, decoded @dc.dataclass(kw_only=True) diff --git a/test/utils/component/model.py b/test/utils/component/model.py index 7a7bcd8a0..a68cd652b 100644 --- a/test/utils/component/model.py +++ b/test/utils/component/model.py @@ -27,7 +27,7 @@ def test_predict_in_db(model: Model, sample_data: t.Any, db: "Datalayer"): db.cfg.auto_schema = True - db["datas"].insert([{"data": sample_data, "i": i} for i in range(10)]).execute() + db["datas"].insert([{"data": sample_data, "i": i} for i in range(10)]) listener = Listener( key="data", @@ -47,7 +47,7 @@ def test_predict_in_db(model: Model, sample_data: t.Any, db: "Datalayer"): def test_model_as_a_listener(model: Model, sample_data: t.Any, db: "Datalayer"): db.cfg.auto_schema = True - db["datas"].insert([{"data": sample_data, "i": i} for i in range(10)]).execute() + db["datas"].insert([{"data": sample_data, "i": i} for i in range(10)]) model.identifier = f'test-{random_id()}' diff --git a/test/utils/database/databackend.py b/test/utils/database/databackend.py index 1de05c562..cd002baca 100644 --- a/test/utils/database/databackend.py +++ b/test/utils/database/databackend.py @@ -1,39 +1,7 @@ -from superduper import CFG from superduper.backends.base.data_backend import BaseDataBackend -from superduper.backends.base.query import Query -from superduper.components.datatype import pickle_serializer from superduper.components.schema import Schema -def test_output_dest(databackend: BaseDataBackend): - assert 
isinstance(databackend, BaseDataBackend)
-    # Create an output destination for the database
-    predict_id = "predict_id"
-
-    assert not databackend.check_output_dest(predict_id)
-
-    table = databackend.create_output_dest(predict_id, pickle_serializer)
-
-    assert table.identifier.startswith(CFG.output_prefix)
-
-    databackend.create_table_and_schema(table.identifier, table.schema)
-
-    assert databackend.check_output_dest(predict_id)
-
-    # Drop the output destination
-    #
-    databackend.drop_outputs()
-
-    assert not databackend.check_output_dest(predict_id)
-
-
-def test_query_builder(databackend: BaseDataBackend):
-    query = databackend.get_query_builder("datas")
-
-    assert isinstance(query, Query)
-    assert query.table == "datas"
-
-
 def test_list_tables_or_collections(databackend: BaseDataBackend):
     fields = {
         'a': int,
@@ -45,6 +13,6 @@ def test_list_tables_or_collections(databackend: BaseDataBackend):
         table_name, schema=Schema(identifier="schema", fields=fields)
     )
 
-    tables = databackend.list_tables_or_collections()
+    tables = databackend.list_tables()
     assert len(tables) == 10
     assert [f"table_{i}" for i in range(10)] == sorted(tables)
diff --git a/test/utils/database/query.py b/test/utils/database/query.py
index 6cfaf64a3..1c0ad7e07 100644
--- a/test/utils/database/query.py
+++ b/test/utils/database/query.py
@@ -1,23 +1,167 @@
 import numpy as np
 import pytest
 
-from superduper.base.document import Document
+from superduper.backends.base.query import parse_query
+
+
+def test_atomic_parse(db):
+    q = db['docs']['x'] == 2
+    r = q.dict()
+
+    parsed = parse_query(query=r['query'], documents=r['documents'], db=db)
+
+    assert len(parsed) == 3
+
+    q = db['docs']['x'] == 'a'
+    r = q.dict()
+
+    parsed = parse_query(query=r['query'], documents=r['documents'], db=db)
+
+    assert len(parsed) == 3
 
 
 def test_insert(db):
     db.cfg.auto_schema = True
     # Test insert one
-    db["documents"].insert([{"this": "is a test"}]).execute()
-    result = list(db["documents"].select("this").execute())[0]
+    db["documents"].insert([{"this": "is a test"}])
+    result = db["documents"].select("this").execute()[0]
     assert result["this"] == "is a test"
 
     # Test insert multiple
-    db["documents"].insert([{"this": "is a test"}, {"this": "is a test"}]).execute()
-    results = list(db["documents"].select("this").execute())
+    db["documents"].insert([{"this": "is a test"}, {"this": "is a test"}])
+    results = db["documents"].select("this").execute()
     assert len(results) == 3
 
 
+def test_outputs(db):
+    db.cfg.auto_schema = True
+
+    ids = db['documents'].insert([{'x': i} for i in range(10)])
+    db['_outputs__a__123456789'].insert(
+        [{'_outputs__a__123456789': i + 2, '_source': id} for i, id in enumerate(ids)]
+    )
+    outputs = db['documents'].outputs('a__123456789').execute()
+    print(outputs)
+
+    for r in outputs:
+        assert r['x'] + 2 == r['_outputs__a__123456789']
+
+
+def test_subset(db):
+    db.cfg.auto_schema = True
+
+    db['documents'].insert([{'x': i} for i in range(10)])
+
+    ids = db['documents'].ids()
+    results = db['documents'].subset(ids[:5])
+
+    pid = db['documents'].primary_id.execute()
+
+    assert set([r[pid] for r in results]) == set(ids[:5])
+
+    db['_outputs__a__123456789'].insert(
+        [{'_outputs__a__123456789': i + 2, '_source': id} for i, id in enumerate(ids)]
+    )
+
+    results = db['documents'].outputs('a__123456789').subset(ids[:5])
+
+    assert set([r[pid] for r in results]) == set(ids[:5])
+
+    assert 'x' in results[0]
+
+
+def test_ids(db):
+    db.cfg.auto_schema = True
+
+    db['documents'].insert([{'x': i} for i in range(10)])
+
+    results = 
db['documents'].ids() + + assert len(results) == 10 + + assert all(isinstance(x, str) for x in results) + + +def test_select_table(db): + db.cfg.auto_schema = True + + db['documents'].insert([{'x': i} for i in range(10)]) + + results = db['documents'].execute() + assert len(results) == 10 + + +def test_select_all_cols(db): + db.cfg.auto_schema = True + + db['documents'].insert([{'x': i} for i in range(10)]) + + q = db['documents'].select() + + results = q.execute() + + assert len(results) == 10 + + +def test_select_one_col(db): + db.cfg.auto_schema = True + + db['documents'].insert([{'x': i} for i in range(10)]) + + q = db['documents'].select('x') + + results = q.execute() + + assert set(results[0].keys()) == {'x'} + + +def test_filter(db): + db.cfg.auto_schema = True + + db['documents'].insert([{'x': i} for i in range(10)]) + + t = db['documents'] + + q = t.filter(t['x'] == 1) + + results = q.execute() + + assert len(results) == 1 + + assert results[0]['x'] == 1 + + +def test_filter_select(db): + db.cfg.auto_schema = True + + db['documents'].insert([{'x': i} for i in range(10)]) + + t = db['documents'] + + pid = t.primary_id.execute() + + q = t.filter(t['x'] == 2).select(pid) + + r = q.execute()[0] + + assert set(r.keys()) == {pid} + + +class ToSave: + def __init__(self, x): + self.x = x + + +def test_encode_decode_data(db): + db.cfg.auto_schema = True + db['docs'].insert([{'x': ToSave(i)} for i in range(10)]) + + results = db['docs'].execute() + + assert isinstance(results[0]['x'], ToSave) + + def test_read(db): def check_keys(data, keys): for k in keys: @@ -39,13 +183,14 @@ def check_keys(data, keys): } ) - db["documents"].insert(datas).execute() + db["documents"].insert(datas) # Test base select - results = list(db["documents"].select().execute()) + results = db["documents"].select().execute() assert len(results) == 10 - primary_id = db["documents"].primary_id + primary_id = db["documents"].primary_id.execute() + for r in results: check_keys(r, ["x", "y", "z", primary_id, "_fold", "n"]) @@ -59,7 +204,7 @@ def check_keys(data, keys): # Test filter select table = db["documents"] primary_id = table.primary_id - select = table.select("x", "y", "n").filter(table.y == 1, table.n > 5) + select = table.select("x", "y", "n").filter(table['y'] == 1, table['n'] > 5) results = list(select.execute()) assert len(results) == 3 assert [6, 8, 10] == [r["n"] for r in results] @@ -71,27 +216,17 @@ def check_keys(data, keys): select = table.select("x", "special-y", "special-n").filter( table["special-y"] == 1, table["special-n"] > 5 ) - results = list(select.execute()) + results = select.execute() assert len(results) == 3 assert [6, 8, 10] == [r["special-n"] for r in results] -# TODO:Add delete common function -def test_delete(db): - pass - - -# TODO Add update common function -def test_update(db): - pass - - def test_like(db): from test.utils.usecase.vector_search import build_vector_index build_vector_index(db) table = db["documents"] - primary_id = table.primary_id + # primary_id = table.primary_id.execute() vector_index = "vector_index" sample_data = list(table.select().execute())[50] @@ -102,13 +237,16 @@ def test_like(db): .execute() ) - scores = out.scores + scores = [r['score'] for r in out] + + primary_id = table.primary_id.execute() ids = [o[primary_id] for o in list(out)] + assert len(ids) == 10 - assert sample_data[primary_id] in ids - assert scores[str(sample_data[primary_id])] > 0.999999 + assert ids[0] == sample_data[primary_id] + assert scores[0] > 0.999999 # Pre-like out = ( @@ -118,12 
+256,9 @@ def test_like(db): .execute() ) - scores = out.scores - results = list(out) - - assert len(results) == 2 + assert len(out) == 2 - assert [r["x"] for r in results] == [49, 51] + assert set(r["x"] for r in out) == {49, 51} # Post-like out = ( @@ -133,34 +268,28 @@ def test_like(db): .execute() ) - scores = out.scores - results = list(out) + scores = [r['score'] for r in out] - assert len(results) == 4 + assert len(out) == 4 - assert [r["x"] for r in results] == [47, 49, 51, 53] + assert set(r["x"] for r in out) == {47, 49, 51, 53} def test_insert_with_auto_schema(db): db.cfg.auto_schema = True import numpy as np - # Doesn't work with the Vector datatype together data = { - # "df": pd.DataFrame(np.random.randn(10, 10)), "array": np.array([1, 2, 3]), } table_or_collection = db["documents"] - datas = [Document(data)] - table_or_collection.insert(datas).execute() - # Make sure multiple insert works - table_or_collection.insert(datas).execute() + table_or_collection.insert([data]) + datas_from_db = list(table_or_collection.select().execute()) - for d, d_db in zip(datas, datas_from_db): - # assert d["df"].values.sum() == d_db["df"].values.sum() + for d, d_db in zip([data], datas_from_db): assert np.all(d["array"] == d_db["array"]) @@ -173,70 +302,16 @@ def test_insert_with_diff_schemas(db): data = { "array": np.array([1, 2, 3]), } - datas = [Document(data)] - table_or_collection.insert(datas).execute() + table_or_collection.insert([data]) - datas_from_db = list(table_or_collection.select().execute()) + datas_from_db = table_or_collection.select().execute() - assert np.all(datas[0]["array"] == datas_from_db[0]["array"]) + assert np.all(data["array"] == datas_from_db[0]["array"]) data = { "df": pd.DataFrame(np.random.randn(10, 10)), } - datas = [Document(data)] # Do not support different schema with pytest.raises(Exception): - table_or_collection.insert(datas).execute() - - -def test_auto_document_wrapping(db): - db.cfg.auto_schema = True - import numpy as np - - table_or_collection = db["my_table"] - data = {"x": np.zeros((1))} - datas = [Document(data)] - table_or_collection.insert(datas).execute() - - def _check(n): - c = list(table_or_collection.select().execute()) - assert len(c) == n - return c - - _check(1) - - # Without `Document` dict data - table_or_collection.insert([data]).execute() - _check(2) - - -def test_model(db): - from test.utils.setup.fake_data import add_models - - add_models(db) - t = np.random.rand(32) - - m = db.load("model", "linear_a") - - out = m.predict(t) - assert isinstance(out, np.ndarray) - - from superduper.backends.base.query import Model - - out = m.predict(t) - assert isinstance(out, np.ndarray) - - q = Model(table="linear_a").predict(t) - - out = db.execute(q).unpack() - assert isinstance(out, np.ndarray) - - -def test_model_query(): - from superduper.backends.base.query import Model - - q = Model(table="my-model").predict("This is a test") - - r = q.dict() - assert r["query"] == 'my-model.predict("This is a test")' + table_or_collection.insert(data) diff --git a/test/utils/setup/fake_data.py b/test/utils/setup/fake_data.py index 9dac29ee7..22ac0b40c 100644 --- a/test/utils/setup/fake_data.py +++ b/test/utils/setup/fake_data.py @@ -52,7 +52,7 @@ def add_random_data( fold = int(random.random() > 0.5) fold = "valid" if fold else "train" data.append({"id": str(i), "x": x, "y": y, "z": z, "_fold": fold}) - db[table_name].insert(data).execute() + db[table_name].insert(data) def add_datatypes(db: Datalayer): diff --git a/test/utils/usecase/chain_listener.py 
b/test/utils/usecase/chain_listener.py index cf8c891bd..cae6c1bce 100644 --- a/test/utils/usecase/chain_listener.py +++ b/test/utils/usecase/chain_listener.py @@ -14,7 +14,7 @@ def build_chain_listener(db: "Datalayer"): {"x": 3}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) data = list(db["documents"].select().execute(eager_mode=True))[0] @@ -51,7 +51,7 @@ def build_chain_listener(db: "Datalayer"): {"x": 6}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) assert db.databackend.check_output_dest(listener_a.predict_id) assert db.databackend.check_output_dest(listener_b.predict_id) diff --git a/test/utils/usecase/graph_listener.py b/test/utils/usecase/graph_listener.py index 5c358a20e..a3aa92054 100644 --- a/test/utils/usecase/graph_listener.py +++ b/test/utils/usecase/graph_listener.py @@ -36,9 +36,9 @@ def build_graph_listener(db: "Datalayer"): {"x": 3, "y": "4", "z": np.array([7, 8, 9])}, ] - db["documents"].insert(data).execute() + db["documents"].insert(data) - data = db['documents'].select().tolist() + data = db['documents'].select().execute() assert isinstance(data[0]['z'], np.ndarray) @@ -121,7 +121,7 @@ def func_c(x, y, z, o_a, o_b): {"x": 6, "y": "7", "z": np.array([16, 17, 18])}, ] - db["documents"].insert(new_data).execute() + db["documents"].insert(new_data) new_data = Document( list( diff --git a/test/utils/usecase/vector_search.py b/test/utils/usecase/vector_search.py index 34d899c67..4e4b63a19 100644 --- a/test/utils/usecase/vector_search.py +++ b/test/utils/usecase/vector_search.py @@ -23,7 +23,7 @@ def add_data(db: "Datalayer", start: int, end: int): "label": int(i % 2 == 0), } ) - db["documents"].insert(data).execute() + db["documents"].insert(data) def build_vector_index(