Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Macro.graph_creator a normal method #568

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pyiron_workflow/nodes/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from pyiron_snippets.factory import classfactory

from pyiron_workflow.compatibility import Self
from pyiron_workflow.io import Inputs
from pyiron_workflow.mixin.has_interface_mixins import HasChannel
from pyiron_workflow.mixin.injection import OutputsWithInjection
Expand Down Expand Up @@ -196,7 +197,6 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC):
>>> class AddThreeMacro(Macro):
... _output_labels = ["three"]
...
... @staticmethod
... def graph_creator(self, x):
... add_three_macro(self, one__x=x)
... return self.three
Expand Down Expand Up @@ -252,7 +252,7 @@ def _setup_node(self) -> None:
super()._setup_node()

ui_nodes = self._prepopulate_ui_nodes_from_graph_creator_signature()
returned_has_channel_objects = self.graph_creator(self, *ui_nodes)
returned_has_channel_objects = self.graph_creator(*ui_nodes)
if returned_has_channel_objects is None:
returned_has_channel_objects = ()
elif isinstance(returned_has_channel_objects, HasChannel):
Expand All @@ -271,10 +271,9 @@ def _setup_node(self) -> None:
remaining_ui_nodes = self._purge_single_use_ui_nodes(ui_nodes)
self._configure_graph_execution(remaining_ui_nodes)

@staticmethod
@abstractmethod
def graph_creator(
self: Macro, *args, **kwargs # noqa: PLW0211
self: Self, *args, **kwargs
) -> HasChannel | tuple[HasChannel, ...] | None:
"""Build the graph the node will run."""

Expand Down Expand Up @@ -480,7 +479,8 @@ def macro_node_factory(
Create a new :class:`Macro` subclass using the given graph creator function.

Args:
graph_creator (callable): Function to create the graph for the :class:`Macro`.
graph_creator (callable): Function to create the graph for this subclass of
:class:`Macro`.
validate_output_labels (bool): Whether to validate the output labels against
the return values of the wrapped function.
use_cache (bool): Whether nodes of this type should default to caching their
Expand All @@ -495,7 +495,7 @@ def macro_node_factory(
graph_creator.__name__,
(Macro,), # Define parentage
{
"graph_creator": staticmethod(graph_creator),
"graph_creator": graph_creator,
"__module__": graph_creator.__module__,
"__qualname__": graph_creator.__qualname__,
"_output_labels": None if len(output_labels) == 0 else output_labels,
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/nodes/test_macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ def test_creation_from_subclass(self):
class MyMacro(Macro):
_output_labels = ("three__result",)

@staticmethod
def graph_creator(self, one__x): # noqa: PLW0211
def graph_creator(self, one__x):
add_three_macro(self, one__x)
return self.three

Expand Down
Loading