Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static_argnames option to qjit #1158

Merged
merged 28 commits into from
Oct 22, 2024
Merged

Add static_argnames option to qjit #1158

merged 28 commits into from
Oct 22, 2024

Conversation

paul0403
Copy link
Contributor

@paul0403 paul0403 commented Sep 26, 2024

Context:
Adding a static_argnames option to qjit for users to configure static arguments by name.

Description of the Change:
Under the hood, this just maps the static_argnames to their argument indices and add to static_argnums.

Benefits:
Users can specify static arguments to jitted functions by name.

Possible Drawbacks:
Even more keyword arguments to qjit...

[sc-41335]

@paul0403 paul0403 requested review from josh146 and a team September 26, 2024 18:21
@paul0403 paul0403 marked this pull request as ready for review September 26, 2024 18:21
Copy link

codecov bot commented Sep 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.96%. Comparing base (d7c7e39) to head (cf5a031).
Report is 201 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1158   +/-   ##
=======================================
  Coverage   97.96%   97.96%           
=======================================
  Files          77       77           
  Lines       11244    11263   +19     
  Branches      967      972    +5     
=======================================
+ Hits        11015    11034   +19     
  Misses        180      180           
  Partials       49       49           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@erick-xanadu
Copy link
Contributor

I would replace the tests with lit tests that don't execute the function. They are quicker.

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just one question!

frontend/catalyst/jit.py Outdated Show resolved Hide resolved
if y < 10:
return x + y

assert r(1, 2) == 3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paul0403 have you tested that it works with different combinations of defining and calling functions? E.g.,

  • defining functions that use * and / in the signature
  • defining functions that use default values in the signature
  • calling with a mixture of keyword and positional values

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Through testing this I realized currently static_argnums do not support calls with default arguments (this is on main):

@qjit(static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

res = f(20)
print(res)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).

Note that jax works:

@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


res = f(20)
print(res)
29

I don't think this is an issue with static_argname. I propose we open a separate issue to fix it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks @paul0403! My intuition was that this is worth testing as edge cases/bugs like this might appear :)

If this affects static_argnum, and would be resolved here once static_argnum is fixed, I am okay opening a separate issue to fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think none of the more complicated patterns work, but not because of static_argname, but because they never worked with static_argnum to begin with.

One particular case was handed out as an assessment: #1163 . In the issue I documented the failure. The apparent failure is in a verification, but I am not sure whether this means (a) the underlying mechanism is buggy and the verification failed as a by-product, or (b) the underlying mechanism is ok, but the verification itself is overly strict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josh146 I suggest we finish the simple base case of static_argnames for release, since the other failed cases were already there with static_argnums as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that works for me!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paul0403 if you haven't already, would you be able to open an issue detailing the more advanced cases that don't work with static_argnums? #1158 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paul0403 if you haven't already, would you be able to open an issue detailing the more advanced cases that don't work with static_argnums? #1158 (comment)

Already did, and actually we sent this out as an assessment! #1163

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(As far as I can tell, the default value issue seems like the root cause of all other patterns' failures.)

@paul0403
Copy link
Contributor Author

I would replace the tests with lit tests that don't execute the function. They are quicker.

34a418c

@dime10
Copy link
Contributor

dime10 commented Oct 1, 2024

Another integration test we might want to consider is with decorators (like vmap, qnode, grad, etc), since this functionality inspects the function signature, there is a chance decorators might not propagate this info properly.

frontend/catalyst/jit.py Outdated Show resolved Hide resolved
@paul0403
Copy link
Contributor Author

Another integration test we might want to consider is with decorators (like vmap, qnode, grad, etc), since this functionality inspects the function signature, there is a chance decorators might not propagate this info properly.

Really conveniently, inspect module has a inspect.signature(), which can take in any callable, not just plain functions.
7e80645

@paul0403 paul0403 requested a review from joeycarter October 22, 2024 14:40
Copy link
Contributor

@joeycarter joeycarter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @paul0403! Just a few minor comments.

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
frontend/catalyst/tracing/type_signatures.py Outdated Show resolved Hide resolved
frontend/catalyst/tracing/type_signatures.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_static_arguments.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_static_arguments.py Outdated Show resolved Hide resolved
note that set(sorted(set(...))) does not work
@paul0403 paul0403 merged commit 4182a20 into main Oct 22, 2024
42 checks passed
@paul0403 paul0403 deleted the qjit_static_argname branch October 22, 2024 18:33
@paul0403 paul0403 added this to the v0.9.0 milestone Oct 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants