Skip to content

Commit

Permalink
feat(treespec): add method PyTreeSpec.one_level
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 14, 2024
1 parent 111ffa1 commit 5f626f9
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 47 deletions.
9 changes: 9 additions & 0 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ class PyTreeSpec {
// Return the child at the given index of the PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Child(ssize_t index) const;

// Return the one-level structure of the PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> GetOneLevel(
const std::optional<Node> &node = std::nullopt) const;

[[nodiscard]] inline Py_ALWAYS_INLINE ssize_t GetNumLeaves() const {
PYTREESPEC_SANITY_CHECK(*this);
return m_traversal.back().num_leaves;
Expand Down Expand Up @@ -195,6 +199,11 @@ class PyTreeSpec {
return GetNumNodes() == 1;
}

// Test whether this PyTreeSpec represents a one-level tree.
[[nodiscard]] inline Py_ALWAYS_INLINE bool IsOneLevel() const {
return GetNumNodes() == GetNumChildren() + 1 && GetNumLeaves() == GetNumChildren();
}

// Return true if this PyTreeSpec is a prefix of `other`.
[[nodiscard]] bool IsPrefix(const PyTreeSpec &other, const bool &strict = false) const;

Expand Down
2 changes: 2 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ class PyTreeSpec:
def entry(self, index: int) -> Any: ...
def children(self) -> list[Self]: ...
def child(self, index: int) -> Self: ...
def one_level(self) -> Self | None: ...
def is_leaf(self, *, strict: bool = True) -> bool: ...
def is_one_level(self) -> bool: ...
def is_prefix(self, other: Self, *, strict: bool = False) -> bool: ...
def is_suffix(self, other: Self, *, strict: bool = False) -> bool: ...
def __eq__(self, other: object) -> bool: ...
Expand Down
46 changes: 29 additions & 17 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,31 +903,31 @@ def tree_map_with_accessor(
{'x': "tree['x'] = 7", 'y': ("tree['y'][0] = 42", "tree['y'][1] = 64")}
>>> tree_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None})
{'x': 8, 'y': (44, 66), 'z': None}
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS
... lambda a, x: a,
... {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
... )
{
'x': PyTreeAccessor(*['x'], (MappingEntry(key='x', type=<class 'dict'>),)),
'x': PyTreeAccessor(*['x'], ...),
'y': (
PyTreeAccessor(*['y'][0], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['y'][1], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
PyTreeAccessor(*['y'][0], ...),
PyTreeAccessor(*['y'][1], ...)
),
'z': {1.5: None}
}
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE
>>> tree_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS
... lambda a, x: a,
... {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
... none_is_leaf=True,
... )
{
'x': PyTreeAccessor(*['x'], (MappingEntry(key='x', type=<class 'dict'>),)),
'x': PyTreeAccessor(*['x'], ...),
'y': (
PyTreeAccessor(*['y'][0], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['y'][1], (MappingEntry(key='y', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
PyTreeAccessor(*['y'][0], ...),
PyTreeAccessor(*['y'][1], ...)
),
'z': {
1.5: PyTreeAccessor(*['z'][1.5], (MappingEntry(key='z', type=<class 'dict'>), MappingEntry(key=1.5, type=<class 'dict'>)))
1.5: PyTreeAccessor(*['z'][1.5], ...)
}
}
Expand All @@ -954,7 +954,7 @@ def tree_map_with_accessor(
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
``func(a, x, *xs)`` where ``(a, x)`` are the accessor and value at the corresponding leaf in
``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
""" # pylint: disable=line-too-long
"""
leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, treespec.accessors(), *flat_args))
Expand Down Expand Up @@ -1322,7 +1322,7 @@ def tree_transpose_map_with_accessor(
'c': (5, 6)
}
}
>>> tree_transpose_map_with_accessor( # doctest: +IGNORE_WHITESPACE
>>> tree_transpose_map_with_accessor( # doctest: +IGNORE_WHITESPACE,ELLIPSIS
... lambda a, x: {'path': a.path, 'accessor': a, 'value': x},
... tree,
... inner_treespec=tree_structure({'path': 0, 'accessor': 0, 'value': 0}),
Expand All @@ -1335,16 +1335,16 @@ def tree_transpose_map_with_accessor(
},
'accessor': {
'b': (
PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['b'][0], ...),
[
PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>)))
PyTreeAccessor(*['b'][1][0], ...),
PyTreeAccessor(*['b'][1][1], ...)
]
),
'a': PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
'a': PyTreeAccessor(*['a'], ...),
'c': (
PyTreeAccessor(*['c'][0], (MappingEntry(key='c', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['c'][1], (MappingEntry(key='c', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
PyTreeAccessor(*['c'][0], ...),
PyTreeAccessor(*['c'][1], ...)
)
},
'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
Expand Down Expand Up @@ -2549,6 +2549,18 @@ def treespec_accessors(treespec: PyTreeSpec) -> list[PyTreeAccessor]:
See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors` and
:meth:`PyTreeSpec.accessors`.
>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_accessors(treespec) # doctest: +IGNORE_WHITESPACE,ELLIPSIS
[
PyTreeAccessor(*['a'][0], ...),
PyTreeAccessor(*['a'][1][0], ...),
PyTreeAccessor(*['a'][1][1], ...),
PyTreeAccessor(*['b'], ...),
PyTreeAccessor(*['c'][0], ...)
]
"""
return treespec.accessors()

Expand Down
21 changes: 17 additions & 4 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
&PyTreeSpec::Child,
"Return the treespec for the child at the given index.",
py::arg("index"))
.def(
"one_level",
[](const PyTreeSpec& t) -> std::optional<std::unique_ptr<PyTreeSpec>> {
if (t.IsLeaf()) [[unlikely]] {
return std::nullopt;
}
return t.GetOneLevel();
},
"Return the one-level structure of the root node. Return None if the root node "
"represents a leaf.")
.def_property_readonly("num_leaves",
&PyTreeSpec::GetNumLeaves,
"Number of leaves in the tree.")
Expand All @@ -237,7 +247,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
"Note that a leaf is also a node but has no children.")
.def_property_readonly("num_children",
&PyTreeSpec::GetNumChildren,
"Number of children in the current node. "
"Number of children of the root node. "
"Note that a leaf is also a node but has no children.")
.def_property_readonly(
"none_is_leaf",
Expand All @@ -252,13 +262,16 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
.def_property_readonly(
"type",
[](const PyTreeSpec& t) -> py::object { return t.GetType(); },
"The type of the current node. Return None if the current node is a leaf.")
.def_property_readonly("kind", &PyTreeSpec::GetPyTreeKind, "The kind of the current node.")
"The type of the root node. Return None if the root node is a leaf.")
.def_property_readonly("kind", &PyTreeSpec::GetPyTreeKind, "The kind of the root node.")
.def("is_leaf",
&PyTreeSpec::IsLeaf,
"Test whether the current node is a leaf.",
"Test whether the treespec represents a leaf.",
py::kw_only(),
py::arg("strict") = true)
.def("is_one_level",
&PyTreeSpec::IsOneLevel,
"Test whether the treespec represents a one-level tree.")
.def("is_prefix",
&PyTreeSpec::IsPrefix,
"Test whether this treespec is a prefix of the given treespec.",
Expand Down
57 changes: 31 additions & 26 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,29 +458,9 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Transform(const std::optional<py::functi
return std::make_unique<PyTreeSpec>(*this);
}

const auto create_nodespec = [this](const Node& node) -> std::unique_ptr<PyTreeSpec> {
auto nodespec = std::make_unique<PyTreeSpec>();
for (ssize_t i = 0; i < node.arity; ++i) {
nodespec->m_traversal.emplace_back(Node{
.kind = PyTreeKind::Leaf,
.arity = 0,
.num_leaves = 1,
.num_nodes = 1,
});
}
auto& root = nodespec->m_traversal.emplace_back(node);
root.num_leaves = (node.kind == PyTreeKind::Leaf ? 1 : node.arity);
root.num_nodes = node.arity + 1;
nodespec->m_none_is_leaf = m_none_is_leaf;
nodespec->m_namespace = m_namespace;
nodespec->m_traversal.shrink_to_fit();
PYTREESPEC_SANITY_CHECK(*nodespec);
return nodespec;
};

const auto transform =
[&create_nodespec, &f_node, &f_leaf](const Node& node) -> std::unique_ptr<PyTreeSpec> {
auto nodespec = create_nodespec(node);
[this, &f_node, &f_leaf](const Node& node) -> std::unique_ptr<PyTreeSpec> {
auto nodespec = GetOneLevel(node);

const auto& func = (node.kind == PyTreeKind::Leaf ? f_leaf : f_node);
if (!func) [[likely]] {
Expand All @@ -491,7 +471,7 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Transform(const std::optional<py::functi
if (!py::isinstance<PyTreeSpec>(out)) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns a PyTreeSpec, got "
<< PyRepr(out) << " (input: " << create_nodespec(node)->ToString() << ").";
<< PyRepr(out) << " (input: " << GetOneLevel(node)->ToString() << ").";
throw py::type_error(oss.str());
}
return std::make_unique<PyTreeSpec>(thread_safe_cast<PyTreeSpec&>(out));
Expand All @@ -510,7 +490,7 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Transform(const std::optional<py::functi
"a PyTreeSpec with the same value of "
<< (m_none_is_leaf ? "`none_is_leaf=True`" : "`none_is_leaf=False`")
<< " as the input, got " << transformed->ToString()
<< " (input: " << create_nodespec(node)->ToString() << ").";
<< " (input: " << GetOneLevel(node)->ToString() << ").";
throw py::value_error(oss.str());
}
if (!transformed->m_namespace.empty()) [[unlikely]] {
Expand All @@ -531,14 +511,14 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Transform(const std::optional<py::functi
oss << "Expected the PyTreeSpec transform function returns "
"a PyTreeSpec with the same number of arity as the input ("
<< node.arity << "), got " << transformed->ToString()
<< " (input: " << create_nodespec(node)->ToString() << ").";
<< " (input: " << GetOneLevel(node)->ToString() << ").";
throw py::value_error(oss.str());
}
if (transformed->GetNumNodes() != node.arity + 1) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns an one-level PyTreeSpec "
"as the input, got "
<< transformed->ToString() << " (input: " << create_nodespec(node)->ToString()
<< transformed->ToString() << " (input: " << GetOneLevel(node)->ToString()
<< ").";
throw py::value_error(oss.str());
}
Expand Down Expand Up @@ -1012,4 +992,29 @@ py::object PyTreeSpec::GetType(const std::optional<Node>& node) const {
}
}

std::unique_ptr<PyTreeSpec> PyTreeSpec::GetOneLevel(const std::optional<Node>& node) const {
if (!node) [[likely]] {
PYTREESPEC_SANITY_CHECK(*this);
}

const Node& n = node.value_or(m_traversal.back());
auto out = std::make_unique<PyTreeSpec>();
for (ssize_t i = 0; i < n.arity; ++i) {
out->m_traversal.emplace_back(Node{
.kind = PyTreeKind::Leaf,
.arity = 0,
.num_leaves = 1,
.num_nodes = 1,
});
}
auto& root = out->m_traversal.emplace_back(n);
root.num_leaves = (n.kind == PyTreeKind::Leaf ? 1 : n.arity);
root.num_nodes = n.arity + 1;
out->m_none_is_leaf = m_none_is_leaf;
out->m_namespace = m_namespace;
out->m_traversal.shrink_to_fit();
PYTREESPEC_SANITY_CHECK(*out);
return out;
}

} // namespace optree

0 comments on commit 5f626f9

Please sign in to comment.