Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[patch] Break the recursive ownership between HasStorage and StorageInterface #423

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 63 additions & 82 deletions pyiron_workflow/mixin/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,59 +26,52 @@ class TypeNotFoundError(ImportError):


class StorageInterface:
def __init__(self, owner: HasStorage):
self._owner = owner

@property
def owner(self) -> HasStorage:
# Property access is just to allow children to override the type hint
return self._owner

def save(self):
root = self.owner.storage_root
def save(self, obj: HasStorage):
root = obj.storage_root
if not root.import_ready:
raise TypeNotFoundError(
f"{self.owner.label} cannot be saved with the "
f"{self.owner.storage_backend} because it (or one of its children) has "
f"{obj.label} cannot be saved with the "
f"{obj.storage_backend} because it (or one of its children) has "
f"a type that cannot be imported. Did you dynamically define this "
f"object? \n"
f"Import readiness report: \n"
f"{self.owner.report_import_readiness()}"
f"{obj.report_import_readiness()}"
)
root_storage = self if root is self.owner else root.storage
root_storage._save()
if root is self:
self._save(obj)
else:
root.storage._save(root)

@abstractmethod
def _save(self):
def _save(self, obj: HasStorage):
pass

def load(self):
def load(self, obj: HasStorage):
# Misdirection is strictly for symmetry with _save, so child classes define the
# private method in both cases
return self._load()
return self._load(obj)

@abstractmethod
def _load(self):
def _load(self, obj: HasStorage):
pass

@property
def has_contents(self) -> bool:
has_contents = self._has_contents
self.owner.tidy_storage_directory()
def has_contents(self, obj: HasStorage) -> bool:
has_contents = self._has_contents(obj)
obj.tidy_storage_directory()
return has_contents

@property
@abstractmethod
def _has_contents(self) -> bool:
def _has_contents(self, obj: HasStorage) -> bool:
"""Whether a save file exists for this backend"""

def delete(self):
def delete(self, obj: HasStorage):
if self.has_contents:
self._delete()
self.owner.tidy_storage_directory()
self._delete(obj)
obj.tidy_storage_directory()

@abstractmethod
def _delete(self):
def _delete(self, obj: HasStorage):
"""Remove an existing save-file for this backend"""


Expand All @@ -87,69 +80,57 @@ class PickleStorage(StorageInterface):
_PICKLE_STORAGE_FILE_NAME = "pickle.pckl"
_CLOUDPICKLE_STORAGE_FILE_NAME = "cloudpickle.cpckl"

def __init__(self, owner: HasPickleStorage):
super().__init__(owner=owner)

@property
def owner(self) -> HasPickleStorage:
return self._owner

def _save(self):
def _save(self, obj: HasStorage):
try:
with open(self._pickle_storage_file_path, "wb") as file:
pickle.dump(self.owner, file)
with open(self._pickle_storage_file_path(obj), "wb") as file:
pickle.dump(obj, file)
except Exception:
self._delete()
with open(self._cloudpickle_storage_file_path, "wb") as file:
cloudpickle.dump(self.owner, file)
self._delete(obj)
with open(self._cloudpickle_storage_file_path(obj), "wb") as file:
cloudpickle.dump(obj, file)

def _load(self):
if self._has_pickle_contents:
with open(self._pickle_storage_file_path, "rb") as file:
def _load(self, obj: HasStorage):
if self._has_pickle_contents(obj):
with open(self._pickle_storage_file_path(obj), "rb") as file:
inst = pickle.load(file)
elif self._has_cloudpickle_contents:
with open(self._cloudpickle_storage_file_path, "rb") as file:
elif self._has_cloudpickle_contents(obj):
with open(self._cloudpickle_storage_file_path(obj), "rb") as file:
inst = cloudpickle.load(file)

if inst.__class__ != self.owner.__class__:
if inst.__class__ != obj.__class__:
raise TypeError(
f"{self.owner.label} cannot load, as it has type "
f"{self.owner.__class__.__name__}, but the saved node has type "
f"{obj.label} cannot load, as it has type "
f"{obj.__class__.__name__}, but the saved node has type "
f"{inst.__class__.__name__}"
)
self.owner.__setstate__(inst.__getstate__())
obj.__setstate__(inst.__getstate__())

