From 5d0b4a4e24618a02e650087f1874a9a588ab06dd Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 7 Dec 2024 16:41:14 +0800 Subject: [PATCH] test: add more tests for `PyTreeSpec.transform` --- src/treespec/treespec.cpp | 2 +- tests/test_treespec.py | 46 ++++++++++++++++++++++++++++++++++----- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index bc19de59..7aa5d5bd 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -584,7 +584,7 @@ std::unique_ptr PyTreeSpec::Transform(const std::optionalGetNumNodes(), "Number of transformed tree nodes mismatch."); treespec->m_none_is_leaf = m_none_is_leaf; - treespec->m_namespace = m_namespace; + treespec->m_namespace = common_registry_namespace; treespec->m_traversal.shrink_to_fit(); PYTREESPEC_SANITY_CHECK(*treespec); return treespec; diff --git a/tests/test_treespec.py b/tests/test_treespec.py index d9eee7e2..10018a87 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -947,6 +947,24 @@ def test_treespec_transform(): ) == optree.tree_structure( {'a': {'a': [0, None, 1], 'b': [2, None, 3], 'c': [4, None, 5]}, 'b': {'a': [6, None, 7]}}, ) + namespaced_treespec = optree.tree_structure( + MyAnotherDict({1: MyAnotherDict({2: 1, 1: 2, 0: 3}), 0: MyAnotherDict({0: 4})}), + namespace='namespace', + ) + assert ( + optree.treespec_transform( + treespec, + lambda spec: optree.tree_structure( + MyAnotherDict(zip(spec.entries(), spec.children())), + namespace='namespace', + ), + ) + == namespaced_treespec + ) + assert optree.treespec_transform( + namespaced_treespec, + lambda spec: optree.treespec_list(spec.children()), + ) == optree.tree_structure([[1, 2, 3], [4]]) with pytest.raises( TypeError, @@ -961,8 +979,8 @@ def test_treespec_transform(): with pytest.raises( ValueError, match=( - r'Expected the PyTreeSpec transform function returns a PyTreeSpec ' - r'with the same value of `none_is_leaf=\w+` as the input' + r'Expected the PyTreeSpec transform function returns ' + r'a PyTreeSpec with the same value of `none_is_leaf=\w+` as the input' ), ): optree.treespec_transform( @@ -972,14 +990,32 @@ def test_treespec_transform(): none_is_leaf=True, ), ) + with pytest.raises(ValueError, match=r'Expected treespec\(s\) with namespace .*, got .*\.'): + + def fn(spec): + with optree.dict_insertion_ordered(True, namespace='undefined'): + return optree.treespec_dict(zip('abcd', spec.children()), namespace='undefined') + + optree.treespec_transform(namespaced_treespec, fn) with pytest.raises( ValueError, - match=( - r'Expected the PyTreeSpec transform function returns a PyTreeSpec ' - r'with the same number of arity as the input' + match=re.escape( + 'Expected the PyTreeSpec transform function returns ' + 'a PyTreeSpec with the same number of arity as the input', ), ): optree.treespec_transform(treespec, lambda _: optree.tree_structure([0, 1])) + with pytest.raises( + ValueError, + match=re.escape( + 'Expected the PyTreeSpec transform function returns ' + 'an one-level PyTreeSpec as the input', + ), + ): + optree.treespec_transform( + treespec, + lambda spec: optree.tree_structure([None] + [0] * spec.num_children), + ) @parametrize(