-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add static_argnames
option to qjit
#1158
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1158 +/- ##
=======================================
Coverage 97.96% 97.96%
=======================================
Files 77 77
Lines 11244 11263 +19
Branches 967 972 +5
=======================================
+ Hits 11015 11034 +19
Misses 180 180
Partials 49 49 ☔ View full report in Codecov by Sentry. |
I would replace the tests with lit tests that don't execute the function. They are quicker. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just one question!
if y < 10: | ||
return x + y | ||
|
||
assert r(1, 2) == 3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@paul0403 have you tested that it works with different combinations of defining and calling functions? E.g.,
- defining functions that use
*
and/
in the signature - defining functions that use default values in the signature
- calling with a mixture of keyword and positional values
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Through testing this I realized currently static_argnums
do not support calls with default arguments (this is on main):
@qjit(static_argnums=[1])
def f(y, x=9):
if x < 10:
return x + y
return 42000
res = f(20)
print(res)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).
Note that jax works:
@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
if x < 10:
return x + y
return 42000
res = f(20)
print(res)
29
I don't think this is an issue with static_argname
. I propose we open a separate issue to fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks @paul0403! My intuition was that this is worth testing as edge cases/bugs like this might appear :)
If this affects static_argnum
, and would be resolved here once static_argnum
is fixed, I am okay opening a separate issue to fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I think none of the more complicated patterns work, but not because of static_argname
, but because they never worked with static_argnum
to begin with.
One particular case was handed out as an assessment: #1163 . In the issue I documented the failure. The apparent failure is in a verification, but I am not sure whether this means (a) the underlying mechanism is buggy and the verification failed as a by-product, or (b) the underlying mechanism is ok, but the verification itself is overly strict.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josh146 I suggest we finish the simple base case of static_argnames
for release, since the other failed cases were already there with static_argnums
as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that works for me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@paul0403 if you haven't already, would you be able to open an issue detailing the more advanced cases that don't work with static_argnums
? #1158 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@paul0403 if you haven't already, would you be able to open an issue detailing the more advanced cases that don't work with
static_argnums
? #1158 (comment)
Already did, and actually we sent this out as an assessment! #1163
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(As far as I can tell, the default value issue seems like the root cause of all other patterns' failures.)
|
Another integration test we might want to consider is with decorators (like vmap, qnode, grad, etc), since this functionality inspects the function signature, there is a chance decorators might not propagate this info properly. |
…t's init; if revert, revert this commit
… python functions
…of having it free floating in QJIT __init__
Really conveniently, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good @paul0403! Just a few minor comments.
note that set(sorted(set(...))) does not work
Context:
Adding a
static_argnames
option to qjit for users to configure static arguments by name.Description of the Change:
Under the hood, this just maps the
static_argnames
to their argument indices and add tostatic_argnums
.Benefits:
Users can specify static arguments to jitted functions by name.
Possible Drawbacks:
Even more keyword arguments to qjit...
[sc-41335]