From cb3b46add1b95067c191ad5033afd1f6c50eeb81 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 6 Dec 2024 15:18:36 +0800 Subject: [PATCH] feat(treespec): add method `PyTreeSpec.transform` --- include/optree/treespec.h | 7 ++ optree/_C.pyi | 13 ++-- src/optree.cpp | 4 + src/treespec/constructor.cpp | 28 +++---- src/treespec/treespec.cpp | 140 +++++++++++++++++++++++++++++++++++ 5 files changed, 172 insertions(+), 20 deletions(-) diff --git a/include/optree/treespec.h b/include/optree/treespec.h index ef10adbc..19c6aae6 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -128,6 +128,13 @@ class PyTreeSpec { [[nodiscard]] std::unique_ptr BroadcastToCommonSuffix( const PyTreeSpec &other) const; + // Transform a PyTreeSpec by applying a function to each node. + // The function input can be an one-level PyTreeSpec for internal nodes or a leaf PyTreeSpec for + // leaf nodes. The function output must be a PyTreeSpec instance. If the input is an one-level + // PyTreeSpec, the output must be an one-level PyTreeSpec with the same arity. If the input is a + // leaf PyTreeSpec, no restriction is imposed on the output PyTreeSpec. + [[nodiscard]] std::unique_ptr Transform(const py::function &func) const; + // Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`. [[nodiscard]] std::unique_ptr Compose(const PyTreeSpec &inner_treespec) const; diff --git a/optree/_C.pyi b/optree/_C.pyi index cf8ce63a..845d20a7 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -110,8 +110,9 @@ class PyTreeSpec: kind: PyTreeKind def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ... def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ... - def broadcast_to_common_suffix(self, other: PyTreeSpec) -> PyTreeSpec: ... - def compose(self, inner_treespec: PyTreeSpec) -> PyTreeSpec: ... + def broadcast_to_common_suffix(self, other: Self) -> Self: ... + def transform(self, func: Callable[[Self], Self]) -> Self: ... + def compose(self, inner_treespec: Self) -> Self: ... def walk( self, f_node: Callable[[tuple[U, ...], MetaData], U], @@ -122,11 +123,11 @@ class PyTreeSpec: def accessors(self) -> list[PyTreeAccessor]: ... def entries(self) -> list[Any]: ... def entry(self, index: int) -> Any: ... - def children(self) -> list[PyTreeSpec]: ... - def child(self, index: int) -> PyTreeSpec: ... + def children(self) -> list[Self]: ... + def child(self, index: int) -> Self: ... def is_leaf(self, strict: bool = True) -> bool: ... - def is_prefix(self, other: PyTreeSpec, strict: bool = False) -> bool: ... - def is_suffix(self, other: PyTreeSpec, strict: bool = False) -> bool: ... + def is_prefix(self, other: Self, strict: bool = False) -> bool: ... + def is_suffix(self, other: Self, strict: bool = False) -> bool: ... def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... def __lt__(self, other: object) -> bool: ... diff --git a/src/optree.cpp b/src/optree.cpp index 27d3c78d..4ce60b2a 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -200,6 +200,10 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeSpec::BroadcastToCommonSuffix, "Broadcast to the common suffix of this treespec and other treespec.", py::arg("other")) + .def("transform", + &PyTreeSpec::Transform, + "Transform the pytree structure by applying a function to each node.", + py::arg("func")) .def("compose", &PyTreeSpec::Compose, "Compose two treespecs. Constructs the inner treespec as a subtree at each leaf node.", diff --git a/src/treespec/constructor.cpp b/src/treespec/constructor.cpp index 903c417a..a023e257 100644 --- a/src/treespec/constructor.cpp +++ b/src/treespec/constructor.cpp @@ -72,9 +72,9 @@ template Node node; node.kind = PyTreeTypeRegistry::GetKind(handle, node.custom, registry_namespace); - const auto verify_children = [&handle, &node](const std::vector& children, - std::vector& treespecs, - std::string& register_namespace) -> void { + const auto verify_children = [&handle, &node, ®istry_namespace]( + const std::vector& children, + std::vector& treespecs) -> void { for (const py::object& child : children) { if (!py::isinstance(child)) [[unlikely]] { std::ostringstream oss{}; @@ -106,16 +106,16 @@ template } } if (!common_registry_namespace.empty()) [[likely]] { - if (register_namespace.empty()) [[likely]] { - register_namespace = common_registry_namespace; - } else if (register_namespace != common_registry_namespace) [[unlikely]] { + if (registry_namespace.empty()) [[likely]] { + registry_namespace = common_registry_namespace; + } else if (registry_namespace != common_registry_namespace) [[unlikely]] { std::ostringstream oss{}; - oss << "Expected treespec(s) with namespace " << PyRepr(register_namespace) + oss << "Expected treespec(s) with namespace " << PyRepr(registry_namespace) << ", got " << PyRepr(common_registry_namespace) << "."; throw py::value_error(oss.str()); } } else if (node.kind != PyTreeKind::Custom) [[likely]] { - register_namespace = ""; + registry_namespace = ""; } }; @@ -143,7 +143,7 @@ template for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(TupleGetItem(handle, i)); } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -155,7 +155,7 @@ template children.emplace_back(ListGetItem(handle, i)); } } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -178,7 +178,7 @@ template children.emplace_back(DictGetItem(dict, key)); } } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { const scoped_critical_section cs{handle}; node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), @@ -197,7 +197,7 @@ template for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(TupleGetItem(tuple, i)); } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -209,7 +209,7 @@ template for (ssize_t i = 0; i < node.arity; ++i) { children.emplace_back(ListGetItem(list, i)); } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); break; } @@ -235,7 +235,7 @@ template children.emplace_back(py::reinterpret_borrow(child)); } } - verify_children(children, treespecs, registry_namespace); + verify_children(children, treespecs); if (num_out == 3) [[likely]] { const py::object node_entries = TupleGetItem(out, 2); if (!node_entries.is_none()) [[likely]] { diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index d3313d85..89b726e6 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -20,6 +20,7 @@ limitations under the License. #include // std::unique_ptr, std::make_unique #include // std::optional #include // std::ostringstream +#include // std::string #include // std::tuple #include // std::move #include // std::vector @@ -448,6 +449,145 @@ std::unique_ptr PyTreeSpec::BroadcastToCommonSuffix(const PyTreeSpec return treespec; } +std::unique_ptr PyTreeSpec::Transform(const py::function& func) const { + PYTREESPEC_SANITY_CHECK(*this); + + const auto create_nodespec = + [](const Node& node, + const bool& none_is_leaf, + const std::string& registry_namespace) -> std::unique_ptr { + auto nodespec = std::make_unique(); + for (ssize_t i = 0; i < node.arity; ++i) { + nodespec->m_traversal.emplace_back(Node{ + .kind = PyTreeKind::Leaf, + .arity = 0, + .num_leaves = 1, + .num_nodes = 1, + }); + } + auto& root = nodespec->m_traversal.emplace_back(node); + root.num_leaves = (node.kind == PyTreeKind::Leaf ? 1 : node.arity); + root.num_nodes = node.arity + 1; + nodespec->m_none_is_leaf = none_is_leaf; + nodespec->m_namespace = registry_namespace; + nodespec->m_traversal.shrink_to_fit(); + PYTREESPEC_SANITY_CHECK(*nodespec); + return nodespec; + }; + + const auto one_level_structure_string = + [&create_nodespec](const Node& node, + const bool& none_is_leaf, + const std::string& registry_namespace) -> std::string { + return create_nodespec(node, none_is_leaf, registry_namespace)->ToString(); + }; + + auto treespec = std::make_unique(); + std::string common_registry_namespace = m_namespace; + ssize_t num_extra_leaves = 0; + ssize_t num_extra_nodes = 0; + auto pending_num_leaves_nodes = reserved_vector>(4); + for (const Node& node : m_traversal) { + auto nodespec = create_nodespec(node, m_none_is_leaf, m_namespace); + const py::object& out = EVALUATE_WITH_LOCK_HELD(func(std::move(nodespec)), func); + + if (!py::isinstance(out)) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns a PyTreeSpec, got " + << PyRepr(out) + << " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace) + << ")."; + } + auto& transformed = py::cast(out); + + if (transformed.m_none_is_leaf != m_none_is_leaf) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns " + "a PyTreeSpec with the same value of " + << (m_none_is_leaf ? "`node_is_leaf=True`" : "`node_is_leaf=False`") + << " as the input, got " << transformed.ToString() + << " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace) + << ")."; + throw py::value_error(oss.str()); + } + if (!transformed.m_namespace.empty()) [[unlikely]] { + if (common_registry_namespace.empty()) [[likely]] { + common_registry_namespace = transformed.m_namespace; + } else if (transformed.m_namespace != common_registry_namespace) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns " + "a PyTreeSpec with namespace " + << PyRepr(common_registry_namespace) << ", got " + << PyRepr(transformed.m_namespace) << "."; + throw py::value_error(oss.str()); + } + } + if (node.kind != PyTreeKind::Leaf) [[likely]] { + if (transformed.GetNumLeaves() != node.arity) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns " + "a PyTreeSpec with the same number of arity as the input (" + << node.arity << "), got " << transformed.ToString() + << " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace) + << ")."; + throw py::value_error(oss.str()); + } + if (transformed.GetNumNodes() != node.arity + 1) [[unlikely]] { + std::ostringstream oss{}; + oss << "Expected the PyTreeSpec transform function returns an one-level PyTreeSpec " + "as the input, got " + << transformed.ToString() + << " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace) + << ")."; + throw py::value_error(oss.str()); + } + auto& subroot = treespec->m_traversal.emplace_back(transformed.m_traversal.back()); + EXPECT_GE(py::ssize_t_cast(pending_num_leaves_nodes.size()), + node.arity, + "PyTreeSpec::Transform() walked off start of array."); + subroot.num_leaves = 0; + subroot.num_nodes = 1; + for (ssize_t i = 0; i < node.arity; ++i) { + const auto& [num_leaves, num_nodes] = pending_num_leaves_nodes.back(); + pending_num_leaves_nodes.pop_back(); + subroot.num_leaves += num_leaves; + subroot.num_nodes += num_nodes; + } + pending_num_leaves_nodes.emplace_back(subroot.num_leaves, subroot.num_nodes); + } else [[unlikely]] { + std::copy(transformed.m_traversal.cbegin(), + transformed.m_traversal.cend(), + std::back_inserter(treespec->m_traversal)); + num_extra_leaves += transformed.GetNumLeaves() - 1; + num_extra_nodes += transformed.GetNumNodes() - 1; + pending_num_leaves_nodes.emplace_back(transformed.GetNumLeaves(), + transformed.GetNumNodes()); + } + } + EXPECT_EQ(pending_num_leaves_nodes.size(), + 1, + "PyTreeSpec::Transform() did not yield a singleton."); + + const auto& root = treespec->m_traversal.back(); + EXPECT_EQ(root.num_leaves, + GetNumLeaves() + num_extra_leaves, + "Number of transformed tree leaves mismatch."); + EXPECT_EQ(root.num_nodes, + GetNumNodes() + num_extra_nodes, + "Number of transformed tree nodes mismatch."); + EXPECT_EQ(root.num_leaves, + treespec->GetNumLeaves(), + "Number of transformed tree leaves mismatch."); + EXPECT_EQ(root.num_nodes, + treespec->GetNumNodes(), + "Number of transformed tree nodes mismatch."); + treespec->m_none_is_leaf = m_none_is_leaf; + treespec->m_namespace = m_namespace; + treespec->m_traversal.shrink_to_fit(); + PYTREESPEC_SANITY_CHECK(*treespec); + return treespec; +} + std::unique_ptr PyTreeSpec::Compose(const PyTreeSpec& inner_treespec) const { PYTREESPEC_SANITY_CHECK(*this); PYTREESPEC_SANITY_CHECK(inner_treespec);