Skip to content

Commit

Permalink
Remove components recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Dec 31, 2024
1 parent 177a25a commit fae88a5
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add Data Component for storing data directly in the template
- Add a standalone flag in Streamlit to mark the page as independent.
- Add secrets directory mount for loading secret env vars.
- Remove components recursively

#### Bug Fixes

Expand Down
1 change: 1 addition & 0 deletions superduper/backends/local/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def put_file(self, file_path: str, file_id: str):
path = Path(file_path)
name = path.name
file_id_folder = os.path.join(self.conn, file_id)

os.makedirs(file_id_folder, exist_ok=True)
os.chmod(file_id_folder, 0o777)
save_path = os.path.join(file_id_folder, name)
Expand Down
74 changes: 45 additions & 29 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random
import typing as t
import warnings
from collections import namedtuple

import click
Expand Down Expand Up @@ -460,6 +459,7 @@ def remove(
type_id: str,
identifier: str,
version: t.Optional[int] = None,
recursive: bool = False,
force: bool = False,
):
"""
Expand All @@ -476,7 +476,7 @@ def remove(
# TODO: versions = [version] if version is not None else ...
if version is not None:
return self._remove_component_version(
type_id, identifier, version=version, force=force
type_id, identifier, version=version, force=force, recursive=recursive
)

versions = self.metadata.show_component_versions(type_id, identifier)
Expand All @@ -497,20 +497,15 @@ def remove(
raise exceptions.ComponentInUseError(
f'Component versions: {component_versions_in_use} are in use'
)
else:
warnings.warn(
exceptions.ComponentInUseWarning(
f'Component versions: {component_versions_in_use}'
', marking as hidden'
)
)

if force or click.confirm(
f'You are about to delete {type_id}/{identifier}, are you sure?',
default=False,
):
for v in sorted(list(set(versions) - set(versions_in_use))):
self._remove_component_version(type_id, identifier, v, force=True)
self._remove_component_version(
type_id, identifier, v, recursive=recursive, force=True
)

for v in sorted(versions_in_use):
self.metadata.hide_component_version(type_id, identifier, v)
Expand Down Expand Up @@ -617,33 +612,54 @@ def _remove_component_version(
identifier: str,
version: int,
force: bool = False,
recursive: bool = False,
):
r = self.metadata.get_component(type_id, identifier, version=version)
if self.metadata.component_version_has_parents(type_id, identifier, version):
parents = self.metadata.get_component_version_parents(r['uuid'])
raise Exception(f'{r["uuid"]} is involved in other components: {parents}')

if force or click.confirm(
f'You are about to delete {type_id}/{identifier}{version}, are you sure?',
default=False,
):
# TODO - make this less I/O intensive
component = self.load(
type_id,
identifier,
version=version,
if not (
force
or click.confirm(
f'You are about to delete {type_id}/{identifier}{version}, '
'are you sure?',
default=False,
)
info = self.metadata.get_component(
type_id, identifier, version=version, allow_hidden=force
)
component.cleanup(self)
try:
del self.cluster.cache[component.uuid]
except KeyError:
pass
):
return

component = self.load(
type_id,
identifier,
version=version,
)
info = self.metadata.get_component(
type_id, identifier, version=version, allow_hidden=force
)
component.cleanup(self)
try:
del self.cluster.cache[component.uuid]
except KeyError:
pass

self._delete_artifacts(r['uuid'], info)
self.metadata.delete_component_version(type_id, identifier, version=version)

self._delete_artifacts(r['uuid'], info)
self.metadata.delete_component_version(type_id, identifier, version=version)
if not recursive:
return

children = component.get_children(deep=True)
children = component.sort_components(children)[::-1]
for c in children:
assert isinstance(c.version, int)
self._remove_component_version(
c.type_id,
c.identifier,
version=c.version,
recursive=False,
force=force,
)

def replace(self, object: t.Any):
"""
Expand Down
21 changes: 21 additions & 0 deletions superduper/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from enum import Enum
from functools import wraps

import networkx
import yaml

from superduper import logging
Expand Down Expand Up @@ -187,9 +188,29 @@ class Component(Leaf, metaclass=ComponentMeta):
build_variables: t.Dict | None = None
build_template: str | None = None

# TODO what's this?
def refresh(self):
pass

@staticmethod
def sort_components(components):
"""Sort components based on topological order.
:param components: List of components.
"""
logging.info('Resorting components based on topological order.')
G = networkx.DiGraph()
lookup = {c.huuid: c for c in components}
for k in lookup:
G.add_node(k)
for d in lookup[k].get_children_refs(): # dependencies:
if d in lookup:
G.add_edge(d, k)

nodes = list(networkx.topological_sort(G))
logging.info(f'New order of components: {nodes}')
return [lookup[n] for n in nodes]

@property
def huuid(self):
"""Return a human-readable uuid."""
Expand Down
2 changes: 1 addition & 1 deletion superduper/rest/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def db_show(
def db_remove(
type_id: str, identifier: str, db: 'Datalayer' = DatalayerDependency()
):
db.remove(type_id=type_id, identifier=identifier, force=True)
db.remove(type_id=type_id, identifier=identifier, recursive=True, force=True)
return {'status': 'ok'}

@app.add('/db/show_template', method='get')
Expand Down
18 changes: 18 additions & 0 deletions test/unittest/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,21 @@ def test_set_db_deep(db):

assert m.upstream[0].db is not None
assert m.model.db is not None


class NewComponent(Component):
...


def test_remove_recursive(db):
c1 = NewComponent(identifier='c1')
c2 = NewComponent(identifier='c2', upstream=[c1])
c3 = NewComponent(identifier='c3', upstream=[c2, c1])

db.apply(c3)

assert sorted([r['identifier'] for r in db.show()]) == ['c1', 'c2', 'c3']

db.remove('component', c3.identifier, recursive=True, force=True)

assert not db.show()

0 comments on commit fae88a5

Please sign in to comment.