Skip to content

Commit

Permalink
More return hints (#552)
Browse files Browse the repository at this point in the history
* Fix returned type of __dir__

Conventionally it returns a list, not a set, of strings

Signed-off-by: liamhuber <[email protected]>

* Add hints to io

Signed-off-by: liamhuber <[email protected]>

* Adjust run_finally signature

Signed-off-by: liamhuber <[email protected]>

* Hint user data

Signed-off-by: liamhuber <[email protected]>

* Hint Workflow.automate_execution

Signed-off-by: liamhuber <[email protected]>

* Provide a type-compliant default

It never actually matters with the current logic, because of all the checks if parent is None and the fact that it is otherwise hinted to be at least a `Composite`, but it shuts mypy up and it does zero harm.

Signed-off-by: liamhuber <[email protected]>

* black

Signed-off-by: liamhuber <[email protected]>

* `mypy` storage (#553)

* Add return hints

Signed-off-by: liamhuber <[email protected]>

* End clause with else

Signed-off-by: liamhuber <[email protected]>

* Explicitly raise an error

After narrowing our search to files, actually throw an error right away if you never found one to load.

Signed-off-by: liamhuber <[email protected]>

* Resolve method extension complaints

Signed-off-by: liamhuber <[email protected]>

* `mypy` signature compliance (#554)

* Extend runnable signatures

Signed-off-by: liamhuber <[email protected]>

* Align Workflow.run with superclass signature

Signed-off-by: liamhuber <[email protected]>

* Relax FromManyInputs._on_run constraint

It was too strict for the DataFrame subclass, so just keep the superclass reference instead of narrowing the constraints.

Signed-off-by: liamhuber <[email protected]>

* black

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber authored Jan 17, 2025
1 parent 51257c0 commit 7422cad
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 60 deletions.
75 changes: 40 additions & 35 deletions pyiron_workflow/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import contextlib
from abc import ABC, abstractmethod
from collections.abc import ItemsView, Iterator
from typing import Any, Generic, TypeVar

from pyiron_snippets.dotdict import DotDict
Expand Down Expand Up @@ -59,7 +60,7 @@ class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC):

channel_dict: DotDict[str, OwnedType]

def __init__(self, *channels: OwnedType):
def __init__(self, *channels: OwnedType) -> None:
self.__dict__["channel_dict"] = DotDict(
{
channel.label: channel
Expand All @@ -74,11 +75,11 @@ def _channel_class(self) -> type[OwnedType]:
pass

@abstractmethod
def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None:
def _assign_a_non_channel_value(self, channel: OwnedType, value: Any) -> None:
"""What to do when some non-channel value gets assigned to a channel"""
pass

def __getattr__(self, item) -> OwnedType:
def __getattr__(self, item: str) -> OwnedType:
try:
return self.channel_dict[item]
except KeyError as key_error:
Expand All @@ -88,7 +89,7 @@ def __getattr__(self, item) -> OwnedType:
f"nor in its channels ({self.labels})"
) from key_error

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: Any) -> None:
if key in self.channel_dict:
self._assign_value_to_existing_channel(self.channel_dict[key], value)
elif isinstance(value, self._channel_class):
Expand All @@ -104,16 +105,16 @@ def __setattr__(self, key, value):
f"attribute {key} got assigned {value} of type {type(value)}"
)

def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None:
def _assign_value_to_existing_channel(self, channel: OwnedType, value: Any) -> None:
if isinstance(value, HasChannel):
channel.connect(value.channel)
else:
self._assign_a_non_channel_value(channel, value)

def __getitem__(self, item) -> OwnedType:
def __getitem__(self, item: str) -> OwnedType:
return self.__getattr__(item)

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
self.__setattr__(key, value)

@property
Expand All @@ -124,11 +125,11 @@ def connections(self) -> list[OwnedConjugate]:
)

@property
def connected(self):
def connected(self) -> bool:
return any(c.connected for c in self)

@property
def fully_connected(self):
def fully_connected(self) -> bool:
return all(c.connected for c in self)

def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
Expand All @@ -145,34 +146,36 @@ def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
return destroyed_connections

@property
def labels(self):
def labels(self) -> list[str]:
return list(self.channel_dict.keys())

