Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(treespec): add method PyTreeSpec.transform #177

Merged
merged 7 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/tests-with-pydebug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,21 @@ jobs:
git clone https://github.com/pyenv/pyenv.git "${PYENV_ROOT}"
echo "PYENV_ROOT=${PYENV_ROOT}" >> "${GITHUB_ENV}"
echo "PATH=${PATH}" >> "${GITHUB_ENV}"
if [[ "${{ runner.os }}" == 'Linux' ]]; then
sudo apt-get update -qq && sudo apt-get install -yqq --no-install-recommends \
make \
build-essential \
libssl-dev \
zlib1g-dev \
libbz2-dev \
libsqlite3-dev \
libncurses-dev \
libreadline-dev \
libgdbm-dev \
liblzma-dev
elif [[ "${{ runner.os }}" == 'macOS' ]]; then
brew install --only-dependencies python@3
fi

- name: Set up pyenv
id: setup-pyenv-windows
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Add method `PyTreeSpec.transform` by [@XuehaiPan](https://github.com/XuehaiPan) in [#177](https://github.com/metaopt/optree/pull/177).

### Changed

Expand Down
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ PyTreeSpec Functions
treespec_entry
treespec_children
treespec_child
treespec_transform
treespec_is_leaf
treespec_is_strict_leaf
treespec_is_prefix
Expand All @@ -145,6 +146,7 @@ PyTreeSpec Functions
.. autofunction:: treespec_entry
.. autofunction:: treespec_children
.. autofunction:: treespec_child
.. autofunction:: treespec_transform
.. autofunction:: treespec_is_leaf
.. autofunction:: treespec_is_strict_leaf
.. autofunction:: treespec_is_prefix
Expand Down
2 changes: 1 addition & 1 deletion include/optree/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class InternalError : public std::logic_error {
: InternalError([&message, &file, &lineno, &function]() -> std::string {
std::ostringstream oss{};
oss << message << " (";
if (function.has_value()) [[likely]] {
if (function) [[likely]] {
oss << "function `" << *function << "` ";
}
oss << "at file " << file << ":" << lineno << ")\n\n"
Expand Down
14 changes: 10 additions & 4 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,20 @@ class PyTreeSpec {
[[nodiscard]] std::unique_ptr<PyTreeSpec> BroadcastToCommonSuffix(
const PyTreeSpec &other) const;

// Transform a PyTreeSpec by applying `f_node(nodespec)` to nodes and `f_leaf(leafspec)` to
// leaves.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Transform(
const std::optional<py::function> &f_node = std::nullopt,
const std::optional<py::function> &f_leaf = std::nullopt) const;

// Compose two PyTreeSpecs, replacing the leaves of this tree with copies of `inner`.
[[nodiscard]] std::unique_ptr<PyTreeSpec> Compose(const PyTreeSpec &inner_treespec) const;

// Map a function over a PyTree structure, applying f_leaf to each leaf, and
// f_node(children, node_data) to each container node.
[[nodiscard]] py::object Walk(const py::function &f_node,
const std::optional<py::function> &f_leaf,
const py::iterable &leaves) const;
// f_node(node_type, node_data, children) to each container node.
[[nodiscard]] py::object Walk(const py::iterable &leaves,
const std::optional<py::function> &f_node = std::nullopt,
const std::optional<py::function> &f_leaf = std::nullopt) const;

// Return paths to all leaves in the PyTreeSpec.
[[nodiscard]] std::vector<py::tuple> Paths() const;
Expand Down
9 changes: 7 additions & 2 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,17 @@ class PyTreeSpec:
def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ...
def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ...
def broadcast_to_common_suffix(self, other: Self) -> Self: ...
def transform(
self,
f_node: Callable[[Self], Self] | None = None,
f_leaf: Callable[[Self], Self] | None = None,
) -> Self: ...
def compose(self, inner_treespec: Self) -> Self: ...
def walk(
self,
f_node: Callable[[tuple[U, ...], MetaData], U],
f_leaf: Callable[[T], U] | None,
leaves: Iterable[T],
f_node: Callable[[builtins.type, MetaData, tuple[U, ...]], U] | None = None,
f_leaf: Callable[[T], U] | None = None,
) -> U: ...
def paths(self) -> list[tuple[Any, ...]]: ...
def accessors(self) -> list[PyTreeAccessor]: ...
Expand Down
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
treespec_ordereddict,
treespec_paths,
treespec_structseq,
treespec_transform,
treespec_tuple,
)
from optree.registry import (
Expand Down Expand Up @@ -170,6 +171,7 @@
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
'treespec_is_prefix',
Expand Down
50 changes: 50 additions & 0 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'treespec_entry',
'treespec_children',
'treespec_child',
'treespec_transform',
'treespec_is_leaf',
'treespec_is_strict_leaf',
'treespec_is_prefix',
Expand Down Expand Up @@ -2586,6 +2587,55 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec:
return treespec.child(index)


def treespec_transform(
treespec: PyTreeSpec,
f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None,
f_leaf: Callable[[PyTreeSpec], PyTreeSpec] | None = None,
) -> PyTreeSpec:
"""Transform a treespec by applying functions to its nodes and leaves.

See also :func:`treespec_children`, :func:`treespec_is_leaf`, and :meth:`PyTreeSpec.transform`.

>>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
>>> treespec
PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
>>> treespec_transform(treespec, lambda spec: treespec_dict(zip(spec.entries(), spec.children())))
PyTreeSpec({'a': {0: *, 1: {0: *, 1: *}}, 'b': *, 'c': {0: *, 1: {}}})
>>> treespec_transform(
... treespec,
... lambda spec: (
... treespec_ordereddict(zip(spec.entries(), spec.children()))
... if spec.type is dict
... else spec
... ),
... )
PyTreeSpec(OrderedDict({'a': (*, [*, *]), 'b': *, 'c': (*, None)}))
>>> treespec_transform(
... treespec,
... lambda spec: (
... treespec_ordereddict(tree_unflatten(spec, spec.children()))
... if spec.type is dict
... else spec
... ),
... )
PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': (*, None)}))
>>> treespec_transform(treespec, lambda spec: treespec_tuple(spec.children()))
PyTreeSpec(((*, (*, *)), *, (*, ())))
>>> treespec_transform(
... treespec,
... lambda spec: (
... treespec_list(spec.children())
... if spec.type is tuple
... else spec
... ),
... )
PyTreeSpec({'a': [*, [*, *]], 'b': *, 'c': [*, None]})
>>> treespec_transform(treespec, None, lambda spec: tree_structure((1, [2])))
PyTreeSpec({'a': ((*, [*]), [(*, [*]), (*, [*])]), 'b': (*, [*]), 'c': ((*, [*]), None)})
"""
return treespec.transform(f_node, f_leaf)


def treespec_is_leaf(treespec: PyTreeSpec, *, strict: bool = True) -> bool:
"""Return whether the treespec is a leaf that has no children.

Expand Down
18 changes: 12 additions & 6 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ py::module_ GetCxxModule(const std::optional<py::module_>& module) {
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::module_> storage;
return storage
.call_once_and_store_result([&module]() -> py::module_ {
EXPECT_TRUE(module.has_value(), "The module must be provided.");
EXPECT_TRUE(module, "The module must be provided.");
return *module;
})
.get_stored();
Expand Down Expand Up @@ -200,17 +200,23 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
&PyTreeSpec::BroadcastToCommonSuffix,
"Broadcast to the common suffix of this treespec and other treespec.",
py::arg("other"))
.def("transform",
&PyTreeSpec::Transform,
"Transform the pytree structure by applying ``f_node(nodespec)`` at nodes and "
"``f_leaf(leafspec)`` at leaves.",
py::arg("f_node") = std::nullopt,
py::arg("f_leaf") = std::nullopt)
.def("compose",
&PyTreeSpec::Compose,
"Compose two treespecs. Constructs the inner treespec as a subtree at each leaf node.",
py::arg("inner_treespec"))
.def("walk",
&PyTreeSpec::Walk,
"Walk over the pytree structure, calling ``f_node(children, node_data)`` at nodes, "
"and ``f_leaf(leaf)`` at leaves.",
py::arg("f_node"),
py::arg("f_leaf"),
py::arg("leaves"))
"Walk over the pytree structure, calling ``f_node(node_type, node_data, children)`` "
"at nodes, and ``f_leaf(leaf)`` at leaves.",
py::arg("leaves"),
py::arg("f_node") = std::nullopt,
py::arg("f_leaf") = std::nullopt)
.def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.")
.def("accessors",
&PyTreeSpec::Accessors,
Expand Down
28 changes: 14 additions & 14 deletions src/treespec/constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ template <bool NoneIsLeaf>
Node node;
node.kind = PyTreeTypeRegistry::GetKind<NoneIsLeaf>(handle, node.custom, registry_namespace);

const auto verify_children = [&handle, &node](const std::vector<py::object>& children,
std::vector<PyTreeSpec>& treespecs,
std::string& register_namespace) -> void {
const auto verify_children = [&handle, &node, &registry_namespace](
const std::vector<py::object>& children,
std::vector<PyTreeSpec>& treespecs) -> void {
for (const py::object& child : children) {
if (!py::isinstance<PyTreeSpec>(child)) [[unlikely]] {
std::ostringstream oss{};
Expand Down Expand Up @@ -106,16 +106,16 @@ template <bool NoneIsLeaf>
}
}
if (!common_registry_namespace.empty()) [[likely]] {
if (register_namespace.empty()) [[likely]] {
register_namespace = common_registry_namespace;
} else if (register_namespace != common_registry_namespace) [[unlikely]] {
if (registry_namespace.empty()) [[likely]] {
registry_namespace = common_registry_namespace;
} else if (registry_namespace != common_registry_namespace) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected treespec(s) with namespace " << PyRepr(register_namespace)
oss << "Expected treespec(s) with namespace " << PyRepr(registry_namespace)
<< ", got " << PyRepr(common_registry_namespace) << ".";
throw py::value_error(oss.str());
}
} else if (node.kind != PyTreeKind::Custom) [[likely]] {
register_namespace = "";
registry_namespace = "";
}
};

Expand Down Expand Up @@ -143,7 +143,7 @@ template <bool NoneIsLeaf>
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(TupleGetItem(handle, i));
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -155,7 +155,7 @@ template <bool NoneIsLeaf>
children.emplace_back(ListGetItem(handle, i));
}
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -178,7 +178,7 @@ template <bool NoneIsLeaf>
children.emplace_back(DictGetItem(dict, key));
}
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
const scoped_critical_section cs{handle};
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
Expand All @@ -197,7 +197,7 @@ template <bool NoneIsLeaf>
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(TupleGetItem(tuple, i));
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -209,7 +209,7 @@ template <bool NoneIsLeaf>
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(ListGetItem(list, i));
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
break;
}

Expand All @@ -235,7 +235,7 @@ template <bool NoneIsLeaf>
children.emplace_back(py::reinterpret_borrow<py::object>(child));
}
}
verify_children(children, treespecs, registry_namespace);
verify_children(children, treespecs);
if (num_out == 3) [[likely]] {
const py::object node_entries = TupleGetItem(out, 2);
if (!node_entries.is_none()) [[likely]] {
Expand Down
46 changes: 32 additions & 14 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <optional> // std::optional
#include <sstream> // std::ostringstream
#include <stdexcept> // std::runtime_error
#include <utility> // std::move

#include "optree/optree.h"

Expand Down Expand Up @@ -171,9 +172,10 @@ py::object PyTreeIter::Next() {
}
}

py::object PyTreeSpec::Walk(const py::function& f_node,
const std::optional<py::function>& f_leaf,
const py::iterable& leaves) const {
// NOLINTNEXTLINE[readability-function-cognitive-complexity]
py::object PyTreeSpec::Walk(const py::iterable& leaves,
const std::optional<py::function>& f_node,
const std::optional<py::function>& f_leaf) const {
PYTREESPEC_SANITY_CHECK(*this);

const scoped_critical_section cs{leaves};
Expand Down Expand Up @@ -203,18 +205,34 @@ py::object PyTreeSpec::Walk(const py::function& f_node,
case PyTreeKind::Deque:
case PyTreeKind::StructSequence:
case PyTreeKind::Custom: {
EXPECT_GE(py::ssize_t_cast(agenda.size()),
node.arity,
"Too few elements for custom type.");
const py::tuple tuple{node.arity};
for (ssize_t i = node.arity - 1; i >= 0; --i) {
TupleSetItem(tuple, i, agenda.back());
agenda.pop_back();
const ssize_t size = py::ssize_t_cast(agenda.size());
EXPECT_GE(size, node.arity, "Too few elements for custom type.");

if (f_node) [[likely]] {
const py::tuple children{node.arity};
for (ssize_t i = node.arity - 1; i >= 0; --i) {
TupleSetItem(children, i, agenda.back());
agenda.pop_back();
}

const py::object& node_type = GetType(node);
{
const scoped_critical_section cs2{node_type};
agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2(
(*f_node)(node_type,
(node.node_data ? node.node_data : py::none()),
children),
node.node_data,
*f_node));
}
} else [[unlikely]] {
py::object out =
MakeNode(node,
(node.arity > 0 ? &agenda[size - node.arity] : nullptr),
node.arity);
agenda.resize(size - node.arity);
agenda.emplace_back(std::move(out));
}
agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2(
f_node(tuple, (node.node_data ? node.node_data : py::none())),
node.node_data,
f_node));
break;
}

Expand Down
Loading
Loading