Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static_argnames option to qjit #1158

Merged
merged 28 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d83ec15
add `static_argnames` option to qjit
paul0403 Sep 26, 2024
9e24d71
add static_argnames
paul0403 Sep 26, 2024
127df24
tests
paul0403 Sep 26, 2024
8d8919a
pylint
paul0403 Sep 26, 2024
0ce2c9a
changelog
paul0403 Sep 26, 2024
54c3972
changelog
paul0403 Sep 26, 2024
002a08b
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Sep 27, 2024
f3ea8a8
raise error when static_argname contains strings that are not functio…
paul0403 Sep 30, 2024
0aef4ab
remove unnecessary else
paul0403 Sep 30, 2024
34a418c
change to tests that don't execute the function
paul0403 Sep 30, 2024
a662cdb
codefactor
paul0403 Sep 30, 2024
8467aa3
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Sep 30, 2024
226f1a2
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 21, 2024
bf05ee6
move the mapping of static_argnames to static_argnums into QJIT objec…
paul0403 Oct 21, 2024
7e80645
change to inspect.signature to deal with any callable, not just plain…
paul0403 Oct 21, 2024
2c82559
add tests for decorators
paul0403 Oct 21, 2024
9506021
factor out merge_static_argname_into_argnum into a function, instead …
paul0403 Oct 21, 2024
574cc25
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 21, 2024
a3d6321
fix CI failure
paul0403 Oct 22, 2024
a34cca3
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 22, 2024
7e5faf4
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 22, 2024
522be57
changelog fix
paul0403 Oct 22, 2024
daa12ce
better sorting
paul0403 Oct 22, 2024
bd4125d
collect all non existent args
paul0403 Oct 22, 2024
a2e793f
fix non existent args order in error msg
paul0403 Oct 22, 2024
7657985
one line error msg
paul0403 Oct 22, 2024
771f2fb
use repr to one-liner the build of non_existent_args_str
paul0403 Oct 22, 2024
cf5a031
format
paul0403 Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,30 @@
Array([2, 4, 6], dtype=int64)
```

* Static arguments of a jit-ted function can now be indicated by a `static_argnames`
argument to `qjit`.
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
[(#1158)](https://github.com/PennyLaneAI/catalyst/pull/1158)

```python
@qjit(static_argnames="y")
def f(x, y):
if y < 10: # y needs to be marked as static since its concrete boolean value is needed
return x + y

@qjit(static_argnames=["x","y"])
def g(x, y):
if x < 10 and y < 10:
return x + y

res_f = f(1, 2)
res_g = g(3, 4)
print(res_f, res_g)
```

```pycon
>>> 3 7
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
```

<h3>Improvements</h3>

* Scalar tensors are eliminated from control flow operations in the program, and are replaced with
Expand Down
18 changes: 18 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def qjit(
logfile=None,
pipelines=None,
static_argnums=None,
static_argnames=None,
abstracted_axes=None,
disable_assertions=False,
seed=None,
Expand Down Expand Up @@ -123,6 +124,8 @@ def qjit(
considered to be used by advanced users for low-level debugging purposes.
static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the
positions of static arguments.
static_argnames(str or Seqence[str]): a string or a sequence of strings that specifies the
names of static arguments.
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]):
An experimental option to specify dynamic tensor shapes.
This option affects the compilation of the annotated function.
Expand Down Expand Up @@ -430,6 +433,21 @@ def sum_abstracted(arr):
if fn is None:
return functools.partial(qjit, **kwargs)

# Map static_argnames to static_argnums
# no need to propagate static_argnames further
if static_argnames is not None:
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
static_argnums = [] if (static_argnums is None) else list(kwargs["static_argnums"])
fn_argnames = inspect.getfullargspec(fn).args
for argname in fn_argnames:
if argname in static_argnames:
static_argnums.append(fn_argnames.index(argname))

# Remove potential duplicates from static_argnums and static_argnames
static_argnums = list(dict.fromkeys(static_argnums))
static_argnums.sort()
kwargs["static_argnums"] = static_argnums
kwargs.pop("static_argnames")

return QJIT(fn, CompileOptions(**kwargs))


Expand Down
50 changes: 50 additions & 0 deletions frontend/test/pytest/test_static_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pennylane as qml
import pytest
from jax.errors import TracerBoolConversionError

from catalyst import qjit
from catalyst.utils.exceptions import CompileError
Expand Down Expand Up @@ -221,6 +222,55 @@ def wrapper(x, c):
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"

def test_static_argnames(self):
# pylint: disable=inconsistent-return-statements
"""Test static arguments specified by names"""

@qjit(static_argnames="y")
def f(x, y):
if y < 10:
return x + y

assert f(1, 2) == 3

@qjit(static_argnames="x")
def g(x, y):
if y < 10:
return x + y

with pytest.raises(
TracerBoolConversionError, match="Attempted boolean conversion of traced"
):
g(1, 2)

@qjit(static_argnames=("x", "y"))
def h(x, y):
if x < 10 and y < 10:
return x + y

assert h(1, 2) == 3

@qjit(static_argnames=("x"), static_argnums=[1])
def p(x, y):
if x < 10 and y < 10:
return x + y

assert p(1, 2) == 3

@qjit(static_argnames=("y"), static_argnums=[0])
def q(x, y):
if x < 10 and y < 10:
return x + y

assert q(1, 2) == 3

@qjit(static_argnames=("y"), static_argnums=[1])
def r(x, y):
if y < 10:
return x + y

assert r(1, 2) == 3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paul0403 have you tested that it works with different combinations of defining and calling functions? E.g.,

  • defining functions that use * and / in the signature
  • defining functions that use default values in the signature
  • calling with a mixture of keyword and positional values

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Through testing this I realized currently static_argnums do not support calls with default arguments (this is on main):

@qjit(static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

res = f(20)
print(res)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).

Note that jax works:

@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


res = f(20)
print(res)
29

I don't think this is an issue with static_argname. I propose we open a separate issue to fix it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks @paul0403! My intuition was that this is worth testing as edge cases/bugs like this might appear :)

If this affects static_argnum, and would be resolved here once static_argnum is fixed, I am okay opening a separate issue to fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think none of the more complicated patterns work, but not because of static_argname, but because they never worked with static_argnum to begin with.

One particular case was handed out as an assessment: #1163 . In the issue I documented the failure. The apparent failure is in a verification, but I am not sure whether this means (a) the underlying mechanism is buggy and the verification failed as a by-product, or (b) the underlying mechanism is ok, but the verification itself is overly strict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josh146 I suggest we finish the simple base case of static_argnames for release, since the other failed cases were already there with static_argnums as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that works for me!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paul0403 if you haven't already, would you be able to open an issue detailing the more advanced cases that don't work with static_argnums? #1158 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paul0403 if you haven't already, would you be able to open an issue detailing the more advanced cases that don't work with static_argnums? #1158 (comment)

Already did, and actually we sent this out as an assessment! #1163

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(As far as I can tell, the default value issue seems like the root cause of all other patterns' failures.)



if __name__ == "__main__":
pytest.main(["-x", __file__])
Loading