From fae88a5e5f544a0e4578ccd99de7356371363515 Mon Sep 17 00:00:00 2001 From: Duncan Blythe Date: Tue, 31 Dec 2024 15:35:10 +0100 Subject: [PATCH] Remove components recursively --- CHANGELOG.md | 1 + superduper/backends/local/artifacts.py | 1 + superduper/base/datalayer.py | 74 ++++++++++++++--------- superduper/components/component.py | 21 +++++++ superduper/rest/build.py | 2 +- test/unittest/component/test_component.py | 18 ++++++ 6 files changed, 87 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d634fefd..b7b041d9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add Data Component for storing data directly in the template - Add a standalone flag in Streamlit to mark the page as independent. - Add secrets directory mount for loading secret env vars. +- Remove components recursively #### Bug Fixes diff --git a/superduper/backends/local/artifacts.py b/superduper/backends/local/artifacts.py index 6288ec31b..6363ca98c 100644 --- a/superduper/backends/local/artifacts.py +++ b/superduper/backends/local/artifacts.py @@ -110,6 +110,7 @@ def put_file(self, file_path: str, file_id: str): path = Path(file_path) name = path.name file_id_folder = os.path.join(self.conn, file_id) + os.makedirs(file_id_folder, exist_ok=True) os.chmod(file_id_folder, 0o777) save_path = os.path.join(file_id_folder, name) diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index b6a1eca08..c2aba72dc 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -1,6 +1,5 @@ import random import typing as t -import warnings from collections import namedtuple import click @@ -460,6 +459,7 @@ def remove( type_id: str, identifier: str, version: t.Optional[int] = None, + recursive: bool = False, force: bool = False, ): """ @@ -476,7 +476,7 @@ def remove( # TODO: versions = [version] if version is not None else ... if version is not None: return self._remove_component_version( - type_id, identifier, version=version, force=force + type_id, identifier, version=version, force=force, recursive=recursive ) versions = self.metadata.show_component_versions(type_id, identifier) @@ -497,20 +497,15 @@ def remove( raise exceptions.ComponentInUseError( f'Component versions: {component_versions_in_use} are in use' ) - else: - warnings.warn( - exceptions.ComponentInUseWarning( - f'Component versions: {component_versions_in_use}' - ', marking as hidden' - ) - ) if force or click.confirm( f'You are about to delete {type_id}/{identifier}, are you sure?', default=False, ): for v in sorted(list(set(versions) - set(versions_in_use))): - self._remove_component_version(type_id, identifier, v, force=True) + self._remove_component_version( + type_id, identifier, v, recursive=recursive, force=True + ) for v in sorted(versions_in_use): self.metadata.hide_component_version(type_id, identifier, v) @@ -617,33 +612,54 @@ def _remove_component_version( identifier: str, version: int, force: bool = False, + recursive: bool = False, ): r = self.metadata.get_component(type_id, identifier, version=version) if self.metadata.component_version_has_parents(type_id, identifier, version): parents = self.metadata.get_component_version_parents(r['uuid']) raise Exception(f'{r["uuid"]} is involved in other components: {parents}') - if force or click.confirm( - f'You are about to delete {type_id}/{identifier}{version}, are you sure?', - default=False, - ): - # TODO - make this less I/O intensive - component = self.load( - type_id, - identifier, - version=version, + if not ( + force + or click.confirm( + f'You are about to delete {type_id}/{identifier}{version}, ' + 'are you sure?', + default=False, ) - info = self.metadata.get_component( - type_id, identifier, version=version, allow_hidden=force - ) - component.cleanup(self) - try: - del self.cluster.cache[component.uuid] - except KeyError: - pass + ): + return + + component = self.load( + type_id, + identifier, + version=version, + ) + info = self.metadata.get_component( + type_id, identifier, version=version, allow_hidden=force + ) + component.cleanup(self) + try: + del self.cluster.cache[component.uuid] + except KeyError: + pass + + self._delete_artifacts(r['uuid'], info) + self.metadata.delete_component_version(type_id, identifier, version=version) - self._delete_artifacts(r['uuid'], info) - self.metadata.delete_component_version(type_id, identifier, version=version) + if not recursive: + return + + children = component.get_children(deep=True) + children = component.sort_components(children)[::-1] + for c in children: + assert isinstance(c.version, int) + self._remove_component_version( + c.type_id, + c.identifier, + version=c.version, + recursive=False, + force=force, + ) def replace(self, object: t.Any): """ diff --git a/superduper/components/component.py b/superduper/components/component.py index dd5259ec4..0fb57a20c 100644 --- a/superduper/components/component.py +++ b/superduper/components/component.py @@ -12,6 +12,7 @@ from enum import Enum from functools import wraps +import networkx import yaml from superduper import logging @@ -187,9 +188,29 @@ class Component(Leaf, metaclass=ComponentMeta): build_variables: t.Dict | None = None build_template: str | None = None + # TODO what's this? def refresh(self): pass + @staticmethod + def sort_components(components): + """Sort components based on topological order. + + :param components: List of components. + """ + logging.info('Resorting components based on topological order.') + G = networkx.DiGraph() + lookup = {c.huuid: c for c in components} + for k in lookup: + G.add_node(k) + for d in lookup[k].get_children_refs(): # dependencies: + if d in lookup: + G.add_edge(d, k) + + nodes = list(networkx.topological_sort(G)) + logging.info(f'New order of components: {nodes}') + return [lookup[n] for n in nodes] + @property def huuid(self): """Return a human-readable uuid.""" diff --git a/superduper/rest/build.py b/superduper/rest/build.py index 18f498991..dbc731945 100644 --- a/superduper/rest/build.py +++ b/superduper/rest/build.py @@ -237,7 +237,7 @@ def db_show( def db_remove( type_id: str, identifier: str, db: 'Datalayer' = DatalayerDependency() ): - db.remove(type_id=type_id, identifier=identifier, force=True) + db.remove(type_id=type_id, identifier=identifier, recursive=True, force=True) return {'status': 'ok'} @app.add('/db/show_template', method='get') diff --git a/test/unittest/component/test_component.py b/test/unittest/component/test_component.py index 1a363d29e..27abe0468 100644 --- a/test/unittest/component/test_component.py +++ b/test/unittest/component/test_component.py @@ -196,3 +196,21 @@ def test_set_db_deep(db): assert m.upstream[0].db is not None assert m.model.db is not None + + +class NewComponent(Component): + ... + + +def test_remove_recursive(db): + c1 = NewComponent(identifier='c1') + c2 = NewComponent(identifier='c2', upstream=[c1]) + c3 = NewComponent(identifier='c3', upstream=[c2, c1]) + + db.apply(c3) + + assert sorted([r['identifier'] for r in db.show()]) == ['c1', 'c2', 'c3'] + + db.remove('component', c3.identifier, recursive=True, force=True) + + assert not db.show()