diff --git a/src/abcdmicro/io.py b/src/abcdmicro/io.py index 5a7d2a2..e7898e4 100644 --- a/src/abcdmicro/io.py +++ b/src/abcdmicro/io.py @@ -1,12 +1,13 @@ from __future__ import annotations -from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, TypeVar, get_type_hints +from typing import Any import itk +import numpy as np from dipy.io.gradients import read_bvals_bvecs +from numpy.typing import NDArray from abcdmicro.resource import ( BvalResource, @@ -17,54 +18,9 @@ VolumeResource, ) -T = TypeVar("T", bound="LoadableResource") - -class LoadableResource(ABC): - """Base class for on-disk resources that have a load method that converts them into in-memory resources""" - - @abstractmethod - def load(self) -> Any: - """Load this resource to get an in-memory version of it.""" - - -def implement_via_loading(method_names: list[str]) -> Callable[[type[T]], type[T]]: - """Decorator that implements the listed abstract methods of a LoadableResource class by calling the - load() method and then using the loaded object's method of the same name.""" - - def implement_via_loading_decorator(cls: type[T]) -> type[T]: - for method_name in method_names: - - def method(self, method_name=method_name): # type: ignore[no-untyped-def] - return getattr(self.load(), method_name)() - - method.__name__ = method_name - method.__doc__ = f"Automatically implemented method that returns `self.load().{method_name}()`." - - for parent_class in cls.__bases__: - if hasattr(parent_class, method_name): - parent_method = getattr(parent_class, method_name) - return_type = get_type_hints(parent_method).get("return", Any) - method.__annotations__ = {"return": return_type} - break - - setattr(cls, method_name, method) - - # If the automatically implemented methods were abstract methods, then remove them from the set - # to indicate that they have been implemented. - if hasattr(cls, "__abstractmethods__"): - cls.__abstractmethods__ = frozenset( - name for name in cls.__abstractmethods__ if name not in method_names - ) - - return cls - - return implement_via_loading_decorator - - -@implement_via_loading(["get_array", "get_metadata"]) @dataclass -class NiftiVolumeResrouce(VolumeResource): # type: ignore[type-var] +class NiftiVolumeResrouce(VolumeResource): """A volume or volume stack that is saved to disk in the nifti file format.""" path: Path @@ -73,10 +29,15 @@ class NiftiVolumeResrouce(VolumeResource): # type: ignore[type-var] def load(self) -> InMemoryVolumeResource: return InMemoryVolumeResource(itk.imread(self.path)) + def get_array(self) -> NDArray[Any]: + return self.load().get_array() + + def get_metadata(self) -> dict[Any, Any]: + return self.load().get_metadata() + -@implement_via_loading(["get"]) @dataclass -class FslBvalResource(BvalResource, LoadableResource): +class FslBvalResource(BvalResource): """A b-value list that is saved to disk in the FSL text file format.""" path: Path @@ -86,10 +47,12 @@ def load(self) -> InMemoryBvalResource: bvals_array, _ = read_bvals_bvecs(self.path, None) return InMemoryBvalResource(bvals_array) + def get(self) -> NDArray[np.floating]: + return self.load().get() + -@implement_via_loading(["get"]) @dataclass -class FslBvecResource(BvecResource, LoadableResource): +class FslBvecResource(BvecResource): """A b-vector list that is saved to disk in the FSL text file format.""" path: Path @@ -98,3 +61,6 @@ class FslBvecResource(BvecResource, LoadableResource): def load(self) -> InMemoryBvecResource: _, bvecs_array = read_bvals_bvecs(None, self.path) return InMemoryBvecResource(bvecs_array) + + def get(self) -> NDArray[np.floating]: + return self.load().get() diff --git a/tests/test_io.py b/tests/test_io.py index b74fb48..1b8da0b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -51,5 +51,5 @@ def test_nifti_volume_resource(volume_array): ] ), ) - volume_resource = NiftiVolumeResrouce(path=volume_file) # type: ignore[abstract] + volume_resource = NiftiVolumeResrouce(path=volume_file) assert np.allclose(volume_resource.get_array(), volume_array)