Skip to content

Commit

Permalink
Return both nodes on replacement
Browse files Browse the repository at this point in the history
Instead of only returning the replaced node.

Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber committed Jan 17, 2025
1 parent e61a9db commit a520df5
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
29 changes: 19 additions & 10 deletions pyiron_workflow/nodes/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def remove_child(self, child: Node | str) -> Node:

def replace_child(
self, owned_node: Node | str, replacement: Node | type[Node]
) -> Node:
) -> tuple[Node, Node]:
"""
Replaces a node currently owned with a new node instance.
The replacement must not belong to any other parent or have any connections.
Expand All @@ -348,7 +348,7 @@ def replace_child(
and simply gets instantiated.)
Returns:
(Node): The node that got removed
(Node, Node): The node that got removed and the new one that replaced it.
"""
if isinstance(owned_node, str):
owned_node = self.children[owned_node]
Expand All @@ -367,15 +367,18 @@ def replace_child(
)
if replacement.connected:
raise ValueError("Replacement node must not have any connections")
replacement_node = replacement
elif issubclass(replacement, Node):
replacement = replacement(label=owned_node.label)
replacement_node = replacement(label=owned_node.label)
else:
raise TypeError(
f"Expected replacement node to be a node instance or node subclass, but "
f"got {replacement}"
)

replacement.copy_io(owned_node) # If the replacement is incompatible, we'll
replacement_node.copy_io(
owned_node
) # If the replacement is incompatible, we'll
# fail here before we've changed the parent at all. Since the replacement was
# first guaranteed to be an unconnected orphan, there is not yet any permanent
# damage
Expand All @@ -388,23 +391,29 @@ def replace_child(
if sending_channel.value_receiver in owned_node.inputs
]
outbound_links = [
(replacement.outputs[sending_channel.label], sending_channel.value_receiver)
(
replacement_node.outputs[sending_channel.label],
sending_channel.value_receiver,
)
for sending_channel in owned_node.outputs
if sending_channel.value_receiver in self.outputs
]
self.remove_child(owned_node)
replacement.label, owned_node.label = owned_node.label, replacement.label
self.add_child(replacement)
replacement_node.label, owned_node.label = (
owned_node.label,
replacement_node.label,
)
self.add_child(replacement_node)
if is_starting_node:
self.starting_nodes.append(replacement)
self.starting_nodes.append(replacement_node)
for sending_channel, receiving_channel in inbound_links + outbound_links:
sending_channel.value_receiver = receiving_channel

# Clear caches
self._cached_inputs = None
replacement._cached_inputs = None
replacement_node._cached_inputs = None

return owned_node
return owned_node, replacement_node

def executor_shutdown(self, wait=True, *, cancel_futures=False):
"""
Expand Down
2 changes: 1 addition & 1 deletion pyiron_workflow/nodes/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC):
>>> # With the replace method
>>> # (replacement target can be specified by label or instance,
>>> # the replacing node can be specified by instance or class)
>>> replaced = adds_six_macro.replace_child(adds_six_macro.one, add_two())
>>> replaced, _ = adds_six_macro.replace_child(adds_six_macro.one, add_two())
>>> # With the replace_with method
>>> adds_six_macro.two.replace_with(add_two())
>>> # And by assignment of a compatible class to an occupied node label
Expand Down
10 changes: 6 additions & 4 deletions pyiron_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,10 @@ def _owned_io_panels(self) -> list[IO]:

def replace_child(
self, owned_node: Node | str, replacement: Node | type[Node]
) -> Node:
super().replace_child(owned_node=owned_node, replacement=replacement)
) -> tuple[Node, Node]:
replaced, replacement_node = super().replace_child(
owned_node=owned_node, replacement=replacement
)

# Finally, make sure the IO is constructible with this new node, which will
# catch things like incompatible IO maps
Expand All @@ -509,11 +511,11 @@ def replace_child(
except Exception as e:
# If IO can't be successfully rebuilt using this node, revert changes and
# raise the exception
self.replace_child(replacement, owned_node) # Guaranteed to work since
self.replace_child(replacement_node, replaced) # Guaranteed to work since
# replacement in the other direction was already a success
raise e

return owned_node
return replaced, replacement_node

@property
def parent(self) -> None:
Expand Down

0 comments on commit a520df5

Please sign in to comment.