Skip to content

Commit

Permalink
Remove __post_init__ from developer contract
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 17, 2025
1 parent 4132691 commit 548c814
Show file tree
Hide file tree
Showing 34 changed files with 374 additions and 125 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
#### Changed defaults / behaviours

- No need to add `.signature` to `Model` implementations
- No need to write `Component.__post_init__` to modify attributes (use `Component.postinit`).

#### New Features & Functionality

Expand Down
5 changes: 3 additions & 2 deletions plugins/anthropic/superduper_anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ class Anthropic(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, example):
def postinit(self):
"""Post-initialization method."""
self.model = self.model or self.identifier
super().__post_init__(db, example=example)
super().postinit()

def init(self, db=None):
"""Initialize the model.
Expand Down
10 changes: 6 additions & 4 deletions plugins/cohere/superduper_cohere/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ class Cohere(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, example):
super().__post_init__(db, example=example)
def postinit(self):
"""Post-initialization method."""
self.identifier = self.identifier or self.model
return super().postinit()


class CohereEmbed(Cohere):
Expand All @@ -46,10 +47,11 @@ class CohereEmbed(Cohere):
batch_size: int = 100
signature: str = 'singleton'

def __post_init__(self, db, example):
super().__post_init__(db, example=example)
def postinit(self):
"""Post-initialization method."""
if self.shape is None:
self.shape = self.shapes[self.identifier]
return super().postinit()

@retry
def predict(self, X: str):
Expand Down
10 changes: 6 additions & 4 deletions plugins/jina/superduper_jina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ class Jina(APIBaseModel):

api_key: t.Optional[str] = None

def __post_init__(self, db, example):
super().__post_init__(db, example=example)
def postinit(self):
"""Post-initialization method."""
self.identifier = self.identifier or self.model
self.client = JinaAPIClient(model_name=self.identifier, api_key=self.api_key)
return super().postinit()


class JinaEmbedding(Jina):
Expand All @@ -40,8 +41,9 @@ class JinaEmbedding(Jina):
shape: t.Optional[t.Sequence[int]] = None
signature: str = 'singleton'

def __post_init__(self, db, example):
super().__post_init__(db, example)
def postinit(self):
"""Post-initialization method."""
super().postinit()
if self.shape is None:
self.shape = (len(self.client.encode_batch(['shape'])[0]),)

Expand Down
15 changes: 8 additions & 7 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,23 @@ class _OpenAI(APIBaseModel):
openai_api_base: t.Optional[str] = None
client_kwargs: t.Optional[dict] = dc.field(default_factory=dict)

def __post_init__(self, db, example):
super().__post_init__(db, example)

def postinit(self):
assert isinstance(self.client_kwargs, dict)
if self.openai_api_key is not None:
self.client_kwargs['api_key'] = self.openai_api_key
if self.openai_api_base is not None:
self.client_kwargs['base_url'] = self.openai_api_base
self.client_kwargs['default_headers'] = self.openai_api_base

super().postinit()

@safe_retry(exceptions.MissingSecretsException, verbose=0)
def init(self, db=None):
def init(self):
"""Initialize the model.
:param db: Database instance.
"""
super().init(db=db)
super().init()

# dall-e is not currently included in list returned by OpenAI model endpoint
if 'OPENAI_API_KEY' not in os.environ or (
Expand Down Expand Up @@ -167,9 +167,10 @@ class OpenAIChatCompletion(_OpenAI):
batch_size: int = 1
prompt: str = ''

def __post_init__(self, db, example):
super().__post_init__(db, example)
def postinit(self):
"""Post-initialization method."""
self.takes_context = True
return super().postinit()

def _format_prompt(self, context, X):
prompt = self.prompt.format(context='\n'.join(context))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ class SentenceTransformer(Model, _DeviceManaged):
postprocess: t.Union[None, t.Callable] = None
signature: Signature = 'singleton'

def __post_init__(self, db, example):
super().__post_init__(db, example=example)

def postinit(self):
"""Post-initialization method."""
if self.model is None:
self.model = self.identifier

self._default_model = False
if self.object is None:
self.object = _SentenceTransformer(self.model, device=self.device)
self._default_model = True
return super().postinit()

def dict(self, metadata: bool = True, defaults: bool = True, refs: bool = False):
"""Serialize as a dictionary."""
Expand Down
5 changes: 3 additions & 2 deletions plugins/torch/superduper_torch/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ class Tensor(BaseDataType):
shape: int | t.Tuple[int]
identifier: str = ''

def __post_init__(self, db):
def postinit(self):
"""Post-initialization method."""
self.encodable = 'encodable'
if not self.identifier:
dtype = str(self.dtype)
self.identifier = f'torch-{dtype}[{str_shape(self.shape)}]'
return super().__post_init__(db)
return super().postinit()

def encode_data(self, item):
"""Encode data.
Expand Down
7 changes: 3 additions & 4 deletions plugins/torch/superduper_torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,11 @@ class TorchModel(Model, _DeviceManaged):
optimizer_state: t.Optional[t.Any] = None
loader_kwargs: t.Dict = dc.field(default_factory=lambda: {})

