Skip to content

Commit

Permalink
feat(treespec): add method PyTreeSpec.transform
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 6, 2024
1 parent 997af1d commit cb3b46a
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 20 deletions.
7 changes: 7 additions & 0 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ class PyTreeSpec {
[[nodiscard]] std::unique_ptr<PyTreeSpec> BroadcastToCommonSuffix(
const PyTreeSpec &other) const;

// Transform a PyTreeSpec by applying a function to each node.
// The function input can be an one-level PyTreeSpec for internal nodes or a leaf PyTreeSpec for
// leaf nodes. The function output must be a PyTreeSpec instance. If the input is an one-level
// PyTreeSpec, the output must be an one-level PyTreeSpec with the same arity. If the input is a
// leaf PyTreeSpec, no restriction is imposed on the output PyTreeSpec.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Transform(const py::function &func) const;

// Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Compose(const PyTreeSpec &inner_treespec) const;

Expand Down
13 changes: 7 additions & 6 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ class PyTreeSpec:
kind: PyTreeKind
def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ...
def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ...
def broadcast_to_common_suffix(self, other: PyTreeSpec) -> PyTreeSpec: ...
def compose(self, inner_treespec: PyTreeSpec) -> PyTreeSpec: ...
def broadcast_to_common_suffix(self, other: Self) -> Self: ...
def transform(self, func: Callable[[Self], Self]) -> Self: ...
def compose(self, inner_treespec: Self) -> Self: ...
def walk(
self,
f_node: Callable[[tuple[U, ...], MetaData], U],
Expand All @@ -122,11 +123,11 @@ class PyTreeSpec:
def accessors(self) -> list[PyTreeAccessor]: ...
def entries(self) -> list[Any]: ...
def entry(self, index: int) -> Any: ...
def children(self) -> list[PyTreeSpec]: ...
def child(self, index: int) -> PyTreeSpec: ...
def children(self) -> list[Self]: ...
def child(self, index: int) -> Self: ...
def is_leaf(self, strict: bool = True) -> bool: ...
def is_prefix(self, other: PyTreeSpec, strict: bool = False) -> bool: ...
def is_suffix(self, other: PyTreeSpec, strict: bool = False) -> bool: ...
def is_prefix(self, other: Self, strict: bool = False) -> bool: ...
def is_suffix(self, other: Self, strict: bool = False) -> bool: ...
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...
def __lt__(self, other: object) -> bool: ...
Expand Down
4 changes: 4 additions & 0 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
&PyTreeSpec::BroadcastToCommonSuffix,
"Broadcast to the common suffix of this treespec and other treespec.",
py::arg("other"))
.def("transform",
&PyTreeSpec::Transform,
"Transform the pytree structure by applying a function to each node.",
py::arg("func"))
.def("compose",
&PyTreeSpec::Compose,
"Compose two treespecs. Constructs the inner treespec as a subtree at each leaf node.",
Expand Down
28 changes: 14 additions & 14 deletions src/treespec/constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ template <bool NoneIsLeaf>
Node node;
node.kind = PyTreeTypeRegistry::GetKind<NoneIsLeaf>(handle, node.custom, registry_namespace);

const auto verify_children = [&handle, &node](const std::vector<py::object>& children,
std::vector<PyTreeSpec>& treespecs,
std::string& register_namespace) -> void {
const auto verify_children = [&handle, &node, &registry_namespace](
const std::vector<py::object>& children,
std::vector<PyTreeSpec>& treespecs) -> void {
for (const py::object& child : children) {
if (!py::isinstance<PyTreeSpec>(child)) [[unlikely]] {
std::ostringstream oss{};
Expand Down Expand Up @@ -106,16 +106,16 @@ template <bool NoneIsLeaf>
}
}
if (!common_registry_namespace.empty()) [[likely]] {
if (register_namespace.empty()) [[likely]] {
register_namespace = common_registry_namespace;
} else if (register_namespace != common_registry_namespace) [[unlikely]] {
if (registry_namespace.empty()) [[likely]] {
registry_namespace = common_registry_namespace;
} else if (registry_namespace != common_registry_namespace) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected treespec(s) with namespace " << PyRepr(register_namespace)
oss << "Expected treespec(s) with namespace " << PyRepr(registry_namespace)
<< ", got " << PyRepr(common_registry_namespace) << ".";
throw py::value_error(oss.str());
}
} else if (node.kind != PyTreeKind::Custom) [[likely]] {
register_namespace = "";
registry_namespace = "";
}
};

Expand Down Expand Up @@ -143,7 +143,7 @@ template <bool NoneIsLeaf>
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(TupleGetItem(handle, i));
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -155,7 +155,7 @@ template <bool NoneIsLeaf>
children.emplace_back(ListGetItem(handle, i));
}
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -178,7 +178,7 @@ template <bool NoneIsLeaf>
children.emplace_back(DictGetItem(dict, key));
}
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
const scoped_critical_section cs{handle};
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
Expand All @@ -197,7 +197,7 @@ template <bool NoneIsLeaf>
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(TupleGetItem(tuple, i));
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -209,7 +209,7 @@ template <bool NoneIsLeaf>
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(ListGetItem(list, i));
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -235,7 +235,7 @@ template <bool NoneIsLeaf>
children.emplace_back(py::reinterpret_borrow<py::object>(child));
}
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
if (num_out == 3) [[likely]] {
const py::object node_entries = TupleGetItem(out, 2);
if (!node_entries.is_none()) [[likely]] {
Expand Down
140 changes: 140 additions & 0 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <memory> // std::unique_ptr, std::make_unique
#include <optional> // std::optional
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <tuple> // std::tuple
#include <utility> // std::move
#include <vector> // std::vector
Expand Down Expand Up @@ -448,6 +449,145 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::BroadcastToCommonSuffix(const PyTreeSpec
return treespec;
}

std::unique_ptr<PyTreeSpec> PyTreeSpec::Transform(const py::function& func) const {
PYTREESPEC_SANITY_CHECK(*this);

const auto create_nodespec =
[](const Node& node,
const bool& none_is_leaf,
const std::string& registry_namespace) -> std::unique_ptr<PyTreeSpec> {
auto nodespec = std::make_unique<PyTreeSpec>();
for (ssize_t i = 0; i < node.arity; ++i) {
nodespec->m_traversal.emplace_back(Node{
.kind = PyTreeKind::Leaf,
.arity = 0,
.num_leaves = 1,
.num_nodes = 1,
});
}
auto& root = nodespec->m_traversal.emplace_back(node);
root.num_leaves = (node.kind == PyTreeKind::Leaf ? 1 : node.arity);
root.num_nodes = node.arity + 1;
nodespec->m_none_is_leaf = none_is_leaf;
nodespec->m_namespace = registry_namespace;
nodespec->m_traversal.shrink_to_fit();
PYTREESPEC_SANITY_CHECK(*nodespec);
return nodespec;
};

const auto one_level_structure_string =
[&create_nodespec](const Node& node,
const bool& none_is_leaf,
const std::string& registry_namespace) -> std::string {
return create_nodespec(node, none_is_leaf, registry_namespace)->ToString();
};

auto treespec = std::make_unique<PyTreeSpec>();
std::string common_registry_namespace = m_namespace;
ssize_t num_extra_leaves = 0;
ssize_t num_extra_nodes = 0;
auto pending_num_leaves_nodes = reserved_vector<std::pair<ssize_t, ssize_t>>(4);
for (const Node& node : m_traversal) {
auto nodespec = create_nodespec(node, m_none_is_leaf, m_namespace);
const py::object& out = EVALUATE_WITH_LOCK_HELD(func(std::move(nodespec)), func);

if (!py::isinstance<PyTreeSpec>(out)) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns a PyTreeSpec, got "
<< PyRepr(out)
<< " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace)
<< ").";
}
auto& transformed = py::cast<PyTreeSpec&>(out);

if (transformed.m_none_is_leaf != m_none_is_leaf) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns "
"a PyTreeSpec with the same value of "
<< (m_none_is_leaf ? "`node_is_leaf=True`" : "`node_is_leaf=False`")
<< " as the input, got " << transformed.ToString()
<< " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace)
<< ").";
throw py::value_error(oss.str());
}
if (!transformed.m_namespace.empty()) [[unlikely]] {
if (common_registry_namespace.empty()) [[likely]] {
common_registry_namespace = transformed.m_namespace;
} else if (transformed.m_namespace != common_registry_namespace) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns "
"a PyTreeSpec with namespace "
<< PyRepr(common_registry_namespace) << ", got "
<< PyRepr(transformed.m_namespace) << ".";
throw py::value_error(oss.str());
}
}
if (node.kind != PyTreeKind::Leaf) [[likely]] {
if (transformed.GetNumLeaves() != node.arity) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns "
"a PyTreeSpec with the same number of arity as the input ("
<< node.arity << "), got " << transformed.ToString()
<< " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace)
<< ").";
throw py::value_error(oss.str());
}
if (transformed.GetNumNodes() != node.arity + 1) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected the PyTreeSpec transform function returns an one-level PyTreeSpec "
"as the input, got "
<< transformed.ToString()
<< " (input: " << one_level_structure_string(node, m_none_is_leaf, m_namespace)
<< ").";
throw py::value_error(oss.str());
}
auto& subroot = treespec->m_traversal.emplace_back(transformed.m_traversal.back());
EXPECT_GE(py::ssize_t_cast(pending_num_leaves_nodes.size()),
node.arity,
"PyTreeSpec::Transform() walked off start of array.");
subroot.num_leaves = 0;
subroot.num_nodes = 1;
for (ssize_t i = 0; i < node.arity; ++i) {
const auto& [num_leaves, num_nodes] = pending_num_leaves_nodes.back();
pending_num_leaves_nodes.pop_back();
subroot.num_leaves += num_leaves;
subroot.num_nodes += num_nodes;
}
pending_num_leaves_nodes.emplace_back(subroot.num_leaves, subroot.num_nodes);
} else [[unlikely]] {
std::copy(transformed.m_traversal.cbegin(),
transformed.m_traversal.cend(),
std::back_inserter(treespec->m_traversal));
num_extra_leaves += transformed.GetNumLeaves() - 1;
num_extra_nodes += transformed.GetNumNodes() - 1;
pending_num_leaves_nodes.emplace_back(transformed.GetNumLeaves(),
transformed.GetNumNodes());
}
}
EXPECT_EQ(pending_num_leaves_nodes.size(),
1,
"PyTreeSpec::Transform() did not yield a singleton.");

const auto& root = treespec->m_traversal.back();
EXPECT_EQ(root.num_leaves,
GetNumLeaves() + num_extra_leaves,
"Number of transformed tree leaves mismatch.");
EXPECT_EQ(root.num_nodes,
GetNumNodes() + num_extra_nodes,
"Number of transformed tree nodes mismatch.");
EXPECT_EQ(root.num_leaves,
treespec->GetNumLeaves(),
"Number of transformed tree leaves mismatch.");
EXPECT_EQ(root.num_nodes,
treespec->GetNumNodes(),
"Number of transformed tree nodes mismatch.");
treespec->m_none_is_leaf = m_none_is_leaf;
treespec->m_namespace = m_namespace;
treespec->m_traversal.shrink_to_fit();
PYTREESPEC_SANITY_CHECK(*treespec);
return treespec;
}

std::unique_ptr<PyTreeSpec> PyTreeSpec::Compose(const PyTreeSpec& inner_treespec) const {
PYTREESPEC_SANITY_CHECK(*this);
PYTREESPEC_SANITY_CHECK(inner_treespec);
Expand Down

0 comments on commit cb3b46a

Please sign in to comment.