def _delete_file(self, file: str):
FileObject(file, self.owner.storage_directory).delete()
def _delete_file(self, file: str, obj: HasStorage):
FileObject(file, obj.storage_directory).delete()

def _delete(self):
def _delete(self, obj: HasStorage):
if self._has_pickle_contents:
self._delete_file(self._PICKLE_STORAGE_FILE_NAME)
self._delete_file(self._PICKLE_STORAGE_FILE_NAME, obj)
elif self._has_cloudpickle_contents:
self._delete_file(self._CLOUDPICKLE_STORAGE_FILE_NAME)
self._delete_file(self._CLOUDPICKLE_STORAGE_FILE_NAME, obj)

def _storage_path(self, file: str):
return str((self.owner.storage_directory.path / file).resolve())
def _storage_path(self, file: str, obj: HasStorage):
return str((obj.storage_directory.path / file).resolve())

@property
def _pickle_storage_file_path(self) -> str:
return self._storage_path(self._PICKLE_STORAGE_FILE_NAME)
def _pickle_storage_file_path(self, obj: HasStorage) -> str:
return self._storage_path(self._PICKLE_STORAGE_FILE_NAME, obj)

@property
def _cloudpickle_storage_file_path(self) -> str:
return self._storage_path(self._CLOUDPICKLE_STORAGE_FILE_NAME)
def _cloudpickle_storage_file_path(self, obj: HasStorage) -> str:
return self._storage_path(self._CLOUDPICKLE_STORAGE_FILE_NAME, obj)

@property
def _has_contents(self) -> bool:
return self._has_pickle_contents or self._has_cloudpickle_contents
def _has_contents(self, obj: HasStorage) -> bool:
return self._has_pickle_contents(obj) or self._has_cloudpickle_contents(obj)

@property
def _has_pickle_contents(self) -> bool:
return os.path.isfile(self._pickle_storage_file_path)
def _has_pickle_contents(self, obj: HasStorage) -> bool:
return os.path.isfile(self._pickle_storage_file_path(obj))

@property
def _has_cloudpickle_contents(self) -> bool:
return os.path.isfile(self._cloudpickle_storage_file_path)
def _has_cloudpickle_contents(self, obj: HasStorage) -> bool:
return os.path.isfile(self._cloudpickle_storage_file_path(obj))


class HasStorage(HasLabel, HasParent, ABC):
Expand Down Expand Up @@ -218,7 +199,7 @@ def save(self):
type can :meth:`load()` the data to return to the same state as the save point,
i.e. the same data IO channel values, the same flags, etc.
"""
self.storage.save()
self.storage.save(self)

save.__doc__ += _save_load_warnings

Expand All @@ -230,24 +211,24 @@ def load(self):
Raises:
TypeError: when the saved node has a different class name.
"""
if self.storage.has_contents:
self.storage.load()
if self.storage.has_contents(self):
self.storage.load(self)
else:
# Check for saved content using any other backend
for backend in self.allowed_backends():
interface = self._storage_interfaces()[backend](self)
if interface.has_contents:
interface.load()
interface = self._storage_interfaces()[backend]()
if interface.has_contents(self):
interface.load(self)
break

save.__doc__ += _save_load_warnings

def delete_storage(self):
"""Remove save files for _all_ available backends."""
for backend in self.allowed_backends():
interface = self._storage_interfaces()[backend](self)
interface = self._storage_interfaces()[backend]()
try:
interface.delete()
interface.delete(self)
except FileNotFoundError:
pass

Expand Down Expand Up @@ -291,12 +272,12 @@ def storage_backend(self, new_backend):
def storage(self) -> StorageInterface:
if self.storage_backend is None:
raise ValueError(f"{self.label} does not have a storage backend set")
return self._storage_interfaces()[self.storage_backend](self)
return self._storage_interfaces()[self.storage_backend]()

@property
def any_storage_has_contents(self):
return any(
self._storage_interfaces()[backend](self).has_contents
self._storage_interfaces()[backend]().has_contents(self)
for backend in self.allowed_backends()
)

Expand Down
Loading