Skip to content

Commit

Permalink
Fixed tree_check raising false positives
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 7, 2023
1 parent 617e043 commit 339611d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 31 deletions.
95 changes: 64 additions & 31 deletions equinox/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,42 @@ def is_leaf(node):


def tree_check(pytree: Any) -> None:
"""Checks if the PyTree is well-formed: does it have no repeated nodes, and does it
have no self-references.
"""Checks if the PyTree is well-formed: does it have no self-references, and does
it have no duplicate layers.
Precisely, a "duplicate layer" is any PyTree node with at least one child node.
!!! info
This is automatically called when creating an `eqx.Module` instance, to help
avoid bugs from duplicating layers.
!!! Example
```python
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is nontrivial!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is trivial
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is trivial
eqx.tree_check([None, None]) # passes, duplicate is trivial
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is nontrivial!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
```
**Arguments:**
Expand All @@ -358,38 +392,37 @@ def tree_check(pytree: Any) -> None:
_tree_check(pytree, all_nodes)


_trivial_treedef = jtu.tree_structure(0)
_leaf_treedef = jtu.tree_structure(0)


def _tree_check(node, all_nodes):
try:
self_referential, type_string = all_nodes[id(node)]
except KeyError:
pass
else:
if self_referential:
raise ValueError(
f"PyTree node of type `{type_string}` is self-referential; that is to "
"say it appears somewhere within its own PyTree structure. This is "
"not allowed."
)
else:
raise ValueError(
f"PyTree node of type `{type_string}` appears in the PyTree multiple "
"times. This is almost always an error, as these nodes will turn into "
"two duplicate copies after flattening/unflattening, e.g. when "
"crossing a JIT boundary."
)
try:
type_string = type(node).__name__
except AttributeError:
# AttributeError: in case we cannot get __name__ for some weird reason.
type_string = "<unknown type>"
all_nodes[id(node)] = (True, type_string)
subnodes, treedef = tree_flatten_one_level(node)
if treedef != _trivial_treedef:
# This does mean that leaves can appear multiple times. This is valid, e.g.
# [4, 4].
# We allow duplicate leaves and empty containers, so don't raise an error with those
if treedef != _leaf_treedef and treedef.num_leaves > 0:
try:
self_referential, type_string = all_nodes[id(node)]
except KeyError:
pass
else:
if self_referential:
raise ValueError(
f"PyTree node of type `{type_string}` is self-referential; that is "
"to say it appears somewhere within its own PyTree structure. This "
"is not allowed."
)
else:
raise ValueError(
f"PyTree node of type `{type_string}` appears in the PyTree "
"multiple times. This is almost always an error, as these nodes "
"will turn into two duplicate copies after "
"flattening/unflattening, e.g. when crossing a JIT boundary."
)
try:
type_string = type(node).__name__
except AttributeError:
# AttributeError: in case we cannot get __name__ for some weird reason.
type_string = "<unknown type>"
all_nodes[id(node)] = (True, type_string)
for subnode in subnodes:
_tree_check(subnode, all_nodes)
all_nodes[id(node)] = (False, type_string)
all_nodes[id(node)] = (False, type_string)
21 changes: 21 additions & 0 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ def test_tree_flatten_one_level():
eqx.tree_flatten_one_level(x)


# This matches the behaviour of `jax._src.tree_util.flatten_one_level`
def test_tree_flatten_one_level_special():
x = [None, None, eqx.Module(), 1, 2]
leaves, treedef = eqx.tree_flatten_one_level(x)
assert leaves == [None, None, eqx.Module(), 1, 2]
assert treedef == jtu.tree_structure([0, 0, 0, 0, 0])


def test_tree_check():
x = []
y = []
Expand Down Expand Up @@ -226,3 +234,16 @@ def _transform(self, x):
a = SubComponent()
with pytest.raises(ValueError):
eqx.tree_check(a)


def test_tree_check_none():
eqx.tree_check([None, None])


def test_tree_check_integer():
eqx.tree_check([0, 0])


def test_tree_check_module():
a = eqx.Module() # same `id(...)` for both entries passed to `tree_check`.
eqx.tree_check([a, a])

0 comments on commit 339611d

Please sign in to comment.