Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static_argnames option to qjit #1158

Merged
merged 28 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d83ec15
add `static_argnames` option to qjit
paul0403 Sep 26, 2024
9e24d71
add static_argnames
paul0403 Sep 26, 2024
127df24
tests
paul0403 Sep 26, 2024
8d8919a
pylint
paul0403 Sep 26, 2024
0ce2c9a
changelog
paul0403 Sep 26, 2024
54c3972
changelog
paul0403 Sep 26, 2024
002a08b
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Sep 27, 2024
f3ea8a8
raise error when static_argname contains strings that are not functio…
paul0403 Sep 30, 2024
0aef4ab
remove unnecessary else
paul0403 Sep 30, 2024
34a418c
change to tests that don't execute the function
paul0403 Sep 30, 2024
a662cdb
codefactor
paul0403 Sep 30, 2024
8467aa3
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Sep 30, 2024
226f1a2
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 21, 2024
bf05ee6
move the mapping of static_argnames to static_argnums into QJIT objec…
paul0403 Oct 21, 2024
7e80645
change to inspect.signature to deal with any callable, not just plain…
paul0403 Oct 21, 2024
2c82559
add tests for decorators
paul0403 Oct 21, 2024
9506021
factor out merge_static_argname_into_argnum into a function, instead …
paul0403 Oct 21, 2024
574cc25
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 21, 2024
a3d6321
fix CI failure
paul0403 Oct 22, 2024
a34cca3
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 22, 2024
7e5faf4
Merge remote-tracking branch 'origin/main' into qjit_static_argname
paul0403 Oct 22, 2024
522be57
changelog fix
paul0403 Oct 22, 2024
daa12ce
better sorting
paul0403 Oct 22, 2024
bd4125d
collect all non existent args
paul0403 Oct 22, 2024
a2e793f
fix non existent args order in error msg
paul0403 Oct 22, 2024
7657985
one line error msg
paul0403 Oct 22, 2024
771f2fb
use repr to one-liner the build of non_existent_args_str
paul0403 Oct 22, 2024
cf5a031
format
paul0403 Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,30 @@
Array([2, 4, 6], dtype=int64)
```

* Static arguments of a qjit-compiled function can now be indicated by a `static_argnames`
argument to `qjit`.
[(#1158)](https://github.com/PennyLaneAI/catalyst/pull/1158)

```python
@qjit(static_argnames="y")
def f(x, y):
if y < 10: # y needs to be marked as static since its concrete boolean value is needed
return x + y

@qjit(static_argnames=["x","y"])
def g(x, y):
if x < 10 and y < 10:
return x + y