def items(self):
def items(self) -> ItemsView[str, OwnedType]:
return self.channel_dict.items()

def __iter__(self):
def __iter__(self) -> Iterator[OwnedType]:
return self.channel_dict.values().__iter__()

def __len__(self):
def __len__(self) -> int:
return len(self.channel_dict)

def __dir__(self):
return set(super().__dir__() + self.labels)
return list(set(super().__dir__() + self.labels))

def __str__(self):
def __str__(self) -> str:
return f"{self.__class__.__name__} {self.labels}"

def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
# Compatibility with python <3.11
return dict(self.__dict__)

def __setstate__(self, state):
def __setstate__(self, state: dict[str, Any]) -> None:
# Because we override getattr, we need to use __dict__ assignment directly in
# __setstate__ the same way we need it in __init__
self.__dict__["channel_dict"] = state["channel_dict"]

def display_state(self, state=None, ignore_private=True):
def display_state(
self, state: dict[str, Any] | None = None, ignore_private: bool = True
) -> dict[str, Any]:
state = dict(self.__getstate__()) if state is None else state
for k, v in state["channel_dict"].items():
state[k] = v
Expand All @@ -192,15 +195,15 @@ class DataIO(IO[DataChannel, DataChannel], ABC):
def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None:
channel.value = value

def to_value_dict(self):
def to_value_dict(self) -> dict[str, Any]:
return {label: channel.value for label, channel in self.channel_dict.items()}

def to_list(self):
def to_list(self) -> list[Any]:
"""A list of channel values (order not guaranteed)"""
return [channel.value for channel in self.channel_dict.values()]

@property
def ready(self):
def ready(self) -> bool:
return all(c.ready for c in self)

def activate_strict_hints(self):
Expand All @@ -215,7 +218,7 @@ class Inputs(InputsIO, DataIO):
def _channel_class(self) -> type[InputData]:
return InputData

def fetch(self):
def fetch(self) -> None:
for c in self:
c.fetch()

Expand All @@ -237,7 +240,7 @@ def _channel_class(self) -> type[OutputData]:


class SignalIO(IO[SignalChannel, SignalChannel], ABC):
def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None:
def _assign_a_non_channel_value(self, channel: SignalChannel, value: Any) -> None:
raise TypeError(
f"Tried to assign {value} ({type(value)} to the {channel.full_label}, "
f"which is already a {type(channel)}. Only other signal channels may be "
Expand Down Expand Up @@ -275,9 +278,9 @@ class Signals(HasStateDisplay):
output (OutputSignals): An empty input signals IO container.
"""

def __init__(self):
self.input = InputSignals()
self.output = OutputSignals()
def __init__(self) -> None:
self.input: InputSignals = InputSignals()
self.output: OutputSignals = OutputSignals()

def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]:
"""
Expand All @@ -293,14 +296,14 @@ def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]:
return self.input.disconnect_run()

@property
def connected(self):
def connected(self) -> bool:
return self.input.connected or self.output.connected

@property
def fully_connected(self):
def fully_connected(self) -> bool:
return self.input.fully_connected and self.output.fully_connected

def __str__(self):
def __str__(self) -> str:
return f"{str(self.input)}\n{str(self.output)}"


Expand All @@ -316,7 +319,7 @@ class HasIO(HasStateDisplay, HasLabel, HasRun, Generic[OutputsType], ABC):
interface.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._signals = Signals()
self._signals.input.run = InputSignal("run", self, self.run)
Expand Down Expand Up @@ -375,17 +378,17 @@ def disconnect(self) -> list[tuple[Channel, Channel]]:
destroyed_connections.extend(self.signals.disconnect())
return destroyed_connections

def activate_strict_hints(self):
def activate_strict_hints(self) -> None:
"""Enable type hint checks for all data IO"""
self.inputs.activate_strict_hints()
self.outputs.activate_strict_hints()

def deactivate_strict_hints(self):
def deactivate_strict_hints(self) -> None:
"""Disable type hint checks for all data IO"""
self.inputs.deactivate_strict_hints()
self.outputs.deactivate_strict_hints()

def _connect_output_signal(self, signal: OutputSignal):
def _connect_output_signal(self, signal: OutputSignal) -> None:
self.signals.input.run.connect(signal)

