diff --git a/CHANGELOG.md b/CHANGELOG.md
index a13d4cd61..015ddc171 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add a standalone flag in Streamlit to mark the page as independent.
 - Add secrets directory mount for loading secret env vars.
 - Remove components recursively
+- Add batched metadata db updates
 
 #### Bug Fixes
 
diff --git a/plugins/ibis/superduper_ibis/__init__.py b/plugins/ibis/superduper_ibis/__init__.py
index 92a43d292..ea1310ffb 100644
--- a/plugins/ibis/superduper_ibis/__init__.py
+++ b/plugins/ibis/superduper_ibis/__init__.py
@@ -1,6 +1,6 @@
 from .data_backend import IbisDataBackend as DataBackend
 from .query import IbisQuery
 
-__version__ = "0.4.7"
+__version__ = "0.4.8"
 
 __all__ = ["IbisQuery", "DataBackend"]
diff --git a/plugins/ibis/superduper_ibis/data_backend.py b/plugins/ibis/superduper_ibis/data_backend.py
index d1174291c..8b54b4174 100644
--- a/plugins/ibis/superduper_ibis/data_backend.py
+++ b/plugins/ibis/superduper_ibis/data_backend.py
@@ -216,11 +216,11 @@ def drop_table_or_collection(self, name: str):
         :param name: Table name to drop.
         """
         try:
-            return self.db.databackend.conn.drop_table(name)
+            return self.conn.drop_table(name)
         except Exception as e:
             msg = "Object found is of type 'VIEW'"
             if msg in str(e):
-                return self.db.databackend.conn.drop_view(name)
+                return self.conn.drop_view(name)
             raise
 
     def create_output_dest(
@@ -300,7 +300,7 @@ def drop(self, force: bool = False):
 
         for table in self.conn.list_tables():
             logging.info(f"Dropping table: {table}")
-            self.conn.drop_table(table)
+            self.drop_table_or_collection(table)
 
     def get_table_or_collection(self, identifier):
         """Get a table or collection from the database.
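The ibis changes route all drops through `drop_table_or_collection`, which falls back to `drop_view` when the backend reports that the object is a view, so `drop` no longer fails on backends that list views alongside tables. A minimal sketch of the fallback pattern, assuming a generic ibis backend connection `conn` (the error text is backend-specific and copied from the hunk above):

    def drop_table_or_view(conn, name: str):
        # Most objects are plain tables, so try the cheap path first.
        try:
            return conn.drop_table(name)
        except Exception as e:
            # Some backends register views in list_tables() and refuse
            # drop_table for them; retry with drop_view in that case.
            if "Object found is of type 'VIEW'" in str(e):
                return conn.drop_view(name)
            raise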
diff --git a/plugins/sqlalchemy/plugin_test/test_metadata.py b/plugins/sqlalchemy/plugin_test/test_metadata.py
index 3f49dd529..8327ae766 100644
--- a/plugins/sqlalchemy/plugin_test/test_metadata.py
+++ b/plugins/sqlalchemy/plugin_test/test_metadata.py
@@ -11,6 +11,7 @@
 @pytest.fixture
 def metadata():
     store = SQLAlchemyMetadata(DATABASE_URL)
+    store._batched = False
     yield store
     store.drop(force=True)
 
diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py b/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
index adaa6208c..4830ede8e 100644
--- a/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
+++ b/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
@@ -1,5 +1,5 @@
 from .metadata import SQLAlchemyMetadata as MetaDataStore
 
-__version__ = "0.4.5"
+__version__ = "0.4.6"
 
 __all__ = ['MetaDataStore']
diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
index 8834205b3..f04cd8a02 100644
--- a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
+++ b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
@@ -1,5 +1,8 @@
+import copy
+import json
 import threading
 import typing as t
+from collections import defaultdict
 from contextlib import contextmanager
 
 import click
@@ -49,6 +52,71 @@ def creator():
     return create_engine("snowflake://not@used/db", creator=creator)
 
 
+class _Cache:
+    def __init__(self):
+        self._uuid2metadata: t.Dict[str, t.Dict] = {}
+        self._type_id_identifier2metadata = defaultdict(dict)
+
+    def replace_metadata(
+        self, metadata, uuid=None, type_id=None, version=None, identifier=None
+    ):
+        metadata = copy.deepcopy(metadata)
+        if 'dict' in metadata:
+            dict_ = metadata['dict']
+            del metadata['dict']
+            metadata = {**metadata, **dict_}
+        if uuid:
+            self._uuid2metadata[uuid] = metadata
+
+        version = metadata.get('version', version)
+        type_id = metadata.get('type_id', type_id)
+        identifier = metadata.get('identifier', identifier)
+        self._type_id_identifier2metadata[(type_id, identifier)][version] = metadata
+        return metadata
+
+    def expire(self, uuid):
+        if uuid in self._uuid2metadata:
+            metadata = self._uuid2metadata[uuid]
+            del self._uuid2metadata[uuid]
+            type_id = metadata['type_id']
+            identifier = metadata['identifier']
+            if (type_id, identifier) in self._type_id_identifier2metadata:
+                del self._type_id_identifier2metadata[(type_id, identifier)]
+
+    def expire_identifier(self, type_id, identifier):
+        if (type_id, identifier) in self._type_id_identifier2metadata:
+            del self._type_id_identifier2metadata[(type_id, identifier)]
+
+    def add_metadata(self, metadata):
+        metadata = copy.deepcopy(metadata)
+        if 'dict' in metadata:
+            dict_ = metadata['dict']
+            del metadata['dict']
+            metadata = {**metadata, **dict_}
+
+        self._uuid2metadata[metadata['uuid']] = metadata
+
+        version = metadata['version']
+        type_id = metadata['type_id']
+        identifier = metadata['identifier']
+        self._type_id_identifier2metadata[(type_id, identifier)][version] = metadata
+        return metadata
+
+    def get_metadata_by_uuid(self, uuid):
+        return self._uuid2metadata.get(uuid)
+
+    def get_metadata_by_identifier(self, type_id, identifier, version):
+        metadata = self._type_id_identifier2metadata[(type_id, identifier)]
+        if not metadata:
+            return None
+        if version is None:
+            version = max(metadata.keys())
+        return metadata[version]
+
+    def update_metadata(self, metadata):
+        self.add_metadata(metadata)
+
+
 class SQLAlchemyMetadata(MetaDataStore):
     """
     Abstraction for storing meta-data separately from primary data.
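The new `_Cache` keeps two views over the same rows: one keyed by `uuid` and one keyed by `(type_id, identifier)` and then by `version`, with the serialized `'dict'` payload flattened into the top level on insertion. A usage sketch against the class above, with hypothetical metadata rows:

    cache = _Cache()
    cache.add_metadata({
        'uuid': 'abc-1', 'type_id': 'model', 'identifier': 'my-model',
        'version': 0, 'dict': {'params': {'lr': 0.1}},
    })
    cache.add_metadata({
        'uuid': 'abc-2', 'type_id': 'model', 'identifier': 'my-model',
        'version': 1, 'dict': {'params': {'lr': 0.01}},
    })

    # The 'dict' payload is flattened into the cached record.
    assert cache.get_metadata_by_uuid('abc-1')['params'] == {'lr': 0.1}
    # version=None resolves to the highest cached version (here: 1).
    assert cache.get_metadata_by_identifier('model', 'my-model', None)['uuid'] == 'abc-2'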
@@ -84,6 +152,54 @@ def __init__(
         self._init_tables()
 
         self._lock = threading.Lock()
+        self._cache = _Cache()
+        self._init_cache()
+        self._insert_flush = {
+            'parent_child': [],
+            'component': [],
+            '_artifact_relations': [],
+            'job': [],
+        }
+        self._parent_relation_cache = []
+        self._batched = True
+
+    def expire(self, uuid):
+        """Expire metadata cache."""
+        self._cache.expire(uuid)
+
+    @property
+    def batched(self):
+        """Batched metadata updates."""
+        return self._batched
+
+    def _init_cache(self):
+        with self.session_context() as session:
+            stmt = select(self.component_table)
+            res = self.query_results(self.component_table, stmt, session)
+            for r in res:
+                self._cache.add_metadata(r)
+
+    def _get_db_table(self, table):
+        if table == 'component':
+            return self.component_table
+        elif table == 'parent_child':
+            return self.parent_child_association_table
+        elif table == 'job':
+            return self.job_table
+        else:
+            return self._table_mapping[table]
+
+    def commit(self):
+        """Flush any batched inserts and commit the session."""
+        if self._insert_flush:
+            for table, flush in self._insert_flush.items():
+                if flush:
+                    with self.session_context() as session:
+                        session.execute(insert(self._get_db_table(table)), flush)
+                        self._insert_flush[table] = []
+        with self.session_context() as session:
+            session.commit()
+        self._batched = False
 
     def reconnect(self):
         """Reconnect to sqlalchmey metadatastore."""
@@ -176,10 +292,16 @@ def _init_tables(self):
 
     def _create_data(self, table_name, datas):
         table = self._table_mapping[table_name]
-        with self.session_context() as session:
-            for data in datas:
-                stmt = insert(table).values(**data)
-                session.execute(stmt)
+        with self.session_context(commit=not self.batched) as session:
+            if not self.batched:
+                for data in datas:
+                    stmt = insert(table).values(**data)
+                    session.execute(stmt)
+            else:
+                if table_name not in self._insert_flush:
+                    self._insert_flush[table_name] = datas
+                else:
+                    self._insert_flush[table_name] += datas
 
     def _delete_data(self, table_name, filter):
         table = self._table_mapping[table_name]
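With `_batched = True`, writes only append rows to the per-table `_insert_flush` buffers; nothing reaches the database until `commit()` replays each non-empty buffer as a single multi-row insert. The flush relies on SQLAlchemy's executemany form, roughly like this self-contained sketch (the `component` table here is a toy stand-in, not the real schema):

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine, insert,
    )

    engine = create_engine('sqlite://')
    md = MetaData()
    component = Table(
        'component', md,
        Column('id', Integer, primary_key=True),
        Column('identifier', String),
    )
    md.create_all(engine)

    buffered = [{'identifier': 'a'}, {'identifier': 'b'}, {'identifier': 'c'}]
    with engine.begin() as conn:
        # One round trip per table instead of one per row.
        conn.execute(insert(component), buffered)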
@@ -237,13 +359,14 @@ def drop(self, force: bool = False):
                 logging.warn(f'Error dropping artifact table {e}')
 
     @contextmanager
-    def session_context(self):
+    def session_context(self, commit=True):
         """Provide a transactional scope around a series of operations."""
         sm = sessionmaker(bind=self.conn)
         session = sm()
         try:
             yield session
-            session.commit()
+            if commit:
+                session.commit()
         except Exception:
             session.rollback()
             raise
@@ -296,15 +419,23 @@ def component_version_has_parents(
             res = self.query_results(self.parent_child_association_table, stmt, session)
             return len(res) > 0
 
-    def create_component(self, info: t.Dict):
+    def create_component(
+        self,
+        info: t.Dict,
+    ):
         """Create a component in the metadata store.
 
         :param info: the information to create the component
         """
         new_info = self._refactor_component_info(info)
-        with self.session_context() as session:
-            stmt = insert(self.component_table).values(**new_info)
-            session.execute(stmt)
+        with self.session_context(commit=not self.batched) as session:
+            if not self.batched:
+                stmt = insert(self.component_table).values(new_info)
+                session.execute(stmt)
+            else:
+                self._insert_flush['component'].append(copy.deepcopy(new_info))
+
+        self._cache.add_metadata(new_info)
 
     def delete_parent_child(self, parent_id: str, child_id: str | None = None):
         """
@@ -328,7 +459,11 @@ def delete_parent_child(self, parent_id: str, child_id: str | None = None):
             )
             session.execute(stmt)
 
-    def create_parent_child(self, parent_id: str, child_id: str):
+    def create_parent_child(
+        self,
+        parent_id: str,
+        child_id: str,
+    ):
         """Create a parent-child relationship between two components.
 
         :param parent_id: the parent component
@@ -337,11 +472,19 @@ def create_parent_child(self, parent_id: str, child_id: str):
         import sqlalchemy
 
         try:
-            with self.session_context() as session:
-                stmt = insert(self.parent_child_association_table).values(
-                    parent_id=parent_id, child_id=child_id
-                )
-                session.execute(stmt)
+            is_new = (parent_id, child_id) not in self._parent_relation_cache
+            self._parent_relation_cache.append((parent_id, child_id))
+            with self.session_context(commit=not self.batched) as session:
+                if not self.batched:
+                    stmt = insert(self.parent_child_association_table).values(
+                        parent_id=parent_id, child_id=child_id
+                    )
+                    session.execute(stmt)
+                elif is_new:
+                    self._insert_flush['parent_child'].append(
+                        {'parent_id': parent_id, 'child_id': child_id}
+                    )
+
         except sqlalchemy.exc.IntegrityError:
             logging.warn(f'Skipping {parent_id} {child_id} since they already exists')
 
@@ -369,6 +512,7 @@ def delete_component_version(self, type_id: str, identifier: str, version: int):
                     self.component_table.c.id == cv['id']
                 )
                 session.execute(stmt_delete)
+                self._cache.expire_identifier(type_id, identifier)
 
         if cv:
             self.delete_parent_child(cv['id'])
@@ -379,6 +523,8 @@ def get_component_by_uuid(self, uuid: str, allow_hidden: bool = False):
         :param uuid: UUID of component
         :param allow_hidden: whether to load hidden components
         """
+        if res := self._cache.get_metadata_by_uuid(uuid):
+            return res
         with self.session_context() as session:
             stmt = (
                 select(self.component_table)
@@ -387,21 +533,18 @@ def get_component_by_uuid(self, uuid: str, allow_hidden: bool = False):
                 )
                 .limit(1)
             )
+            if not allow_hidden:
+                stmt = stmt.where(self.component_table.c.hidden == allow_hidden)
             res = self.query_results(self.component_table, stmt, session)
             try:
-                r = res[0]
+                res = res[0]
+                res = self._cache.add_metadata(res)
+                return res
             except IndexError:
                 raise NonExistentMetadataError(
                     f'Table with uuid: {uuid} does not exist'
                 )
 
-            return self._get_component(
-                type_id=r['type_id'],
-                identifier=r['identifier'],
-                version=r['version'],
-                allow_hidden=allow_hidden,
-            )
-
     def _get_component(
         self,
         type_id: str,
@@ -416,6 +559,8 @@ def _get_component(
         :param version: the version of the component
         :param allow_hidden: whether to allow hidden components
         """
+        if res := self._cache.get_metadata_by_identifier(type_id, identifier, version):
+            return res
         with self.session_context() as session:
             stmt = select(self.component_table).where(
                 self.component_table.c.type_id == type_id,
@@ -428,9 +573,7 @@ def _get_component(
             res = self.query_results(self.component_table, stmt, session)
             if res:
                 res = res[0]
-                dict_ = res['dict']
-                del res['dict']
-                res = {**res, **dict_}
+                res = self._cache.add_metadata(res)
             return res
 
     def get_component_version_parents(self, uuid: str):
@@ -473,6 +616,8 @@ def get_latest_version(
         :param identifier: the identifier of the component
         :param allow_hidden: whether to allow hidden components
         """
+        if res := self._cache.get_metadata_by_identifier(type_id, identifier, None):
+            return res['version']
         with self.session_context() as session:
             stmt = (
                 select(self.component_table)
@@ -484,7 +629,6 @@ def get_latest_version(
                 .order_by(self.component_table.c.version.desc())
                 .limit(1)
             )
-            res = session.execute(stmt)
             res = self.query_results(self.component_table, stmt, session)
             versions = [r['version'] for r in res]
             if len(versions) == 0:
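The read path is now read-through: each getter checks `_cache` first, and on a miss queries the table, normalizes the row via `add_metadata` (flattening `'dict'`), and returns the cached form. Schematically, and only as a sketch of the flow above (`_query_row_by_uuid` is a hypothetical helper, not superduper API):

    def get_component_by_uuid(self, uuid):
        # Fast path: rows loaded by _init_cache() or an earlier query.
        if (hit := self._cache.get_metadata_by_uuid(uuid)) is not None:
            return hit
        row = self._query_row_by_uuid(uuid)  # hypothetical: one SELECT
        if row is None:
            raise KeyError(uuid)
        # Normalize and memoize before returning.
        return self._cache.add_metadata(row)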
@@ -534,6 +678,12 @@ def _replace_object(
                     .values(**info)
                 )
                 session.execute(stmt)
+                self._cache.replace_metadata(
+                    type_id=type_id,
+                    identifier=identifier,
+                    version=version,
+                    metadata=info,
+                )
         else:
             with self.session_context() as session:
                 stmt = (
@@ -542,6 +692,7 @@ def _replace_object(
                     .values(**info)
                 )
                 session.execute(stmt)
+                self._cache.replace_metadata(uuid=uuid, metadata=info)
 
     def show_cdc_tables(self):
         """Show tables to be consumed with cdc."""
@@ -593,6 +744,7 @@ def _show_components(self, type_id: t.Optional[str] = None):
 
         :param type_id: the type of the component
         """
+        # TODO: cache it.
         with self.session_context() as session:
             stmt = select(self.component_table)
             if type_id is not None:
@@ -638,18 +790,20 @@ def _update_object(
 
     # --------------- JOBS -----------------
 
-    def create_job(self, info: t.Dict):
+    def create_job(self, info: t.Union[t.Dict, t.List[t.Dict]]):
         """Create a job with the given info.
 
         :param info: The information used to create the job
         """
         if 'dependencies' in info:
-            import json
-
             info['dependencies'] = json.dumps(info['dependencies'])
-        with self.session_context() as session:
-            stmt = insert(self.job_table).values(**info)
-            session.execute(stmt)
+
+        with self.session_context(commit=not self.batched) as session:
+            if not self.batched:
+                stmt = insert(self.job_table).values(**info)
+                session.execute(stmt)
+            else:
+                self._insert_flush['job'].append(info)
 
     def get_job(self, job_id: str):
         """Get the job with the given job_id.
diff --git a/superduper/backends/base/cluster.py b/superduper/backends/base/cluster.py
index 1e11c2777..339566803 100644
--- a/superduper/backends/base/cluster.py
+++ b/superduper/backends/base/cluster.py
@@ -1,4 +1,5 @@
 import dataclasses as dc
+import time
 from abc import ABC, abstractmethod
 
 from superduper.backends.base.cache import Cache
@@ -82,6 +83,9 @@ def initialize(self, with_compute: bool = False):
 
         :param with_compute: Boolean to init compute.
         """
+        from superduper import logging
+
+        start = time.time()
         assert self.db
         if with_compute:
             self.compute.initialize()
@@ -91,3 +95,5 @@ def initialize(self, with_compute: bool = False):
         self.vector_search.initialize()
         self.crontab.initialize()
         self.cdc.initialize()
+
+        logging.info(f"Cluster initialized in {time.time() - start:.2f} seconds.")
diff --git a/superduper/backends/base/data_backend.py b/superduper/backends/base/data_backend.py
index 82605e3d2..80bb10f96 100644
--- a/superduper/backends/base/data_backend.py
+++ b/superduper/backends/base/data_backend.py
@@ -211,10 +211,11 @@ def _try_execute(self, attr):
         @functools.wraps(attr)
         def wrapper(*args, **kwargs):
             try:
-                return attr(*args, **kwargs)
+                result = attr(*args, **kwargs)
+                return result
             except Exception as e:
                 error_message = str(e).lower()
-                if 'expire' in error_message and 'token' in error_message:
+                if "expire" in error_message and "token" in error_message:
                     logging.warn(
                         f"Authentication expiry detected: {e}. "
                         "Attempting to reconnect..."
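For context, `_try_execute` (only lightly touched here) wraps data-backend calls so that an expired-token error triggers a reconnect instead of surfacing to the caller. The retry shape, as a sketch assuming an object with a `reconnect()` method:

    import functools

    def retry_on_token_expiry(method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                message = str(e).lower()
                if "expire" in message and "token" in message:
                    # Credentials aged out mid-session: rebuild the
                    # connection once and replay the call.
                    self.reconnect()
                    return method(self, *args, **kwargs)
                raise
        return wrapper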
+ """ + @abstractmethod def delete_parent_child(self, parent: str, child: str): """ @@ -425,6 +436,7 @@ def _replace_object( type_id: str | None = None, version: int | None = None, uuid: str | None = None, + batch: bool = False, ): pass diff --git a/superduper/backends/base/queue.py b/superduper/backends/base/queue.py index ec6d2ce8f..518988f9d 100644 --- a/superduper/backends/base/queue.py +++ b/superduper/backends/base/queue.py @@ -15,6 +15,15 @@ if t.TYPE_CHECKING: from superduper.base.datalayer import Datalayer +BATCH_SIZE = 100 + + +def _chunked_list(lst, batch_size=BATCH_SIZE): + if len(lst) <= batch_size: + return [lst] + + return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)] + class BaseQueueConsumer(ABC): """ @@ -189,11 +198,21 @@ def _consume_event_type(event_type, ids, table, db: 'Datalayer'): jobs += sub_jobs logging.info(f'Streaming with {component.type_id}:{component.identifier}') - for job in jobs: - job.execute(db) + if db.metadata.batched: + for chunk in _chunked_list(jobs): + for job in chunk: + job.execute(db) + db.metadata.commit() + else: + for job in jobs: + job.execute(db) + db.cluster.compute.release_futures(context) +table_type_ids = {'table', 'schema', 'data', 'datatype'} + + def consume_events(events, table: str, db=None): """ Consume events from table queue. @@ -205,5 +224,29 @@ def consume_events(events, table: str, db=None): if table != '_apply': consume_streaming_events(events=events, table=table, db=db) else: - for event in events: + if not db.metadata.batched: + for event in events: + event.execute(db) + return + + # table events + table_events = [] + non_table_events = [] + + for ix, event in enumerate(events): + if event.genus == 'create' and event.component['type_id'] in table_type_ids: + table_events.append(event) + else: + non_table_events.append(event) + + for event in table_events: event.execute(db) + db.metadata.commit() + + # non table events + for event in non_table_events: + event.execute(db) + if event.genus == 'update': + db.metadata.commit() + + db.metadata.commit() diff --git a/superduper/backends/local/queue.py b/superduper/backends/local/queue.py index db0e02e9f..fdf13f315 100644 --- a/superduper/backends/local/queue.py +++ b/superduper/backends/local/queue.py @@ -5,7 +5,7 @@ from superduper.backends.base.queue import ( BaseQueueConsumer, BaseQueuePublisher, - consume_streaming_events, + consume_events, ) from superduper.base.event import Event from superduper.components.cdc import CDC @@ -101,13 +101,8 @@ def consume(self, db: 'Datalayer', queue: t.Dict[str, t.List[Event]]): """Consume the current queue and run jobs.""" keys = list(queue.keys())[:] for k in keys: - if k != '_apply': - consume_streaming_events(events=queue[k], table=k, db=db) - queue[k] = [] - else: - while queue['_apply']: - event = queue['_apply'].pop(0) - event.execute(db) + consume_events(events=queue[k], table=k, db=db) + queue[k] = [] logging.info('Consumed all events') diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index 516f27593..d9029d436 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -76,6 +76,7 @@ def __init__( self._cfg = s.CFG self.startup_cache: t.Dict[str, t.Any] = {} + logging.info("Data Layer built") def __getitem__(self, item): return self.databackend.get_query_builder(item) @@ -452,7 +453,8 @@ def apply( :param wait: Wait for apply events. :return: Tuple containing the added object(s) and the original object(s). 
""" - return apply.apply(db=self, object=object, force=force, wait=wait) + result = apply.apply(db=self, object=object, force=force, wait=wait) + return result def remove( self, @@ -706,6 +708,8 @@ def _replace_fn(component): serialized = serialized.encode(keep_schema=False) self._delete_artifacts(object.uuid, info) + artifact_ids, _ = self._find_artifacts(info) + self.metadata.create_artifact_relation(object.uuid, artifact_ids) serialized = self._save_artifact(object.uuid, serialized) self.metadata.replace_object( @@ -722,10 +726,12 @@ def _replace_fn(component): def expire(self, uuid): """Expire a component from the cache.""" self.cluster.cache.expire(uuid) + self.metadata.expire(uuid) parents = self.metadata.get_component_version_parents(uuid) while parents: for uuid in parents: self.cluster.cache.expire(uuid) + self.metadata.expire(uuid) parents = sum( [self.metadata.get_component_version_parents(uuid) for uuid in parents], [], @@ -737,8 +743,6 @@ def _save_artifact(self, uuid, info: t.Dict): :param artifact: The artifact to save. """ - artifact_ids, _ = self._find_artifacts(info) - self.metadata.create_artifact_relation(uuid, artifact_ids) return self.artifact_store.save_artifact(info) def _delete_artifacts(self, uuid, info: t.Dict): diff --git a/superduper/base/event.py b/superduper/base/event.py index 3e484344f..08424957e 100644 --- a/superduper/base/event.py +++ b/superduper/base/event.py @@ -35,7 +35,10 @@ def create(cls, kwargs): return cls(**kwargs) @abstractmethod - def execute(self, db: 'Datalayer'): + def execute( + self, + db: 'Datalayer', + ): """Execute the event. :param db: Datalayer instance @@ -57,10 +60,14 @@ class Signal(Event): msg: str context: str - def execute(self, db: 'Datalayer'): + def execute( + self, + db: 'Datalayer', + ): """Execute the signal. :param db: Datalayer instance + """ if self.msg.lower() == 'done': db.cluster.compute.release_futures(self.context) @@ -91,7 +98,10 @@ def create(cls, kwargs): kwargs.pop('genus') return cls(**kwargs) - def execute(self, db: 'Datalayer'): + def execute( + self, + db: 'Datalayer', + ): """Execute the change event. :param db: Datalayer instance @@ -119,9 +129,12 @@ class Create(Event): def execute(self, db: 'Datalayer'): """Execute the create event.""" # TODO decide where to assign version + artifact_ids, _ = db._find_artifacts(self.component) + db.metadata.create_artifact_relation(self.component['uuid'], artifact_ids) + db.metadata.create_component(self.component) - # TODO - do we really need to load the whole component? 
diff --git a/superduper/base/event.py b/superduper/base/event.py
index 3e484344f..08424957e 100644
--- a/superduper/base/event.py
+++ b/superduper/base/event.py
@@ -35,7 +35,10 @@ def create(cls, kwargs):
         return cls(**kwargs)
 
     @abstractmethod
-    def execute(self, db: 'Datalayer'):
+    def execute(
+        self,
+        db: 'Datalayer',
+    ):
         """Execute the event.
 
         :param db: Datalayer instance
@@ -57,10 +60,14 @@ class Signal(Event):
     msg: str
     context: str
 
-    def execute(self, db: 'Datalayer'):
+    def execute(
+        self,
+        db: 'Datalayer',
+    ):
         """Execute the signal.
 
         :param db: Datalayer instance
+
         """
         if self.msg.lower() == 'done':
             db.cluster.compute.release_futures(self.context)
@@ -91,7 +98,10 @@ def create(cls, kwargs):
         kwargs.pop('genus')
         return cls(**kwargs)
 
-    def execute(self, db: 'Datalayer'):
+    def execute(
+        self,
+        db: 'Datalayer',
+    ):
         """Execute the change event.
 
         :param db: Datalayer instance
@@ -119,9 +129,12 @@ class Create(Event):
 
     def execute(self, db: 'Datalayer'):
         """Execute the create event."""
         # TODO decide where to assign version
+        artifact_ids, _ = db._find_artifacts(self.component)
+        db.metadata.create_artifact_relation(self.component['uuid'], artifact_ids)
+
         db.metadata.create_component(self.component)
-        # TODO - do we really need to load the whole component?
         component = db.load(uuid=self.component['uuid'])
+
         if self.parent:
             db.metadata.create_parent_child(self.parent, component.uuid)
@@ -129,7 +142,10 @@ def execute(self, db: 'Datalayer'):
         for dep in component.dependencies:
             if isinstance(dep, (tuple, list)):
                 dep = dep[-1]
-            db.metadata.create_parent_child(component.uuid, dep)
+            db.metadata.create_parent_child(
+                component.uuid,
+                dep,
+            )
         component.on_create(db=db)
 
     @property
@@ -159,9 +175,14 @@ class Update(Event):
 
     component: t.Dict
     parent: str | None = None
 
-    def execute(self, db: 'Datalayer'):
+    def execute(
+        self,
+        db: 'Datalayer',
+    ):
         """Execute the create event."""
         # TODO decide where to assign version
+        artifact_ids, _ = db._find_artifacts(self.component)
+        db.metadata.create_artifact_relation(self.component['uuid'], artifact_ids)
         db.metadata.replace_object(self.component, uuid=self.component['uuid'])
         db.expire(self.component['uuid'])
@@ -238,14 +259,16 @@ def get_args_kwargs(self, futures):
             kwargs['dependencies'] = dependencies
         return args, kwargs
 
-    def execute(self, db: 'Datalayer'):
+    def execute(
+        self,
+        db: 'Datalayer',
+    ):
         """Execute the job event.
 
         :param db: Datalayer instance
         """
-        db.metadata.create_job(
-            {k: v for k, v in self.dict().items() if k not in {'genus', 'queue'}}
-        )
+        meta = {k: v for k, v in self.dict().items() if k not in {'genus', 'queue'}}
+        db.metadata.create_job(meta)
         return db.cluster.compute.submit(self)
diff --git a/superduper/base/superduper.py b/superduper/base/superduper.py
index e2a3d774d..4c0152653 100644
--- a/superduper/base/superduper.py
+++ b/superduper/base/superduper.py
@@ -15,7 +15,7 @@ def superduper(item: str | None = None, **kwargs) -> t.Any:
     from superduper.base.build import build_datalayer
 
     if item is None:
-        return build_datalayer()
+        return build_datalayer(**kwargs)
 
     if item.startswith('mongomock://'):
         kwargs['data_backend'] = item
diff --git a/test/rest/test_rest.py b/test/rest/test_rest.py
index fc1d0c1a8..7ef7be3e9 100644
--- a/test/rest/test_rest.py
+++ b/test/rest/test_rest.py
@@ -59,6 +59,8 @@ def test_apply(setup):
             },
         },
         '_base': '?my_function',
+        'build_template': {},
+        'identifier': 'my_function',
     }
 
     _ = setup.post(
diff --git a/test/unittest/base/test_apply.py b/test/unittest/base/test_apply.py
index 30d3d8bfb..9ec0a0e56 100644
--- a/test/unittest/base/test_apply.py
+++ b/test/unittest/base/test_apply.py
@@ -191,7 +191,6 @@ def test_job_on_update(db: Datalayer):
     assert db.show('my', 'test') == [0]
 
     c = MyComponent('test', a='value', b=2, sub=MyValidator('valid', target=2))
-
     db.apply(c)
 
     reload = db.load('my', 'test')
diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py
index 9d934402e..bdb5a8b95 100644
--- a/test/unittest/base/test_datalayer.py
+++ b/test/unittest/base/test_datalayer.py
@@ -272,10 +272,12 @@ def test_remove_component_with_clean_up(db):
 def test_remove_component_from_data_layer_dict(db):
     # Test component is deleted from datalayer
     test_datatype = DataType(identifier='test_datatype')
+    db.metadata._batched = False
     db.apply(test_datatype)
     db._remove_component_version('datatype', 'test_datatype', 0, force=True)
     with pytest.raises(FileNotFoundError):
         db.load('datatype', 'test_datatype')
+    db.metadata._batched = True
 
 
 def test_remove_component_with_artifact(db):
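The test changes pin down the batching contract: tests that assert immediate visibility of writes flip `_batched` off first and restore it afterwards. A fixture-style sketch of that discipline, assuming a pytest `db` fixture like the one these tests use:

    import pytest

    @pytest.fixture
    def unbatched_metadata(db):
        # Write-through mode: every metadata write executes immediately,
        # so deletes and loads inside the test see a consistent store.
        previous = db.metadata._batched
        db.metadata._batched = False
        yield db.metadata
        db.metadata._batched = previous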