From b0b84e424200c9847e9db5aa12ca56a9cf6aa277 Mon Sep 17 00:00:00 2001 From: TheDude Date: Sun, 5 Jan 2025 00:28:34 +0530 Subject: [PATCH] Fix version issue in metadata cache --- plugins/ibis/superduper_ibis/data_backend.py | 1 - .../sqlalchemy/plugin_test/test_metadata.py | 1 + .../superduper_sqlalchemy/metadata.py | 30 +++++++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/plugins/ibis/superduper_ibis/data_backend.py b/plugins/ibis/superduper_ibis/data_backend.py index 6b4474648..8b54b4174 100644 --- a/plugins/ibis/superduper_ibis/data_backend.py +++ b/plugins/ibis/superduper_ibis/data_backend.py @@ -104,7 +104,6 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None): 'vector': 'superduper.components.datatype.NativeVector' } - def _setup(self, conn): self.dialect = getattr(conn, "name", "base") self.db_helper = get_db_helper(self.dialect) diff --git a/plugins/sqlalchemy/plugin_test/test_metadata.py b/plugins/sqlalchemy/plugin_test/test_metadata.py index 3f49dd529..8327ae766 100644 --- a/plugins/sqlalchemy/plugin_test/test_metadata.py +++ b/plugins/sqlalchemy/plugin_test/test_metadata.py @@ -11,6 +11,7 @@ @pytest.fixture def metadata(): store = SQLAlchemyMetadata(DATABASE_URL) + store._batched = False yield store store.drop(force=True) diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py index d9c27fb96..dcdca9914 100644 --- a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py +++ b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py @@ -57,7 +57,9 @@ def __init__(self): self._uuid2metadata: t.Dict[str, t.Dict] = {} self._type_id_identifier2metadata = defaultdict(dict) - def replace_metadata(self, metadata, uuid=None, type_id=None, version=None, identifier=None): + def replace_metadata( + self, metadata, uuid=None, type_id=None, version=None, identifier=None + ): metadata = copy.deepcopy(metadata) if 'dict' in metadata: dict_ = metadata['dict'] @@ -107,7 +109,8 @@ def get_metadata_by_identifier(self, type_id, identifier, version): metadata = self._type_id_identifier2metadata[(type_id, identifier)] if not metadata: return None - version = version or max(metadata.keys()) + if version is None: + version = max(metadata.keys()) return metadata[version] def update_metadata(self, metadata): @@ -155,7 +158,7 @@ def __init__( 'parent_child': [], 'component': [], '_artifact_relations': [], - 'job': [] + 'job': [], } self._parent_relation_cache = [] self._batched = True @@ -206,7 +209,6 @@ def reconnect(self): # a reconnect. self._init_tables() - def _init_tables(self): # Get the DB config for the given dialect DBConfig = get_db_config(self.dialect) @@ -416,7 +418,10 @@ def component_version_has_parents( res = self.query_results(self.parent_child_association_table, stmt, session) return len(res) > 0 - def create_component(self, info: t.Dict, ): + def create_component( + self, + info: t.Dict, + ): """Create a component in the metadata store. :param info: the information to create the component @@ -453,7 +458,11 @@ def delete_parent_child(self, parent_id: str, child_id: str | None = None): ) session.execute(stmt) - def create_parent_child(self, parent_id: str, child_id: str, ): + def create_parent_child( + self, + parent_id: str, + child_id: str, + ): """Create a parent-child relationship between two components. :param parent_id: the parent component @@ -668,7 +677,12 @@ def _replace_object( .values(**info) ) session.execute(stmt) - self._cache.replace_metadata(type_id=type_id, identifier=identifier, version=version, metadata=info) + self._cache.replace_metadata( + type_id=type_id, + identifier=identifier, + version=version, + metadata=info, + ) else: with self.session_context() as session: stmt = ( @@ -791,8 +805,6 @@ def create_job(self, info: t.Union[t.Dict, t.List[t.Dict]]): else: self._insert_flush['job'].append(info) - - def get_job(self, job_id: str): """Get the job with the given job_id.