def __post_init__(self, db, example):
super().__post_init__(db, example=example)

def init(self):
"""Initialize the model data."""
super().init()
if self.optimizer_state is not None:
self.optimizer.load_state_dict(self.optimizer_state)

self._validation_set_cache = {}

@property
Expand Down
7 changes: 4 additions & 3 deletions plugins/transformers/superduper_transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
_DeviceManaged,
)
from superduper.components.training import Checkpoint
from superduper.ext.llm.model import BaseLLM
from superduper.components.llm.model import BaseLLM
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
Expand Down Expand Up @@ -209,10 +209,11 @@ def _build_pipeline(self):
model=self.model_cls.from_pretrained(self.model_name),
)

def __post_init__(self, db, example):
def postinit(self):
"""Post-initialization method."""
if self.pipeline is None:
self._build_pipeline()
super().__post_init__(db, example)
super().postinit()

def predict(self, text: str):
"""Predict the class of a single text.
Expand Down
4 changes: 2 additions & 2 deletions plugins/vllm/superduper_vllm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class _VLLMCore(Model):

vllm_params: dict = dc.field(default_factory=dict)

def __post_init__(self, db, example):
super().__post_init__(db, example)
def postinit(self):
assert "model" in self.vllm_params, "model is required in vllm_params"
self._async_llm = None
self._sync_llm = None
Expand All @@ -29,6 +28,7 @@ def __post_init__(self, db, example):
parallel_size = max(tensor_parallel_size, pipeline_parallel_size)
self.compute_kwargs["num_gpus"] = parallel_size
logging.info(f"Setting num_gpus to {parallel_size}")
super().postinit()

def _init_sync_llm(self):
if self._sync_llm is not None:
Expand Down
1 change: 0 additions & 1 deletion superduper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
'superduper',
'BaseDataType',
'Document',
'code',
'ObjectModel',
'QueryModel',
'Validation',
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def apply(
:param db: Datalayer instance
:param object: Object to be stored.
:param force: List of jobs which should execute before component
initialization begins.
initialization begins.
:param wait: Blocks execution till create events finish.
"""
if force is None:
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def _deep_flat_encode(

# inline components do not need to be kept
# they are simply parametrized by their inputs
if isinstance(r, leaves_to_keep):
if isinstance(r, leaves_to_keep):
builds[key] = r
return '?' + key

Expand Down
5 changes: 5 additions & 0 deletions superduper/base/leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class Leaf(metaclass=LeafMeta):
db: dc.InitVar[t.Optional['Datalayer']] = None
uuid: str = dc.field(default_factory=build_uuid)

def postinit(self):
"""Post-initialization method."""
pass

def _get_metadata(self):
return {}

Expand All @@ -148,6 +152,7 @@ def metadata(self):

def __post_init__(self, db: t.Optional['Datalayer'] = None):
self.db = db
self.postinit()

@property
def leaves(self):
Expand Down
9 changes: 4 additions & 5 deletions superduper/components/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@ class Application(Component):
namespace: t.Optional[t.Sequence[t.Tuple[str, str]]] = None
link: t.Optional[str] = None

def __post_init__(self, db):
super().__post_init__(db)
self._sort_components_and_set_upstream()

def _sort_components_and_set_upstream(self):
def postinit(self):
"""Post initialization method."""
logging.info('Resorting components based on topological order.')
G = networkx.DiGraph()
lookup = {c.huuid: c for c in self.components}
Expand All @@ -57,6 +54,8 @@ def _sort_components_and_set_upstream(self):
components = [lookup[n] for n in nodes]
self.components = components

super().postinit()

def pre_create(self, db: "Datalayer"):
"""Pre-create hook.
Expand Down
16 changes: 8 additions & 8 deletions superduper/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ def leaves(self):

def __post_init__(self, db):
super().__post_init__(db)
self.postinit()

def postinit(self):
"""Post initialization method."""
self.version: t.Optional[int] = None
if not self.identifier:
raise ValueError('identifier cannot be empty or None')
Expand All @@ -555,13 +559,9 @@ def dependencies(self):
"""Get dependencies on the component."""
return ()

def init(self, db=None):
"""Method to help initiate component field dependencies.
:param db: The `Datalayer` to use for the operation.
"""
self.db = self.db or db
self.unpack(db=db)
def init(self):
"""Method to help initiate component field dependencies."""
self.unpack(db=self.db)

# TODO Why both methods?
def unpack(self, db=None):
Expand All @@ -575,7 +575,7 @@ def unpack(self, db=None):
def _init(item):
nonlocal db
if isinstance(item, Component):
item.init(db=db)
item.init()
return item

if isinstance(item, dict):
Expand Down
31 changes: 10 additions & 21 deletions superduper/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import cached_property

import numpy
from overrides import override

from superduper.backends.base.query import Query
from superduper.base.datalayer import Datalayer
Expand Down Expand Up @@ -35,39 +34,29 @@ class Dataset(Component):
raw_data: t.Optional[t.Sequence[t.Any]] = None
pin: bool = False

def __post_init__(self, db):
"""Post-initialization method.
:param artifacts: Optional additional artifacts for initialization.
"""
super().__post_init__(db=db)
def postinit(self):
"""Post initialization method."""
self._data = None
super().postinit()

@property
@ensure_initialized
def data(self):
"""Property representing the dataset's data."""
return self._data

def init(self, db=None):
"""Initialization method.
:param db: The database to use for the operation.
"""
db = db or self.db
super().init(db=db)
def init(self):
"""Initialization method."""
super().init()
if self.pin:
assert self.raw_data is not None
self._data = [Document.decode(r, db=db).unpack() for r in self.raw_data]
self._data = [
Document.decode(r, db=self.db).unpack() for r in self.raw_data
]
else:
self._data = self._load_data(db)
self._data = self._load_data(self.db)

@override
def _pre_create(self, db: 'Datalayer', startup_cache: t.Dict = {}) -> None:
"""Pre-create hook for database operations.
:param db: The database to use for the operation.
"""
if self.raw_data is None and self.pin:
data = self._load_data(db)
self.raw_data = [r.encode() for r in data]
Expand Down
4 changes: 2 additions & 2 deletions superduper/components/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ class FileItem(Saveable):

path: str = ''

def __post_init__(self, db=None):
def postinit(self):
"""Post init."""
if not self.identifier:
self.identifier = get_hash(self.path)
return super().__post_init__(db)

def init(self):
"""Initialize the file to local disk."""
Expand Down
Loading

0 comments on commit 548c814

Please sign in to comment.