Skip to content

Commit

Permalink
Change return on Composite.remove_child
Browse files Browse the repository at this point in the history
To match return in parent class. Disconnections were only ever used in the test case, and users are always free to disconnect and _then_ remove if they want to capture the broken connections explicitly.

Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber committed Jan 17, 2025
1 parent 2240e98 commit 286e688
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
8 changes: 4 additions & 4 deletions pyiron_workflow/nodes/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def add_child(
self._cached_inputs = None # Reset cache after graph change
return super().add_child(child, label=label, strict_naming=strict_naming)

def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]:
def remove_child(self, child: Node | str) -> Node:
"""
Remove a child from the :attr:`children` collection, disconnecting it and
setting its :attr:`parent` to None.
Expand All @@ -316,14 +316,14 @@ def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]:
child (Node|str): The child (or its label) to remove.
Returns:
(list[tuple[Channel, Channel]]): Any connections that node had.
(Node): The (now disconnected and de-parented) (former) child node.
"""
child = super().remove_child(child)
disconnected = child.disconnect()
child.disconnect()
if child in self.starting_nodes:
self.starting_nodes.remove(child)
self._cached_inputs = None # Reset cache after graph change
return disconnected
return child

def replace_child(
self, owned_node: Node | str, replacement: Node | type[Node]
Expand Down
19 changes: 4 additions & 15 deletions tests/unit/nodes/test_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,18 @@ def test_node_removal(self):
# Connect it inside the composite
self.comp.foo.inputs.x = self.comp.owned.outputs.y

disconnected = self.comp.remove_child(node)
self.comp.remove_child(node)
self.assertIsNone(node.parent, msg="Removal should de-parent")
self.assertFalse(node.connected, msg="Removal should disconnect")
self.assertListEqual(
[(node.inputs.x, self.comp.owned.outputs.y)],
disconnected,
msg="Removal should return destroyed connections",
)
self.assertListEqual(
self.comp.starting_nodes,
[],
msg="Removal should also remove from starting nodes",
)

node_owned = self.comp.owned
disconnections = self.comp.remove_child(node_owned.label)
self.assertEqual(
node_owned.parent,
None,
msg="Should be able to remove nodes by label as well as by object",
)
self.assertListEqual(
[], disconnections, msg="node1 should have no connections left"
[],
self.comp.owned.connections,
msg="Remaining node should have no connections left",
)

def test_label_uniqueness(self):
Expand Down

0 comments on commit 286e688

Please sign in to comment.