Skip to content

Commit

Permalink
test: add more tests for PyTreeSpec.transform
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 7, 2024
1 parent c8c437b commit f7b67a0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ std::unique_ptr<PyTreeSpec> PyTreeSpec::Transform(const std::optional<py::functi
treespec->GetNumNodes(),
"Number of transformed tree nodes mismatch.");
treespec->m_none_is_leaf = m_none_is_leaf;
treespec->m_namespace = m_namespace;
treespec->m_namespace = common_registry_namespace;
treespec->m_traversal.shrink_to_fit();
PYTREESPEC_SANITY_CHECK(*treespec);
return treespec;
Expand Down
46 changes: 41 additions & 5 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,24 @@ def test_treespec_transform():
) == optree.tree_structure(
{'a': {'a': [0, None, 1], 'b': [2, None, 3], 'c': [4, None, 5]}, 'b': {'a': [6, None, 7]}},
)
namespaced_treespec = optree.tree_structure(
MyAnotherDict({1: MyAnotherDict({2: 1, 1: 2, 0: 3}), 0: MyAnotherDict({0: 4})}),
namespace='namespace',
)
assert (
optree.treespec_transform(
treespec,
lambda spec: optree.tree_structure(
MyAnotherDict(zip(spec.entries(), spec.children())),
namespace='namespace',
),
)
== namespaced_treespec
)
assert optree.treespec_transform(
namespaced_treespec,
lambda spec: optree.treespec_list(spec.children()),
) == optree.tree_structure([[1, 2, 3], [4]])

with pytest.raises(
TypeError,
Expand All @@ -961,8 +979,8 @@ def test_treespec_transform():
with pytest.raises(
ValueError,
match=(
r'Expected the PyTreeSpec transform function returns a PyTreeSpec '
r'with the same value of `none_is_leaf=\w+` as the input'
r'Expected the PyTreeSpec transform function returns '
r'a PyTreeSpec with the same value of `none_is_leaf=\w+` as the input'
),
):
optree.treespec_transform(
Expand All @@ -972,14 +990,32 @@ def test_treespec_transform():
none_is_leaf=True,
),
)
with pytest.raises(ValueError, match=r'Expected treespec\(s\) with namespace .*, got .*\.'):

def fn(spec):
with optree.dict_insertion_ordered(True, namespace='undefined'):
return optree.treespec_dict(zip('abcd', spec.children()), namespace='undefined')

optree.treespec_transform(namespaced_treespec, fn)
with pytest.raises(
ValueError,
match=(
r'Expected the PyTreeSpec transform function returns a PyTreeSpec '
r'with the same number of arity as the input'
match=re.escape(
'Expected the PyTreeSpec transform function returns '
'a PyTreeSpec with the same number of arity as the input',
),
):
optree.treespec_transform(treespec, lambda _: optree.tree_structure([0, 1]))
with pytest.raises(
ValueError,
match=re.escape(
'Expected the PyTreeSpec transform function returns '
'an one-level PyTreeSpec as the input',
),
):
optree.treespec_transform(
treespec,
lambda spec: optree.tree_structure([None] + [0] * spec.num_children),
)


@parametrize(
Expand Down

0 comments on commit f7b67a0

Please sign in to comment.