def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
Expand All @@ -395,10 +398,12 @@ def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
other._connect_output_signal(self.signals.output.ran)
return other

def _connect_accumulating_input_signal(self, signal: AccumulatingInputSignal):
def _connect_accumulating_input_signal(
self, signal: AccumulatingInputSignal
) -> None:
self.signals.output.ran.connect(signal)

def __lshift__(self, others):
def __lshift__(self, others: tuple[OutputSignal | HasIO, ...]):
"""
Connect one or more `ran` signals to `accumulate_and_run` signals like:
`this << some_object, another_object, or_by_channel.signals.output.ran`
Expand Down
3 changes: 2 additions & 1 deletion pyiron_workflow/mixin/display_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from abc import ABC
from json import dumps
from typing import Any

from pyiron_workflow.mixin.has_interface_mixins import UsesState

Expand All @@ -24,7 +25,7 @@ class HasStateDisplay(UsesState, ABC):

def display_state(
self, state: dict | None = None, ignore_private: bool = True
) -> dict:
) -> dict[str, Any]:
"""
A dictionary of JSON-compatible objects based on the object state (plus
whatever modifications to the state the class designer has chosen to make).
Expand Down
11 changes: 7 additions & 4 deletions pyiron_workflow/mixin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_args(self) -> tuple[tuple, dict]:
Any data needed for :meth:`on_run`, will be passed as (*args, **kwargs).
"""

def process_run_result(self, run_output):
def process_run_result(self, run_output: Any) -> Any:
"""
What to _do_ with the results of :meth:`on_run` once you have them.
Expand Down Expand Up @@ -165,7 +165,9 @@ def _none_to_dict(inp: dict | None) -> dict:
**run_kwargs,
)

def _before_run(self, /, check_readiness, **kwargs) -> tuple[bool, Any]:
def _before_run(
self, /, check_readiness: bool, *args, **kwargs
) -> tuple[bool, Any]:
"""
Things to do _before_ running.
Expand Down Expand Up @@ -194,6 +196,7 @@ def _run(
run_exception_kwargs: dict,
run_finally_kwargs: dict,
finish_run_kwargs: dict,
*args,
**kwargs,
) -> Any | tuple | Future:
"""
Expand Down Expand Up @@ -254,15 +257,15 @@ def _run(
)
return self.future

def _run_exception(self, /, **kwargs):
def _run_exception(self, /, *args, **kwargs):
"""
What to do if an exception is encountered inside :meth:`_run` or
:meth:`_finish_run.
"""
self.running = False
self.failed = True

def _run_finally(self, /, **kwargs):
def _run_finally(self, /, *args, **kwargs):
"""
What to do after :meth:`_finish_run` (whether an exception is encountered or
not), or in :meth:`_run` after an exception is encountered.
Expand Down
6 changes: 4 additions & 2 deletions pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def __init__(
self._do_clean: bool = False # Power-user override for cleaning up temporary
# serialized results and empty directories (or not).
self._cached_inputs = None
self._user_data = {} # A place for power-users to bypass node-injection

self._user_data: dict[str, Any] = {}
# A place for power-users to bypass node-injection

self._setup_node()
self._after_node_setup(
Expand Down Expand Up @@ -629,7 +631,7 @@ def run_data_tree(self, run_parent_trees_too=False) -> None:

try:
parent_starting_nodes = (
self.parent.starting_nodes if self.parent is not None else None
self.parent.starting_nodes if self.parent is not None else []
) # We need these for state recovery later, even if we crash

if len(data_tree_starters) == 1 and data_tree_starters[0] is self:
Expand Down
4 changes: 0 additions & 4 deletions pyiron_workflow/nodes/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ class FromManyInputs(Transformer, ABC):
# Inputs convert to `run_args` as a value dictionary
# This must be commensurate with the internal expectations of _on_run

@abstractmethod
def _on_run(self, **inputs_to_value_dict) -> Any:
"""Must take inputs kwargs"""

@property
def _run_args(self) -> tuple[tuple, dict]:
return (), self.inputs.to_value_dict()
Expand Down
Loading

0 comments on commit 7422cad

Please sign in to comment.