class Cache:
    """In-memory cache of component metadata records.

    Every record is indexed twice:

    * by its ``uuid``;
    * by ``(type_id, identifier)``, with a nested mapping from ``version``
      to the record.

    Lookups return ``None`` on a miss so that callers can fall back to the
    database.
    """

    def __init__(self):
        # uuid -> metadata record
        self._uuid2metadata: t.Dict[str, t.Dict] = {}
        # (type_id, identifier) -> {version: metadata record}
        self._type_id_identifier2metadata: t.DefaultDict[
            t.Tuple[t.Any, t.Any], t.Dict[t.Any, t.Dict]
        ] = defaultdict(dict)

    def add_metadata(self, metadata):
        """Insert ``metadata`` into both indexes, replacing any previous entry.

        :param metadata: mapping with at least the keys ``uuid``, ``version``,
            ``type_id`` and ``identifier``.
        :raises KeyError: if one of the required keys is missing.
        """
        self._uuid2metadata[metadata['uuid']] = metadata
        key = (metadata['type_id'], metadata['identifier'])
        self._type_id_identifier2metadata[key][metadata['version']] = metadata

    def get_metadata_by_uuid(self, uuid):
        """Return the record cached under ``uuid``, or ``None`` if absent.

        :param uuid: UUID of the component.
        """
        return self._uuid2metadata.get(uuid)

    def get_metadata_by_identifier(self, type_id, identifier, version=None):
        """Return the cached record for a component version, or ``None``.

        :param type_id: type of the component.
        :param identifier: identifier of the component.
        :param version: version to fetch; ``None`` selects the highest cached
            version.
        """
        # ``.get`` rather than ``[...]``: indexing a defaultdict would insert
        # an empty entry for every key that is merely looked up.
        versions = self._type_id_identifier2metadata.get((type_id, identifier))
        if not versions:
            return None
        # Compare against ``None`` explicitly; version ``0`` is a real version
        # and must not be mistaken for "latest".
        if version is None:
            version = max(versions)
        # A version that is not cached is a miss, not an error.
        return versions.get(version)

    def update_metadata(self, metadata):
        """Replace the cached record for ``metadata`` (same as adding it).

        :param metadata: mapping with the same required keys as in
            ``add_metadata``.
        """
        self.add_metadata(metadata)
def _init_cache(self):
    """Warm the metadata cache with every row of the component table.

    Rows are flattened before caching: the nested ``'dict'`` column is
    merged into the top level of the record, which is the same shape the
    read paths of this class build before they cache a record. Without
    this, a cache hit would return a differently shaped record than a
    database read.
    """
    with self.session_context() as session:
        stmt = select(self.component_table)
        rows = self.query_results(self.component_table, stmt, session)
        # NOTE(review): every row is cached and the cache lookups take no
        # ``allow_hidden`` argument -- confirm hidden components should be
        # served from the cache.
        for row in rows:
            # Copy so the result row itself is never mutated.
            record = dict(row)
            # presumably every row carries a 'dict' column; tolerate its
            # absence rather than fail at start-up -- TODO confirm schema.
            nested = record.pop('dict', None)
            if nested:
                record = {**record, **nested}
            self._cache.add_metadata(record)
return res def get_component_version_parents(self, uuid: str): @@ -482,6 +529,8 @@ def get_latest_version( :param identifier: the identifier of the component :param allow_hidden: whether to allow hidden components """ + if res:= self._cache.get_metadata_by_identifier(type_id, identifier, None): + return res['version'] with self.session_context() as session: stmt = ( select(self.component_table) @@ -493,7 +542,6 @@ def get_latest_version( .order_by(self.component_table.c.version.desc()) .limit(1) ) - res = session.execute(stmt) res = self.query_results(self.component_table, stmt, session) versions = [r['version'] for r in res] if len(versions) == 0: diff --git a/superduper/backends/base/data_backend.py b/superduper/backends/base/data_backend.py index 82605e3d2..8a622ed13 100644 --- a/superduper/backends/base/data_backend.py +++ b/superduper/backends/base/data_backend.py @@ -211,10 +211,17 @@ def _try_execute(self, attr): @functools.wraps(attr) def wrapper(*args, **kwargs): try: - return attr(*args, **kwargs) + logging.warn(f"Executing {attr.__name__} with args: {args}, kwargs: {kwargs}") + import time + start = time.time() + result = attr(*args, **kwargs) + end = time.time() + global_time[attr.__name__] += end - start + global_count[attr.__name__] += 1 + return result except Exception as e: error_message = str(e).lower() - if 'expire' in error_message and 'token' in error_message: + if "expire" in error_message and "token" in error_message: logging.warn( f"Authentication expiry detected: {e}. " "Attempting to reconnect..." @@ -232,3 +239,9 @@ def __getattr__(self, name): if callable(attr): return self._try_execute(attr) return attr + + +from collections import Counter + +global_time = Counter() +global_count = Counter() diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index 64df47337..5466d9f06 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -454,9 +454,16 @@ def apply( :param wait: Wait for apply events. 
:return: Tuple containing the added object(s) and the original object(s). """ + from superduper.backends.base import data_backend + from collections import Counter start = time.time() result = apply.apply(db=self, object=object, force=force, wait=wait) logging.info(f'Apply took {time.time() - start} seconds') + logging.info(data_backend.global_time) + logging.info(data_backend.global_count) + data_backend.global_time = Counter() + data_backend.global_count = Counter() + return result def remove( diff --git a/superduper/base/superduper.py b/superduper/base/superduper.py index e2a3d774d..4c0152653 100644 --- a/superduper/base/superduper.py +++ b/superduper/base/superduper.py @@ -15,7 +15,7 @@ def superduper(item: str | None = None, **kwargs) -> t.Any: from superduper.base.build import build_datalayer if item is None: - return build_datalayer() + return build_datalayer(**kwargs) if item.startswith('mongomock://'): kwargs['data_backend'] = item