Skip to content

Commit

Permalink
feat(treespec): add method PyTreeSpec.one_level (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 15, 2024
1 parent 111ffa1 commit 92428dd
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add method `PyTreeSpec.one_level` and `PyTreeSpec.is_one_level` by [@XuehaiPan](https://github.com/XuehaiPan) in [#179](https://github.com/metaopt/optree/pull/179).
- Add method `PyTreeSpec.transform` by [@XuehaiPan](https://github.com/XuehaiPan) in [#177](https://github.com/metaopt/optree/pull/177).

### Changed
Expand Down
2 changes: 1 addition & 1 deletion conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dependencies:
- pybind11 >= 2.13.1

# Benchmark
- pytorch::pytorch >= 2.0, < 2.4.0a0
- pytorch::pytorch >= 2.0, < 2.6.0a0
- pytorch::torchvision
- pytorch::pytorch-mutex = *=*cpu*
- conda-forge::jax >= 0.4.6, < 0.5.0a0
Expand Down
4 changes: 4 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ PyTreeSpec Functions
treespec_entry
treespec_children
treespec_child
treespec_one_level
treespec_transform
treespec_is_leaf
treespec_is_strict_leaf
treespec_is_one_level
treespec_is_prefix
treespec_is_suffix
treespec_leaf
Expand All @@ -146,9 +148,11 @@ PyTreeSpec Functions
.. autofunction:: treespec_entry
.. autofunction:: treespec_children
.. autofunction:: treespec_child
.. autofunction:: treespec_one_level
.. autofunction:: treespec_transform
.. autofunction:: treespec_is_leaf
.. autofunction:: treespec_is_strict_leaf
.. autofunction:: treespec_is_one_level
.. autofunction:: treespec_is_prefix
.. autofunction:: treespec_is_suffix
.. autofunction:: treespec_leaf
Expand Down
9 changes: 9 additions & 0 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ class PyTreeSpec {
// Return the child at the given index of the PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Child(ssize_t index) const;

// Return the one-level structure of the PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> GetOneLevel(
const std::optional<Node> &node = std::nullopt) const;

[[nodiscard]] inline Py_ALWAYS_INLINE ssize_t GetNumLeaves() const {
PYTREESPEC_SANITY_CHECK(*this);
return m_traversal.back().num_leaves;
Expand Down Expand Up @@ -195,6 +199,11 @@ class PyTreeSpec {
return GetNumNodes() == 1;
}

// Test whether this PyTreeSpec represents a one-level tree.
[[nodiscard]] inline Py_ALWAYS_INLINE bool IsOneLevel() const {
return GetNumNodes() == GetNumChildren() + 1 && GetNumLeaves() == GetNumChildren();
}

// Return true if this PyTreeSpec is a prefix of `other`.
[[nodiscard]] bool IsPrefix(const PyTreeSpec &other, const bool &strict = false) const;

Expand Down
2 changes: 2 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ class PyTreeSpec:
def entry(self, index: int) -> Any: ...
def children(self) -> list[Self]: ...
def child(self, index: int) -> Self: ...
def one_level(self) -> Self | None: ...
def is_leaf(self, *, strict: bool = True) -> bool: ...
def is_one_level(self) -> bool: ...
def is_prefix(self, other: Self, *, strict: bool = False) -> bool: ...
def is_suffix(self, other: Self, *, strict: bool = False) -> bool: ...
def __eq__(self, other: object) -> bool: ...
Expand Down
4 changes: 4 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,15 @@
treespec_entry,
treespec_from_collection,
treespec_is_leaf,
treespec_is_one_level,
treespec_is_prefix,
treespec_is_strict_leaf,
treespec_is_suffix,
treespec_leaf,
treespec_list,
treespec_namedtuple,
treespec_none,
treespec_one_level,
treespec_ordereddict,
treespec_paths,
treespec_structseq,
Expand Down Expand Up @@ -171,9 +173,11 @@
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_one_level',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
'treespec_is_one_level',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_leaf',
Expand Down
112 changes: 92 additions & 20 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_one_level',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
'treespec_is_one_level',
'treespec_is_prefix',
'treespec_is_suffix',
'treespec_leaf',
Expand Down Expand Up @@ -903,31 +905,31 @@ def tree_map_with_accessor(
{'x': "tree['x'] = 7", 'y': ("tree['y'][0] = 42", "tree['y'][1] = 64")}
>>> tree_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS
... lambda a, x: a,
... {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
... )
{
'x': PyTreeAccessor(*['x'], (MappingEntry(key='x', type=<class 'dict'>),)),
'x': PyTreeAccessor(*['x'], ...),
'y': (
PyTreeAccessor(*['y'][0], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['y'][1], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
PyTreeAccessor(*['y'][0], ...),
PyTreeAccessor(*['y'][1], ...)
),
'z': {1.5: None}
}
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS
... lambda a, x: a,
... {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
... none_is_leaf=True,
... )
{
'x': PyTreeAccessor(*['x'], (MappingEntry(key='x', type=<class 'dict'>),)),
'x': PyTreeAccessor(*['x'], ...),
'y': (
PyTreeAccessor(*['y'][0], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['y'][1], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
PyTreeAccessor(*['y'][0], ...),
PyTreeAccessor(*['y'][1], ...)
),
'z': {
1.5: PyTreeAccessor(*['z'][1.5], (MappingEntry(key='z', type=<class 'dict'>), MappingEntry(key=1.5, type=<class 'dict'>)))
1.5: PyTreeAccessor(*['z'][1.5], ...)
}
}
Expand All @@ -954,7 +956,7 @@ def tree_map_with_accessor(
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
``func(a, x, *xs)`` where ``(a, x)`` are the accessor and value at the corresponding leaf in
``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
""" # pylint: disable=line-too-long
"""
leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, treespec.accessors(), *flat_args))
Expand Down Expand Up @@ -1322,7 +1324,7 @@ def tree_transpose_map_with_accessor(
'c': (5, 6)
}
}
>>> tree_transpose_map_with_accessor( # doctest: +IGNORE_WHITESPACE
>>> tree_transpose_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS
... lambda a, x: {'path': a.path, 'accessor': a, 'value': x},
... tree,
... inner_treespec=tree_structure({'path': 0, 'accessor': 0, 'value': 0}),
Expand All @@ -1335,16 +1337,16 @@ def tree_transpose_map_with_accessor(
},
'accessor': {
'b': (
PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['b'][0], ...),
[
PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>)))
PyTreeAccessor(*['b'][1][0], ...),
PyTreeAccessor(*['b'][1][1], ...)
]
),
'a': PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
'a': PyTreeAccessor(*['a'], ...),
'c': (
PyTreeAccessor(*['c'][0], (MappingEntry(key='c', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['c'][1], (MappingEntry(key='c', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
PyTreeAccessor(*['c'][0], ...),
PyTreeAccessor(*['c'][1], ...)
)
},
'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
Expand Down Expand Up @@ -2540,15 +2542,37 @@ def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]:
"""Return a list of paths to the leaves of a treespec.
See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_paths(treespec)
[('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)]
"""
return treespec.paths()


def treespec_accessors(treespec: PyTreeSpec) -> list[PyTreeAccessor]:
"""Return a list of accessors to the leaves of a treespec.
See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors` and
:meth:`PyTreeSpec.accessors`.
See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors`,
and :meth:`PyTreeSpec.accessors`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_accessors(treespec) # doctest: +IGNORE_WHITESPACE,ELLIPSIS
[
PyTreeAccessor(*['a'][0], ...),
PyTreeAccessor(*['a'][1][0], ...),
PyTreeAccessor(*['a'][1][1], ...),
PyTreeAccessor(*['b'], ...),
PyTreeAccessor(*['c'][0], ...)
]
>>> treespec_accessors(treespec_leaf())
[PyTreeAccessor(*, ())]
>>> treespec_accessors(treespec_none())
[]
"""
return treespec.accessors()

Expand All @@ -2558,6 +2582,12 @@ def treespec_entries(treespec: PyTreeSpec) -> list[Any]:
See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`,
and :meth:`PyTreeSpec.entries`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_entries(treespec)
['a', 'b', 'c']
"""
return treespec.entries()

Expand All @@ -2574,7 +2604,13 @@ def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]:
"""Return a list of treespecs for the children of a treespec.
See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`,
and :meth:`PyTreeSpec.children`.
:func:`treespec_one_level`, and :meth:`PyTreeSpec.children`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_children(treespec)
[PyTreeSpec((*, [*, *])), PyTreeSpec(*), PyTreeSpec((*, None))]
"""
return treespec.children()

Expand All @@ -2587,6 +2623,20 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec:
return treespec.child(index)


def treespec_one_level(treespec: PyTreeSpec) -> PyTreeSpec | None:
"""Return the one-level tree structure of the treespec or :data:`None` if the treespec is a leaf.
See also :func:`treespec_children`, :func:`treespec_is_one_level`, and :meth:`PyTreeSpec.one_level`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_one_level(treespec)
PyTreeSpec({'a': *, 'b': *, 'c': *})
"""
return treespec.one_level()


def treespec_transform(
treespec: PyTreeSpec,
f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None,
Expand Down Expand Up @@ -2713,6 +2763,28 @@ def treespec_is_strict_leaf(treespec: PyTreeSpec) -> bool:
return treespec.num_nodes == 1 and treespec.num_leaves == 1


def treespec_is_one_level(treespec: PyTreeSpec) -> bool:
"""Return whether the treespec is a one-level tree structure.
See also :func:`treespec_is_leaf`, :func:`treespec_one_level`, and :meth:`PyTreeSpec.is_one_level`.
>>> treespec_is_one_level(tree_structure(1))
False
>>> treespec_is_one_level(tree_structure((1, 2)))
True
>>> treespec_is_one_level(tree_structure({'a': 1, 'b': 2, 'c': 3}))
True
>>> treespec_is_one_level(tree_structure({'a': 1, 'b': (2, 3), 'c': 4}))
False
>>> treespec_is_one_level(tree_structure(None))
True
"""
return (
treespec.num_nodes == treespec.num_children + 1
and treespec.num_leaves == treespec.num_children
)


def treespec_is_prefix(
treespec: PyTreeSpec,
other_treespec: PyTreeSpec,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ docs = [
]
benchmark = [
"jax[cpu] >= 0.4.6, < 0.5.0a0",
"torch >= 2.0, < 2.4.0a0",
"torch >= 2.0, < 2.6.0a0",
"torchvision",
"dm-tree >= 0.1, < 0.2.0a0",
"pandas",
Expand Down
21 changes: 17 additions & 4 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
&PyTreeSpec::Child,
"Return the treespec for the child at the given index.",
py::arg("index"))
.def(
"one_level",
[](const PyTreeSpec& t) -> std::optional<std::unique_ptr<PyTreeSpec>> {
if (t.IsLeaf()) [[unlikely]] {
return std::nullopt;
}
return t.GetOneLevel();
},
"Return the one-level structure of the root node. Return None if the root node "
"represents a leaf.")
.def_property_readonly("num_leaves",
&PyTreeSpec::GetNumLeaves,
"Number of leaves in the tree.")
Expand All @@ -237,7 +247,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
"Note that a leaf is also a node but has no children.")
.def_property_readonly("num_children",
&PyTreeSpec::GetNumChildren,
"Number of children in the current node. "
"Number of children of the root node. "
"Note that a leaf is also a node but has no children.")
.def_property_readonly(
"none_is_leaf",
Expand All @@ -252,13 +262,16 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
.def_property_readonly(
"type",
[](const PyTreeSpec& t) -> py::object { return t.GetType(); },
"The type of the current node. Return None if the current node is a leaf.")
.def_property_readonly("kind", &PyTreeSpec::GetPyTreeKind, "The kind of the current node.")
"The type of the root node. Return None if the root node is a leaf.")
.def_property_readonly("kind", &PyTreeSpec::GetPyTreeKind, "The kind of the root node.")
.def("is_leaf",
&PyTreeSpec::IsLeaf,
"Test whether the current node is a leaf.",
"Test whether the treespec represents a leaf.",
py::kw_only(),
py::arg("strict") = true)
.def("is_one_level",
&PyTreeSpec::IsOneLevel,
"Test whether the treespec represents a one-level tree.")
.def("is_prefix",
&PyTreeSpec::IsPrefix,
"Test whether this treespec is a prefix of the given treespec.",
Expand Down
Loading

0 comments on commit 92428dd

Please sign in to comment.