Skip to content

Commit

Permalink
Add an option to simplify keystr output and use a custom separator.
Browse files Browse the repository at this point in the history
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
```

PiperOrigin-RevId: 717971583
  • Loading branch information
tomhennigan authored and Google-ML-Automation committed Jan 21, 2025
1 parent 96a3ed3 commit 7f43316
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
37 changes: 32 additions & 5 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,22 +722,49 @@ def _equality_errors(path, t1, t2, is_leaf):


@export
def keystr(keys: KeyPath):
def keystr(keys: KeyPath, *, simple: bool = False, separator: str = '') -> str:
"""Helper to pretty-print a tuple of keys.
Args:
keys: A tuple of ``KeyEntry`` or any class that can be converted to string.
simple: If True, use a simplified string representation for keys. The
simple representation of keys will be more compact than the default, but
is ambiguous in some cases (for example "0" might refer to the first item
in a list or a dictionary key for the integer 0 or string "0").
separator: The separator to use to join string representations of the keys.
Returns:
A string that joins all string representations of the keys.
Examples:
>>> import jax
>>> keys = (0, 1, 'a', 'b')
>>> jax.tree_util.keystr(keys)
'01ab'
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz
"""
return ''.join(map(str, keys))
str_fn = _simple_entrystr if simple else str
return separator.join(map(str_fn, keys))


def _simple_entrystr(key: KeyEntry) -> str:
match key:
case (
SequenceKey(idx=key)
| DictKey(key=key)
| GetAttrKey(name=key)
| FlattenedIndexKey(key=key)
):
return str(key)
case _:
return str(key)


# TODO(ivyzheng): remove this after another jaxlib release.
Expand Down
13 changes: 13 additions & 0 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,19 @@ def testKeyStr(self):
],
)

strs = [f"{tree_util.keystr(kp, simple=True, separator='/')}: {x}"
for kp, x in flattened]
self.assertEqual(
strs,
[
"0/foo: 12",
"0/bar/cin/0: 1",
"0/bar/cin/1: 4",
"0/bar/cin/2: 10",
"1: [0 1 2 3 4]",
],
)

def testTreeMapWithPathWithIsLeafArgument(self):
x = ((1, 2), [3, 4, 5])
y = (([3], jnp.array(0)), ([0], 7, [5, 6]))
Expand Down

0 comments on commit 7f43316

Please sign in to comment.