Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed Jan 3, 2025
1 parent 834a98b commit bf916cb
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 6 deletions.
54 changes: 51 additions & 3 deletions plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
import typing as t
from contextlib import contextmanager
from collections import defaultdict

import click
from sqlalchemy import (
Expand Down Expand Up @@ -51,6 +52,35 @@ def creator():
return create_engine("snowflake://not@used/db", creator=creator)


class Cache:
    """In-memory index of component metadata.

    Each metadata dict is indexed twice: by its ``uuid`` and by the pair
    ``(type_id, identifier)`` and then ``version``.
    """

    def __init__(self):
        # uuid -> metadata dict
        self._uuid2metadata: t.Dict[str, t.Dict] = {}
        # (type_id, identifier) -> {version -> metadata dict}
        self._type_id_identifier2metadata = defaultdict(dict)

    def add_metadata(self, metadata):
        """Index ``metadata`` under its uuid and its (type_id, identifier, version).

        An existing entry with the same uuid or the same
        (type_id, identifier, version) is overwritten.

        :param metadata: mapping with at least the keys ``uuid``, ``version``,
                         ``type_id`` and ``identifier``
        """
        self._uuid2metadata[metadata['uuid']] = metadata
        version = metadata['version']
        type_id = metadata['type_id']
        identifier = metadata['identifier']
        self._type_id_identifier2metadata[(type_id, identifier)][version] = metadata

    def get_metadata_by_uuid(self, uuid):
        """Return the metadata stored for ``uuid``, or ``None`` if not cached.

        :param uuid: UUID of the component
        """
        return self._uuid2metadata.get(uuid)

    def get_metadata_by_identifier(self, type_id, identifier, version):
        """Return the cached metadata for a component version.

        :param type_id: type of the component
        :param identifier: identifier of the component
        :param version: version to fetch; ``None`` selects the highest
                        cached version
        :return: the metadata dict, or ``None`` when nothing matching is cached
        """
        # Use ``.get`` rather than indexing: indexing a defaultdict inserts an
        # empty entry for every key that misses, so lookups would grow the cache.
        versions = self._type_id_identifier2metadata.get((type_id, identifier))
        if not versions:
            return None
        # Compare against None explicitly: version 0 is falsy, and
        # ``version or max(...)`` would silently return the latest version.
        if version is None:
            version = max(versions)
        # A version that is not cached is a miss, not an error, so callers can
        # fall back to their own lookup.
        return versions.get(version)

    def update_metadata(self, metadata):
        """Replace the cached entry for ``metadata``; same as ``add_metadata``.

        :param metadata: mapping with the keys required by ``add_metadata``
        """
        self.add_metadata(metadata)



class SQLAlchemyMetadata(MetaDataStore):
"""
Abstraction for storing meta-data separately from primary data.
Expand Down Expand Up @@ -87,6 +117,18 @@ def __init__(

self._lock = threading.Lock()
self._connect()
self._cache = Cache()
self._init_cache()


def _init_cache(self):
    """Warm the metadata cache with every row of the component table.

    Rows are flattened the same way the component read paths flatten them
    before caching (the nested ``'dict'`` entry is merged into the top
    level), so a cache hit has the same shape whether the entry came from
    this warm-up or from a later database read.
    """
    with self.session_context() as session:
        stmt = select(self.component_table)
        rows = self.query_results(self.component_table, stmt, session)
        for row in rows:
            if 'dict' in row:
                # Mirror the read paths: drop the 'dict' key and let its
                # contents override same-named top-level keys.
                nested = row['dict']
                row = {
                    **{key: value for key, value in row.items() if key != 'dict'},
                    **nested,
                }
            self._cache.add_metadata(row)

def _connect(self):

Expand Down Expand Up @@ -251,10 +293,8 @@ def session_context(self):
"""Provide a transactional scope around a series of operations."""
if not self.session.is_active:
self._connect()

try:
yield self.session
self.session.commit()
except Exception:
self.session.rollback()
raise
Expand Down Expand Up @@ -315,6 +355,7 @@ def create_component(self, info: t.Dict):
with self.session_context() as session:
stmt = insert(self.component_table).values(**new_info)
session.execute(stmt)
# self._cache.add_metadata(new_info)

def delete_parent_child(self, parent_id: str, child_id: str | None = None):
"""
Expand Down Expand Up @@ -389,6 +430,8 @@ def get_component_by_uuid(self, uuid: str, allow_hidden: bool = False):
:param uuid: UUID of component
:param allow_hidden: whether to load hidden components
"""
if res:= self._cache.get_metadata_by_uuid(uuid):
return res
with self.session_context() as session:
stmt = (
select(self.component_table)
Expand All @@ -405,6 +448,7 @@ def get_component_by_uuid(self, uuid: str, allow_hidden: bool = False):
dict_ = res['dict']
del res['dict']
res = {**res, **dict_}
self._cache.add_metadata(res)
return res
except IndexError:
raise NonExistentMetadataError(
Expand All @@ -425,6 +469,8 @@ def _get_component(
:param version: the version of the component
:param allow_hidden: whether to allow hidden components
"""
if res:= self._cache.get_metadata_by_identifier(type_id, identifier, version):
return res
with self.session_context() as session:
stmt = select(self.component_table).where(
self.component_table.c.type_id == type_id,
Expand All @@ -440,6 +486,7 @@ def _get_component(
dict_ = res['dict']
del res['dict']
res = {**res, **dict_}
self._cache.add_metadata(res)
return res

def get_component_version_parents(self, uuid: str):
Expand Down Expand Up @@ -482,6 +529,8 @@ def get_latest_version(
:param identifier: the identifier of the component
:param allow_hidden: whether to allow hidden components
"""
if res:= self._cache.get_metadata_by_identifier(type_id, identifier, None):
return res['version']
with self.session_context() as session:
stmt = (
select(self.component_table)
Expand All @@ -493,7 +542,6 @@ def get_latest_version(
.order_by(self.component_table.c.version.desc())
.limit(1)
)
res = session.execute(stmt)
res = self.query_results(self.component_table, stmt, session)
versions = [r['version'] for r in res]
if len(versions) == 0:
Expand Down
17 changes: 15 additions & 2 deletions superduper/backends/base/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,17 @@ def _try_execute(self, attr):
@functools.wraps(attr)
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
logging.warn(f"Executing {attr.__name__} with args: {args}, kwargs: {kwargs}")
import time
start = time.time()
result = attr(*args, **kwargs)
end = time.time()
global_time[attr.__name__] += end - start
global_count[attr.__name__] += 1
return result
except Exception as e:
error_message = str(e).lower()
if 'expire' in error_message and 'token' in error_message:
if "expire" in error_message and "token" in error_message:
logging.warn(
f"Authentication expiry detected: {e}. "
"Attempting to reconnect..."
Expand All @@ -232,3 +239,9 @@ def __getattr__(self, name):
if callable(attr):
return self._try_execute(attr)
return attr


from collections import Counter

# Module-level profiling counters, keyed by the wrapped callable's ``__name__``.
# The ``_try_execute`` wrapper adds each call's elapsed wall-clock seconds to
# ``global_time`` and increments ``global_count`` once per successful call.
# ``Datalayer.apply`` logs both and then rebinds them to fresh ``Counter()``s.
global_time = Counter()
global_count = Counter()
7 changes: 7 additions & 0 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,16 @@ def apply(
:param wait: Wait for apply events.
:return: Tuple containing the added object(s) and the original object(s).
"""
from superduper.backends.base import data_backend
from collections import Counter
start = time.time()
result = apply.apply(db=self, object=object, force=force, wait=wait)
logging.info(f'Apply took {time.time() - start} seconds')
logging.info(data_backend.global_time)
logging.info(data_backend.global_count)
data_backend.global_time = Counter()
data_backend.global_count = Counter()

return result

def remove(
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/superduper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def superduper(item: str | None = None, **kwargs) -> t.Any:
from superduper.base.build import build_datalayer

if item is None:
return build_datalayer()
return build_datalayer(**kwargs)

if item.startswith('mongomock://'):
kwargs['data_backend'] = item
Expand Down

0 comments on commit bf916cb

Please sign in to comment.