Skip to content

Commit

Permalink
Fixed model cleanup on cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Jan 2, 2025
1 parent a18c1aa commit 504f2ad
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add create events waiting on db apply.
- Refactor secrets loading method.
- Add db.load in db wait
- Add model component cleanup

#### New Features & Functionality

Expand Down
10 changes: 10 additions & 0 deletions superduper/backends/base/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def initialize(self, with_compute: bool = False):
self.vector_search.initialize()
self.crontab.initialize()
self.cdc.initialize()

def drop_component(self, uuid: str):
    """Remove a component from the cluster.

    Evicts the entry stored under ``uuid`` from ``self.cache``. A uuid
    that is not present in the cache is ignored.

    :param uuid: Component uuid.
    """
    try:
        del self.cache[uuid]
    except KeyError:
        # Nothing cached under this uuid, so there is nothing to drop.
        return
6 changes: 6 additions & 0 deletions superduper/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,9 @@ def db(self, value: 'Datalayer'):
:param value: ``Datalayer`` instance.
"""
self._db = value

def drop_component(self, uuid: str):
    """Remove a component from the compute backend.

    This implementation performs no work and returns ``None``.

    :param uuid: Component uuid.
    """
    return None
1 change: 1 addition & 0 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def remove(

for v in sorted(versions_in_use):
self.metadata.hide_component_version(type_id, identifier, v)

else:
logging.warn('aborting.')

Expand Down
7 changes: 7 additions & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,13 @@ def __post_init__(self, db, example):
if not self.identifier:
raise Exception('_Predictor identifier must be non-empty')

def cleanup(self, db: "Datalayer") -> None:
    """Release this model's compute resources on deletion.

    Asks the compute backend of ``db.cluster`` to drop the component
    registered under this model's ``uuid``.

    :param db: Data layer instance to process.
    """
    compute_backend = db.cluster.compute
    compute_backend.drop_component(self.uuid)

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down
4 changes: 2 additions & 2 deletions superduper/ext/llm/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class RetrievalPrompt(QueryModel):
prompt_introduction: str = PROMPT_INTRODUCTION
join: str = "\n---\n"

def __post_init__(self, db):
def __post_init__(self, db, example):
assert 'prompt' in self.select.variables
return super().__post_init__(db)
return super().__post_init__(db, example)

@property
def inputs(self):
Expand Down
5 changes: 2 additions & 3 deletions superduper/rest/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import threading
import time
import os
import typing as t
from functools import cached_property
from traceback import format_exc
Expand Down Expand Up @@ -198,13 +199,11 @@ def _add_templates(self, db):

existing = db.show('template')
for t in self.templates:
if t in existing:
if t in existing or t is None:
logging.info(f'Found existing template: {t}')
continue
logging.info(f'Applying template: {t}')

import os

if t is None:
continue

Expand Down
2 changes: 2 additions & 0 deletions superduper/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def ls():
def __getattr__(name: str):
    # Module-level attribute hook: called when ``name`` is not found as a
    # regular attribute of this module.
    import re

    # NOTE: a leftover ``breakpoint()`` debugging call was removed here; it
    # dropped every lookup through this hook into the debugger.

    # Names without an ``x.y.z`` version fragment must be known templates.
    # The pattern is a raw string so ``\.`` is a regex escape rather than an
    # invalid string escape sequence.
    if not re.match(r'.*[0-9]+\.[0-9]+\.[0-9]+.*', name):
        assert name in TEMPLATES, f'{name} not in supported templates {TEMPLATES}'
        # Keep only the final path segment of the registered template entry.
        file = TEMPLATES[name].split('/')[-1]
Expand Down

0 comments on commit 504f2ad

Please sign in to comment.