From 339611d0be16ec843c1d3c02497c1af3bd1556d9 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 7 Sep 2023 07:45:00 -0700 Subject: [PATCH] Fixed tree_check raising false positives --- equinox/_tree.py | 95 +++++++++++++++++++++++++++++++--------------- tests/test_tree.py | 21 ++++++++++ 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/equinox/_tree.py b/equinox/_tree.py index c0e1d39e..de1b1dc3 100644 --- a/equinox/_tree.py +++ b/equinox/_tree.py @@ -339,8 +339,42 @@ def is_leaf(node): def tree_check(pytree: Any) -> None: - """Checks if the PyTree is well-formed: does it have no repeated nodes, and does it - have no self-references. + """Checks if the PyTree is well-formed: does it have no self-references, and does + it have no duplicate layers. + + Precisely, a "duplicate layer" is any PyTree node that has at least one child node and appears more than once in the PyTree. + + !!! info + + This is automatically called when creating an `eqx.Module` instance, to help + avoid bugs from duplicating layers. + + !!! Example + + ```python + a = 1 + eqx.tree_check([a, a]) # passes, duplicate is a leaf + + b = eqx.nn.Linear(...) + eqx.tree_check([b, b]) # fails, duplicate is nontrivial! + + c = [] # empty list + eqx.tree_check([c, c]) # passes, duplicate is trivial + + d = eqx.Module() + eqx.tree_check([d, d]) # passes, duplicate is trivial + + eqx.tree_check([None, None]) # passes, duplicate is trivial + + e = [1] + eqx.tree_check([e, e]) # fails, duplicate is nontrivial! + + eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]` + # has the same structure, but they're different. + + # passes, not actually a duplicate: each Linear layer is a separate layer. 
+ eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)]) + ``` **Arguments:** @@ -358,38 +392,37 @@ def tree_check(pytree: Any) -> None: _tree_check(pytree, all_nodes) -_trivial_treedef = jtu.tree_structure(0) +_leaf_treedef = jtu.tree_structure(0) def _tree_check(node, all_nodes): - try: - self_referential, type_string = all_nodes[id(node)] - except KeyError: - pass - else: - if self_referential: - raise ValueError( - f"PyTree node of type `{type_string}` is self-referential; that is to " - "say it appears somewhere within its own PyTree structure. This is " - "not allowed." - ) - else: - raise ValueError( - f"PyTree node of type `{type_string}` appears in the PyTree multiple " - "times. This is almost always an error, as these nodes will turn into " - "two duplicate copies after flattening/unflattening, e.g. when " - "crossing a JIT boundary." - ) - try: - type_string = type(node).__name__ - except AttributeError: - # AttributeError: in case we cannot get __name__ for some weird reason. - type_string = "" - all_nodes[id(node)] = (True, type_string) subnodes, treedef = tree_flatten_one_level(node) - if treedef != _trivial_treedef: - # This does mean that leaves can appear multiple times. This is valid, e.g. - # [4, 4]. + # We allow duplicate leaves and empty containers, so don't raise an error with those + if treedef != _leaf_treedef and treedef.num_leaves > 0: + try: + self_referential, type_string = all_nodes[id(node)] + except KeyError: + pass + else: + if self_referential: + raise ValueError( + f"PyTree node of type `{type_string}` is self-referential; that is " + "to say it appears somewhere within its own PyTree structure. This " + "is not allowed." + ) + else: + raise ValueError( + f"PyTree node of type `{type_string}` appears in the PyTree " + "multiple times. This is almost always an error, as these nodes " + "will turn into two duplicate copies after " + "flattening/unflattening, e.g. when crossing a JIT boundary." 
+ ) + try: + type_string = type(node).__name__ + except AttributeError: + # AttributeError: in case we cannot get __name__ for some weird reason. + type_string = "" + all_nodes[id(node)] = (True, type_string) for subnode in subnodes: _tree_check(subnode, all_nodes) - all_nodes[id(node)] = (False, type_string) + all_nodes[id(node)] = (False, type_string) diff --git a/tests/test_tree.py b/tests/test_tree.py index 8e4e38e4..9f883704 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -188,6 +188,14 @@ def test_tree_flatten_one_level(): eqx.tree_flatten_one_level(x) +# This matches the behaviour of `jax._src.tree_util.flatten_one_level` +def test_tree_flatten_one_level_special(): + x = [None, None, eqx.Module(), 1, 2] + leaves, treedef = eqx.tree_flatten_one_level(x) + assert leaves == [None, None, eqx.Module(), 1, 2] + assert treedef == jtu.tree_structure([0, 0, 0, 0, 0]) + + def test_tree_check(): x = [] y = [] @@ -226,3 +234,16 @@ def _transform(self, x): a = SubComponent() with pytest.raises(ValueError): eqx.tree_check(a) + + +def test_tree_check_none(): + eqx.tree_check([None, None]) + + +def test_tree_check_integer(): + eqx.tree_check([0, 0]) + + +def test_tree_check_module(): + a = eqx.Module() # same `id(...)` for both entries passed to `tree_check`. + eqx.tree_check([a, a])