res_f = f(1, 2)
res_g = g(3, 4)
print(res_f, res_g)
```

```pycon
3 7
```

<h3>Improvements</h3>

* Implement a Catalyst runtime plugin that mocks out all functions in the QuantumDevice interface.
Expand Down
3 changes: 3 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class CompileOptions:
the main compilation pipeline is complete. Default is ``True``.
static_argnums (Optional[Union[int, Iterable[int]]]): indices of static arguments.
Default is ``None``.
static_argnames (Optional[Union[str, Iterable[str]]]): names of static arguments.
Default is ``None``.
abstracted_axes (Optional[Any]): store the abstracted_axes value. Defaults to ``None``.
disable_assertions (Optional[bool]): disables all assertions. Default is ``False``.
seed (Optional[int]) : the seed for random operations in a qjit call.
Expand All @@ -92,6 +94,7 @@ class CompileOptions:
autograph_include: Optional[Iterable[str]] = ()
async_qnodes: Optional[bool] = False
static_argnums: Optional[Union[int, Iterable[int]]] = None
static_argnames: Optional[Union[str, Iterable[str]]] = None
abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None
lower_to_llvm: Optional[bool] = True
checkpoint_stage: Optional[str] = ""
Expand Down
10 changes: 10 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
filter_static_args,
get_abstract_signature,
get_type_annotations,
merge_static_argname_into_argnum,
merge_static_args,
promote_arguments,
verify_static_argnums,
Expand Down Expand Up @@ -82,6 +83,7 @@ def qjit(
logfile=None,
pipelines=None,
static_argnums=None,
static_argnames=None,
abstracted_axes=None,
disable_assertions=False,
seed=None,
Expand Down Expand Up @@ -124,6 +126,8 @@ def qjit(
considered to be used by advanced users for low-level debugging purposes.
static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the
positions of static arguments.
static_argnames(str or Seqence[str]): a string or a sequence of strings that specifies the
names of static arguments.
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]):
An experimental option to specify dynamic tensor shapes.
This option affects the compilation of the annotated function.
Expand Down Expand Up @@ -482,6 +486,12 @@ def __init__(self, fn, compile_options):
self.user_sig = get_type_annotations(fn)
self._validate_configuration()

# If static_argnames are present, convert them to static_argnums
if compile_options.static_argnames is not None:
compile_options.static_argnums = merge_static_argname_into_argnum(
fn, compile_options.static_argnames, compile_options.static_argnums
)

# Patch the conversion rules by adding the included modules before the block list
include_convertlist = tuple(
ag_config.Convert(rule) for rule in self.compile_options.autograph_include
Expand Down
35 changes: 35 additions & 0 deletions frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,41 @@ def split_static_args(args, static_argnums):
return tuple(dynamic_args), tuple(static_args)


def merge_static_argname_into_argnum(fn: Callable, static_argnames, static_argnums):
"""Map static_argnames of the callable to the corresponding argument indices,
and add them to static_argnums"""
new_static_argnums = [] if (static_argnums is None) else list(static_argnums)
fn_argnames = list(inspect.signature(fn).parameters.keys())

# static_argnames can be a single str, or a list/tuple of strs
# convert all of them to list
if isinstance(static_argnames, str):
static_argnames = [static_argnames]

non_existent_args = []
for static_argname in static_argnames:
if static_argname in fn_argnames:
new_static_argnums.append(fn_argnames.index(static_argname))
continue
non_existent_args.append(static_argname)

if non_existent_args:
non_existent_args_str = "{"
for arg in non_existent_args:
non_existent_args_str += "'" + arg + "', "
non_existent_args_str = non_existent_args_str[:-2] + "}"
paul0403 marked this conversation as resolved.
Show resolved Hide resolved

raise ValueError(
f"qjitted function has invalid argname {non_existent_args_str} in static_argnames. "
"Function does not take these args."
)

# Remove potential duplicates from static_argnums and static_argnames
new_static_argnums = tuple(sorted(set(new_static_argnums)))

return new_static_argnums


def merge_static_args(signature, args, static_argnums):
"""Merge static arguments back into an abstract signature, retaining the original ordering.

Expand Down
77 changes: 76 additions & 1 deletion frontend/test/pytest/test_static_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pennylane as qml
import pytest

from catalyst import qjit
from catalyst import grad, qjit
from catalyst.utils.exceptions import CompileError


Expand Down Expand Up @@ -221,6 +221,81 @@ def wrapper(x, c):
captured = capsys.readouterr()
assert captured.out.strip() == "Inside QNode: 0.5"

def test_static_argnames(self):
# pylint: disable=unused-argument, function-redefined
"""Test static arguments specified by names"""

@qjit(static_argnames="y")
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {1}

with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"):

@qjit(static_argnames="yy")
def f_badname(x, y):
return

with pytest.raises(ValueError, match="qjitted function has invalid argname {'yy'}"):

@qjit(static_argnames=["y", "yy"])
def f_badname_list(x, y):
return

with pytest.raises(ValueError, match="qjitted function has invalid argname {'xx', 'yy'}"):

@qjit(static_argnames=["xx", "yy"])
def f_badname_list(x, y):
return

@qjit(static_argnames=("x", "y"))
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {0, 1}

@qjit(static_argnames=("x"), static_argnums=[1])
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {0, 1}

@qjit(static_argnames=("y"), static_argnums=[0])
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {0, 1}

@qjit(static_argnames=("y"), static_argnums=[1])
def f(x, y):
return

assert set(f.compile_options.static_argnums) == {1}

def test_static_argnames_with_decorator(self):
# pylint: disable=unused-argument, function-redefined
"""Test static arguments specified by names
on functions with decorators"""

dev = qml.device("lightning.qubit", wires=3)

@qjit(static_argnames="theta")
@qml.qnode(dev)
def f(theta, phi):
qml.RX(theta, wires=0)
qml.RY(phi, wires=1)
return qml.probs()

assert set(f.compile_options.static_argnums) == {0}

@qjit(static_argnames=("x", "y"))
@grad
def f(x, y):
return x * y

assert set(f.compile_options.static_argnums) == {0, 1}


if __name__ == "__main__":
pytest.main(["-x", __file__])
Loading