From 25c12e8af6147ddc7e15f30c8914495085fcd3e1 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 24 Jan 2025 14:46:27 -0500 Subject: [PATCH] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 666941f52e6613c4ea3b53b177698bf006954c56 Author: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Fri Jan 24 11:22:28 2025 -0500 `qml.execute` no longer accepts `mcm_config` argument (#6807) **Context:** Further `qml.workflow` clean-up. This enables `qml.execute` to mimic the signature of `QNode`, ensuring that we don't get incompatible configurations from having two different entry points. **Description of the Change:** - [x] Catalyst: https://github.com/PennyLaneAI/catalyst/pull/1452 - [x] Lightning: No instances of deprecated code found. - [x] QML Demos: No instances of deprecated code found. - [x] Pennylane-AQT: No instances of deprecated code found. - [x] Pennylane-Qiskit: No instances of deprecated code found. - [x] Pennylane-IonQ: No instances of deprecated code found. - [x] Pennylane-Qrack: No instances of deprecated code found. - [x] Pennylane-Cirq: No instances of deprecated code found. - [x] Pennylane-Qulacs: No instances of deprecated code found. Introduce the kwargs `postselect_mode` and `mcm_method` into the `qml.execute` signature. Raise a deprecation warning if the user tries to use the `mcm_config` kwarg. Side effect: had to `xfail` any tests related to Catalyst, since they assume certain keys from `QNode.execute_kwargs`. Will be reverted in https://github.com/PennyLaneAI/pennylane/pull/6873. **Benefits:** Keyword parity with `qml.QNode`. **Possible Drawbacks:** None identified.
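The `mcm_config` deprecation above can be sketched in plain Python. This is an illustrative stand-in only, not PennyLane's actual implementation: `MCMConfig` and `execute` here are simplified mock-ups showing how a legacy `mcm_config` kwarg can be mapped onto the new `mcm_method`/`postselect_mode` kwargs while emitting a deprecation warning.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class MCMConfig:
    """Simplified stand-in for PennyLane's mid-circuit-measurement config."""

    mcm_method: Optional[str] = None
    postselect_mode: Optional[str] = None


def execute(tapes, device=None, *, mcm_method=None, postselect_mode=None, **kwargs):
    """Sketch of an entry point that accepts the new kwargs directly and maps a
    legacy ``mcm_config`` kwarg onto them with a deprecation warning."""
    if "mcm_config" in kwargs:
        warnings.warn(
            "The mcm_config argument is deprecated; "
            "use mcm_method and postselect_mode instead.",
            DeprecationWarning,
        )
        legacy = kwargs.pop("mcm_config")
        # Explicit new-style kwargs win; otherwise fall back to the legacy container.
        mcm_method = mcm_method or legacy.mcm_method
        postselect_mode = postselect_mode or legacy.postselect_mode
    if kwargs:
        raise TypeError(f"execute got unexpected keyword arguments {set(kwargs)}")
    return MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode)
```

Once the deprecation cycle ends, the `mcm_config` branch is simply deleted and the stray-kwarg `TypeError` takes over, which is what gives the single validated entry point described above.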
[sc-80541] --------- Co-authored-by: Christina Lee Co-authored-by: Yushao Chen (Jerry) commit dbb33159294537cb28fd5ab01db2776ef9935256 Author: ringo-but-quantum Date: Fri Jan 24 09:51:39 2025 +0000 [no ci] bump nightly version commit f367e01791f26923bf9222ca863d6910c1e08287 Author: Rashid N H M <95639609+rashidnhm@users.noreply.github.com> Date: Thu Jan 23 10:47:17 2025 -0500 Add merge queue trigger for required workflows (#6860) **Context:** This pull request adds a new trigger to the existing workflows that run `on.pull_request`. This trigger indicates to GitHub which workflows need to be run when a merge queue is building. **Description of the Change:** The change adds `on.merge_group` to our required workflows. **Benefits:** This change itself will not enable merge queues; that needs to be enabled from the admin settings of the branch protection rules. The changes in this PR mainly tell merge queues which workflows to run. **Possible Drawbacks:** Merge queues are new for PennyLane; if issues arise, we can roll back. **Related GitHub Issues:** None. [sc-82039] commit 8a12fa59dcf42c83dfd57935007ae860269d7410 Author: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu Jan 23 10:06:38 2025 -0500 Promote `gradient_kwargs` to a positional keyword argument in `QNode` (#6828) **Context:** `gradient_kwargs` is now a positional keyword argument for the `QNode`. This means you cannot simply write ```python qml.QNode(func, dev, h=1) ``` Instead, you must be explicit: ```python qml.QNode(func, dev, gradient_kwargs={"h": 1}) ``` This allows easier and cleaner input validation. This PR could have widespread impact, as it is very common to specify `gradient_kwargs` casually as additional kwargs. - [x] Catalyst: https://github.com/PennyLaneAI/catalyst/pull/1480 - [x] Lightning: https://github.com/PennyLaneAI/pennylane-lightning/pull/1045 - [x] QML Demos: No instances of deprecated code found.
- [x] Pennylane-AQT: No instances of deprecated code found. - [x] Pennylane-Qiskit: No instances of deprecated code found. - [x] Pennylane-IonQ: No instances of deprecated code found. - [x] Pennylane-Qrack: No instances of deprecated code found. - [x] Pennylane-Cirq: No instances of deprecated code found. - [x] Pennylane-Qulacs: No instances of deprecated code found. **Description of the Change:** Allow additional kwargs for now to preserve the same functionality, but raise a deprecation warning. Append those additional kwargs to the internal `gradient_kwargs` dictionary. **Benefits:** Improved input validation for users. **Possible Drawbacks:** Might have missed some ecosystem changes, especially with CI **sometimes** not raising `PennyLaneDeprecationWarning`s as errors 😒. [sc-81531] --------- Co-authored-by: Christina Lee commit 875ae112ab9b2da436c328c11cf5505bb39b13f4 Author: Christina Lee Date: Thu Jan 23 09:31:47 2025 -0500 Revert end-to-end jitting with default qubit (#6869) **Context:** In #6788, we started allowing executions on default qubit to be jitted end-to-end. Unfortunately, we found that the compilation overheads on these executions can get very, very expensive. So until we find a way to reduce the compilation overheads, we are using pure callbacks and conversion to numpy. **Description of the Change:** Default to `convert_to_numpy=False`, and xfail relevant tests. This change can be undone once we figure out how to resolve the compilation issue. **Benefits:** Reduced compilation overheads, because the execution itself does not get compiled. **Possible Drawbacks:** Slowdown on post-compiled workflows. No way to jit an entire execution on default qubit.
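The `gradient_kwargs` transition path from #6828 above can be sketched in plain Python. This is an illustrative mock-up, not PennyLane's implementation: `make_qnode` is a hypothetical stand-in for the `QNode` constructor, showing how stray kwargs are folded into the explicit `gradient_kwargs` dictionary with a deprecation warning during the transition period.

```python
import warnings


def make_qnode(func, device, *, gradient_kwargs=None, **kwargs):
    """Sketch of a QNode-like constructor: accepts the new explicit
    ``gradient_kwargs`` dictionary, and for now still folds any stray
    kwargs into it while warning that this path is deprecated."""
    gradient_kwargs = dict(gradient_kwargs or {})
    if kwargs:
        warnings.warn(
            "Passing gradient keyword arguments directly to the QNode is "
            "deprecated; use gradient_kwargs={...} instead.",
            DeprecationWarning,
        )
        # Preserve old behaviour during the deprecation cycle.
        gradient_kwargs.update(kwargs)
    return {"func": func, "device": device, "gradient_kwargs": gradient_kwargs}
```

When the deprecation is enforced, the `kwargs.update` branch becomes a `TypeError`, so typos like `diff_methd=` fail loudly instead of being silently swallowed into the gradient options — the "easier and cleaner input validation" the commit message refers to.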
**Related GitHub Issues:** --------- Co-authored-by: Pietropaolo Frisoni commit 63cca88c2c44b034eb85d164778de0c0355c6f04 Author: ringo-but-quantum Date: Thu Jan 23 09:51:30 2025 +0000 [no ci] bump nightly version commit 61dbc7145cb9b0883e3dff817399c983db178279 Author: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed Jan 22 16:15:54 2025 -0500 Deprecate `qml.gradients.hamiltonian_grad` (#6849) **Context:** **Description of the Change:** _Source-Code_ Standard deprecation of the `hamiltonian_grad` function. _Test suite_ I noticed that the tests I've removed from `test_parameter_shift.py` have improved duplicates in `tests/workflow/interfaces` under the `TestHamiltonianWorkflows` test class. Therefore, they were all removed except `test_jax`; that test follows the outdated workflow that still hits the branch in `parameter_shift.py::expval_param_shift` that raises the deprecation warning. So, I've added a warning and left that test. This should be removed along with the `hamiltonian_grad` function in the next release. **Impact:** No deprecated code found elsewhere. Impact to the ecosystem should be minimal. [sc-81526] commit 3e1521bdef05235d915ddb8273ea177fcf62d755 Author: Mudit Pandey Date: Wed Jan 22 11:09:07 2025 -0500 [Capture] Add a `QmlPrimitive` class to differentiate between different types of primitives (#6847) This PR adds a `QmlPrimitive` subclass of `jax.core.Primitive`. This class contains a `prim_type` property set using a new `PrimitiveType` enum. The `PrimitiveType`s currently available are "default", "operator", "measurement", "transform", and "higher_order". This can be made more or less fine-grained as needed, but should be enough to differentiate between different types of primitives for now. Additionally, this PR: * updates `NonInterpPrimitive` to be a subclass of `QmlPrimitive` * updates all existing PennyLane primitives to be either `QmlPrimitive` or `NonInterpPrimitive`.
See [this comment](https://github.com/PennyLaneAI/pennylane/pull/6851#discussion_r1922462699) to see the logic used to determine which `Primitive` subclass is used for each primitive. * updates `PlxprInterpreter.eval` and `CancelInversesInterpreter.eval` to use this `prim_type` property. [sc-82420] --------- Co-authored-by: Pietropaolo Frisoni commit fdf34ec6f287bf953a73df75d7fdcb531ec4dc87 Author: ringo-but-quantum Date: Wed Jan 22 09:51:47 2025 +0000 [no ci] bump nightly version commit 90dc57cf50fc49349a0c2bce2b3f2abac2abd453 Author: Christina Lee Date: Tue Jan 21 16:21:11 2025 -0500 [Capture] Add backprop validation (#6852) **Context:** We currently use un-validated backprop for differentiation with program capture. This leads to some unintuitive errors if you try and take a gradient on lightning with capture enabled. **Description of the Change:** Adds some validation to make sure the device supports backprop. Adds the backprop logic to a `_backprop` jvp function, and dispatches to that method based on the diff method. **Benefits:** Improved error messages when backprop or the requested diff method isn't supported. **Possible Drawbacks:** The code currently is a little clunky, but it is private so we should be able to move things around once we have more information. 
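The `QmlPrimitive`/`prim_type` design from #6847, and the interpreter dispatch it enables, can be sketched in plain Python. The real class subclasses `jax.core.Primitive`; the minimal `Primitive` base here is a stand-in so the sketch is self-contained, and `dispatch` is a hypothetical condensation of the `PlxprInterpreter.eval` branching shown later in this patch.

```python
from enum import Enum


class PrimitiveType(Enum):
    """Categories of PennyLane primitives, per the commit message above."""

    DEFAULT = "default"
    OPERATOR = "operator"
    MEASUREMENT = "measurement"
    TRANSFORM = "transform"
    HIGHER_ORDER = "higher_order"


class Primitive:
    """Minimal stand-in for jax.core.Primitive."""

    def __init__(self, name):
        self.name = name


class QmlPrimitive(Primitive):
    """Primitive tagged with a category, so interpreters can dispatch on
    ``prim_type`` instead of inspecting abstract output values."""

    _prim_type = PrimitiveType.DEFAULT

    @property
    def prim_type(self):
        return self._prim_type.value

    @prim_type.setter
    def prim_type(self, value):
        # Accepts either a string like "operator" or a PrimitiveType member.
        self._prim_type = PrimitiveType(value)


def dispatch(primitive):
    """Mirror of the eval-loop branching: getattr with a default keeps
    plain (non-PennyLane) primitives on the generic bind path."""
    kind = getattr(primitive, "prim_type", "")
    if kind == "operator":
        return "interpret_operation_eqn"
    if kind == "measurement":
        return "interpret_measurement_eqn"
    return "bind"
```

Dispatching on a cheap string tag rather than `isinstance(eqn.outvars[0].aval, AbstractOperator)` also works for primitives with no output variables, which is part of why the interpreter hunk below rewrites the branch conditions this way.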
**Related GitHub Issues:** [sc-82166] commit 98bb29b3002175052ea5fd10371e3d0d1afdcbba Author: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com> Date: Tue Jan 21 14:14:39 2025 -0500 `lie_closure_dense` typo in docstring (#6858) **Context:** **Description of the Change:** Docstring code example **Benefits:** Docstring code example works **Possible Drawbacks:** 0️⃣ **Related GitHub Issues:** commit fe9c9a18856acc3b9fabbafc2e35614de998ec6e Author: Yushao Chen (Jerry) Date: Tue Jan 21 13:48:22 2025 -0500 Fix the deprecated usage of `MultiControlledX` in labs (#6862) **Context:** In the tests `pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py` there are several test suites that were using the deprecated interfaces of `MultiControlledX`, which will fail after the removal of corresponding deprecated `control_wires` arg. We fix the usages in this PR. **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** commit 872607dfae49441a9dad5170dd393111feb3a817 Author: Mudit Pandey Date: Tue Jan 21 11:48:25 2025 -0500 Bump `torch` version in CI to 2.5.0 (#6868) As name says. We should be testing against the latest version of `torch`. This PR updates the torch version used in CI to `2.5.0`. Also updated torch installation to use `~=` instead of `==` for choosing the version, so that bug fix releases (highest possible `2.5.X`) are automatically used instead of sticking with `2.5.0`. This should be a safe change to make, as torch only adds bug fixes to patch releases and no breaking changes ([ref](https://lightning.ai/docs/pytorch/stable/versioning.html)). commit 633b5bd454c27925914eff82b8d99bd33ae34741 Author: Mudit Pandey Date: Tue Jan 21 10:14:00 2025 -0500 Update docs workflow schedule (#6867) [sc-82706] Docs workflow currently opens a PR to update stable deps on Wednesday, while the tests workflow opens the stable deps update PR on Monday. 
Having 2 PRs per week for the same thing is annoying, so the schedule time for the docs workflow should be the same as the tests workflow. commit 37de9755ab5ebdb2f865231c699475891cf16aff Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue Jan 21 09:31:51 2025 -0500 Update stable dependency files (#6856) commit a97fca5bb707895adfbdb9abc2042f0afbb24288 Author: ringo-but-quantum Date: Tue Jan 21 09:51:51 2025 +0000 [no ci] bump nightly version commit 5eaaccbd7dd1292b368616a8450ecb1174942ab4 Author: David Wierichs Date: Tue Jan 21 10:05:20 2025 +0100 Add some explanations on `NonInterpPrimitive` class (#6851) **Context:** The capture module uses a variant of `jax.core.Primitive` called `NonInterpPrimitive`. There were questions about why we need this and what it does. **Description of the Change:** This PR only adds some explanations to the respective `md` file to motivate our usage of this primitive variant. **Benefits:** Explain code **Possible Drawbacks:** N/A **Related GitHub Issues:** --------- Co-authored-by: Mudit Pandey Co-authored-by: Pietropaolo Frisoni --- .github/stable/all_interfaces.txt | 15 +- .github/stable/core.txt | 11 +- .github/stable/external.txt | 30 +-- .github/stable/jax.txt | 13 +- .github/stable/tf.txt | 15 +- .github/stable/torch.txt | 11 +- .github/workflows/docs.yml | 7 +- .github/workflows/format.yml | 3 + .../interface-dependency-versions.yml | 4 +- .github/workflows/module-validation.yml | 3 + .github/workflows/tests-gpu.yml | 3 + .github/workflows/tests.yml | 3 + doc/development/deprecations.rst | 19 ++ doc/releases/changelog-dev.md | 39 +++- pennylane/_version.py | 2 +- pennylane/capture/base_interpreter.py | 13 +- pennylane/capture/capture_diff.py | 41 +--- pennylane/capture/capture_measurements.py | 15 +- pennylane/capture/capture_operators.py | 7 +- pennylane/capture/custom_primitives.py | 64 +++++ pennylane/capture/explanations.md | 106 ++++++++- pennylane/compiler/qjit_api.py | 17 +- 
pennylane/devices/default_qubit.py | 17 +- pennylane/gradients/hamiltonian_grad.py | 11 + pennylane/gradients/parameter_shift.py | 7 + pennylane/labs/dla/lie_closure_dense.py | 4 +- .../ops/op_math/test_controlled_ops.py | 29 +-- pennylane/measurements/mid_measure.py | 7 +- pennylane/measurements/sample.py | 2 + pennylane/ops/op_math/adjoint.py | 9 +- pennylane/ops/op_math/condition.py | 9 +- pennylane/ops/op_math/controlled.py | 10 +- .../transforms/core/transform_dispatcher.py | 5 +- .../optimization/cancel_inverses.py | 8 +- pennylane/workflow/_capture_qnode.py | 39 +++- pennylane/workflow/execution.py | 28 ++- pennylane/workflow/qnode.py | 44 ++-- tests/capture/test_capture_qnode.py | 65 +++++- tests/capture/test_custom_primitives.py | 48 ++++ tests/capture/test_switches.py | 19 +- .../test_default_qubit_native_mcm.py | 5 +- .../test_default_qubit_preprocessing.py | 8 +- .../core/test_hamiltonian_gradient.py | 27 ++- tests/gradients/core/test_pulse_gradient.py | 56 +++-- .../finite_diff/test_spsa_gradient.py | 2 +- .../parameter_shift/test_cv_gradients.py | 6 +- .../parameter_shift/test_parameter_shift.py | 159 +------------ .../test_parameter_shift_shot_vec.py | 219 +----------------- tests/measurements/test_probs.py | 6 +- tests/ops/functions/test_matrix.py | 1 + tests/ops/op_math/test_exp.py | 1 + tests/resource/test_specs.py | 13 +- .../test_mottonen_state_prep.py | 2 +- tests/test_compiler.py | 28 +++ tests/test_qnode.py | 49 ++-- tests/test_qnode_legacy.py | 4 +- .../core/test_transform_dispatcher.py | 1 + tests/transforms/test_add_noise.py | 4 +- .../interfaces/execute/test_execute.py | 22 ++ .../interfaces/qnode/test_autograd_qnode.py | 63 ++--- .../qnode/test_autograd_qnode_shot_vector.py | 42 ++-- .../interfaces/qnode/test_jax_jit_qnode.py | 45 ++-- .../interfaces/qnode/test_jax_qnode.py | 65 +++--- .../qnode/test_jax_qnode_shot_vector.py | 100 ++++++-- ..._tensorflow_autograph_qnode_shot_vector.py | 28 +-- 
.../interfaces/qnode/test_tensorflow_qnode.py | 51 ++-- .../test_tensorflow_qnode_shot_vector.py | 64 +++-- .../interfaces/qnode/test_torch_qnode.py | 56 +++-- tests/workflow/test_construct_batch.py | 2 +- 69 files changed, 1111 insertions(+), 820 deletions(-) create mode 100644 pennylane/capture/custom_primitives.py create mode 100644 tests/capture/test_custom_primitives.py diff --git a/.github/stable/all_interfaces.txt b/.github/stable/all_interfaces.txt index b53f5569d44..84489804731 100644 --- a/.github/stable/all_interfaces.txt +++ b/.github/stable/all_interfaces.txt @@ -38,7 +38,7 @@ isort==5.13.2 jax==0.4.28 jaxlib==0.4.28 Jinja2==3.1.5 -keras==3.7.0 +keras==3.8.0 kiwisolver==1.4.8 lazy-object-proxy==1.10.0 libclang==18.1.1 @@ -68,12 +68,12 @@ nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.6.85 nvidia-nvtx-cu12==12.1.105 opt_einsum==3.4.0 -optree==0.13.1 +optree==0.14.0 osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -83,7 +83,7 @@ protobuf==4.25.5 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -101,7 +101,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -115,14 +116,14 @@ termcolor==2.5.0 tf_keras==2.16.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 torch==2.3.0 triton==2.3.0 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 Werkzeug==3.1.3 wrapt==1.12.1 diff --git a/.github/stable/core.txt b/.github/stable/core.txt index fc24ed2dff0..925553c4c2f 100644 --- a/.github/stable/core.txt +++ b/.github/stable/core.txt @@ -43,7 +43,7 @@ osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 
+PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -52,7 +52,7 @@ prompt_toolkit==3.0.48 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -70,7 +70,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -78,11 +79,11 @@ tach==0.13.1 termcolor==2.5.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 wrapt==1.12.1 diff --git a/.github/stable/external.txt b/.github/stable/external.txt index 5cb949a78d7..80052907bc0 100644 --- a/.github/stable/external.txt +++ b/.github/stable/external.txt @@ -27,14 +27,14 @@ clarabel==0.9.0 click==8.1.8 comm==0.2.2 contourpy==1.3.1 -cotengra==0.6.2 +cotengra==0.7.0 coverage==7.6.10 cryptography==44.0.0 cvxopt==1.3.2 cvxpy==1.6.0 cycler==0.12.1 cytoolz==1.0.1 -debugpy==1.8.11 +debugpy==1.8.12 decorator==5.1.1 defusedxml==0.7.1 diastatic-malt==2.15.2 @@ -60,8 +60,8 @@ h11==0.14.0 h5py==3.12.1 httpcore==1.0.7 httpx==0.28.1 -ibm-cloud-sdk-core==3.22.0 -ibm-platform-services==0.59.0 +ibm-cloud-sdk-core==3.22.1 +ibm-platform-services==0.59.1 identify==2.6.5 idna==3.10 iniconfig==2.0.0 @@ -89,7 +89,7 @@ jupyterlab==4.3.4 jupyterlab_pygments==0.3.0 jupyterlab_server==2.27.3 jupyterlab_widgets==1.1.11 -keras==3.7.0 +keras==3.8.0 kiwisolver==1.4.8 lark==1.1.9 lazy-object-proxy==1.10.0 @@ -120,7 +120,7 @@ numba==0.60.0 numpy==1.26.4 opt_einsum==3.4.0 optax==0.2.4 -optree==0.13.1 +optree==0.14.0 osqp==0.6.7.post3 overrides==7.7.0 packaging==24.2 @@ -129,8 +129,8 @@ pandocfilters==1.5.1 parso==0.8.4 pathspec==0.12.1 pbr==6.1.0 -PennyLane-Catalyst==0.10.0.dev39 -PennyLane-qiskit @ git+https://github.com/PennyLaneAI/pennylane-qiskit.git@40a4d24f126e51e0e3e28a4cd737f883a6fd5ebc 
+PennyLane-Catalyst==0.11.0.dev1 +PennyLane-qiskit @ git+https://github.com/PennyLaneAI/pennylane-qiskit.git@b46fbca3372979534bc33af701194df548cf4b16 PennyLane_Lightning==0.40.0 PennyLane_Lightning_Kokkos==0.40.0.dev41 pexpect==4.9.0 @@ -148,10 +148,10 @@ pure_eval==0.2.3 py==1.11.0 py-cpuinfo==9.0.0 pycparser==2.22 -pydantic==2.10.4 +pydantic==2.10.5 pydantic_core==2.27.2 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 PyJWT==2.10.1 pylint==2.7.4 pyparsing==3.2.1 @@ -173,11 +173,11 @@ pyzmq==26.2.0 pyzx==0.8.0 qdldl==0.1.7.post5 qiskit==1.2.4 -qiskit-aer==0.15.1 +qiskit-aer==0.16.0 qiskit-ibm-provider==0.11.0 qiskit-ibm-runtime==0.29.0 quimb==1.10.0 -referencing==0.35.1 +referencing==0.36.1 requests==2.32.3 requests_ntlm==1.3.0 rfc3339-validator==0.1.4 @@ -211,7 +211,7 @@ tf_keras==2.16.0 tinycss2==1.4.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 toolz==1.0.0 tornado==6.4.2 @@ -222,12 +222,12 @@ typing_extensions==4.12.2 tzdata==2024.2 uri-template==1.3.0 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 webcolors==24.11.1 webencodings==0.5.1 websocket-client==1.8.0 -websockets==14.1 +websockets==14.2 Werkzeug==3.1.3 widgetsnbextension==3.6.10 wrapt==1.12.1 diff --git a/.github/stable/jax.txt b/.github/stable/jax.txt index ed2c0b32037..504bcff8e47 100644 --- a/.github/stable/jax.txt +++ b/.github/stable/jax.txt @@ -37,7 +37,7 @@ markdown-it-py==3.0.0 matplotlib==3.10.0 mccabe==0.6.1 mdurl==0.1.2 -ml_dtypes==0.5.0 +ml_dtypes==0.5.1 mypy-extensions==1.0.0 networkx==3.4.2 nodeenv==1.9.1 @@ -47,7 +47,7 @@ osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -56,7 +56,7 @@ prompt_toolkit==3.0.48 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -74,7 +74,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 
rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -82,11 +83,11 @@ tach==0.13.1 termcolor==2.5.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 wrapt==1.12.1 diff --git a/.github/stable/tf.txt b/.github/stable/tf.txt index 634fbd09526..02d559768eb 100644 --- a/.github/stable/tf.txt +++ b/.github/stable/tf.txt @@ -34,7 +34,7 @@ identify==2.6.5 idna==3.10 iniconfig==2.0.0 isort==5.13.2 -keras==3.7.0 +keras==3.8.0 kiwisolver==1.4.8 lazy-object-proxy==1.10.0 libclang==18.1.1 @@ -51,12 +51,12 @@ networkx==3.4.2 nodeenv==1.9.1 numpy==1.26.4 opt_einsum==3.4.0 -optree==0.13.1 +optree==0.14.0 osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -66,7 +66,7 @@ protobuf==4.25.5 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -84,7 +84,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -97,12 +98,12 @@ termcolor==2.5.0 tf_keras==2.16.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 Werkzeug==3.1.3 wrapt==1.12.1 diff --git a/.github/stable/torch.txt b/.github/stable/torch.txt index 3ab4c0e93f2..46b5f061df6 100644 --- a/.github/stable/torch.txt +++ b/.github/stable/torch.txt @@ -59,7 +59,7 @@ osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -68,7 +68,7 @@ prompt_toolkit==3.0.48 
py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -86,7 +86,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -95,13 +96,13 @@ tach==0.13.1 termcolor==2.5.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 torch==2.3.0 triton==2.3.0 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 wrapt==1.12.1 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 02b95edc395..5683d05880c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -16,15 +16,18 @@ name: "Documentation check" on: + merge_group: + types: + - checks_requested pull_request: types: - opened - reopened - synchronize - ready_for_review - # Scheduled trigger on Wednesdays at 3:00am UTC + # Scheduled trigger on Monday at 2:47am UTC schedule: - - cron: "0 3 * * 3" + - cron: "47 2 * * 1" permissions: write-all diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 0a29e842bf1..ba5d6ce7776 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -1,5 +1,8 @@ name: Formatting check on: + merge_group: + types: + - checks_requested pull_request: types: - opened diff --git a/.github/workflows/interface-dependency-versions.yml b/.github/workflows/interface-dependency-versions.yml index f567321b27c..d8520064dc1 100644 --- a/.github/workflows/interface-dependency-versions.yml +++ b/.github/workflows/interface-dependency-versions.yml @@ -26,7 +26,7 @@ on: description: The version of PyTorch to use for testing required: false type: string - default: '2.3.0' + default: '2.5.0' outputs: jax-version: description: The version of JAX to use @@ -72,6 +72,6 @@ jobs: outputs: jax-version: jax==${{ steps.jax.outputs.version }} jaxlib==${{ 
steps.jax.outputs.version }} tensorflow-version: tensorflow~=${{ steps.tensorflow.outputs.version }} tf-keras~=${{ steps.tensorflow.outputs.version }} - pytorch-version: torch==${{ steps.pytorch.outputs.version }} + pytorch-version: torch~=${{ steps.pytorch.outputs.version }} catalyst-nightly: ${{ steps.catalyst.outputs.nightly }} pennylane-lightning-latest: ${{ steps.pennylane-lightning.outputs.latest }} diff --git a/.github/workflows/module-validation.yml b/.github/workflows/module-validation.yml index 0db4f16faeb..66af8bc218e 100644 --- a/.github/workflows/module-validation.yml +++ b/.github/workflows/module-validation.yml @@ -1,6 +1,9 @@ name: Validate module imports on: + merge_group: + types: + - checks_requested pull_request: types: - opened diff --git a/.github/workflows/tests-gpu.yml b/.github/workflows/tests-gpu.yml index bd8cd5dae67..92a7bae8ac9 100644 --- a/.github/workflows/tests-gpu.yml +++ b/.github/workflows/tests-gpu.yml @@ -3,6 +3,9 @@ on: push: branches: - master + merge_group: + types: + - checks_requested pull_request: types: - opened diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 63c8074e524..620dc39ac57 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,6 +14,9 @@ on: # Scheduled trigger on Monday at 2:47am UTC schedule: - cron: "47 2 * * 1" + merge_group: + types: + - checks_requested concurrency: group: unit-tests-${{ github.ref }} diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst index ff6048e18be..a61b8338222 100644 --- a/doc/development/deprecations.rst +++ b/doc/development/deprecations.rst @@ -9,6 +9,25 @@ deprecations are listed below. Pending deprecations -------------------- +* The ``mcm_config`` argument to ``qml.execute`` has been deprecated. + Instead, use the ``mcm_method`` and ``postselect_mode`` arguments. 
+ + - Deprecated in v0.41 + - Will be removed in v0.42 + +* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated + and will be removed in v0.42. The gradient keyword arguments should be passed to the new + keyword argument ``gradient_kwargs`` via an explicit dictionary, like ``gradient_kwargs={"h": 1e-4}``. + + - Deprecated in v0.41 + - Will be removed in v0.42 + +* The `qml.gradients.hamiltonian_grad` function has been deprecated. + This gradient recipe is not required with the new operator arithmetic system. + + - Deprecated in v0.41 + - Will be removed in v0.42 + * The ``inner_transform_program`` and ``config`` keyword arguments in ``qml.execute`` have been deprecated. If more detailed control over the execution is required, use ``qml.workflow.run`` with these arguments instead. diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 89e21560387..88280392227 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -29,11 +29,11 @@ 'parameter-shift' ``` -* Finite shot and parameter-shift executions on `default.qubit` can now - be natively jitted end-to-end, leading to performance improvements. - Devices can now configure whether or not ML framework data is sent to them - via an `ExecutionConfig.convert_to_numpy` parameter. +* Devices can now configure whether or not ML framework data is sent to them + via an `ExecutionConfig.convert_to_numpy` parameter. This is not used on + `default.qubit` due to compilation overheads when jitting. [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) + [(#6869)](https://github.com/PennyLaneAI/pennylane/pull/6869) * The coefficients of observables now have improved differentiability. [(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598) @@ -44,6 +44,9 @@ * An informative error is raised when a `QNode` with `diff_method=None` is differentiated. 
[(#6770)](https://github.com/PennyLaneAI/pennylane/pull/6770) +* The requested `diff_method` is now validated when program capture is enabled. + [(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852) +

Breaking changes πŸ’”

* `MultiControlledX` no longer accepts strings as control values. @@ -51,6 +54,7 @@ * The input argument `control_wires` of `MultiControlledX` has been removed. [(#6832)](https://github.com/PennyLaneAI/pennylane/pull/6832) + [(#6862)](https://github.com/PennyLaneAI/pennylane/pull/6862) * `qml.execute` now has a collection of keyword-only arguments. [(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598) @@ -80,17 +84,40 @@

Deprecations πŸ‘‹

+* The `mcm_config` keyword argument to `qml.execute` is deprecated. Instead, use the `mcm_method` and `postselect_mode` arguments. + [(#6807)](https://github.com/PennyLaneAI/pennylane/pull/6807) + +* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated + and will be removed in v0.42. The gradient keyword arguments should be passed to the new + keyword argument `gradient_kwargs` via an explicit dictionary. This change will improve qnode argument + validation. + [(#6828)](https://github.com/PennyLaneAI/pennylane/pull/6828) + +* The `qml.gradients.hamiltonian_grad` function has been deprecated. + This gradient recipe is not required with the new operator arithmetic system. + [(#6849)](https://github.com/PennyLaneAI/pennylane/pull/6849) + * The ``inner_transform_program`` and ``config`` keyword arguments in ``qml.execute`` have been deprecated. If more detailed control over the execution is required, use ``qml.workflow.run`` with these arguments instead. [(#6822)](https://github.com/PennyLaneAI/pennylane/pull/6822)

Internal changes βš™οΈ

+* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module. + This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives. + Consequently, `QmlPrimitive` is now used to define all PennyLane primitives. + [(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847) +

Documentation πŸ“

-* Updated documentation for vibrational Hamiltonians +* The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`, + `qml.qchem.localize_normal_modes`, and `qml.qchem.VibrationalPES` have been updated to include better + code examples. [(#6717)](https://github.com/PennyLaneAI/pennylane/pull/6717) +* Fixed a typo in the code example for `qml.labs.dla.lie_closure_dense`. + [(#6858)](https://github.com/PennyLaneAI/pennylane/pull/6858) +

Bug fixes πŸ›

* `BasisState` now casts its input to integers. @@ -101,8 +128,10 @@ This release contains contributions from (in alphabetical order): Yushao Chen, +Isaac De Vlugt, Diksha Dhawan, Pietropaolo Frisoni, Marcus GisslΓ©n, Christina Lee, +Mudit Pandey, Andrija Paurevic diff --git a/pennylane/_version.py b/pennylane/_version.py index 704c372802e..9295240495c 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.41.0-dev10" +__version__ = "0.41.0-dev14" diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 4af11cb6198..2b7314f7c2e 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -25,8 +25,6 @@ from .flatfn import FlatFn from .primitives import ( - AbstractMeasurement, - AbstractOperator, adjoint_transform_prim, cond_prim, ctrl_transform_prim, @@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: self._env[constvar] = const for eqn in jaxpr.eqns: + primitive = eqn.primitive + custom_handler = self._primitive_registrations.get(primitive, None) - custom_handler = self._primitive_registrations.get(eqn.primitive, None) if custom_handler: invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif isinstance(eqn.outvars[0].aval, AbstractOperator): + elif getattr(primitive, "prim_type", "") == "operator": outvals = self.interpret_operation_eqn(eqn) - elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + elif getattr(primitive, "prim_type", "") == "measurement": outvals = self.interpret_measurement_eqn(eqn) else: invals = [self.read(invar) for invar in eqn.invars] - outvals = eqn.primitive.bind(*invals, **eqn.params) + outvals = primitive.bind(*invals, **eqn.params) - if not eqn.primitive.multiple_results: + if not primitive.multiple_results: outvals = [outvals] for outvar, outval in 
zip(eqn.outvars, outvals, strict=True): self._env[outvar] = outval diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index 482f692df69..ba7c3846693 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -24,34 +24,6 @@ has_jax = False -@lru_cache -def create_non_interpreted_prim(): - """Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace - and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive. - """ - - if not has_jax: # pragma: no cover - return None - - # pylint: disable=too-few-public-methods - class NonInterpPrimitive(jax.core.Primitive): - """A subclass to JAX's Primitive that works like a Python function - when evaluating JVPTracers and BatchTracers.""" - - def bind_with_trace(self, trace, args, params): - """Bind the ``NonInterpPrimitive`` with a trace. - - If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call. - Otherwise, the bind call of JAX's standard Primitive is used.""" - if isinstance( - trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace) - ): - return self.impl(*args, **params) - return super().bind_with_trace(trace, args, params) - - return NonInterpPrimitive - - @lru_cache def _get_grad_prim(): """Create a primitive for gradient computations. @@ -60,8 +32,11 @@ def _get_grad_prim(): if not has_jax: # pragma: no cover return None - grad_prim = create_non_interpreted_prim()("grad") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + grad_prim = NonInterpPrimitive("grad") grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + grad_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @grad_prim.def_impl @@ -91,8 +66,14 @@ def _get_jacobian_prim(): """Create a primitive for Jacobian computations. 
This primitive is used when capturing ``qml.jacobian``. """ - jacobian_prim = create_non_interpreted_prim()("jacobian") + if not has_jax: # pragma: no cover + return None + + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + jacobian_prim = NonInterpPrimitive("jacobian") jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + jacobian_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @jacobian_prim.def_impl diff --git a/pennylane/capture/capture_measurements.py b/pennylane/capture/capture_measurements.py index 59bf5490679..e23457bc7b7 100644 --- a/pennylane/capture/capture_measurements.py +++ b/pennylane/capture/capture_measurements.py @@ -128,7 +128,10 @@ def create_measurement_obs_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_obs") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_obs") + primitive.prim_type = "measurement" @primitive.def_impl def _(obs, **kwargs): @@ -165,7 +168,10 @@ def create_measurement_mcm_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_mcm") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_mcm") + primitive.prim_type = "measurement" @primitive.def_impl def _(*mcms, single_mcm=True, **kwargs): @@ -200,7 +206,10 @@ def create_measurement_wires_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_wires") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_wires") + primitive.prim_type = "measurement" @primitive.def_impl def _(*args, has_eigvals=False, **kwargs): diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py index 23c98f38944..2124b5b9fe4 100644 --- 
a/pennylane/capture/capture_operators.py +++ b/pennylane/capture/capture_operators.py @@ -20,8 +20,6 @@ import pennylane as qml -from .capture_diff import create_non_interpreted_prim - has_jax = True try: import jax @@ -103,7 +101,10 @@ def create_operator_primitive( if not has_jax: return None - primitive = create_non_interpreted_prim()(operator_type.__name__) + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(operator_type.__name__) + primitive.prim_type = "operator" @primitive.def_impl def _(*args, **kwargs): diff --git a/pennylane/capture/custom_primitives.py b/pennylane/capture/custom_primitives.py new file mode 100644 index 00000000000..183ae05771b --- /dev/null +++ b/pennylane/capture/custom_primitives.py @@ -0,0 +1,64 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule offers custom primitives for the PennyLane capture module. 
+""" +from enum import Enum +from typing import Union + +import jax + + +class PrimitiveType(Enum): + """Enum to define valid set of primitive classes""" + + DEFAULT = "default" + OPERATOR = "operator" + MEASUREMENT = "measurement" + HIGHER_ORDER = "higher_order" + TRANSFORM = "transform" + + +# pylint: disable=too-few-public-methods,abstract-method +class QmlPrimitive(jax.core.Primitive): + """A subclass of JAX's Primitive that differentiates between different + classes of primitives.""" + + _prim_type: PrimitiveType = PrimitiveType.DEFAULT + + @property + def prim_type(self): + """Value of Enum representing the primitive type to differentiate between various + sets of PennyLane primitives.""" + return self._prim_type.value + + @prim_type.setter + def prim_type(self, value: Union[str, PrimitiveType]): + """Setter for QmlPrimitive.prim_type.""" + self._prim_type = PrimitiveType(value) + + +# pylint: disable=too-few-public-methods,abstract-method +class NonInterpPrimitive(QmlPrimitive): + """A subclass of JAX's Primitive that works like a Python function + when evaluating JVPTracers and BatchTracers.""" + + def bind_with_trace(self, trace, args, params): + """Bind the ``NonInterpPrimitive`` with a trace. + + If the trace is a ``JVPTrace`` or a ``BatchTrace``, binding falls back to a standard Python function call. + Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)): + return self.impl(*args, **params) + return super().bind_with_trace(trace, args, params) diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 84feef9786f..71033aac3ee 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -255,7 +255,7 @@ class MyClass(metaclass=MyMetaClass): self.kwargs = kwargs ``` - Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}.
+ Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}. And that we have set a class property `a` @@ -272,7 +272,7 @@ But can we actually create instances of these classes? ```python >> obj = MyClass(0.1, a=2) >>> obj -creating an instance of type with (0.1,), {'a': 2}. +creating an instance of type with (0.1,), {'a': 2}. now creating an instance in __init__ <__main__.MyClass at 0x11c5a2810> ``` @@ -294,7 +294,7 @@ class MyClass2(metaclass=MetaClass2): self.args = args ``` -You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`. +You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`. Using a metaclass, we can hijack what happens when a type is called. @@ -425,3 +425,103 @@ Now in our jaxpr, we can see that `PrimitiveClass2` returns something of type `A >>> jax.make_jaxpr(PrimitiveClass2)(0.1) { lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) } ``` + +# Non-interpreted primitives + +Some of the primitives in the capture module have a somewhat non-standard requirement for their +behaviour under differentiation or batching: they should ignore that an input is a differentiation +or batching tracer and just execute the standard implementation on them. + +We will look at an example to make the necessity for such a non-interpreted primitive clear. + +Consider a finite-difference differentiation routine together with some test function `fun`. + +```python +def finite_diff_impl(x, fun, delta): +    """Finite difference differentiation routine. Only supports differentiating +    a function `fun` with a single scalar argument, for simplicity.""" + +    out_plus = fun(x + delta) +    out_minus = fun(x - delta) +    return tuple((out_p - out_m) / (2 * delta) for out_p, out_m in zip(out_plus, out_minus)) + +def fun(x): +    return (x**2, 4 * x - 3, x**23) +``` + +Now suppose we want to turn this into a primitive.
We could just promote it to a standard +`jax.core.Primitive` as + +```python +import jax + +fd_prim = jax.core.Primitive("finite_diff") +fd_prim.multiple_results = True +fd_prim.def_impl(finite_diff_impl) + +def finite_diff(x, fun, delta=1e-5): +    return fd_prim.bind(x, fun, delta) +``` + +This allows us to use the forward pass as usual (to compute the first-order derivative): + +```pycon +>>> finite_diff(1., fun, delta=1e-6) +(2.000000000002, 3.999999999892978, 23.000000001216492) +``` + +Now if we want to make this primitive differentiable (with automatic +differentiation/backprop, not by using a higher-order finite difference scheme), +we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example +that would make the finite difference scheme readily differentiable. This is +somewhat beside the point, because none of those alternatives was applicable +in the PennyLane code.) + +However, the finite difference rule is just a standard +algebraic function making use of calls to `fun` and some elementary operations, so ideally +we would like to just use the chain rule as it is known to the automatic differentiation framework. A JVP rule would +then just manually re-implement this chain rule, which we'd rather not do. + +Instead, we define a non-interpreted type of primitive and create such a primitive +for our finite difference method. We also create the usual method that binds the +primitive to inputs. + +```python +class NonInterpPrimitive(jax.core.Primitive): +    """A subclass of JAX's Primitive that works like a Python function +    when evaluating JVPTracers.""" + +    def bind_with_trace(self, trace, args, params): +        """Bind the ``NonInterpPrimitive`` with a trace. +        If the trace is a ``JVPTrace``, it falls back to a standard Python function call.
+ Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance(trace, jax.interpreters.ad.JVPTrace): + return self.impl(*args, **params) + return super().bind_with_trace(trace, args, params) + +fd_prim_2 = NonInterpPrimitive("finite_diff_2") +fd_prim_2.multiple_results = True +fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer + +def finite_diff_2(x, fun, delta=1e-5): + return fd_prim_2.bind(x, fun, delta) +``` + +Now we can use the primitive in a differentiable workflow, without defining a JVP rule +that just repeats the chain rule: + +```pycon +>>> # Forward execution of finite_diff_2 (-> first-order derivative) +>>> finite_diff_2(1., fun, delta=1e-6) +(2.000000000002, 3.999999999892978, 23.000000001216492) +>>> # Differentiation of finite_diff_2 (-> second-order derivative) +>>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6) +(Array(1.9375, dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True), Array(498., dtype=float32, weak_type=True)) +``` + +In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, quantum operators +have non-interpreted primitives as well. This is because their differentiation is performed +by the surrounding QNode primitive rather than through the standard chain rule that acts +"locally" (in the circuit). In short, we only want gates to store their tracers (which will help +determine the differentiability of gate arguments, for example), but not to do anything with them. 
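As a complement to the diff above, the enum-backed `prim_type` tagging that `QmlPrimitive` introduces can be sketched without JAX at all. The following is a hypothetical stand-alone version — `TaggedPrimitive` is an illustrative stand-in for `jax.core.Primitive`, not a PennyLane class — showing how the property setter normalizes both plain strings and `PrimitiveType` members:

```python
from enum import Enum
from typing import Union


class PrimitiveType(Enum):
    """Valid classes of primitives, mirroring capture/custom_primitives.py."""

    DEFAULT = "default"
    OPERATOR = "operator"
    MEASUREMENT = "measurement"
    HIGHER_ORDER = "higher_order"
    TRANSFORM = "transform"


class TaggedPrimitive:
    """Stand-in for jax.core.Primitive that carries a ``prim_type`` tag."""

    _prim_type: PrimitiveType = PrimitiveType.DEFAULT

    def __init__(self, name: str):
        self.name = name

    @property
    def prim_type(self) -> str:
        # Interpreters compare against plain strings, so expose the enum's value.
        return self._prim_type.value

    @prim_type.setter
    def prim_type(self, value: Union[str, PrimitiveType]):
        # PrimitiveType("operator") and PrimitiveType(PrimitiveType.OPERATOR)
        # resolve to the same member; an unknown string raises ValueError.
        self._prim_type = PrimitiveType(value)


prim = TaggedPrimitive("RX")
prim.prim_type = "operator"
assert prim.prim_type == "operator"
```

An interpreter loop can then branch on `getattr(primitive, "prim_type", "") == "operator"` without importing any abstract value classes, which is exactly the dispatch the `base_interpreter.py` hunk switches to.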
+ diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 08d88988b79..797a7437abb 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -17,7 +17,6 @@ from collections.abc import Callable import pennylane as qml -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from .compiler import ( @@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable: def _get_while_loop_qfunc_prim(): """Get the while_loop primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - while_loop_prim = create_non_interpreted_prim()("while_loop") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + while_loop_prim = NonInterpPrimitive("while_loop") while_loop_prim.multiple_results = True + while_loop_prim.prim_type = "higher_order" @while_loop_prim.def_impl def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice): @@ -626,10 +629,14 @@ def _decorator(body_fn): def _get_for_loop_qfunc_prim(): """Get the loop_for primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive - for_loop_prim = create_non_interpreted_prim()("for_loop") + for_loop_prim = NonInterpPrimitive("for_loop") for_loop_prim.multiple_results = True + for_loop_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @for_loop_prim.def_impl diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 48f8ffbc686..38d5e1b34a1 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -591,13 +591,15 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio """ updated_values = {} - jax_interfaces = 
{qml.math.Interface.JAX, qml.math.Interface.JAX_JIT} - updated_values["convert_to_numpy"] = ( - execution_config.interface not in jax_interfaces - or execution_config.gradient_method == "adjoint" - # need numpy to use caching, and need caching higher order derivatives - or execution_config.derivative_order > 1 - ) + # uncomment once compilation overhead with jitting improved + # TODO: [sc-82874] + # jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT} + # updated_values["convert_to_numpy"] = ( + # execution_config.interface not in jax_interfaces + # or execution_config.gradient_method == "adjoint" + # # need numpy to use caching, and need caching higher order derivatives + # or execution_config.derivative_order > 1 + # ) for option in execution_config.device_options: if option not in self._device_options: raise qml.DeviceError(f"device option {option} not present on {self}") @@ -643,6 +645,7 @@ def execute( prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))] if max_workers is None: + return tuple( _simulate_wrapper( c, diff --git a/pennylane/gradients/hamiltonian_grad.py b/pennylane/gradients/hamiltonian_grad.py index e83d942dcc9..95769bdc6d1 100644 --- a/pennylane/gradients/hamiltonian_grad.py +++ b/pennylane/gradients/hamiltonian_grad.py @@ -13,6 +13,8 @@ # limitations under the License. """Contains a gradient recipe for the coefficients of Hamiltonians.""" # pylint: disable=protected-access,unnecessary-lambda +import warnings + import pennylane as qml @@ -20,10 +22,19 @@ def hamiltonian_grad(tape, idx): """Computes the tapes necessary to get the gradient of a tape with respect to a Hamiltonian observable's coefficients. + .. warning:: + This function is deprecated and will be removed in v0.42. This gradient recipe is not + required for the new operator arithmetic of PennyLane. 
+ Args: tape (qml.tape.QuantumTape): tape with a single Hamiltonian expectation as measurement idx (int): index of parameter that we differentiate with respect to """ + warnings.warn( + "The 'hamiltonian_grad' function is deprecated and will be removed in v0.42. " + "This gradient recipe is not required for the new operator arithmetic system.", + qml.PennyLaneDeprecationWarning, + ) op, m_pos, p_idx = tape.get_operation(idx) diff --git a/pennylane/gradients/parameter_shift.py b/pennylane/gradients/parameter_shift.py index 6fcdd17df19..26a2b0e4a8b 100644 --- a/pennylane/gradients/parameter_shift.py +++ b/pennylane/gradients/parameter_shift.py @@ -15,6 +15,7 @@ This module contains functions for computing the parameter-shift gradient of a qubit-based quantum tape. """ +import warnings from functools import partial import numpy as np @@ -372,6 +373,12 @@ def expval_param_shift( op, op_idx, _ = tape.get_operation(idx) if op.name == "LinearCombination": + warnings.warn( + "Please use qml.gradients.split_to_single_terms so that the ML framework " + "can compute the gradients of the coefficients.", + UserWarning, + ) + # operation is a Hamiltonian if tape[op_idx].return_type is not qml.measurements.Expectation: raise ValueError( diff --git a/pennylane/labs/dla/lie_closure_dense.py b/pennylane/labs/dla/lie_closure_dense.py index 71131cd1064..a2fb76d02e9 100644 --- a/pennylane/labs/dla/lie_closure_dense.py +++ b/pennylane/labs/dla/lie_closure_dense.py @@ -97,9 +97,11 @@ def lie_closure_dense( Compute the Lie closure of the isotropic Heisenberg model with generators :math:`\{X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1}\}_{i=0}^{n-1}`. + >>> from pennylane import X, Y, Z + >>> from pennylane.labs.dla import lie_closure_dense >>> n = 5 >>> gens = [X(i) @ X(i+1) + Y(i) @ Y(i+1) + Z(i) @ Z(i+1) for i in range(n-1)] - >>> g = lie_closure_mat(gens, n) + >>> g = lie_closure_dense(gens, n) The result is a ``numpy`` array. 
We can turn the matrices back into PennyLane operators by employing :func:`~batched_pauli_decompose`. diff --git a/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py b/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py index 6d59fea1b96..440d053dee0 100644 --- a/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py +++ b/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py @@ -643,22 +643,17 @@ class TestResourceMultiControlledX: """Test the ResourceMultiControlledX operation""" res_ops = ( - re.ResourceMultiControlledX(control_wires=[0], wires=["t"], control_values=[1]), - re.ResourceMultiControlledX(control_wires=[0, 1], wires=["t"], control_values=[1, 1]), - re.ResourceMultiControlledX(control_wires=[0, 1, 2], wires=["t"], control_values=[1, 1, 1]), + re.ResourceMultiControlledX(wires=[0, "t"], control_values=[1]), + re.ResourceMultiControlledX(wires=[0, 1, "t"], control_values=[1, 1]), + re.ResourceMultiControlledX(wires=[0, 1, 2, "t"], control_values=[1, 1, 1]), + re.ResourceMultiControlledX(wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 1, 1, 1, 1]), + re.ResourceMultiControlledX(wires=[0, "t"], control_values=[0], work_wires=["w1"]), re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], wires=["t"], control_values=[1, 1, 1, 1, 1] + wires=[0, 1, "t"], control_values=[1, 0], work_wires=["w1", "w2"] ), + re.ResourceMultiControlledX(wires=[0, 1, 2, "t"], control_values=[0, 0, 1]), re.ResourceMultiControlledX( - control_wires=[0], wires=["t"], control_values=[0], work_wires=["w1"] - ), - re.ResourceMultiControlledX( - control_wires=[0, 1], wires=["t"], control_values=[1, 0], work_wires=["w1", "w2"] - ), - re.ResourceMultiControlledX(control_wires=[0, 1, 2], wires=["t"], control_values=[0, 0, 1]), - re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], - wires=["t"], + wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 0, 0, 1, 0], work_wires=["w1"], ), @@ 
-732,8 +727,7 @@ def test_resource_params(self, op, params): def test_resource_adjoint(self): """Test that the adjoint resources are as expected""" op = re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], - wires=["t"], + wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 0, 0, 1, 0], work_wires=["w1"], ) @@ -777,7 +771,7 @@ def test_resource_adjoint(self): ) def test_resource_controlled(self, ctrl_wires, ctrl_values, work_wires, expected_res): """Test that the controlled resources are as expected""" - op = re.ResourceMultiControlledX(control_wires=[0], wires=["t"], control_values=[1]) + op = re.ResourceMultiControlledX(wires=[0, "t"], control_values=[1]) num_ctrl_wires = len(ctrl_wires) num_ctrl_values = len([v for v in ctrl_values if not v]) @@ -806,8 +800,7 @@ def test_resource_controlled(self, ctrl_wires, ctrl_values, work_wires, expected def test_resource_pow(self, z, expected_res): """Test that the pow resources are as expected""" op = re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], - wires=["t"], + wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 0, 0, 1, 0], work_wires=["w1"], ) diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 5cdcd8cd708..3c9bdc8f1a8 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -243,9 +243,12 @@ def _create_mid_measure_primitive(): measurement. 
""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - mid_measure_p = jax.core.Primitive("measure") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + mid_measure_p = NonInterpPrimitive("measure") @mid_measure_p.def_impl def _(wires, reset=False, postselect=None): diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index cbc4d0a0bdb..18da1b9d284 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -228,6 +228,8 @@ def shape(self, shots: Optional[int] = None, num_device_wires: int = 0) -> tuple ) if self.obs: num_values_per_shot = 1 # one single eigenvalue + elif self.mv is not None: + num_values_per_shot = 1 if isinstance(self.mv, MeasurementValue) else len(self.mv) else: # one value per wire num_values_per_shot = len(self.wires) if len(self.wires) > 0 else num_device_wires diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index 400f2fc83c0..5bda04440d4 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -18,7 +18,6 @@ from typing import Callable, overload import pennylane as qml -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.math import conj, moveaxis, transpose from pennylane.operation import Observable, Operation, Operator @@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy): def _get_adjoint_qfunc_prim(): """See capture/explanations.md : Higher Order primitives for more information on this code.""" # if capture is enabled, jax should be installed - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive - adjoint_prim = create_non_interpreted_prim()("adjoint_transform") + adjoint_prim = NonInterpPrimitive("adjoint_transform") 
adjoint_prim.multiple_results = True + adjoint_prim.prim_type = "higher_order" @adjoint_prim.def_impl def _(*args, jaxpr, lazy, n_consts): diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index deace92e73c..a15fdafff1d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -20,7 +20,6 @@ import pennylane as qml from pennylane import QueuingManager -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from pennylane.compiler import compiler from pennylane.measurements import MeasurementValue @@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement def _get_cond_qfunc_prim(): """Get the cond primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - cond_prim = create_non_interpreted_prim()("cond") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + cond_prim = NonInterpPrimitive("cond") cond_prim.multiple_results = True + cond_prim.prim_type = "higher_order" @cond_prim.def_impl def _(*all_args, jaxpr_branches, consts_slices, args_slice): diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index d49209660c6..17e62323223 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -28,7 +28,6 @@ import pennylane as qml from pennylane import math as qmlmath from pennylane import operation -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.operation import Operator from pennylane.wires import Wires, WiresLike @@ -233,10 +232,15 @@ def wrapper(*args, **kwargs): def _get_ctrl_qfunc_prim(): """See capture/explanations.md : Higher Order primitives for more information on this code.""" # if capture is enabled, jax should be installed - import jax 
# pylint: disable=import-outside-toplevel - ctrl_prim = create_non_interpreted_prim()("ctrl_transform") + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive + + ctrl_prim = NonInterpPrimitive("ctrl_transform") ctrl_prim.multiple_results = True + ctrl_prim.prim_type = "higher_order" @ctrl_prim.def_impl def _(*args, n_control, jaxpr, control_values, work_wires, n_consts): diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index e00bda09c8d..1cefb724cd2 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -540,12 +540,13 @@ def final_transform(self): def _create_transform_primitive(name): try: # pylint: disable=import-outside-toplevel - import jax + from pennylane.capture.custom_primitives import NonInterpPrimitive except ImportError: return None - transform_prim = jax.core.Primitive(name + "_transform") + transform_prim = NonInterpPrimitive(name + "_transform") transform_prim.multiple_results = True + transform_prim.prim_type = "transform" @transform_prim.def_impl def _( diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 418e68941a6..85dc0320fb7 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -70,7 +70,7 @@ def _get_plxpr_cancel_inverses(): # pylint: disable=missing-function-docstring, # pylint: disable=import-outside-toplevel from jax import make_jaxpr - from pennylane.capture import AbstractMeasurement, AbstractOperator, PlxprInterpreter + from pennylane.capture import PlxprInterpreter from pennylane.operation import Operator except ImportError: # pragma: no cover return None, None @@ -204,15 +204,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: self.interpret_all_previous_ops() 
invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractOperator): + elif getattr(eqn.primitive, "prim_type", "") == "operator": outvals = self.interpret_operation_eqn(eqn) - elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractMeasurement): + elif getattr(eqn.primitive, "prim_type", "") == "measurement": self.interpret_all_previous_ops() outvals = self.interpret_measurement_eqn(eqn) else: # Transform primitives don't have custom handlers, so we check for them here # to purge the stored ops in self.previous_ops - if eqn.primitive.name.endswith("_transform"): + if getattr(eqn.primitive, "prim_type", "") == "transform": self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 05f6b196440..2552770a159 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -107,7 +107,6 @@ """ from copy import copy -from dataclasses import asdict from functools import partial from numbers import Number from warnings import warn @@ -117,6 +116,7 @@ import pennylane as qml from pennylane.capture import FlatFn +from pennylane.capture.custom_primitives import QmlPrimitive from pennylane.typing import TensorLike @@ -177,8 +177,9 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=( return shapes -qnode_prim = jax.core.Primitive("qnode") +qnode_prim = QmlPrimitive("qnode") qnode_prim.multiple_results = True +qnode_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments, unused-argument @@ -249,7 +250,6 @@ def _qnode_batching_rule( "using parameter broadcasting to a quantum operation that supports batching.", UserWarning, ) - # To resolve this ambiguity, we might add more properties to the 
AbstractOperator # class to indicate which operators support batching and check them here. # As above, at this stage we raise a warning and give the user full flexibility. @@ -277,15 +277,43 @@ def _qnode_batching_rule( return result, (0,) * len(result) +### JVP CALCULATION ######################################################### +# This structure will change as we add more diff methods + + def _make_zero(tan, arg): return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan -def _qnode_jvp(args, tangents, **impl_kwargs): +def _backprop(args, tangents, **impl_kwargs): tangents = tuple(map(_make_zero, tangents, args)) return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents) +diff_method_map = {"backprop": _backprop} + + +def _resolve_diff_method(diff_method: str, device) -> str: + # check if best is backprop + if diff_method == "best": + config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax") + diff_method = device.setup_execution_config(config).gradient_method + + if diff_method not in diff_method_map: + raise NotImplementedError(f"diff_method {diff_method} not yet implemented.") + + return diff_method + + +def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs): + diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device) + return diff_method_map[diff_method]( + args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs + ) + + +### END JVP CALCULATION ####################################################### + ad.primitive_jvps[qnode_prim] = _qnode_jvp batching.primitive_batchers[qnode_prim] = _qnode_batching_rule @@ -375,8 +403,7 @@ def f(x): qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args) execute_kwargs = copy(qnode.execute_kwargs) - mcm_config = asdict(execute_kwargs.pop("mcm_config")) - qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} + qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs} flat_args = 
jax.tree_util.tree_leaves(args) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index cfe93f322d2..fb9ac5ef708 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -47,14 +47,16 @@ def execute( diff_method: Optional[Union[Callable, SupportedDiffMethods, TransformDispatcher]] = None, interface: Optional[InterfaceLike] = Interface.AUTO, *, + transform_program: TransformProgram = None, grad_on_execution: Literal[bool, "best"] = "best", cache: Union[None, bool, dict, Cache] = True, cachesize: int = 10000, max_diff: int = 1, device_vjp: Union[bool, None] = False, + postselect_mode=None, + mcm_method=None, gradient_kwargs: dict = None, - transform_program: TransformProgram = None, - mcm_config: "qml.devices.MCMConfig" = None, + mcm_config: "qml.devices.MCMConfig" = "unset", config="unset", inner_transform="unset", ) -> ResultBatch: @@ -86,10 +88,18 @@ def execute( (classical) computational overhead during the backward pass. device_vjp=False (Optional[bool]): whether or not to use the device-provided Jacobian product if it is available. - mcm_config (dict): Dictionary containing configuration options for handling - mid-circuit measurements. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Default is ``None``. + mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. + ``"deferred"`` is ignored. If mid-circuit measurements are found in the circuit, + the device will use ``"tree-traversal"`` if specified and the ``"one-shot"`` method + otherwise. For usage details, please refer to the + :doc:`dynamic quantum circuits page `. gradient_kwargs (dict): dictionary of keyword arguments to pass when determining the gradients of tapes. + mcm_config="unset": **DEPRECATED**. 
This keyword argument has been replaced by ``postselect_mode`` + and ``mcm_method`` and will be removed in v0.42. config="unset": **DEPRECATED**. This keyword argument has been deprecated and will be removed in v0.42. inner_transform="unset": **DEPRECATED**. This keyword argument has been deprecated @@ -173,6 +183,14 @@ def cost_fn(params, x): qml.PennyLaneDeprecationWarning, ) + if mcm_config != "unset": + warn( + "The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.", + qml.PennyLaneDeprecationWarning, + ) + mcm_method = mcm_config.mcm_method + postselect_mode = mcm_config.postselect_mode + if logger.isEnabledFor(logging.DEBUG): logger.debug( ( @@ -209,7 +227,7 @@ def cost_fn(params, x): gradient_method=diff_method, grad_on_execution=None if grad_on_execution == "best" else grad_on_execution, use_device_jacobian_product=device_vjp, - mcm_config=mcm_config or {}, + mcm_config=qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method), gradient_keyword_arguments=gradient_kwargs or {}, derivative_order=max_diff, ) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 29430c9a645..ea1a9e33900 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -107,6 +107,10 @@ def _to_qfunc_output_type( return qml.pytrees.unflatten(results, qfunc_output_structure) +def _validate_mcm_config(postselect_mode: str, mcm_method: str) -> None: + qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method) + + def _validate_gradient_kwargs(gradient_kwargs: dict) -> None: for kwarg in gradient_kwargs: if kwarg == "expansion_strategy": @@ -133,7 +137,8 @@ def _validate_gradient_kwargs(gradient_kwargs: dict) -> None: elif kwarg not in qml.gradients.SUPPORTED_GRADIENT_KWARGS: warnings.warn( f"Received gradient_kwarg {kwarg}, which is not included in the list of " - "standard qnode gradient kwargs." + "standard qnode gradient kwargs. 
Please specify all gradient kwargs through " + "the gradient_kwargs argument as a dictionary." ) @@ -284,9 +289,7 @@ class QNode: as the name suggests. If not provided, the device will determine the best choice automatically. For usage details, please refer to the :doc:`dynamic quantum circuits page `. - - Keyword Args: - **kwargs: Any additional keyword arguments provided are passed to the differentiation + gradient_kwargs (dict): A dictionary of keyword arguments that are passed to the differentiation method. Please refer to the :mod:`qml.gradients <.gradients>` module for details on supported options for your chosen gradient transform. @@ -505,10 +508,12 @@ def __init__( device_vjp: Union[None, bool] = False, postselect_mode: Literal[None, "hw-like", "fill-shots"] = None, mcm_method: Literal[None, "deferred", "one-shot", "tree-traversal"] = None, - **gradient_kwargs, + gradient_kwargs: Optional[dict] = None, + **kwargs, ): self._init_args = locals() del self._init_args["self"] + del self._init_args["kwargs"] if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -536,7 +541,16 @@ def __init__( if not isinstance(device, qml.devices.Device): device = qml.devices.LegacyDeviceFacade(device) + gradient_kwargs = gradient_kwargs or {} + if kwargs: + if any(k in qml.gradients.SUPPORTED_GRADIENT_KWARGS for k in list(kwargs.keys())): + warnings.warn( + f"Specifying gradient keyword arguments {list(kwargs.keys())} is deprecated and will be removed in v0.42. Instead, please specify all arguments in the gradient_kwargs argument.", + qml.PennyLaneDeprecationWarning, + ) + gradient_kwargs |= kwargs _validate_gradient_kwargs(gradient_kwargs) + if "shots" in inspect.signature(func).parameters: warnings.warn( "Detected 'shots' as an argument to the given quantum function. 
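Folding loose keyword arguments into the `gradient_kwargs` dictionary while warning about the deprecated call style can be sketched as follows; the supported-kwargs set here is an illustrative assumption, not the real `qml.gradients.SUPPORTED_GRADIENT_KWARGS`:

```python
import warnings

# assumed stand-in for qml.gradients.SUPPORTED_GRADIENT_KWARGS
SUPPORTED_GRADIENT_KWARGS = {"h", "broadcast", "num_split_times", "sampler_seed"}

def merge_gradient_kwargs(gradient_kwargs, **kwargs):
    """Fold loose keyword arguments into gradient_kwargs, warning on deprecated use."""
    gradient_kwargs = dict(gradient_kwargs or {})
    if kwargs:
        if any(k in SUPPORTED_GRADIENT_KWARGS for k in kwargs):
            warnings.warn(
                f"Specifying gradient keyword arguments {list(kwargs)} directly is "
                "deprecated; pass them via the gradient_kwargs dictionary instead.",
                DeprecationWarning,
            )
        # |= updates the explicit dictionary in place with the loose kwargs
        gradient_kwargs |= kwargs
    return gradient_kwargs
```

The explicit dictionary wins the old behaviour back only under a warning, which is what makes the migration non-breaking for one release cycle.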
" @@ -553,17 +567,18 @@ def __init__( self.device = device self._interface = get_canonical_interface_name(interface) self.diff_method = diff_method - mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) cache = (max_diff > 1) if cache == "auto" else cache # execution keyword arguments + _validate_mcm_config(postselect_mode, mcm_method) self.execute_kwargs = { "grad_on_execution": grad_on_execution, "cache": cache, "cachesize": cachesize, "max_diff": max_diff, "device_vjp": device_vjp, - "mcm_config": mcm_config, + "postselect_mode": postselect_mode, + "mcm_method": mcm_method, } # internal data attributes @@ -676,16 +691,19 @@ def circuit(x): tensor(0.5403, dtype=torch.float64) """ if not kwargs: - valid_params = ( - set(self._init_args.copy().pop("gradient_kwargs")) - | qml.gradients.SUPPORTED_GRADIENT_KWARGS - ) + valid_params = set(self._init_args.copy()) | qml.gradients.SUPPORTED_GRADIENT_KWARGS raise ValueError( f"Must specify at least one configuration property to update. Valid properties are: {valid_params}." ) original_init_args = self._init_args.copy() - gradient_kwargs = original_init_args.pop("gradient_kwargs") - original_init_args.update(gradient_kwargs) + # gradient_kwargs defaults to None + original_init_args["gradient_kwargs"] = original_init_args["gradient_kwargs"] or {} + # nested dictionary update + new_gradient_kwargs = kwargs.pop("gradient_kwargs", {}) + old_gradient_kwargs = original_init_args.get("gradient_kwargs").copy() + old_gradient_kwargs.update(new_gradient_kwargs) + kwargs["gradient_kwargs"] = old_gradient_kwargs + original_init_args.update(kwargs) updated_qn = QNode(**original_init_args) # pylint: disable=protected-access diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 0ac083a768c..680d11f083e 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -14,7 +14,6 @@ """ Tests for capturing a qnode into jaxpr. 
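The `update` logic treats `gradient_kwargs` as a nested dictionary: new entries are merged into the saved ones rather than replacing them wholesale. A minimal stand-alone sketch of that merge, assuming the saved constructor arguments live in a plain dict (the real code works on `QNode._init_args`):

```python
def update_init_args(original_init_args, **kwargs):
    """Merge new settings into saved constructor arguments, updating
    gradient_kwargs as a nested dictionary instead of replacing it."""
    args = dict(original_init_args)
    # gradient_kwargs defaults to None, so normalize to an empty dict first
    args["gradient_kwargs"] = args.get("gradient_kwargs") or {}
    # nested dictionary update: keep old entries, overlay the new ones
    new_gradient_kwargs = kwargs.pop("gradient_kwargs", {})
    kwargs["gradient_kwargs"] = {**args["gradient_kwargs"], **new_gradient_kwargs}
    args.update(kwargs)
    return args
```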
""" -from dataclasses import asdict from functools import partial # pylint: disable=protected-access @@ -130,7 +129,6 @@ def circuit(x): assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) - expected_kwargs.update(asdict(expected_kwargs.pop("mcm_config"))) assert eqn0.params["qnode_kwargs"] == expected_kwargs qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] @@ -339,18 +337,61 @@ def circuit(x): assert list(out.keys()) == ["a", "b"] -def test_qnode_jvp(): - """Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule.""" +class TestDifferentiation: - @qml.qnode(qml.device("default.qubit", wires=1)) - def circuit(x): - qml.RX(x, 0) - return qml.expval(qml.Z(0)) + def test_error_backprop_unsupported(self): + """Test an error is raised with backprop if the device does not support it.""" + + # pylint: disable=too-few-public-methods + class DummyDev(qml.devices.Device): + + def execute(self, *_, **__): + return 0 + + with pytest.raises(qml.QuantumFunctionError, match="does not support backprop"): + + @qml.qnode(DummyDev(wires=2), diff_method="backprop") + def _(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + def test_error_unsupported_diff_method(self): + """Test an error is raised for a non-backprop diff method.""" + + @qml.qnode(qml.device("default.qubit", wires=2), diff_method="parameter-shift") + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + with pytest.raises( + NotImplementedError, match="diff_method parameter-shift not yet implemented." 
+ ): + jax.grad(circuit)(0.5) + + @pytest.mark.parametrize("diff_method", ("best", "backprop")) + def test_default_qubit_backprop(self, diff_method): + """Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule.""" + + @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + x = 0.9 + xt = -0.6 + jvp = jax.jvp(circuit, (x,), (xt,)) + assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt)) + + def test_no_gradients_with_lightning(self): + """Test that we get an error if we try and differentiate a lightning execution.""" + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) - x = 0.9 - xt = -0.6 - jvp = jax.jvp(circuit, (x,), (xt,)) - assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt)) + with pytest.raises(NotImplementedError, match=r"diff_method adjoint not yet implemented"): + jax.grad(circuit)(0.5) def test_qnode_jit(): diff --git a/tests/capture/test_custom_primitives.py b/tests/capture/test_custom_primitives.py new file mode 100644 index 00000000000..3d35e3e57e4 --- /dev/null +++ b/tests/capture/test_custom_primitives.py @@ -0,0 +1,48 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for PennyLane custom primitives. 
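The backprop JVP test relies on the identity that the expectation of Z after `RX(x)` on the zero state equals cos(x), so `jax.jvp` should return `(cos(x), -sin(x) * tangent)`. That identity can be checked numerically without JAX or PennyLane installed:

```python
import math

def expval_z(x):
    # <Z> after RX(x) applied to |0> is cos(x)
    return math.cos(x)

def jvp_expval_z(x, tangent):
    # analytic JVP: primal value plus directional derivative
    return expval_z(x), -math.sin(x) * tangent

# same sample point as the test above
x, xt = 0.9, -0.6
primal, tangent_out = jvp_expval_z(x, xt)

# finite-difference check of the tangent component
eps = 1e-6
fd = (expval_z(x + eps * xt) - expval_z(x)) / eps
assert abs(tangent_out - fd) < 1e-4
```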
+""" +# pylint: disable=wrong-import-position +import pytest + +jax = pytest.importorskip("jax") + +from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive + +pytestmark = pytest.mark.jax + + +def test_qml_primitive_prim_type_default(): + """Test that the default prim_type of a QmlPrimitive is set correctly.""" + prim = QmlPrimitive("primitive") + assert prim._prim_type == PrimitiveType("default") # pylint: disable=protected-access + assert prim.prim_type == "default" + + +@pytest.mark.parametrize("cast_in_enum", [True, False]) +@pytest.mark.parametrize("prim_type", ["operator", "measurement", "transform", "higher_order"]) +def test_qml_primitive_prim_type_setter(prim_type, cast_in_enum): + """Test that the QmlPrimitive.prim_type setter works correctly""" + prim = QmlPrimitive("primitive") + prim.prim_type = PrimitiveType(prim_type) if cast_in_enum else prim_type + assert prim._prim_type == PrimitiveType(prim_type) # pylint: disable=protected-access + assert prim.prim_type == prim_type + + +def test_qml_primitive_prim_type_setter_invalid(): + """Test that setting an invalid prim_type raises an error""" + prim = QmlPrimitive("primitive") + with pytest.raises(ValueError, match="not a valid PrimitiveType"): + prim.prim_type = "blah" diff --git a/tests/capture/test_switches.py b/tests/capture/test_switches.py index 52f50321740..72a170205e7 100644 --- a/tests/capture/test_switches.py +++ b/tests/capture/test_switches.py @@ -32,10 +32,15 @@ def test_switches_with_jax(): def test_switches_without_jax(): """Test switches and status reporting function.""" - - assert qml.capture.enabled() is False - with pytest.raises(ImportError, match="plxpr requires JAX to be installed."): - qml.capture.enable() - assert qml.capture.enabled() is False - assert qml.capture.disable() is None - assert qml.capture.enabled() is False + # We want to skip the test if jax is installed + try: + # pylint: disable=import-outside-toplevel, unused-import + import jax + except 
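The `QmlPrimitive.prim_type` tests suggest an enum-backed property that accepts either a `PrimitiveType` member or its string value and rejects anything else. A hypothetical stand-in class illustrating that pattern (not PennyLane's actual implementation):

```python
from enum import Enum

class PrimitiveType(Enum):
    DEFAULT = "default"
    OPERATOR = "operator"
    MEASUREMENT = "measurement"
    TRANSFORM = "transform"
    HIGHER_ORDER = "higher_order"

class QmlPrimitiveStub:
    """Hypothetical sketch: stores the enum member internally and
    exposes its string value through a validating property."""

    def __init__(self, name):
        self.name = name
        self._prim_type = PrimitiveType("default")

    @property
    def prim_type(self):
        return self._prim_type.value

    @prim_type.setter
    def prim_type(self, value):
        # accepts either an enum member or its string value;
        # PrimitiveType("blah") raises ValueError for invalid inputs
        self._prim_type = PrimitiveType(value)
```

Routing every assignment through `PrimitiveType(...)` is what makes the "not a valid PrimitiveType" error in the test fire automatically for bad strings.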
ImportError: + + assert qml.capture.enabled() is False + with pytest.raises(ImportError, match="plxpr requires JAX to be installed."): + qml.capture.enable() + assert qml.capture.enabled() is False + assert qml.capture.disable() is None + assert qml.capture.enabled() is False diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 42faaa10809..32367c455f4 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -389,7 +389,10 @@ def func(x, y, z): results1 = func1(*params) jaxpr = str(jax.make_jaxpr(func)(*params)) - assert "pure_callback" not in jaxpr + # will change once we solve the compilation overhead issue + # assert "pure_callback" not in jaxpr + # TODO: [sc-82874] + assert "pure_callback" in jaxpr func2 = jax.jit(func) results2 = func2(*params) diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py index 59a9098f7ce..31a2c978745 100644 --- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py +++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py @@ -142,15 +142,17 @@ def circuit(x): assert dev.tracker.totals["execute_and_derivative_batches"] == 1 @pytest.mark.parametrize("interface", ("jax", "jax-jit")) - def test_not_convert_to_numpy_with_jax(self, interface): + def test_convert_to_numpy_with_jax(self, interface): """Test that we will not convert to numpy when working with jax.""" - + # separate test so we can easily update it once we solve the + # compilation overhead issue + # TODO: [sc-82874] dev = qml.device("default.qubit") config = qml.devices.ExecutionConfig( gradient_method=qml.gradients.param_shift, interface=interface ) processed = dev.setup_execution_config(config) - assert not processed.convert_to_numpy + assert processed.convert_to_numpy def 
test_convert_to_numpy_with_adjoint(self): """Test that we will convert to numpy with adjoint.""" diff --git a/tests/gradients/core/test_hamiltonian_gradient.py b/tests/gradients/core/test_hamiltonian_gradient.py index 1bcb4bfc4fe..9474faf39a0 100644 --- a/tests/gradients/core/test_hamiltonian_gradient.py +++ b/tests/gradients/core/test_hamiltonian_gradient.py @@ -12,10 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for the gradients.hamiltonian module.""" +import pytest + import pennylane as qml from pennylane.gradients.hamiltonian_grad import hamiltonian_grad +def test_hamiltonian_grad_deprecation(): + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + with qml.queuing.AnnotatedQueue() as q: + qml.RY(0.3, wires=0) + qml.RX(0.5, wires=1) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.Hamiltonian([-1.5, 2.0], [qml.PauliZ(0), qml.PauliZ(1)])) + + tape = qml.tape.QuantumScript.from_queue(q) + tape.trainable_params = {2, 3} + hamiltonian_grad(tape, idx=0) + + def test_behaviour(): """Test that the function behaves as expected.""" @@ -29,10 +46,16 @@ def test_behaviour(): tape = qml.tape.QuantumScript.from_queue(q) tape.trainable_params = {2, 3} - tapes, processing_fn = hamiltonian_grad(tape, idx=0) + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + tapes, processing_fn = hamiltonian_grad(tape, idx=0) res1 = processing_fn(dev.execute(tapes)) - tapes, processing_fn = hamiltonian_grad(tape, idx=1) + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + tapes, processing_fn = hamiltonian_grad(tape, idx=1) res2 = processing_fn(dev.execute(tapes)) with qml.queuing.AnnotatedQueue() as q1: diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index c684d53b0c6..b06b7c711d0 
100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -1385,7 +1385,10 @@ def test_simple_qnode_expval(self, num_split_times, shots, tol, seed): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1415,7 +1418,10 @@ def test_simple_qnode_expval_two_evolves(self, num_split_times, shots, tol, seed ham_y = qml.pulse.constant * qml.PauliX(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_x)(params[0], T_x) @@ -1444,7 +1450,10 @@ def test_simple_qnode_probs(self, num_split_times, shots, tol, seed): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1471,7 +1480,10 @@ def test_simple_qnode_probs_expval(self, num_split_times, shots, tol, seed): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1490,6 +1502,7 @@ def circuit(params): assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0) jax.clear_caches() + @pytest.mark.xfail # TODO: [sc-82874] 
@pytest.mark.parametrize("num_split_times", [1, 2]) @pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"]) def test_simple_qnode_jit(self, num_split_times, time_interface): @@ -1503,7 +1516,10 @@ def test_simple_qnode_jit(self, num_split_times, time_interface): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params, T=None): qml.evolve(ham_single_q_const)(params, T) @@ -1542,8 +1558,7 @@ def ansatz(params): dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - sampler_seed=seed, + gradient_kwargs={"num_split_times": num_split_times, "sampler_seed": seed}, ) qnode_backprop = qml.QNode(ansatz, dev, interface="jax") @@ -1574,8 +1589,7 @@ def test_qnode_probs_expval_broadcasting(self, num_split_times, shots, tol, seed dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - use_broadcasting=True, + gradient_kwargs={"num_split_times": num_split_times, "use_broadcasting": True}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1619,18 +1633,22 @@ def ansatz(params): dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - use_broadcasting=True, - sampler_seed=seed, + gradient_kwargs={ + "num_split_times": num_split_times, + "use_broadcasting": True, + "sampler_seed": seed, + }, ) circuit_no_bc = qml.QNode( ansatz, dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - use_broadcasting=False, - sampler_seed=seed, + gradient_kwargs={ + "num_split_times": num_split_times, + "use_broadcasting": False, + "sampler_seed": seed, + }, ) params = [jnp.array(0.4)] jac_bc = jax.jacobian(circuit_bc)(params) @@ -1684,9 +1702,7 @@ def ansatz(params): dev, interface="jax", 
diff_method=qml.gradients.stoch_pulse_grad, - num_split_times=7, - use_broadcasting=True, - sampler_seed=seed, + gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True}, ) cost_jax = qml.QNode(ansatz, dev, interface="jax") params = (0.42,) @@ -1729,9 +1745,7 @@ def ansatz(params): dev, interface="jax", diff_method=qml.gradients.stoch_pulse_grad, - num_split_times=7, - use_broadcasting=True, - sampler_seed=seed, + gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True}, ) cost_jax = qml.QNode(ansatz, dev, interface="jax") diff --git a/tests/gradients/finite_diff/test_spsa_gradient.py b/tests/gradients/finite_diff/test_spsa_gradient.py index 1bd2a198bca..413bc8ff6ec 100644 --- a/tests/gradients/finite_diff/test_spsa_gradient.py +++ b/tests/gradients/finite_diff/test_spsa_gradient.py @@ -161,7 +161,7 @@ def test_invalid_sampler_rng(self): """Tests that if sampler_rng has an unexpected type, an error is raised.""" dev = qml.device("default.qubit", wires=1) - @qml.qnode(dev, diff_method="spsa", sampler_rng="foo") + @qml.qnode(dev, diff_method="spsa", gradient_kwargs={"sampler_rng": "foo"}) def circuit(param): qml.RX(param, wires=0) return qml.expval(qml.PauliZ(0)) diff --git a/tests/gradients/parameter_shift/test_cv_gradients.py b/tests/gradients/parameter_shift/test_cv_gradients.py index c9709753283..55955396ab7 100644 --- a/tests/gradients/parameter_shift/test_cv_gradients.py +++ b/tests/gradients/parameter_shift/test_cv_gradients.py @@ -268,7 +268,11 @@ def qf(x, y): grad_F = jax.grad(qf)(*par) - @qml.qnode(device=gaussian_dev, diff_method="parameter-shift", force_order2=True) + @qml.qnode( + device=gaussian_dev, + diff_method="parameter-shift", + gradient_kwargs={"force_order2": True}, + ) def qf2(x, y): qml.Displacement(0.5, 0, wires=[0]) qml.Squeezing(x, 0, wires=[0]) diff --git a/tests/gradients/parameter_shift/test_parameter_shift.py b/tests/gradients/parameter_shift/test_parameter_shift.py index 
f35a4b4fc95..4b741c5a66b 100644 --- a/tests/gradients/parameter_shift/test_parameter_shift.py +++ b/tests/gradients/parameter_shift/test_parameter_shift.py @@ -3473,58 +3473,6 @@ def test_trainable_coeffs(self, tol, broadcast): assert np.allclose(res[0], expected[0], atol=tol, rtol=0) assert np.allclose(res[1], expected[1], atol=tol, rtol=0) - def test_multiple_hamiltonians(self, tol, broadcast): - """Test multiple trainable Hamiltonian coefficients""" - dev = qml.device("default.qubit", wires=2) - - obs = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)] - coeffs = np.array([0.1, 0.2, 0.3]) - a, b, c = coeffs - H1 = qml.Hamiltonian(coeffs, obs) - - obs = [qml.PauliZ(0)] - coeffs = np.array([0.7]) - d = coeffs[0] - H2 = qml.Hamiltonian(coeffs, obs) - - weights = np.array([0.4, 0.5]) - x, y = weights - - with qml.queuing.AnnotatedQueue() as q: - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - qml.expval(H1) - qml.expval(H2) - - tape = qml.tape.QuantumScript.from_queue(q) - tape.trainable_params = {0, 1, 2, 4, 5} - - res = dev.execute([tape]) - expected = [-c * np.sin(x) * np.sin(y) + np.cos(x) * (a + b * np.sin(y)), d * np.cos(x)] - assert np.allclose(res, expected, atol=tol, rtol=0) - - tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast) - # two shifts per rotation gate (in one batched tape if broadcasting), - # one circuit per trainable H term - assert len(tapes) == 2 * (1 if broadcast else 2) - - res = fn(dev.execute(tapes)) - assert isinstance(res, tuple) - assert len(res) == 2 - assert len(res[0]) == 2 - assert len(res[1]) == 2 - - expected = [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - ], - [-d * np.sin(x), 0], - ] - - assert np.allclose(np.stack(res), expected, atol=tol, rtol=0) - @staticmethod def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False): """Cost function for gradient tests""" @@ -3547,95 
+3495,8 @@ def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False): jac = fn(dev.execute(tapes)) return jac - @staticmethod - def cost_fn_expected(weights, coeffs1, coeffs2): - """Analytic jacobian of cost_fn above""" - a, b, c = coeffs1 - d = coeffs2[0] - x, y = weights - return [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - ], - [-d * np.sin(x), 0], - ] - - @pytest.mark.autograd - def test_autograd(self, tol, broadcast): - """Test gradient of multiple trainable Hamiltonian coefficients - using autograd""" - coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True) - coeffs2 = np.array([0.7], requires_grad=True) - weights = np.array([0.4, 0.5], requires_grad=True) - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(res, np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # res = qml.jacobian(self.cost_fn)(weights, coeffs1, coeffs2, dev=dev) - # assert np.allclose(res[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(res[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.tf - def test_tf(self, tol, broadcast): - """Test gradient of multiple trainable Hamiltonian coefficients - using tf""" - import tensorflow as tf - - coeffs1 = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64) - coeffs2 = tf.Variable([0.7], dtype=tf.float64) - weights = tf.Variable([0.4, 0.5], dtype=tf.float64) - - dev = qml.device("default.qubit", wires=2) - - with tf.GradientTape() as _: - jac = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - - expected = self.cost_fn_expected(weights.numpy(), coeffs1.numpy(), coeffs2.numpy()) - assert np.allclose(jac[0], np.array(expected)[0], atol=tol, 
rtol=0) - assert np.allclose(jac[1], np.array(expected)[1], atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero. - # When activating the following, rename the GradientTape above from _ to t - # --- - # hess = t.jacobian(jac, [coeffs1, coeffs2]) - # assert np.allclose(hess[0][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[1][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.torch - def test_torch(self, tol, broadcast): - """Test gradient of multiple trainable Hamiltonian coefficients - using torch""" - import torch - - coeffs1 = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64, requires_grad=True) - coeffs2 = torch.tensor([0.7], dtype=torch.float64, requires_grad=True) - weights = torch.tensor([0.4, 0.5], dtype=torch.float64, requires_grad=True) - - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected( - weights.detach().numpy(), coeffs1.detach().numpy(), coeffs2.detach().numpy() - ) - res = tuple(tuple(_r.detach() for _r in r) for r in res) - assert np.allclose(res, expected, atol=tol, rtol=0) - - # second derivative wrt to Hamiltonian coefficients should be zero - # hess = torch.autograd.functional.jacobian( - # lambda *args: self.cost_fn(*args, dev, broadcast), (weights, coeffs1, coeffs2) - # ) - # assert np.allclose(hess[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - @pytest.mark.jax - def test_jax(self, tol, broadcast): + def test_jax(self, broadcast): """Test gradient of multiple trainable Hamiltonian coefficients using JAX""" import jax @@ -3647,19 +3508,11 @@ def test_jax(self, tol, broadcast): weights = jnp.array([0.4, 0.5]) dev = qml.device("default.qubit", wires=2) - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - 
expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(np.array(res)[:, :2], np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # second derivative wrt to Hamiltonian coefficients should be zero - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + with pytest.warns(UserWarning, match="Please use qml.gradients.split_to_single_terms"): + self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) @pytest.mark.autograd diff --git a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py index 34478ccecd0..e3151735e98 100644 --- a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py +++ b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py @@ -1970,12 +1970,12 @@ def expval(self, observable, **kwargs): dev = DeviceSupporingSpecialObservable(wires=1, shots=None) - @qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast}) def qnode(x): qml.RY(x, wires=0) return qml.expval(SpecialObservable(wires=0)) - @qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast}) def reference_qnode(x): qml.RY(x, wires=0) return qml.expval(qml.PauliZ(wires=0)) @@ -2128,221 +2128,6 @@ def test_trainable_coeffs(self, broadcast, tol): for r in res: 
assert qml.math.allclose(r, expected, atol=shot_vec_tol) - @pytest.mark.xfail(reason="TODO") - def test_multiple_hamiltonians(self, mocker, broadcast, tol): - """Test multiple trainable Hamiltonian coefficients""" - shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=2, shots=shot_vec) - spy = mocker.spy(qml.gradients, "hamiltonian_grad") - - obs = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)] - coeffs = np.array([0.1, 0.2, 0.3]) - a, b, c = coeffs - H1 = qml.Hamiltonian(coeffs, obs) - - obs = [qml.PauliZ(0)] - coeffs = np.array([0.7]) - d = coeffs[0] - H2 = qml.Hamiltonian(coeffs, obs) - - weights = np.array([0.4, 0.5]) - x, y = weights - - with qml.queuing.AnnotatedQueue() as q: - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - qml.expval(H1) - qml.expval(H2) - - tape = qml.tape.QuantumScript.from_queue(q, shots=shot_vec) - tape.trainable_params = {0, 1, 2, 4, 5} - - res = dev.execute([tape]) - expected = [-c * np.sin(x) * np.sin(y) + np.cos(x) * (a + b * np.sin(y)), d * np.cos(x)] - assert np.allclose(res, expected, atol=tol, rtol=0) - - if broadcast: - with pytest.raises( - NotImplementedError, match="Broadcasting with multiple measurements" - ): - qml.gradients.param_shift(tape, broadcast=broadcast) - return - - tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast) - # two shifts per rotation gate, one circuit per trainable H term - assert len(tapes) == 2 * 2 + 3 - spy.assert_called() - - res = fn(dev.execute(tapes)) - assert isinstance(res, tuple) - assert len(res) == 2 - assert len(res[0]) == 5 - assert len(res[1]) == 5 - - expected = [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - np.cos(x), - -(np.sin(x) * np.sin(y)), - 0, - ], - [-d * np.sin(x), 0, 0, 0, np.cos(x)], - ] - - assert np.allclose(np.stack(res), expected, atol=tol, rtol=0) - - @staticmethod - def cost_fn(weights, coeffs1, 
coeffs2, dev=None, broadcast=False): - """Cost function for gradient tests""" - obs1 = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)] - H1 = qml.Hamiltonian(coeffs1, obs1) - - obs2 = [qml.PauliZ(0)] - H2 = qml.Hamiltonian(coeffs2, obs2) - - with qml.queuing.AnnotatedQueue() as q: - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - qml.expval(H1) - qml.expval(H2) - - tape = qml.tape.QuantumScript.from_queue(q, shots=dev.shots) - tape.trainable_params = {0, 1, 2, 3, 4, 5} - tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast) - return fn(dev.execute(tapes)) - - @staticmethod - def cost_fn_expected(weights, coeffs1, coeffs2): - """Analytic jacobian of cost_fn above""" - a, b, c = coeffs1 - d = coeffs2[0] - x, y = weights - return [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - np.cos(x), - np.cos(x) * np.sin(y), - -(np.sin(x) * np.sin(y)), - 0, - ], - [-d * np.sin(x), 0, 0, 0, 0, np.cos(x)], - ] - - @pytest.mark.xfail(reason="TODO") - @pytest.mark.autograd - def test_autograd(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients - using autograd""" - coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True) - coeffs2 = np.array([0.7], requires_grad=True) - weights = np.array([0.4, 0.5], requires_grad=True) - shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=2, shots=shot_vec) - - if broadcast: - with pytest.raises( - NotImplementedError, match="Broadcasting with multiple measurements" - ): - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - return - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(res, np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients 
should be zero - # --- - # res = qml.jacobian(self.cost_fn)(weights, coeffs1, coeffs2, dev=dev) - # assert np.allclose(res[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(res[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.xfail(reason="TODO") - @pytest.mark.tf - def test_tf(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients using tf""" - import tensorflow as tf - - coeffs1 = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64) - coeffs2 = tf.Variable([0.7], dtype=tf.float64) - weights = tf.Variable([0.4, 0.5], dtype=tf.float64) - - shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=2, shots=shot_vec) - - with tf.GradientTape() as _: - jac = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - - expected = self.cost_fn_expected(weights.numpy(), coeffs1.numpy(), coeffs2.numpy()) - assert np.allclose(jac[0], np.array(expected)[0], atol=tol, rtol=0) - assert np.allclose(jac[1], np.array(expected)[1], atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # When activating the following, rename the GradientTape above from _ to t - # --- - # hess = t.jacobian(jac, [coeffs1, coeffs2]) - # assert np.allclose(hess[0][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[1][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.torch - def test_torch(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients - using torch""" - import torch - - coeffs1 = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64, requires_grad=True) - coeffs2 = torch.tensor([0.7], dtype=torch.float64, requires_grad=True) - weights = torch.tensor([0.4, 0.5], dtype=torch.float64, requires_grad=True) - - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = 
self.cost_fn_expected( - weights.detach().numpy(), coeffs1.detach().numpy(), coeffs2.detach().numpy() - ) - for actual, _expected in zip(res, expected): - for val, exp_val in zip(actual, _expected): - assert qml.math.allclose(val.detach(), exp_val, atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # hess = torch.autograd.functional.jacobian( - # lambda *args: self.cost_fn(*args, dev, broadcast), (weights, coeffs1, coeffs2) - # ) - # assert np.allclose(hess[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.jax - def test_jax(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients - using JAX""" - import jax - - jnp = jax.numpy - - coeffs1 = jnp.array([0.1, 0.2, 0.3]) - coeffs2 = jnp.array([0.7]) - weights = jnp.array([0.4, 0.5]) - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(res, np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # second derivative wrt to Hamiltonian coefficients should be zero - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - pauliz = qml.PauliZ(wires=0) proj = qml.Projector([1], wires=0) diff --git a/tests/measurements/test_probs.py b/tests/measurements/test_probs.py index 07bd8fcd463..b5af8118b07 100644 --- a/tests/measurements/test_probs.py +++ 
b/tests/measurements/test_probs.py @@ -338,7 +338,7 @@ def circuit(): @pytest.mark.jax @pytest.mark.parametrize("shots", (None, 500)) @pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1))) - @pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2])) + @pytest.mark.parametrize("params", (np.pi / 2, [np.pi / 2, np.pi / 2, np.pi / 2])) def test_integration_jax(self, tol_stochastic, shots, obs, params, seed): """Test the probability is correct for a known state preparation when jitted with JAX.""" jax = pytest.importorskip("jax") @@ -359,7 +359,9 @@ def circuit(x): # expected probability, using [00, 01, 10, 11] # ordering, is [0.5, 0.5, 0, 0] - assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params)) + # TODO: [sc-82874] + # revert once we are able to jit end to end without extreme compilation overheads + assert "pure_callback" in str(jax.make_jaxpr(circuit)(params)) res = jax.jit(circuit)(params) expected = np.array([0.5, 0.5, 0, 0]) diff --git a/tests/ops/functions/test_matrix.py b/tests/ops/functions/test_matrix.py index 5adc8912cd2..a1f82e6077a 100644 --- a/tests/ops/functions/test_matrix.py +++ b/tests/ops/functions/test_matrix.py @@ -683,6 +683,7 @@ def circuit(theta): assert np.allclose(matrix, expected_matrix) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.catalyst @pytest.mark.external def test_catalyst(self): diff --git a/tests/ops/op_math/test_exp.py b/tests/ops/op_math/test_exp.py index 5abead84595..6ca4c68091b 100644 --- a/tests/ops/op_math/test_exp.py +++ b/tests/ops/op_math/test_exp.py @@ -768,6 +768,7 @@ def circ(phi): grad = jax.grad(circ)(phi) assert qml.math.allclose(grad, -jnp.sin(phi)) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.catalyst @pytest.mark.external def test_catalyst_qnode(self): diff --git a/tests/resource/test_specs.py b/tests/resource/test_specs.py index a02b35ef97b..5a5764e7153 100644 
--- a/tests/resource/test_specs.py +++ b/tests/resource/test_specs.py @@ -33,10 +33,15 @@ class TestSpecsTransform: """Tests for the transform specs using the QNode""" def sample_circuit(self): + @qml.transforms.merge_rotations @qml.transforms.undo_swaps @qml.transforms.cancel_inverses - @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4) + @qml.qnode( + qml.device("default.qubit"), + diff_method="parameter-shift", + gradient_kwargs={"shifts": pnp.pi / 4}, + ) def circuit(x): qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1)) qml.RX(x, wires=0) @@ -222,7 +227,11 @@ def test_splitting_transforms(self): @qml.transforms.split_non_commuting @qml.transforms.merge_rotations - @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4) + @qml.qnode( + qml.device("default.qubit"), + diff_method="parameter-shift", + gradient_kwargs={"shifts": pnp.pi / 4}, + ) def circuit(x): qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1)) qml.RX(x, wires=0) diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py index 885def86e60..a63a2d20129 100644 --- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py +++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py @@ -431,7 +431,7 @@ def circuit(coeffs): qml.MottonenStatePreparation(coeffs, wires=[0, 1]) return qml.probs(wires=[0, 1]) - circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05) + circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", gradient_kwargs={"h": 0.05}) circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift") circuit_exact = qml.QNode(circuit, dev_no_shots) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 38d61cd86cf..6eb3fa851a4 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -76,6 +76,7 @@ def test_compiler(self): assert 
qml.compiler.available("catalyst") assert qml.compiler.available_compilers() == ["catalyst", "cuda_quantum"] + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_active_compiler(self): """Test `qml.compiler.active_compiler` inside a simple circuit""" dev = qml.device("lightning.qubit", wires=2) @@ -91,6 +92,7 @@ def circuit(phi, theta): assert jnp.allclose(circuit(jnp.pi, jnp.pi / 2), 1.0) assert jnp.allclose(qml.qjit(circuit)(jnp.pi, jnp.pi / 2), -1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_active(self): """Test `qml.compiler.active` inside a simple circuit""" dev = qml.device("lightning.qubit", wires=2) @@ -114,6 +116,7 @@ def test_jax_enable_x64(self, jax_enable_x64): qml.compiler.active() assert jax.config.jax_enable_x64 is jax_enable_x64 + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_circuit(self): """Test JIT compilation of a circuit with 2-qubit""" dev = qml.device("lightning.qubit", wires=2) @@ -128,6 +131,7 @@ def circuit(theta): assert jnp.allclose(circuit(0.5), 0.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_aot(self): """Test AOT compilation of a circuit with 2-qubit""" @@ -152,6 +156,7 @@ def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)): ) assert jnp.allclose(result, expected) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.parametrize( "_in,_out", [ @@ -196,6 +201,7 @@ def workflow1(params1, params2): result = workflow1(params1, params2) assert jnp.allclose(result, expected) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_return_value_dict(self): """Test pytree return values.""" dev = qml.device("lightning.qubit", wires=2) @@ -218,6 +224,7 @@ def circuit1(params): assert jnp.allclose(result["w0"], expected["w0"]) assert jnp.allclose(result["w1"], 
expected["w1"]) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_python_if(self): """Test JIT compilation with the autograph support""" dev = qml.device("lightning.qubit", wires=2) @@ -235,6 +242,7 @@ def circuit(x: int): assert jnp.allclose(circuit(3), 0.0) assert jnp.allclose(circuit(5), 1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_compilation_opt(self): """Test user-configurable compilation options""" dev = qml.device("lightning.qubit", wires=2) @@ -250,6 +258,7 @@ def circuit(x: float): result_header = "func.func public @circuit(%arg0: tensor) -> tensor" assert result_header in mlir_str + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_adjoint(self): """Test JIT compilation with adjoint support""" dev = qml.device("lightning.qubit", wires=2) @@ -273,6 +282,7 @@ def func(): assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1])) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_adjoint_lazy(self): """Test that the lazy kwarg is supported.""" dev = qml.device("lightning.qubit", wires=2) @@ -287,6 +297,7 @@ def workflow_pl(theta, wires): assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1])) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_control(self): """Test that control works with qjit.""" dev = qml.device("lightning.qubit", wires=2) @@ -317,6 +328,7 @@ def cond_fn(): class TestCatalystControlFlow: """Test ``qml.qjit`` with Catalyst's control-flow operations""" + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_alternating_while_loop(self): """Test simple while loop.""" dev = qml.device("lightning.qubit", wires=1) @@ -334,6 +346,7 @@ def loop(v): assert jnp.allclose(circuit(1), -1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def 
test_nested_while_loops(self): """Test nested while loops.""" dev = qml.device("lightning.qubit", wires=1) @@ -393,6 +406,7 @@ def loop(v): expected = [qml.PauliX(0) for i in range(4)] _ = [qml.assert_equal(i, j) for i, j in zip(tape.operations, expected)] + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_dynamic_wires_for_loops(self): """Test for loops with iteration index-dependant wires.""" dev = qml.device("lightning.qubit", wires=6) @@ -414,6 +428,7 @@ def loop_fn(i): assert jnp.allclose(circuit(6), expected) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_nested_for_loops(self): """Test nested for loops.""" dev = qml.device("lightning.qubit", wires=4) @@ -445,6 +460,7 @@ def inner(j): assert jnp.allclose(circuit(4), jnp.eye(2**4)[0]) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_for_loop_python_fallback(self): """Test that qml.for_loop fallsback to Python interpretation if Catalyst is not available""" @@ -496,6 +512,7 @@ def inner(j): _ = [qml.assert_equal(i, j) for i, j in zip(res, expected)] + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond(self): """Test condition with simple true_fn""" dev = qml.device("lightning.qubit", wires=1) @@ -514,6 +531,7 @@ def ansatz_true(): assert jnp.allclose(circuit(1.4), 1.0) assert jnp.allclose(circuit(1.6), 0.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond_with_else(self): """Test condition with simple true_fn and false_fn""" dev = qml.device("lightning.qubit", wires=1) @@ -535,6 +553,7 @@ def ansatz_false(): assert jnp.allclose(circuit(1.4), 0.16996714) assert jnp.allclose(circuit(1.6), 0.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond_with_elif(self): """Test condition with a simple elif branch""" dev = qml.device("lightning.qubit", wires=1) @@ -558,6 
+577,7 @@ def false_fn(): assert jnp.allclose(circuit(1.2), 0.13042371) assert jnp.allclose(circuit(jnp.pi), -1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond_with_elifs(self): """Test condition with multiple elif branches""" dev = qml.device("lightning.qubit", wires=1) @@ -630,6 +650,7 @@ def conditional_false_fn(): # pylint: disable=unused-variable class TestCatalystGrad: """Test ``qml.qjit`` with Catalyst's grad operations""" + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_grad_classical_preprocessing(self): """Test the grad transformation with classical preprocessing.""" @@ -647,6 +668,7 @@ def circuit(x): assert jnp.allclose(workflow(2.0), -jnp.pi) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_grad_with_postprocessing(self): """Test the grad transformation with classical preprocessing and postprocessing.""" dev = qml.device("lightning.qubit", wires=1) @@ -665,6 +687,7 @@ def loss(theta): assert jnp.allclose(workflow(1.0), 5.04324559) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_grad_with_multiple_qnodes(self): """Test the grad transformation with multiple QNodes with their own differentiation methods.""" dev = qml.device("lightning.qubit", wires=1) @@ -703,6 +726,7 @@ def dsquare(x: float): assert jnp.allclose(dsquare(2.3), 4.6) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_jacobian_diff_method(self): """Test the Jacobian transformation with the device diff_method.""" dev = qml.device("lightning.qubit", wires=1) @@ -721,6 +745,7 @@ def workflow(p: float): assert jnp.allclose(result, reference) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_jacobian_auto(self): """Test the Jacobian transformation with 'auto'.""" dev = qml.device("lightning.qubit", wires=1) @@ -740,6 +765,7 @@ def circuit(x): 
assert jnp.allclose(result, reference) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_jacobian_fd(self): """Test the Jacobian transformation with 'fd'.""" dev = qml.device("lightning.qubit", wires=1) @@ -838,6 +864,7 @@ def f(x): vjp(x, dy) +@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") class TestCatalystSample: """Test qml.sample with Catalyst.""" @@ -858,6 +885,7 @@ def circuit(x): assert circuit(jnp.pi) == 1 +@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") class TestCatalystMCMs: """Test dynamic_one_shot with Catalyst.""" diff --git a/tests/test_qnode.py b/tests/test_qnode.py index c60ae352e4b..5d270d1f9b8 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -36,6 +36,17 @@ def dummyfunc(): return None +def test_additional_kwargs_is_deprecated(): + """Test that passing gradient_kwargs as additional kwargs raises a deprecation warning.""" + dev = qml.device("default.qubit", wires=1) + + with pytest.warns( + qml.PennyLaneDeprecationWarning, + match=r"Specifying gradient keyword arguments \[\'atol\'\] is deprecated", + ): + QNode(dummyfunc, dev, atol=1) + + # pylint: disable=unused-argument class CustomDevice(qml.devices.Device): """A null device that just returns 0.""" @@ -145,24 +156,21 @@ def test_update_gradient_kwargs(self): """Test that gradient kwargs are updated correctly""" dev = qml.device("default.qubit") - @qml.qnode(dev, atol=1) + @qml.qnode(dev, gradient_kwargs={"atol": 1}) def circuit(x): qml.RZ(x, wires=0) qml.CNOT(wires=[0, 1]) qml.RY(x, wires=1) return qml.expval(qml.PauliZ(1)) - assert len(circuit.gradient_kwargs) == 1 - assert list(circuit.gradient_kwargs.keys()) == ["atol"] + assert set(circuit.gradient_kwargs.keys()) == {"atol"} - new_atol_circuit = circuit.update(atol=2) - assert len(new_atol_circuit.gradient_kwargs) == 1 - assert list(new_atol_circuit.gradient_kwargs.keys()) == ["atol"] + new_atol_circuit = 
circuit.update(gradient_kwargs={"atol": 2}) + assert set(new_atol_circuit.gradient_kwargs.keys()) == {"atol"} assert new_atol_circuit.gradient_kwargs["atol"] == 2 - new_kwarg_circuit = circuit.update(h=1) - assert len(new_kwarg_circuit.gradient_kwargs) == 2 - assert list(new_kwarg_circuit.gradient_kwargs.keys()) == ["atol", "h"] + new_kwarg_circuit = circuit.update(gradient_kwargs={"h": 1}) + assert set(new_kwarg_circuit.gradient_kwargs.keys()) == {"atol", "h"} assert new_kwarg_circuit.gradient_kwargs["atol"] == 1 assert new_kwarg_circuit.gradient_kwargs["h"] == 1 @@ -170,7 +178,7 @@ def circuit(x): UserWarning, match="Received gradient_kwarg blah, which is not included in the list of standard qnode gradient kwargs.", ): - circuit.update(blah=1) + circuit.update(gradient_kwargs={"blah": 1}) def test_update_multiple_arguments(self): """Test that multiple parameters can be updated at once.""" @@ -194,7 +202,7 @@ def test_update_transform_program(self): dev = qml.device("default.qubit", wires=2) @qml.transforms.combine_global_phases - @qml.qnode(dev, atol=1) + @qml.qnode(dev) def circuit(x): qml.RZ(x, wires=0) qml.GlobalPhase(phi=1) @@ -248,7 +256,7 @@ def circuit(return_type): def test_expansion_strategy_error(self): """Test that an error is raised if expansion_strategy is passed to the qnode.""" - with pytest.raises(ValueError, match=r"'expansion_strategy' is no longer"): + with pytest.raises(ValueError, match="'expansion_strategy' is no longer"): @qml.qnode(qml.device("default.qubit"), expansion_strategy="device") def _(): @@ -453,7 +461,7 @@ def test_unrecognized_kwargs_raise_warning(self): with warnings.catch_warnings(record=True) as w: - @qml.qnode(dev, random_kwarg=qml.gradients.finite_diff) + @qml.qnode(dev, gradient_kwargs={"random_kwarg": qml.gradients.finite_diff}) def circuit(params): qml.RX(params[0], wires=0) return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0)) @@ -846,7 +854,7 @@ def test_single_expectation_value_with_argnum_one(self, diff_method, 
tol): y = pnp.array(-0.654, requires_grad=True) @qnode( - dev, diff_method=diff_method, argnum=[1] + dev, diff_method=diff_method, gradient_kwargs={"argnum": [1]} ) # <--- we only choose one trainable parameter def circuit(x, y): qml.RX(x, wires=[0]) @@ -1320,11 +1328,11 @@ def ansatz0(): return qml.expval(qml.X(0)) with pytest.raises(ValueError, match="'shots' is not a valid gradient_kwarg."): - qml.QNode(ansatz0, dev, shots=100) + qml.QNode(ansatz0, dev, gradient_kwargs={"shots": 100}) with pytest.raises(ValueError, match="'shots' is not a valid gradient_kwarg."): - @qml.qnode(dev, shots=100) + @qml.qnode(dev, gradient_kwargs={"shots": 100}) def _(): return qml.expval(qml.X(0)) @@ -1918,14 +1926,17 @@ def circuit(x, mp): return mp(qml.PauliZ(0)) _ = circuit(1.8, qml.expval, shots=10) - assert circuit.execute_kwargs["mcm_config"] == original_config + assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode + assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method if mcm_method != "one-shot": _ = circuit(1.8, qml.expval) - assert circuit.execute_kwargs["mcm_config"] == original_config + assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode + assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method _ = circuit(1.8, qml.expval, shots=10) - assert circuit.execute_kwargs["mcm_config"] == original_config + assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode + assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method class TestTapeExpansion: diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py index 26c46687934..5a8c63c7d22 100644 --- a/tests/test_qnode_legacy.py +++ b/tests/test_qnode_legacy.py @@ -201,7 +201,7 @@ def test_unrecognized_kwargs_raise_warning(self): with warnings.catch_warnings(record=True) as w: - @qml.qnode(dev, random_kwarg=qml.gradients.finite_diff) + @qml.qnode(dev, gradient_kwargs={"random_kwarg": 
qml.gradients.finite_diff}) def circuit(params): qml.RX(params[0], wires=0) return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0)) @@ -627,7 +627,7 @@ def test_single_expectation_value_with_argnum_one(self, diff_method, tol): y = pnp.array(-0.654, requires_grad=True) @qnode( - dev, diff_method=diff_method, argnum=[1] + dev, diff_method=diff_method, gradient_kwargs={"argnum": [1]} ) # <--- we only choose one trainable parameter def circuit(x, y): qml.RX(x, wires=[0]) diff --git a/tests/transforms/core/test_transform_dispatcher.py b/tests/transforms/core/test_transform_dispatcher.py index 4a24627c739..6784c3b0a6e 100644 --- a/tests/transforms/core/test_transform_dispatcher.py +++ b/tests/transforms/core/test_transform_dispatcher.py @@ -216,6 +216,7 @@ def test_the_transform_container_attributes(self): class TestTransformDispatcher: # pylint: disable=too-many-public-methods """Test the transform function (validate and dispatch).""" + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.catalyst @pytest.mark.external def test_error_on_qjit(self): diff --git a/tests/transforms/test_add_noise.py b/tests/transforms/test_add_noise.py index 779dcf91e36..3beab611105 100644 --- a/tests/transforms/test_add_noise.py +++ b/tests/transforms/test_add_noise.py @@ -414,7 +414,7 @@ def test_add_noise_level(self, level1, level2): @qml.transforms.undo_swaps @qml.transforms.merge_rotations @qml.transforms.cancel_inverses - @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4}) def f(w, x, y, z): qml.RX(w, wires=0) qml.RY(x, wires=1) @@ -447,7 +447,7 @@ def test_add_noise_level_with_final(self): @qml.transforms.undo_swaps @qml.transforms.merge_rotations @qml.transforms.cancel_inverses - @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4}) def f(w, x, y, 
z): qml.RX(w, wires=0) qml.RY(x, wires=1) diff --git a/tests/workflow/interfaces/execute/test_execute.py b/tests/workflow/interfaces/execute/test_execute.py index 8bea335ff7e..4a253727261 100644 --- a/tests/workflow/interfaces/execute/test_execute.py +++ b/tests/workflow/interfaces/execute/test_execute.py @@ -57,6 +57,28 @@ def test_execute_legacy_device(): assert qml.math.allclose(res[0], np.cos(0.1)) +def test_mcm_config_deprecation(mocker): + """Test that mcm_config argument has been deprecated.""" + + tape = qml.tape.QuantumScript( + [qml.RX(qml.numpy.array(1.0), 0)], [qml.expval(qml.Z(0))], shots=10 + ) + dev = qml.device("default.qubit") + + with dev.tracker: + with pytest.warns( + qml.PennyLaneDeprecationWarning, + match="The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.", + ): + spy = mocker.spy(qml.dynamic_one_shot, "_transform") + qml.execute( + (tape,), + dev, + mcm_config=qml.devices.MCMConfig(mcm_method="one-shot", postselect_mode=None), + ) + spy.assert_called_once() + + def test_config_deprecation(): """Test that the config argument has been deprecated.""" diff --git a/tests/workflow/interfaces/qnode/test_autograd_qnode.py b/tests/workflow/interfaces/qnode/test_autograd_qnode.py index ac30592926b..18e5f70cb83 100644 --- a/tests/workflow/interfaces/qnode/test_autograd_qnode.py +++ b/tests/workflow/interfaces/qnode/test_autograd_qnode.py @@ -137,14 +137,15 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, tol, dev device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA a = np.array(0.1, requires_grad=True) b = np.array(0.2, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -183,14 +184,15 @@ def 
test_jacobian_no_evaluate( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA a = np.array(0.1, requires_grad=True) b = np.array(0.2, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -222,13 +224,14 @@ def test_jacobian_options(self, interface, dev, diff_method, grad_on_execution, a = np.array([0.1, 0.2], requires_grad=True) + gradient_kwargs = {"h": 1e-8, "approx_order": 2} + @qnode( dev, interface=interface, - h=1e-8, - order=2, grad_on_execution=grad_on_execution, device_vjp=device_vjp, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -408,10 +411,10 @@ def test_differentiable_expand( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 10 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 10 tol = TOL_FOR_SPSA # pylint: disable=too-few-public-methods @@ -429,7 +432,7 @@ def decomposition(self): a = np.array(0.1, requires_grad=False) p = np.array([0.1, 0.2, 0.3], requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, p): qml.RX(a, wires=0) U3(p[0], p[1], p[2], wires=0) @@ -555,15 +558,15 @@ def test_probability_differentiation( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + 
@qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -593,14 +596,15 @@ def test_multiple_probability_differentiation( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -660,14 +664,15 @@ def test_ragged_differentiation( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -711,8 +716,9 @@ def test_ragged_differentiation_variance( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("Hadamard gradient does not support variances.") @@ -720,7 +726,7 @@ def test_ragged_differentiation_variance( x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -826,12 +832,13 @@ def test_chained_gradient_value( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = 
np.random.default_rng(seed) tol = TOL_FOR_SPSA dev1 = qml.device("default.qubit") - @qnode(dev1, **kwargs) + @qnode(dev1, **kwargs, gradient_kwargs=gradient_kwargs) def circuit1(a, b, c): qml.RX(a, wires=0) qml.RX(b, wires=1) @@ -1350,8 +1357,9 @@ def test_projector( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("Hadamard gradient does not support variances.") @@ -1359,7 +1367,7 @@ def test_projector( P = np.array(state, requires_grad=False) x, y = np.array([0.765, -0.654], requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=0) qml.RY(y, wires=1) @@ -1498,16 +1506,17 @@ def test_hamiltonian_expansion_analytic( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method in ["adjoint", "hadamard"]: pytest.skip("The diff method requested does not yet support Hamiltonians") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 10 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 10 tol = TOL_FOR_SPSA obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) qml.templates.AngleEmbedding(data, wires=[0, 1]) @@ -1584,7 +1593,7 @@ def test_hamiltonian_finite_shots( grad_on_execution=grad_on_execution, max_diff=max_diff, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) diff --git a/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py 
b/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py index 87e534fdb52..54072ae442e 100644 --- a/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py @@ -52,7 +52,7 @@ def test_jac_single_measurement_param( """For one measurement and one param, the gradient is a float.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -74,7 +74,7 @@ def test_jac_single_measurement_multiple_param( """For one measurement and multiple param, the gradient is a tuple of arrays.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -100,7 +100,7 @@ def test_jacobian_single_measurement_multiple_param_array( """For one measurement and multiple param as a single array params, the gradient is an array.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -123,7 +123,7 @@ def test_jacobian_single_measurement_param_probs( dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -146,7 +146,7 @@ def test_jacobian_single_measurement_probs_multiple_param( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, 
gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -173,7 +173,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -198,7 +198,7 @@ def test_jacobian_expval_expval_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=1, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=1, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -223,7 +223,7 @@ def test_jacobian_expval_expval_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -250,7 +250,7 @@ def test_jacobian_var_var_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -275,7 +275,7 @@ def test_jacobian_var_var_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -299,7 +299,7 @@ def test_jacobian_multiple_measurement_single_param( """The jacobian of multiple measurements 
with a single params return an array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -322,7 +322,7 @@ def test_jacobian_multiple_measurement_multiple_param( """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -349,7 +349,7 @@ def test_jacobian_multiple_measurement_multiple_param_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -382,7 +382,7 @@ def test_hessian_expval_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -412,7 +412,7 @@ def test_hessian_expval_multiple_param_array( params = np.array([0.1, 0.2]) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -440,7 +440,7 @@ def test_hessian_var_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def 
circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -470,7 +470,7 @@ def test_hessian_var_multiple_param_array( params = np.array([0.1, 0.2]) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -500,7 +500,7 @@ def test_hessian_probs_expval_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -533,7 +533,7 @@ def test_hessian_expval_probs_multiple_param_array( params = np.array([0.1, 0.2]) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -571,7 +571,7 @@ def test_single_expectation_value( x = np.array(0.543) y = np.array(-0.654) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -604,7 +604,7 @@ def test_prob_expectation_values( x = np.array(0.543) y = np.array(-0.654) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) diff --git a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py index 862806999c6..02a586be118 100644 --- a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py +++ b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py @@ -230,7 +230,7 @@ def decomposition(self): interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - 
**gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, p): qml.RX(a, wires=0) @@ -266,14 +266,15 @@ def test_jacobian_options( a = np.array([0.1, 0.2], requires_grad=True) + gradient_kwargs = {"h": 1e-8, "approx_order": 2} + @qnode( get_device(dev_name, wires=1, seed=seed), interface=interface, diff_method="finite-diff", - h=1e-8, - approx_order=2, grad_on_execution=grad_on_execution, device_vjp=device_vjp, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -323,7 +324,7 @@ def test_diff_expval_expval( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -389,7 +390,7 @@ def test_jacobian_no_evaluate( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -456,7 +457,7 @@ def test_diff_single_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -512,7 +513,7 @@ def test_diff_multi_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -601,7 +602,7 @@ def test_diff_expval_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -679,7 +680,7 @@ def test_diff_expval_probs_sub_argnums( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -740,7 +741,7 @@ def test_diff_var_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - 
**gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -1130,7 +1131,7 @@ def test_second_derivative( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1183,7 +1184,7 @@ def test_hessian( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1237,7 +1238,7 @@ def test_hessian_vector_valued( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1300,7 +1301,7 @@ def test_hessian_vector_valued_postprocessing( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RX(x[0], wires=0) @@ -1366,7 +1367,7 @@ def test_hessian_vector_valued_separate_args( grad_on_execution=grad_on_execution, max_diff=2, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -1478,7 +1479,7 @@ def test_projector( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=0) @@ -1593,7 +1594,7 @@ def test_hamiltonian_expansion_analytic( grad_on_execution=grad_on_execution, max_diff=max_diff, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) @@ -1663,7 +1664,7 @@ def test_hamiltonian_finite_shots( grad_on_execution=grad_on_execution, max_diff=max_diff, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) @@ -1862,7 +1863,7 @@ def 
test_gradient( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -2013,7 +2014,7 @@ def test_gradient_scalar_cost_vector_valued_qnode( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -3062,7 +3063,7 @@ def test_single_measurement( grad_on_execution=grad_on_execution, cache=False, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -3123,7 +3124,7 @@ def test_multi_measurements( diff_method=diff_method, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(a, b): qml.RY(a, wires=0) diff --git a/tests/workflow/interfaces/qnode/test_jax_qnode.py b/tests/workflow/interfaces/qnode/test_jax_qnode.py index ae1cb53a398..ae7ea130193 100644 --- a/tests/workflow/interfaces/qnode/test_jax_qnode.py +++ b/tests/workflow/interfaces/qnode/test_jax_qnode.py @@ -188,10 +188,10 @@ def test_differentiable_expand( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 10 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 10 tol = TOL_FOR_SPSA class U3(qml.U3): # pylint:disable=too-few-public-methods @@ -206,7 +206,7 @@ def decomposition(self): a = jax.numpy.array(0.1) p = jax.numpy.array([0.1, 0.2, 0.3]) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, p): qml.RX(a, wires=0) U3(p[0], p[1], p[2], wires=0) @@ -240,12 +240,13 @@ def test_jacobian_options( a = jax.numpy.array([0.1, 0.2]) + gradient_kwargs = 
{"h": 1e-8, "approx_order": 2} + @qnode( get_device(dev_name, wires=1, seed=seed), interface=interface, diff_method="finite-diff", - h=1e-8, - approx_order=2, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -273,9 +274,9 @@ def test_diff_expval_expval( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -283,7 +284,7 @@ def test_diff_expval_expval( a = jax.numpy.array(0.1) b = jax.numpy.array(0.2) - @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -334,14 +335,15 @@ def test_jacobian_no_evaluate( if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA a = jax.numpy.array(0.1) b = jax.numpy.array(0.2) - @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -387,8 +389,9 @@ def test_diff_single_probs( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -396,7 +399,7 @@ def test_diff_single_probs( x = 
jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -436,8 +439,9 @@ def test_diff_multi_probs( "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -445,7 +449,7 @@ def test_diff_multi_probs( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -517,8 +521,9 @@ def test_diff_expval_probs( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -526,7 +531,7 @@ def test_diff_expval_probs( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -599,7 +604,7 @@ def test_diff_expval_probs_sub_argnums( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -646,18 +651,19 @@ def test_diff_var_probs( "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == 
"hadamard": pytest.skip("Hadamard does not support var") if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -968,13 +974,14 @@ def test_second_derivative( "max_diff": 2, } + gradient_kwargs = {} if diff_method == "adjoint": pytest.skip("Adjoint does not second derivative.") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA - @qnode(get_device(dev_name, wires=0, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=0, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RY(x[0], wires=0) qml.RX(x[1], wires=0) @@ -1024,7 +1031,7 @@ def test_hessian( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1079,7 +1086,7 @@ def test_hessian_vector_valued( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1142,7 +1149,7 @@ def test_hessian_vector_valued_postprocessing( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RX(x[0], wires=0) @@ -1208,7 +1215,7 @@ def test_hessian_vector_valued_separate_args( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + 
gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -1316,7 +1323,7 @@ def test_projector( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=0) @@ -1425,7 +1432,7 @@ def test_split_non_commuting_analytic( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=max_diff, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) @@ -1513,7 +1520,7 @@ def test_hamiltonian_finite_shots( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=max_diff, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) diff --git a/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py index 697a1e90223..210ba601586 100644 --- a/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py @@ -60,7 +60,7 @@ def test_jac_single_measurement_param( """For one measurement and one param, the gradient is a float.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -86,7 +86,7 @@ def test_jac_single_measurement_multiple_param( """For one measurement and multiple param, the gradient is a tuple of arrays.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -115,7 +115,7 @@ def 
test_jacobian_single_measurement_multiple_param_array( """For one measurement and multiple param as a single array params, the gradient is an array.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -141,7 +141,7 @@ def test_jacobian_single_measurement_param_probs( dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -167,7 +167,7 @@ def test_jacobian_single_measurement_probs_multiple_param( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -199,7 +199,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -227,7 +227,13 @@ def test_jacobian_expval_expval_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=1, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=1, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -264,7 +270,7 @@ def 
test_jacobian_expval_expval_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -298,7 +304,7 @@ def test_jacobian_var_var_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -336,7 +342,7 @@ def test_jacobian_var_var_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -367,7 +373,7 @@ def test_jacobian_multiple_measurement_single_param( """The jacobian of multiple measurements with a single params return an array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -398,7 +404,7 @@ def test_jacobian_multiple_measurement_multiple_param( """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, 
diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -438,7 +444,7 @@ def test_jacobian_multiple_measurement_multiple_param_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -471,7 +477,13 @@ def test_hessian_expval_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -508,7 +520,13 @@ def test_hessian_expval_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -534,7 +552,13 @@ def test_hessian_var_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -571,7 +595,13 @@ def test_hessian_var_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + 
max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -597,7 +627,13 @@ def test_hessian_probs_expval_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -655,7 +691,13 @@ def test_hessian_expval_probs_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -687,7 +729,13 @@ def test_hessian_probs_var_multiple_params( par_0 = qml.numpy.array(0.1) par_1 = qml.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -742,7 +790,13 @@ def test_hessian_var_probs_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -792,7 +846,7 @@ def test_single_expectation_value( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) 
qml.RY(y, wires=[1]) @@ -827,7 +881,7 @@ def test_prob_expectation_values( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py index f24eda4b382..6a20f264ec6 100644 --- a/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py @@ -73,7 +73,7 @@ def test_jac_single_measurement_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a, wires=0) @@ -101,7 +101,7 @@ def test_jac_single_measurement_multiple_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b, **_): qml.RY(a, wires=0) @@ -133,7 +133,7 @@ def test_jacobian_single_measurement_multiple_param_array( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a[0], wires=0) @@ -162,7 +162,7 @@ def test_jacobian_single_measurement_param_probs( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a, wires=0) @@ -191,7 +191,7 @@ def test_jacobian_single_measurement_probs_multiple_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b, **_): qml.RY(a, wires=0) @@ 
-224,7 +224,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(a, **_):
         qml.RY(a[0], wires=0)
@@ -255,7 +255,7 @@ def test_jacobian_expval_expval_multiple_params(
         diff_method=diff_method,
         interface=interface,
         max_diff=1,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(x, y, **_):
         qml.RX(x, wires=[0])
@@ -285,7 +285,7 @@ def test_jacobian_expval_expval_multiple_params_array(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(a, **_):
         qml.RY(a[0], wires=0)
@@ -314,7 +314,7 @@ def test_jacobian_multiple_measurement_single_param(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(a, **_):
         qml.RY(a, wires=0)
@@ -342,7 +342,7 @@ def test_jacobian_multiple_measurement_multiple_param(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(a, b, **_):
         qml.RY(a, wires=0)
@@ -374,7 +374,7 @@ def test_jacobian_multiple_measurement_multiple_param_array(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(a, **_):
         qml.RY(a[0], wires=0)
@@ -421,7 +421,7 @@ def test_hessian_expval_multiple_params(
         diff_method=diff_method,
         interface=interface,
         max_diff=2,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(x, y, **_):
         qml.RX(x, wires=[0])
@@ -471,7 +471,7 @@ def test_single_expectation_value(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(x, y, **_):
         qml.RX(x, wires=[0])
@@ -509,7 +509,7 @@ def test_prob_expectation_values(
         qml.device(dev_name, seed=seed),
         diff_method=diff_method,
         interface=interface,
-        **gradient_kwargs,
+        gradient_kwargs=gradient_kwargs,
     )
     def circuit(x, y, **_):
         qml.RX(x, wires=[0])
diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py b/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
index 85a39cc588f..79f19bd2f4d 100644
--- a/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
@@ -158,6 +158,7 @@ def circuit(p1, p2=y, **kwargs):
     def test_jacobian(self, dev, diff_method, grad_on_execution, device_vjp, tol, interface, seed):
         """Test jacobian calculation"""
+        gradient_kwargs = {}
         kwargs = {
             "diff_method": diff_method,
             "grad_on_execution": grad_on_execution,
@@ -165,14 +166,14 @@ def test_jacobian(self, dev, diff_method, grad_on_execution, device_vjp, tol, in
             "device_vjp": device_vjp,
         }
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         a = tf.Variable(0.1, dtype=tf.float64)
         b = tf.Variable(0.2, dtype=tf.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(a, b):
             qml.RY(a, wires=0)
             qml.RX(b, wires=1)
@@ -200,14 +201,15 @@ def test_jacobian_options(self, dev, diff_method, grad_on_execution, device_vjp,
         a = tf.Variable([0.1, 0.2])
+        gradient_kwargs = {"approx_order": 2, "h": 1e-8}
+
         @qnode(
             dev,
             interface=interface,
-            h=1e-8,
-            approx_order=2,
             diff_method=diff_method,
             grad_on_execution=grad_on_execution,
             device_vjp=device_vjp,
+            gradient_kwargs=gradient_kwargs,
         )
         def circuit(a):
             qml.RY(a[0], wires=0)
@@ -240,7 +242,7 @@ def test_changing_trainability(
             diff_method=diff_method,
             grad_on_execution=grad_on_execution,
             device_vjp=device_vjp,
-            **diff_kwargs,
+            gradient_kwargs=diff_kwargs,
         )
         def circuit(a, b):
             qml.RY(a, wires=0)
@@ -370,6 +372,7 @@ def test_differentiable_expand(
     ):
         """Test that operation and nested tapes expansion is differentiable"""
+        gradient_kwargs = {}
         kwargs = {
             "diff_method": diff_method,
             "grad_on_execution": grad_on_execution,
@@ -377,8 +380,8 @@ def test_differentiable_expand(
             "device_vjp": device_vjp,
         }
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         class U3(qml.U3):
@@ -393,7 +396,7 @@ def decomposition(self):
         a = np.array(0.1)
         p = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(a, p):
             qml.RX(a, wires=0)
             U3(p[0], p[1], p[2], wires=0)
@@ -548,15 +551,16 @@ def test_probability_differentiation(
             "interface": interface,
             "device_vjp": device_vjp,
         }
+        gradient_kwargs = {}
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         x = tf.Variable(0.543, dtype=tf.float64)
         y = tf.Variable(-0.654, dtype=tf.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -605,15 +609,16 @@ def test_ragged_differentiation(
             "interface": interface,
             "device_vjp": device_vjp,
         }
+        gradient_kwargs = {}
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         x = tf.Variable(0.543, dtype=tf.float64)
         y = tf.Variable(-0.654, dtype=tf.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -943,13 +948,14 @@ def test_projector(
             "interface": interface,
             "device_vjp": device_vjp,
         }
+        gradient_kwargs = {}
         if diff_method == "adjoint":
             pytest.skip("adjoint supports either all expvals or all diagonal measurements.")
         if diff_method == "hadamard":
             pytest.skip("Variance not implemented yet.")
         elif diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         if dev.name == "reference.qubit":
             pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)")
@@ -959,7 +965,7 @@ def test_projector(
         x, y = 0.765, -0.654
         weights = tf.Variable([x, y], dtype=tf.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(weights):
             qml.RX(weights[0], wires=0)
             qml.RY(weights[1], wires=1)
@@ -1129,16 +1135,17 @@ def test_hamiltonian_expansion_analytic(
             "interface": interface,
             "device_vjp": device_vjp,
         }
+        gradient_kwargs = {}
         if diff_method in ["adjoint", "hadamard"]:
             pytest.skip("The adjoint/hadamard method does not yet support Hamiltonians")
         elif diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(data, weights, coeffs):
             weights = tf.reshape(weights, [1, -1])
             qml.templates.AngleEmbedding(data, wires=[0, 1])
@@ -1211,7 +1218,7 @@ def test_hamiltonian_finite_shots(
             max_diff=max_diff,
             interface=interface,
             device_vjp=device_vjp,
-            **gradient_kwargs,
+            gradient_kwargs=gradient_kwargs,
         )
         def circuit(data, weights, coeffs):
             weights = tf.reshape(weights, [1, -1])
diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py
index 037469657bc..fafd85e7e1e 100644
--- a/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py
+++ b/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py
@@ -70,7 +70,7 @@ def test_jac_single_measurement_param(
     ):
         """For one measurement and one param, the gradient is a float."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a, wires=0)
             qml.RX(0.2, wires=0)
@@ -92,7 +92,7 @@ def test_jac_single_measurement_multiple_param(
     ):
         """For one measurement and multiple param, the gradient is a tuple of arrays."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a, b):
             qml.RY(a, wires=0)
             qml.RX(b, wires=0)
@@ -118,7 +118,7 @@ def test_jacobian_single_measurement_multiple_param_array(
     ):
         """For one measurement and multiple param as a single array params, the gradient is an array."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a[0], wires=0)
             qml.RX(a[1], wires=0)
@@ -141,7 +141,7 @@ def test_jacobian_single_measurement_param_probs(
         """For a multi dimensional measurement (probs), check that a single array is returned
         with the correct dimension"""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a, wires=0)
             qml.RX(0.2, wires=0)
@@ -164,7 +164,7 @@ def test_jacobian_single_measurement_probs_multiple_param(
         """For a multi dimensional measurement (probs), check that a single tuple is returned
         containing arrays with the correct dimension"""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a, b):
             qml.RY(a, wires=0)
             qml.RX(b, wires=0)
@@ -191,7 +191,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array(
         """For a multi dimensional measurement (probs), check that a single tuple is returned
         containing arrays with the correct dimension"""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a[0], wires=0)
             qml.RX(a[1], wires=0)
@@ -216,7 +216,13 @@ def test_jacobian_expval_expval_multiple_params(
         par_0 = tf.Variable(0.1)
         par_1 = tf.Variable(0.2)
-        @qnode(dev, diff_method=diff_method, interface=interface, max_diff=1, **gradient_kwargs)
+        @qnode(
+            dev,
+            diff_method=diff_method,
+            interface=interface,
+            max_diff=1,
+            gradient_kwargs=gradient_kwargs,
+        )
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -240,7 +246,7 @@ def test_jacobian_expval_expval_multiple_params_array(
     ):
         """The jacobian of multiple measurements with a multiple params array return a single array."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a[0], wires=0)
             qml.RX(a[1], wires=0)
@@ -263,7 +269,7 @@ def test_jacobian_multiple_measurement_single_param(
     ):
         """The jacobian of multiple measurements with a single params return an array."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a, wires=0)
             qml.RX(0.2, wires=0)
@@ -285,7 +291,7 @@ def test_jacobian_multiple_measurement_multiple_param(
     ):
         """The jacobian of multiple measurements with a multiple params return a tuple of arrays."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a, b):
             qml.RY(a, wires=0)
             qml.RX(b, wires=0)
@@ -311,7 +317,7 @@ def test_jacobian_multiple_measurement_multiple_param_array(
     ):
         """The jacobian of multiple measurements with a multiple params array return a single array."""
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(a):
             qml.RY(a[0], wires=0)
             qml.RX(a[1], wires=0)
@@ -343,7 +349,13 @@ def test_hessian_expval_multiple_params(
         par_0 = tf.Variable(0.1, dtype=tf.float64)
         par_1 = tf.Variable(0.2, dtype=tf.float64)
-        @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+        @qnode(
+            dev,
+            diff_method=diff_method,
+            interface=interface,
+            max_diff=2,
+            gradient_kwargs=gradient_kwargs,
+        )
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -373,7 +385,13 @@ def test_hessian_expval_multiple_param_array(
         params = tf.Variable([0.1, 0.2], dtype=tf.float64)
-        @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+        @qnode(
+            dev,
+            diff_method=diff_method,
+            interface=interface,
+            max_diff=2,
+            gradient_kwargs=gradient_kwargs,
+        )
         def circuit(x):
             qml.RX(x[0], wires=[0])
             qml.RY(x[1], wires=[1])
@@ -400,7 +418,13 @@ def test_hessian_probs_expval_multiple_params(
         par_0 = tf.Variable(0.1, dtype=tf.float64)
         par_1 = tf.Variable(0.2, dtype=tf.float64)
-        @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+        @qnode(
+            dev,
+            diff_method=diff_method,
+            interface=interface,
+            max_diff=2,
+            gradient_kwargs=gradient_kwargs,
+        )
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -430,7 +454,13 @@ def test_hessian_expval_probs_multiple_param_array(
         params = tf.Variable([0.1, 0.2], dtype=tf.float64)
-        @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+        @qnode(
+            dev,
+            diff_method=diff_method,
+            interface=interface,
+            max_diff=2,
+            gradient_kwargs=gradient_kwargs,
+        )
         def circuit(x):
             qml.RX(x[0], wires=[0])
             qml.RY(x[1], wires=[1])
@@ -467,7 +497,7 @@ def test_single_expectation_value(
         x = tf.Variable(0.543, dtype=tf.float64)
         y = tf.Variable(-0.654, dtype=tf.float64)
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -500,7 +530,7 @@ def test_prob_expectation_values(
         x = tf.Variable(0.543, dtype=tf.float64)
         y = tf.Variable(-0.654, dtype=tf.float64)
-        @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+        @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
diff --git a/tests/workflow/interfaces/qnode/test_torch_qnode.py b/tests/workflow/interfaces/qnode/test_torch_qnode.py
index e9581898522..a2922a7ce62 100644
--- a/tests/workflow/interfaces/qnode/test_torch_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_torch_qnode.py
@@ -173,9 +173,10 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, device_v
             interface=interface,
             device_vjp=device_vjp,
         )
+        gradient_kwargs = {}
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         a_val = 0.1
@@ -184,7 +185,7 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, device_v
         a = torch.tensor(a_val, dtype=torch.float64, requires_grad=True)
         b = torch.tensor(b_val, dtype=torch.float64, requires_grad=True)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(a, b):
             qml.RY(a, wires=0)
             qml.RX(b, wires=1)
@@ -276,14 +277,15 @@ def test_jacobian_options(
         a = torch.tensor([0.1, 0.2], requires_grad=True)
+        gradient_kwargs = {"h": 1e-8, "approx_order": 2}
+
         @qnode(
             dev,
             diff_method=diff_method,
             grad_on_execution=grad_on_execution,
             interface=interface,
-            h=1e-8,
-            approx_order=2,
             device_vjp=device_vjp,
+            gradient_kwargs=gradient_kwargs,
         )
         def circuit(a):
             qml.RY(a[0], wires=0)
@@ -483,9 +485,10 @@ def test_differentiable_expand(
             interface=interface,
             device_vjp=device_vjp,
         )
+        gradient_kwargs = {}
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         class U3(qml.U3):  # pylint:disable=too-few-public-methods
@@ -501,7 +504,7 @@ def decomposition(self):
         p_val = [0.1, 0.2, 0.3]
         p = torch.tensor(p_val, dtype=torch.float64, requires_grad=True)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(a, p):
             qml.RX(a, wires=0)
             U3(p[0], p[1], p[2], wires=0)
@@ -644,9 +647,9 @@ def test_probability_differentiation(
         with prob and expval outputs"""
         if "lightning" in getattr(dev, "name", "").lower():
             pytest.xfail("lightning does not support measureing probabilities with adjoint.")
-        kwargs = {}
+        gradient_kwargs = {}
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
             tol = TOL_FOR_SPSA
         x_val = 0.543
@@ -660,7 +663,7 @@ def test_probability_differentiation(
             grad_on_execution=grad_on_execution,
             interface=interface,
             device_vjp=device_vjp,
-            **kwargs,
+            gradient_kwargs=gradient_kwargs,
         )
         def circuit(x, y):
             qml.RX(x, wires=[0])
@@ -708,9 +711,10 @@ def test_ragged_differentiation(
             interface=interface,
             device_vjp=device_vjp,
         )
+        gradient_kwargs = {}
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         x_val = 0.543
@@ -718,7 +722,7 @@ def test_ragged_differentiation(
         x = torch.tensor(x_val, requires_grad=True, dtype=torch.float64)
         y = torch.tensor(y_val, requires_grad=True, dtype=torch.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(x, y):
             qml.RX(x, wires=[0])
             qml.RY(y, wires=[1])
@@ -813,7 +817,7 @@ def test_hessian(self, interface, dev, diff_method, grad_on_execution, device_vj
             max_diff=2,
             interface=interface,
             device_vjp=device_vjp,
-            **options,
+            gradient_kwargs=options,
         )
         def circuit(x):
             qml.RY(x[0], wires=0)
@@ -866,7 +870,7 @@ def test_hessian_vector_valued(
             max_diff=2,
             interface=interface,
             device_vjp=device_vjp,
-            **options,
+            gradient_kwargs=options,
         )
         def circuit(x):
             qml.RY(x[0], wires=0)
@@ -928,7 +932,7 @@ def test_hessian_ragged(self, interface, dev, diff_method, grad_on_execution, de
             max_diff=2,
             interface=interface,
             device_vjp=device_vjp,
-            **options,
+            gradient_kwargs=options,
        )
         def circuit(x):
             qml.RY(x[0], wires=0)
@@ -1003,7 +1007,7 @@ def test_hessian_vector_valued_postprocessing(
             max_diff=2,
             interface=interface,
             device_vjp=device_vjp,
-            **options,
+            gradient_kwargs=options,
         )
         def circuit(x):
             qml.RX(x[0], wires=0)
@@ -1100,11 +1104,12 @@ def test_projector(
             grad_on_execution=grad_on_execution,
             device_vjp=device_vjp,
         )
+        gradient_kwargs = {}
         if diff_method == "adjoint":
             pytest.skip("adjoint supports either all expvals or all diagonal measurements")
         if diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         elif diff_method == "hadamard":
             pytest.skip("Hadamard does not support variances.")
@@ -1116,7 +1121,7 @@ def test_projector(
         x, y = 0.765, -0.654
         weights = torch.tensor([x, y], requires_grad=True, dtype=torch.float64)
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(x, y):
             qml.RX(x, wires=0)
             qml.RY(y, wires=1)
@@ -1285,18 +1290,19 @@ def test_hamiltonian_expansion_analytic(
             interface="torch",
             device_vjp=device_vjp,
         )
+        gradient_kwargs = {}
         if diff_method == "adjoint":
             pytest.skip("The adjoint method does not yet support Hamiltonians")
         elif diff_method == "spsa":
-            kwargs["sampler_rng"] = np.random.default_rng(seed)
-            kwargs["num_directions"] = 20
+            gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+            gradient_kwargs["num_directions"] = 20
             tol = TOL_FOR_SPSA
         elif diff_method == "hadamard":
             pytest.skip("The hadamard method does not yet support Hamiltonians")
         obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]
-        @qnode(dev, **kwargs)
+        @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
         def circuit(data, weights, coeffs):
             weights = torch.reshape(weights, [1, -1])
             qml.templates.AngleEmbedding(data, wires=[0, 1])
@@ -1389,7 +1395,7 @@ def test_hamiltonian_finite_shots(
             max_diff=max_diff,
             interface="torch",
             device_vjp=device_vjp,
-            **gradient_kwargs,
+            gradient_kwargs=gradient_kwargs,
         )
         def circuit(data, weights, coeffs):
             weights = torch.reshape(weights, [1, -1])
diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py
index f11eaee459c..d4288da65d9 100644
--- a/tests/workflow/test_construct_batch.py
+++ b/tests/workflow/test_construct_batch.py
@@ -72,7 +72,7 @@ def test_get_transform_program_diff_method_transform(self):
         @partial(qml.transforms.compile, num_passes=2)
         @partial(qml.transforms.merge_rotations, atol=1e-5)
         @qml.transforms.cancel_inverses
-        @qml.qnode(dev, diff_method="parameter-shift", shifts=2)
+        @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": 2})
        def circuit():
            return qml.expval(qml.PauliZ(0))
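
The diff above repeatedly replaces loose `**gradient_kwargs` with an explicit `gradient_kwargs=...` dict, which the commit message says "allows easier and cleaner input validation". The following is a toy sketch (not PennyLane's actual implementation; `qnode_like` and `_KNOWN_GRADIENT_OPTIONS` are hypothetical names) of why an explicit dict helps: unknown options can be rejected eagerly instead of being silently absorbed as extra keyword arguments. The option names come from the patch itself.

```python
# Toy sketch of explicit-dict validation; not PennyLane's real qnode.
_KNOWN_GRADIENT_OPTIONS = {"h", "approx_order", "shifts", "sampler_rng", "num_directions"}

def qnode_like(func=None, *, diff_method="best", gradient_kwargs=None):
    """Decorator factory that validates gradient options up front."""
    gradient_kwargs = dict(gradient_kwargs or {})
    unknown = set(gradient_kwargs) - _KNOWN_GRADIENT_OPTIONS
    if unknown:
        # With loose **kwargs, a typo like approx_ordr would be silently accepted.
        raise ValueError(f"Unknown gradient options: {sorted(unknown)}")

    def wrap(f):
        f.gradient_kwargs = gradient_kwargs
        return f

    return wrap(func) if func is not None else wrap

@qnode_like(diff_method="finite-diff", gradient_kwargs={"h": 1e-8, "approx_order": 2})
def circuit(a):
    return a

print(circuit.gradient_kwargs)  # {'h': 1e-08, 'approx_order': 2}
```

The same eager check is impossible with `@qnode(dev, **gradient_kwargs)`-style forwarding, since the decorator cannot distinguish a mistyped gradient option from any other keyword argument.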