diff --git a/.github/stable/all_interfaces.txt b/.github/stable/all_interfaces.txt index b53f5569d44..84489804731 100644 --- a/.github/stable/all_interfaces.txt +++ b/.github/stable/all_interfaces.txt @@ -38,7 +38,7 @@ isort==5.13.2 jax==0.4.28 jaxlib==0.4.28 Jinja2==3.1.5 -keras==3.7.0 +keras==3.8.0 kiwisolver==1.4.8 lazy-object-proxy==1.10.0 libclang==18.1.1 @@ -68,12 +68,12 @@ nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.6.85 nvidia-nvtx-cu12==12.1.105 opt_einsum==3.4.0 -optree==0.13.1 +optree==0.14.0 osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -83,7 +83,7 @@ protobuf==4.25.5 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -101,7 +101,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -115,14 +116,14 @@ termcolor==2.5.0 tf_keras==2.16.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 torch==2.3.0 triton==2.3.0 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 Werkzeug==3.1.3 wrapt==1.12.1 diff --git a/.github/stable/core.txt b/.github/stable/core.txt index fc24ed2dff0..925553c4c2f 100644 --- a/.github/stable/core.txt +++ b/.github/stable/core.txt @@ -43,7 +43,7 @@ osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -52,7 +52,7 @@ prompt_toolkit==3.0.48 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -70,7 +70,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -78,11 +79,11 @@ tach==0.13.1 termcolor==2.5.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 wrapt==1.12.1 diff --git a/.github/stable/external.txt b/.github/stable/external.txt index 5cb949a78d7..80052907bc0 100644 --- a/.github/stable/external.txt +++ b/.github/stable/external.txt @@ -27,14 +27,14 @@ clarabel==0.9.0 click==8.1.8 comm==0.2.2 contourpy==1.3.1 -cotengra==0.6.2 +cotengra==0.7.0 coverage==7.6.10 cryptography==44.0.0 cvxopt==1.3.2 cvxpy==1.6.0 cycler==0.12.1 cytoolz==1.0.1 -debugpy==1.8.11 +debugpy==1.8.12 decorator==5.1.1 defusedxml==0.7.1 diastatic-malt==2.15.2 @@ -60,8 +60,8 @@ h11==0.14.0 h5py==3.12.1 httpcore==1.0.7 httpx==0.28.1 -ibm-cloud-sdk-core==3.22.0 -ibm-platform-services==0.59.0 +ibm-cloud-sdk-core==3.22.1 +ibm-platform-services==0.59.1 identify==2.6.5 idna==3.10 iniconfig==2.0.0 @@ -89,7 +89,7 @@ jupyterlab==4.3.4 jupyterlab_pygments==0.3.0 jupyterlab_server==2.27.3 jupyterlab_widgets==1.1.11 -keras==3.7.0 +keras==3.8.0 kiwisolver==1.4.8 lark==1.1.9 lazy-object-proxy==1.10.0 @@ -120,7 +120,7 @@ numba==0.60.0 numpy==1.26.4 opt_einsum==3.4.0 optax==0.2.4 -optree==0.13.1 +optree==0.14.0 osqp==0.6.7.post3 overrides==7.7.0 packaging==24.2 @@ -129,8 +129,8 @@ pandocfilters==1.5.1 parso==0.8.4 pathspec==0.12.1 pbr==6.1.0 -PennyLane-Catalyst==0.10.0.dev39 -PennyLane-qiskit @ git+https://github.com/PennyLaneAI/pennylane-qiskit.git@40a4d24f126e51e0e3e28a4cd737f883a6fd5ebc +PennyLane-Catalyst==0.11.0.dev1 +PennyLane-qiskit @ git+https://github.com/PennyLaneAI/pennylane-qiskit.git@b46fbca3372979534bc33af701194df548cf4b16 PennyLane_Lightning==0.40.0 PennyLane_Lightning_Kokkos==0.40.0.dev41 pexpect==4.9.0 @@ -148,10 +148,10 @@ pure_eval==0.2.3 py==1.11.0 py-cpuinfo==9.0.0 pycparser==2.22 -pydantic==2.10.4 +pydantic==2.10.5 pydantic_core==2.27.2 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 PyJWT==2.10.1 pylint==2.7.4 pyparsing==3.2.1 @@ -173,11 +173,11 @@ pyzmq==26.2.0 pyzx==0.8.0 qdldl==0.1.7.post5 qiskit==1.2.4 -qiskit-aer==0.15.1 +qiskit-aer==0.16.0 qiskit-ibm-provider==0.11.0 qiskit-ibm-runtime==0.29.0 quimb==1.10.0 -referencing==0.35.1 +referencing==0.36.1 requests==2.32.3 requests_ntlm==1.3.0 rfc3339-validator==0.1.4 @@ -211,7 +211,7 @@ tf_keras==2.16.0 tinycss2==1.4.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 toolz==1.0.0 tornado==6.4.2 @@ -222,12 +222,12 @@ typing_extensions==4.12.2 tzdata==2024.2 uri-template==1.3.0 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 webcolors==24.11.1 webencodings==0.5.1 websocket-client==1.8.0 -websockets==14.1 +websockets==14.2 Werkzeug==3.1.3 widgetsnbextension==3.6.10 wrapt==1.12.1 diff --git a/.github/stable/jax.txt b/.github/stable/jax.txt index ed2c0b32037..504bcff8e47 100644 --- a/.github/stable/jax.txt +++ b/.github/stable/jax.txt @@ -37,7 +37,7 @@ markdown-it-py==3.0.0 matplotlib==3.10.0 mccabe==0.6.1 mdurl==0.1.2 -ml_dtypes==0.5.0 +ml_dtypes==0.5.1 mypy-extensions==1.0.0 networkx==3.4.2 nodeenv==1.9.1 @@ -47,7 +47,7 @@ osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -56,7 +56,7 @@ prompt_toolkit==3.0.48 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -74,7 +74,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -82,11 +83,11 @@ tach==0.13.1 termcolor==2.5.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 wrapt==1.12.1 diff --git a/.github/stable/tf.txt b/.github/stable/tf.txt index 634fbd09526..02d559768eb 100644 --- a/.github/stable/tf.txt +++ b/.github/stable/tf.txt @@ -34,7 +34,7 @@ identify==2.6.5 idna==3.10 iniconfig==2.0.0 isort==5.13.2 -keras==3.7.0 +keras==3.8.0 kiwisolver==1.4.8 lazy-object-proxy==1.10.0 libclang==18.1.1 @@ -51,12 +51,12 @@ networkx==3.4.2 nodeenv==1.9.1 numpy==1.26.4 opt_einsum==3.4.0 -optree==0.13.1 +optree==0.14.0 osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -66,7 +66,7 @@ protobuf==4.25.5 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -84,7 +84,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -97,12 +98,12 @@ termcolor==2.5.0 tf_keras==2.16.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 Werkzeug==3.1.3 wrapt==1.12.1 diff --git a/.github/stable/torch.txt b/.github/stable/torch.txt index 3ab4c0e93f2..46b5f061df6 100644 --- a/.github/stable/torch.txt +++ b/.github/stable/torch.txt @@ -59,7 +59,7 @@ osqp==0.6.7.post3 packaging==24.2 pandas==2.2.3 pathspec==0.12.1 -PennyLane_Lightning==0.40.0 +PennyLane_Lightning==0.41.0 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 @@ -68,7 +68,7 @@ prompt_toolkit==3.0.48 py==1.11.0 py-cpuinfo==9.0.0 pydot==3.0.4 -Pygments==2.19.0 +Pygments==2.19.1 pylint==2.7.4 pyparsing==3.2.1 pytest==8.3.4 @@ -86,7 +86,8 @@ qdldl==0.1.7.post5 requests==2.32.3 rich==13.9.4 rustworkx==0.15.1 -scipy==1.15.0 +scipy==1.15.1 +scipy-openblas32==0.3.28.0.2 scs==3.2.7.post2 six==1.17.0 smmap==5.0.2 @@ -95,13 +96,13 @@ tach==0.13.1 termcolor==2.5.0 toml==0.10.2 tomli==2.2.1 -tomli_w==1.1.0 +tomli_w==1.2.0 tomlkit==0.13.2 torch==2.3.0 triton==2.3.0 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.3.0 -virtualenv==20.28.1 +virtualenv==20.29.1 wcwidth==0.2.13 wrapt==1.12.1 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 02b95edc395..5683d05880c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -16,15 +16,18 @@ name: "Documentation check" on: + merge_group: + types: + - checks_requested pull_request: types: - opened - reopened - synchronize - ready_for_review - # Scheduled trigger on Wednesdays at 3:00am UTC + # Scheduled trigger on Monday at 2:47am UTC schedule: - - cron: "0 3 * * 3" + - cron: "47 2 * * 1" permissions: write-all diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 0a29e842bf1..ba5d6ce7776 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -1,5 +1,8 @@ name: Formatting check on: + merge_group: + types: + - checks_requested pull_request: types: - opened diff --git a/.github/workflows/interface-dependency-versions.yml b/.github/workflows/interface-dependency-versions.yml index f567321b27c..d8520064dc1 100644 --- a/.github/workflows/interface-dependency-versions.yml +++ b/.github/workflows/interface-dependency-versions.yml @@ -26,7 +26,7 @@ on: description: The version of PyTorch to use for testing required: false type: string - default: '2.3.0' + default: '2.5.0' outputs: jax-version: description: The version of JAX to use @@ -72,6 +72,6 @@ jobs: outputs: jax-version: jax==${{ steps.jax.outputs.version }} jaxlib==${{ steps.jax.outputs.version }} tensorflow-version: tensorflow~=${{ steps.tensorflow.outputs.version }} tf-keras~=${{ steps.tensorflow.outputs.version }} - pytorch-version: torch==${{ steps.pytorch.outputs.version }} + pytorch-version: torch~=${{ steps.pytorch.outputs.version }} catalyst-nightly: ${{ steps.catalyst.outputs.nightly }} pennylane-lightning-latest: ${{ steps.pennylane-lightning.outputs.latest }} diff --git a/.github/workflows/module-validation.yml b/.github/workflows/module-validation.yml index 0db4f16faeb..66af8bc218e 100644 --- a/.github/workflows/module-validation.yml +++ b/.github/workflows/module-validation.yml @@ -1,6 +1,9 @@ name: Validate module imports on: + merge_group: + types: + - checks_requested pull_request: types: - opened diff --git a/.github/workflows/tests-gpu.yml b/.github/workflows/tests-gpu.yml index bd8cd5dae67..92a7bae8ac9 100644 --- a/.github/workflows/tests-gpu.yml +++ b/.github/workflows/tests-gpu.yml @@ -3,6 +3,9 @@ on: push: branches: - master + merge_group: + types: + - checks_requested pull_request: types: - opened diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 63c8074e524..620dc39ac57 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,6 +14,9 @@ on: # Scheduled trigger on Monday at 2:47am UTC schedule: - cron: "47 2 * * 1" + merge_group: + types: + - checks_requested concurrency: group: unit-tests-${{ github.ref }} diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst index ff6048e18be..a61b8338222 100644 --- a/doc/development/deprecations.rst +++ b/doc/development/deprecations.rst @@ -9,6 +9,25 @@ deprecations are listed below. Pending deprecations -------------------- +* The ``mcm_config`` argument to ``qml.execute`` has been deprecated. + Instead, use the ``mcm_method`` and ``postselect_mode`` arguments. + + - Deprecated in v0.41 + - Will be removed in v0.42 + +* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated + and will be removed in v0.42. The gradient keyword arguments should be passed to the new + keyword argument ``gradient_kwargs`` via an explicit dictionary, like ``gradient_kwargs={"h": 1e-4}``. + + - Deprecated in v0.41 + - Will be removed in v0.42 + +* The `qml.gradients.hamiltonian_grad` function has been deprecated. + This gradient recipe is not required with the new operator arithmetic system. + + - Deprecated in v0.41 + - Will be removed in v0.42 + * The ``inner_transform_program`` and ``config`` keyword arguments in ``qml.execute`` have been deprecated. If more detailed control over the execution is required, use ``qml.workflow.run`` with these arguments instead. diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 89e21560387..88280392227 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -29,11 +29,11 @@ 'parameter-shift' ``` -* Finite shot and parameter-shift executions on `default.qubit` can now - be natively jitted end-to-end, leading to performance improvements. - Devices can now configure whether or not ML framework data is sent to them - via an `ExecutionConfig.convert_to_numpy` parameter. +* Devices can now configure whether or not ML framework data is sent to them + via an `ExecutionConfig.convert_to_numpy` parameter. This is not used on + `default.qubit` due to compilation overheads when jitting. [(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788) + [(#6869)](https://github.com/PennyLaneAI/pennylane/pull/6869) * The coefficients of observables now have improved differentiability. [(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598) @@ -44,6 +44,9 @@ * An informative error is raised when a `QNode` with `diff_method=None` is differentiated. [(#6770)](https://github.com/PennyLaneAI/pennylane/pull/6770) +* The requested `diff_method` is now validated when program capture is enabled. + [(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852) +

Breaking changes πŸ’”

* `MultiControlledX` no longer accepts strings as control values. @@ -51,6 +54,7 @@ * The input argument `control_wires` of `MultiControlledX` has been removed. [(#6832)](https://github.com/PennyLaneAI/pennylane/pull/6832) + [(#6862)](https://github.com/PennyLaneAI/pennylane/pull/6862) * `qml.execute` now has a collection of keyword-only arguments. [(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598) @@ -80,17 +84,40 @@

Deprecations πŸ‘‹

+* The `mcm_method` keyword in `qml.execute` is deprecated. Instead, use the ``mcm_method`` and ``postselect_mode`` arguments. + [(#6807)](https://github.com/PennyLaneAI/pennylane/pull/6807) + +* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated + and will be removed in v0.42. The gradient keyword arguments should be passed to the new + keyword argument `gradient_kwargs` via an explicit dictionary. This change will improve qnode argument + validation. + [(#6828)](https://github.com/PennyLaneAI/pennylane/pull/6828) + +* The `qml.gradients.hamiltonian_grad` function has been deprecated. + This gradient recipe is not required with the new operator arithmetic system. + [(#6849)](https://github.com/PennyLaneAI/pennylane/pull/6849) + * The ``inner_transform_program`` and ``config`` keyword arguments in ``qml.execute`` have been deprecated. If more detailed control over the execution is required, use ``qml.workflow.run`` with these arguments instead. [(#6822)](https://github.com/PennyLaneAI/pennylane/pull/6822)

Internal changes βš™οΈ

+* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module. + This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives. + Consequently, `QmlPrimitive` is now used to define all PennyLane primitives. + [(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847) +

Documentation πŸ“

-* Updated documentation for vibrational Hamiltonians +* The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`, + `qml.qchem.localize_normal_modes`, and `qml.qchem.VibrationalPES` have been updated to include better + code examples. [(#6717)](https://github.com/PennyLaneAI/pennylane/pull/6717) +* Fixed a typo in the code example for `qml.labs.dla.lie_closure_dense`. + [(#6858)](https://github.com/PennyLaneAI/pennylane/pull/6858) +

Bug fixes πŸ›

* `BasisState` now casts its input to integers. @@ -101,8 +128,10 @@ This release contains contributions from (in alphabetical order): Yushao Chen, +Isaac De Vlugt, Diksha Dhawan, Pietropaolo Frisoni, Marcus GisslΓ©n, Christina Lee, +Mudit Pandey, Andrija Paurevic diff --git a/pennylane/_version.py b/pennylane/_version.py index 704c372802e..9295240495c 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.41.0-dev10" +__version__ = "0.41.0-dev14" diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 4af11cb6198..2b7314f7c2e 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -25,8 +25,6 @@ from .flatfn import FlatFn from .primitives import ( - AbstractMeasurement, - AbstractOperator, adjoint_transform_prim, cond_prim, ctrl_transform_prim, @@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: self._env[constvar] = const for eqn in jaxpr.eqns: + primitive = eqn.primitive + custom_handler = self._primitive_registrations.get(primitive, None) - custom_handler = self._primitive_registrations.get(eqn.primitive, None) if custom_handler: invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif isinstance(eqn.outvars[0].aval, AbstractOperator): + elif getattr(primitive, "prim_type", "") == "operator": outvals = self.interpret_operation_eqn(eqn) - elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + elif getattr(primitive, "prim_type", "") == "measurement": outvals = self.interpret_measurement_eqn(eqn) else: invals = [self.read(invar) for invar in eqn.invars] - outvals = eqn.primitive.bind(*invals, **eqn.params) + outvals = primitive.bind(*invals, **eqn.params) - if not eqn.primitive.multiple_results: + if not primitive.multiple_results: outvals = [outvals] for outvar, outval in zip(eqn.outvars, outvals, strict=True): self._env[outvar] = outval diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index 482f692df69..ba7c3846693 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -24,34 +24,6 @@ has_jax = False -@lru_cache -def create_non_interpreted_prim(): - """Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace - and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive. - """ - - if not has_jax: # pragma: no cover - return None - - # pylint: disable=too-few-public-methods - class NonInterpPrimitive(jax.core.Primitive): - """A subclass to JAX's Primitive that works like a Python function - when evaluating JVPTracers and BatchTracers.""" - - def bind_with_trace(self, trace, args, params): - """Bind the ``NonInterpPrimitive`` with a trace. - - If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call. - Otherwise, the bind call of JAX's standard Primitive is used.""" - if isinstance( - trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace) - ): - return self.impl(*args, **params) - return super().bind_with_trace(trace, args, params) - - return NonInterpPrimitive - - @lru_cache def _get_grad_prim(): """Create a primitive for gradient computations. @@ -60,8 +32,11 @@ def _get_grad_prim(): if not has_jax: # pragma: no cover return None - grad_prim = create_non_interpreted_prim()("grad") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + grad_prim = NonInterpPrimitive("grad") grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + grad_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @grad_prim.def_impl @@ -91,8 +66,14 @@ def _get_jacobian_prim(): """Create a primitive for Jacobian computations. This primitive is used when capturing ``qml.jacobian``. """ - jacobian_prim = create_non_interpreted_prim()("jacobian") + if not has_jax: # pragma: no cover + return None + + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + jacobian_prim = NonInterpPrimitive("jacobian") jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init + jacobian_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @jacobian_prim.def_impl diff --git a/pennylane/capture/capture_measurements.py b/pennylane/capture/capture_measurements.py index 59bf5490679..e23457bc7b7 100644 --- a/pennylane/capture/capture_measurements.py +++ b/pennylane/capture/capture_measurements.py @@ -128,7 +128,10 @@ def create_measurement_obs_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_obs") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_obs") + primitive.prim_type = "measurement" @primitive.def_impl def _(obs, **kwargs): @@ -165,7 +168,10 @@ def create_measurement_mcm_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_mcm") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_mcm") + primitive.prim_type = "measurement" @primitive.def_impl def _(*mcms, single_mcm=True, **kwargs): @@ -200,7 +206,10 @@ def create_measurement_wires_primitive( if not has_jax: return None - primitive = jax.core.Primitive(name + "_wires") + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(name + "_wires") + primitive.prim_type = "measurement" @primitive.def_impl def _(*args, has_eigvals=False, **kwargs): diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py index 23c98f38944..2124b5b9fe4 100644 --- a/pennylane/capture/capture_operators.py +++ b/pennylane/capture/capture_operators.py @@ -20,8 +20,6 @@ import pennylane as qml -from .capture_diff import create_non_interpreted_prim - has_jax = True try: import jax @@ -103,7 +101,10 @@ def create_operator_primitive( if not has_jax: return None - primitive = create_non_interpreted_prim()(operator_type.__name__) + from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel + + primitive = NonInterpPrimitive(operator_type.__name__) + primitive.prim_type = "operator" @primitive.def_impl def _(*args, **kwargs): diff --git a/pennylane/capture/custom_primitives.py b/pennylane/capture/custom_primitives.py new file mode 100644 index 00000000000..183ae05771b --- /dev/null +++ b/pennylane/capture/custom_primitives.py @@ -0,0 +1,64 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule offers custom primitives for the PennyLane capture module. +""" +from enum import Enum +from typing import Union + +import jax + + +class PrimitiveType(Enum): + """Enum to define valid set of primitive classes""" + + DEFAULT = "default" + OPERATOR = "operator" + MEASUREMENT = "measurement" + HIGHER_ORDER = "higher_order" + TRANSFORM = "transform" + + +# pylint: disable=too-few-public-methods,abstract-method +class QmlPrimitive(jax.core.Primitive): + """A subclass for JAX's Primitive that differentiates between different + classes of primitives.""" + + _prim_type: PrimitiveType = PrimitiveType.DEFAULT + + @property + def prim_type(self): + """Value of Enum representing the primitive type to differentiate between various + sets of PennyLane primitives.""" + return self._prim_type.value + + @prim_type.setter + def prim_type(self, value: Union[str, PrimitiveType]): + """Setter for QmlPrimitive.prim_type.""" + self._prim_type = PrimitiveType(value) + + +# pylint: disable=too-few-public-methods,abstract-method +class NonInterpPrimitive(QmlPrimitive): + """A subclass to JAX's Primitive that works like a Python function + when evaluating JVPTracers and BatchTracers.""" + + def bind_with_trace(self, trace, args, params): + """Bind the ``NonInterpPrimitive`` with a trace. + + If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call. + Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)): + return self.impl(*args, **params) + return super().bind_with_trace(trace, args, params) diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 84feef9786f..71033aac3ee 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -255,7 +255,7 @@ class MyClass(metaclass=MyMetaClass): self.kwargs = kwargs ``` - Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}. + Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}. And that we have set a class property `a` @@ -272,7 +272,7 @@ But can we actually create instances of these classes? ```python >> obj = MyClass(0.1, a=2) >>> obj -creating an instance of type with (0.1,), {'a': 2}. +creating an instance of type with (0.1,), {'a': 2}. now creating an instance in __init__ <__main__.MyClass at 0x11c5a2810> ``` @@ -294,7 +294,7 @@ class MyClass2(metaclass=MetaClass2): self.args = args ``` -You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`. +You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`. Using a metaclass, we can hijack what happens when a type is called. @@ -425,3 +425,103 @@ Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `A >>> jax.make_jaxpr(PrimitiveClass2)(0.1) { lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) } ``` + +# Non-interpreted primitives + +Some of the primitives in the capture module have a somewhat non-standard requirement for the +behaviour under differentiation or batching: they should ignore that an input is a differentiation +or batching tracer and just execute the standard implementation on them. + +We will look at an example to make the necessity for such a non-interpreted primitive clear. + +Consider a finite-difference differentiation routine together with some test function `fun`. + +```python +def finite_diff_impl(x, fun, delta): + """Finite difference differentiation routine. Only supports differentiating + a function `fun` with a single scalar argument, for simplicity.""" + + out_plus = fun(x + delta) + out_minus = fun(x - delta) + return tuple((out_p - out_m) / (2 * delta) for out_p, out_m in zip(out_plus, out_minus)) + +def fun(x): + return (x**2, 4 * x - 3, x**23) +``` + +Now suppose we want to turn this into a primitive. We could just promote it to a standard +`jax.core.Primitive` as + +```python +import jax + +fd_prim = jax.core.Primitive("finite_diff") +fd_prim.multiple_results = True +fd_prim.def_impl(finite_diff_impl) + +def finite_diff(x, fun, delta=1e-5): + return fd_prim.bind(x, fun, delta) +``` + +This allows us to use the forward pass as usual (to compute the first-order derivative): + +```pycon +>>> finite_diff(1., fun, delta=1e-6) +(2.000000000002, 3.999999999892978, 23.000000001216492) +``` + +Now if we want to make this primitive differentiable (with automatic +differentiation/backprop, not by using a higher-order finite difference scheme), +we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example +that we could use to implement a finite difference scheme that is readily differentiable. This is +somewhat beside the point because we did not identify the possibility of using any of those +alternatives in the PennyLane code). + +However, the finite difference rule is just a standard +algebraic function making use of calls to `fun` and some elementary operations, so ideally +we would like to just use the chain rule as it is known to the automatic differentiation framework. A JVP rule would +then just manually re-implement this chain rule, which we'd rather not do. + +Instead, we define a non-interpreted type of primitive and create such a primitive +for our finite difference method. We also create the usual method that binds the +primitive to inputs. + +```python +class NonInterpPrimitive(jax.core.Primitive): + """A subclass to JAX's Primitive that works like a Python function + when evaluating JVPTracers.""" + + def bind_with_trace(self, trace, args, params): + """Bind the ``NonInterpPrimitive`` with a trace. + If the trace is a ``JVPTrace``, it falls back to a standard Python function call. + Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance(trace, jax.interpreters.ad.JVPTrace): + return self.impl(*args, **params) + return super().bind_with_trace(trace, args, params) + +fd_prim_2 = NonInterpPrimitive("finite_diff_2") +fd_prim_2.multiple_results = True +fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer + +def finite_diff_2(x, fun, delta=1e-5): + return fd_prim_2.bind(x, fun, delta) +``` + +Now we can use the primitive in a differentiable workflow, without defining a JVP rule +that just repeats the chain rule: + +```pycon +>>> # Forward execution of finite_diff_2 (-> first-order derivative) +>>> finite_diff_2(1., fun, delta=1e-6) +(2.000000000002, 3.999999999892978, 23.000000001216492) +>>> # Differentiation of finite_diff_2 (-> second-order derivative) +>>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6) +(Array(1.9375, dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True), Array(498., dtype=float32, weak_type=True)) +``` + +In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, quantum operators +have non-interpreted primitives as well. This is because their differentiation is performed +by the surrounding QNode primitive rather than through the standard chain rule that acts +"locally" (in the circuit). In short, we only want gates to store their tracers (which will help +determine the differentiability of gate arguments, for example), but not to do anything with them. + diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 08d88988b79..797a7437abb 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -17,7 +17,6 @@ from collections.abc import Callable import pennylane as qml -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from .compiler import ( @@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable: def _get_while_loop_qfunc_prim(): """Get the while_loop primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - while_loop_prim = create_non_interpreted_prim()("while_loop") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + while_loop_prim = NonInterpPrimitive("while_loop") while_loop_prim.multiple_results = True + while_loop_prim.prim_type = "higher_order" @while_loop_prim.def_impl def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice): @@ -626,10 +629,14 @@ def _decorator(body_fn): def _get_for_loop_qfunc_prim(): """Get the loop_for primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive - for_loop_prim = create_non_interpreted_prim()("for_loop") + for_loop_prim = NonInterpPrimitive("for_loop") for_loop_prim.multiple_results = True + for_loop_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments @for_loop_prim.def_impl diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 48f8ffbc686..38d5e1b34a1 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -591,13 +591,15 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio """ updated_values = {} - jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT} - updated_values["convert_to_numpy"] = ( - execution_config.interface not in jax_interfaces - or execution_config.gradient_method == "adjoint" - # need numpy to use caching, and need caching higher order derivatives - or execution_config.derivative_order > 1 - ) + # uncomment once compilation overhead with jitting improved + # TODO: [sc-82874] + # jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT} + # updated_values["convert_to_numpy"] = ( + # execution_config.interface not in jax_interfaces + # or execution_config.gradient_method == "adjoint" + # # need numpy to use caching, and need caching higher order derivatives + # or execution_config.derivative_order > 1 + # ) for option in execution_config.device_options: if option not in self._device_options: raise qml.DeviceError(f"device option {option} not present on {self}") @@ -643,6 +645,7 @@ def execute( prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))] if max_workers is None: + return tuple( _simulate_wrapper( c, diff --git a/pennylane/gradients/hamiltonian_grad.py b/pennylane/gradients/hamiltonian_grad.py index e83d942dcc9..95769bdc6d1 100644 --- a/pennylane/gradients/hamiltonian_grad.py +++ b/pennylane/gradients/hamiltonian_grad.py @@ -13,6 +13,8 @@ # limitations under the License. """Contains a gradient recipe for the coefficients of Hamiltonians.""" # pylint: disable=protected-access,unnecessary-lambda +import warnings + import pennylane as qml @@ -20,10 +22,19 @@ def hamiltonian_grad(tape, idx): """Computes the tapes necessary to get the gradient of a tape with respect to a Hamiltonian observable's coefficients. + .. warning:: + This function is deprecated and will be removed in v0.42. This gradient recipe is not + required for the new operator arithmetic of PennyLane. + Args: tape (qml.tape.QuantumTape): tape with a single Hamiltonian expectation as measurement idx (int): index of parameter that we differentiate with respect to """ + warnings.warn( + "The 'hamiltonian_grad' function is deprecated and will be removed in v0.42. " + "This gradient recipe is not required for the new operator arithmetic system.", + qml.PennyLaneDeprecationWarning, + ) op, m_pos, p_idx = tape.get_operation(idx) diff --git a/pennylane/gradients/parameter_shift.py b/pennylane/gradients/parameter_shift.py index 6fcdd17df19..26a2b0e4a8b 100644 --- a/pennylane/gradients/parameter_shift.py +++ b/pennylane/gradients/parameter_shift.py @@ -15,6 +15,7 @@ This module contains functions for computing the parameter-shift gradient of a qubit-based quantum tape. """ +import warnings from functools import partial import numpy as np @@ -372,6 +373,12 @@ def expval_param_shift( op, op_idx, _ = tape.get_operation(idx) if op.name == "LinearCombination": + warnings.warn( + "Please use qml.gradients.split_to_single_terms so that the ML framework " + "can compute the gradients of the coefficients.", + UserWarning, + ) + # operation is a Hamiltonian if tape[op_idx].return_type is not qml.measurements.Expectation: raise ValueError( diff --git a/pennylane/labs/dla/lie_closure_dense.py b/pennylane/labs/dla/lie_closure_dense.py index 71131cd1064..a2fb76d02e9 100644 --- a/pennylane/labs/dla/lie_closure_dense.py +++ b/pennylane/labs/dla/lie_closure_dense.py @@ -97,9 +97,11 @@ def lie_closure_dense( Compute the Lie closure of the isotropic Heisenberg model with generators :math:`\{X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1}\}_{i=0}^{n-1}`. + >>> from pennylane import X, Y, Z + >>> from pennylane.labs.dla import lie_closure_dense >>> n = 5 >>> gens = [X(i) @ X(i+1) + Y(i) @ Y(i+1) + Z(i) @ Z(i+1) for i in range(n-1)] - >>> g = lie_closure_mat(gens, n) + >>> g = lie_closure_dense(gens, n) The result is a ``numpy`` array. We can turn the matrices back into PennyLane operators by employing :func:`~batched_pauli_decompose`. diff --git a/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py b/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py index 6d59fea1b96..440d053dee0 100644 --- a/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py +++ b/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py @@ -643,22 +643,17 @@ class TestResourceMultiControlledX: """Test the ResourceMultiControlledX operation""" res_ops = ( - re.ResourceMultiControlledX(control_wires=[0], wires=["t"], control_values=[1]), - re.ResourceMultiControlledX(control_wires=[0, 1], wires=["t"], control_values=[1, 1]), - re.ResourceMultiControlledX(control_wires=[0, 1, 2], wires=["t"], control_values=[1, 1, 1]), + re.ResourceMultiControlledX(wires=[0, "t"], control_values=[1]), + re.ResourceMultiControlledX(wires=[0, 1, "t"], control_values=[1, 1]), + re.ResourceMultiControlledX(wires=[0, 1, 2, "t"], control_values=[1, 1, 1]), + re.ResourceMultiControlledX(wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 1, 1, 1, 1]), + re.ResourceMultiControlledX(wires=[0, "t"], control_values=[0], work_wires=["w1"]), re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], wires=["t"], control_values=[1, 1, 1, 1, 1] + wires=[0, 1, "t"], control_values=[1, 0], work_wires=["w1", "w2"] ), + re.ResourceMultiControlledX(wires=[0, 1, 2, "t"], control_values=[0, 0, 1]), re.ResourceMultiControlledX( - control_wires=[0], wires=["t"], control_values=[0], work_wires=["w1"] - ), - re.ResourceMultiControlledX( - control_wires=[0, 1], wires=["t"], control_values=[1, 0], work_wires=["w1", "w2"] - ), - re.ResourceMultiControlledX(control_wires=[0, 1, 2], wires=["t"], control_values=[0, 0, 1]), - re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], - wires=["t"], + wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 0, 0, 1, 0], work_wires=["w1"], ), @@ -732,8 +727,7 @@ def test_resource_params(self, op, params): def test_resource_adjoint(self): """Test that the adjoint resources are as expected""" op = re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], - wires=["t"], + wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 0, 0, 1, 0], work_wires=["w1"], ) @@ -777,7 +771,7 @@ def test_resource_adjoint(self): ) def test_resource_controlled(self, ctrl_wires, ctrl_values, work_wires, expected_res): """Test that the controlled resources are as expected""" - op = re.ResourceMultiControlledX(control_wires=[0], wires=["t"], control_values=[1]) + op = re.ResourceMultiControlledX(wires=[0, "t"], control_values=[1]) num_ctrl_wires = len(ctrl_wires) num_ctrl_values = len([v for v in ctrl_values if not v]) @@ -806,8 +800,7 @@ def test_resource_controlled(self, ctrl_wires, ctrl_values, work_wires, expected def test_resource_pow(self, z, expected_res): """Test that the pow resources are as expected""" op = re.ResourceMultiControlledX( - control_wires=[0, 1, 2, 3, 4], - wires=["t"], + wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 0, 0, 1, 0], work_wires=["w1"], ) diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index 5cdcd8cd708..3c9bdc8f1a8 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -243,9 +243,12 @@ def _create_mid_measure_primitive(): measurement. """ - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - mid_measure_p = jax.core.Primitive("measure") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + mid_measure_p = NonInterpPrimitive("measure") @mid_measure_p.def_impl def _(wires, reset=False, postselect=None): diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py index cbc4d0a0bdb..18da1b9d284 100644 --- a/pennylane/measurements/sample.py +++ b/pennylane/measurements/sample.py @@ -228,6 +228,8 @@ def shape(self, shots: Optional[int] = None, num_device_wires: int = 0) -> tuple ) if self.obs: num_values_per_shot = 1 # one single eigenvalue + elif self.mv is not None: + num_values_per_shot = 1 if isinstance(self.mv, MeasurementValue) else len(self.mv) else: # one value per wire num_values_per_shot = len(self.wires) if len(self.wires) > 0 else num_device_wires diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index 400f2fc83c0..5bda04440d4 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -18,7 +18,6 @@ from typing import Callable, overload import pennylane as qml -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.math import conj, moveaxis, transpose from pennylane.operation import Observable, Operation, Operator @@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy): def _get_adjoint_qfunc_prim(): """See capture/explanations.md : Higher Order primitives for more information on this code.""" # if capture is enabled, jax should be installed - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive - adjoint_prim = create_non_interpreted_prim()("adjoint_transform") + adjoint_prim = NonInterpPrimitive("adjoint_transform") adjoint_prim.multiple_results = True + adjoint_prim.prim_type = "higher_order" @adjoint_prim.def_impl def _(*args, jaxpr, lazy, n_consts): diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index deace92e73c..a15fdafff1d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -20,7 +20,6 @@ import pennylane as qml from pennylane import QueuingManager -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from pennylane.compiler import compiler from pennylane.measurements import MeasurementValue @@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement def _get_cond_qfunc_prim(): """Get the cond primitive for quantum functions.""" - import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + import jax - cond_prim = create_non_interpreted_prim()("cond") + from pennylane.capture.custom_primitives import NonInterpPrimitive + + cond_prim = NonInterpPrimitive("cond") cond_prim.multiple_results = True + cond_prim.prim_type = "higher_order" @cond_prim.def_impl def _(*all_args, jaxpr_branches, consts_slices, args_slice): diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index d49209660c6..17e62323223 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -28,7 +28,6 @@ import pennylane as qml from pennylane import math as qmlmath from pennylane import operation -from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.operation import Operator from pennylane.wires import Wires, WiresLike @@ -233,10 +232,15 @@ def wrapper(*args, **kwargs): def _get_ctrl_qfunc_prim(): """See capture/explanations.md : Higher Order primitives for more information on this code.""" # if capture is enabled, jax should be installed - import jax # pylint: disable=import-outside-toplevel - ctrl_prim = create_non_interpreted_prim()("ctrl_transform") + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture.custom_primitives import NonInterpPrimitive + + ctrl_prim = NonInterpPrimitive("ctrl_transform") ctrl_prim.multiple_results = True + ctrl_prim.prim_type = "higher_order" @ctrl_prim.def_impl def _(*args, n_control, jaxpr, control_values, work_wires, n_consts): diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index e00bda09c8d..1cefb724cd2 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -540,12 +540,13 @@ def final_transform(self): def _create_transform_primitive(name): try: # pylint: disable=import-outside-toplevel - import jax + from pennylane.capture.custom_primitives import NonInterpPrimitive except ImportError: return None - transform_prim = jax.core.Primitive(name + "_transform") + transform_prim = NonInterpPrimitive(name + "_transform") transform_prim.multiple_results = True + transform_prim.prim_type = "transform" @transform_prim.def_impl def _( diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 418e68941a6..85dc0320fb7 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -70,7 +70,7 @@ def _get_plxpr_cancel_inverses(): # pylint: disable=missing-function-docstring, # pylint: disable=import-outside-toplevel from jax import make_jaxpr - from pennylane.capture import AbstractMeasurement, AbstractOperator, PlxprInterpreter + from pennylane.capture import PlxprInterpreter from pennylane.operation import Operator except ImportError: # pragma: no cover return None, None @@ -204,15 +204,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractOperator): + elif getattr(eqn.primitive, "prim_type", "") == "operator": outvals = self.interpret_operation_eqn(eqn) - elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractMeasurement): + elif getattr(eqn.primitive, "prim_type", "") == "measurement": self.interpret_all_previous_ops() outvals = self.interpret_measurement_eqn(eqn) else: # Transform primitives don't have custom handlers, so we check for them here # to purge the stored ops in self.previous_ops - if eqn.primitive.name.endswith("_transform"): + if getattr(eqn.primitive, "prim_type", "") == "transform": self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 05f6b196440..2552770a159 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -107,7 +107,6 @@ """ from copy import copy -from dataclasses import asdict from functools import partial from numbers import Number from warnings import warn @@ -117,6 +116,7 @@ import pennylane as qml from pennylane.capture import FlatFn +from pennylane.capture.custom_primitives import QmlPrimitive from pennylane.typing import TensorLike @@ -177,8 +177,9 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=( return shapes -qnode_prim = jax.core.Primitive("qnode") +qnode_prim = QmlPrimitive("qnode") qnode_prim.multiple_results = True +qnode_prim.prim_type = "higher_order" # pylint: disable=too-many-arguments, unused-argument @@ -249,7 +250,6 @@ def _qnode_batching_rule( "using parameter broadcasting to a quantum operation that supports batching.", UserWarning, ) - # To resolve this ambiguity, we might add more properties to the AbstractOperator # class to indicate which operators support batching and check them here. # As above, at this stage we raise a warning and give the user full flexibility. @@ -277,15 +277,43 @@ def _qnode_batching_rule( return result, (0,) * len(result) +### JVP CALCULATION ######################################################### +# This structure will change as we add more diff methods + + def _make_zero(tan, arg): return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan -def _qnode_jvp(args, tangents, **impl_kwargs): +def _backprop(args, tangents, **impl_kwargs): tangents = tuple(map(_make_zero, tangents, args)) return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents) +diff_method_map = {"backprop": _backprop} + + +def _resolve_diff_method(diff_method: str, device) -> str: + # check if best is backprop + if diff_method == "best": + config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax") + diff_method = device.setup_execution_config(config).gradient_method + + if diff_method not in diff_method_map: + raise NotImplementedError(f"diff_method {diff_method} not yet implemented.") + + return diff_method + + +def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs): + diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device) + return diff_method_map[diff_method]( + args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs + ) + + +### END JVP CALCULATION ####################################################### + ad.primitive_jvps[qnode_prim] = _qnode_jvp batching.primitive_batchers[qnode_prim] = _qnode_batching_rule @@ -375,8 +403,7 @@ def f(x): qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args) execute_kwargs = copy(qnode.execute_kwargs) - mcm_config = asdict(execute_kwargs.pop("mcm_config")) - qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} + qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs} flat_args = jax.tree_util.tree_leaves(args) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index cfe93f322d2..fb9ac5ef708 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -47,14 +47,16 @@ def execute( diff_method: Optional[Union[Callable, SupportedDiffMethods, TransformDispatcher]] = None, interface: Optional[InterfaceLike] = Interface.AUTO, *, + transform_program: TransformProgram = None, grad_on_execution: Literal[bool, "best"] = "best", cache: Union[None, bool, dict, Cache] = True, cachesize: int = 10000, max_diff: int = 1, device_vjp: Union[bool, None] = False, + postselect_mode=None, + mcm_method=None, gradient_kwargs: dict = None, - transform_program: TransformProgram = None, - mcm_config: "qml.devices.MCMConfig" = None, + mcm_config: "qml.devices.MCMConfig" = "unset", config="unset", inner_transform="unset", ) -> ResultBatch: @@ -86,10 +88,18 @@ def execute( (classical) computational overhead during the backward pass. device_vjp=False (Optional[bool]): whether or not to use the device-provided Jacobian product if it is available. - mcm_config (dict): Dictionary containing configuration options for handling - mid-circuit measurements. + postselect_mode (str): Configuration for handling shots with mid-circuit measurement + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to + keep the same number of shots. Default is ``None``. + mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements. + ``"deferred"`` is ignored. If mid-circuit measurements are found in the circuit, + the device will use ``"tree-traversal"`` if specified and the ``"one-shot"`` method + otherwise. For usage details, please refer to the + :doc:`dynamic quantum circuits page `. gradient_kwargs (dict): dictionary of keyword arguments to pass when determining the gradients of tapes. + mcm_config="unset": **DEPRECATED**. This keyword argument has been replaced by ``postselect_mode`` + and ``mcm_method`` and will be removed in v0.42. config="unset": **DEPRECATED**. This keyword argument has been deprecated and will be removed in v0.42. inner_transform="unset": **DEPRECATED**. This keyword argument has been deprecated @@ -173,6 +183,14 @@ def cost_fn(params, x): qml.PennyLaneDeprecationWarning, ) + if mcm_config != "unset": + warn( + "The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.", + qml.PennyLaneDeprecationWarning, + ) + mcm_method = mcm_config.mcm_method + postselect_mode = mcm_config.postselect_mode + if logger.isEnabledFor(logging.DEBUG): logger.debug( ( @@ -209,7 +227,7 @@ def cost_fn(params, x): gradient_method=diff_method, grad_on_execution=None if grad_on_execution == "best" else grad_on_execution, use_device_jacobian_product=device_vjp, - mcm_config=mcm_config or {}, + mcm_config=qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method), gradient_keyword_arguments=gradient_kwargs or {}, derivative_order=max_diff, ) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 29430c9a645..ea1a9e33900 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -107,6 +107,10 @@ def _to_qfunc_output_type( return qml.pytrees.unflatten(results, qfunc_output_structure) +def _validate_mcm_config(postselect_mode: str, mcm_method: str) -> None: + qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method) + + def _validate_gradient_kwargs(gradient_kwargs: dict) -> None: for kwarg in gradient_kwargs: if kwarg == "expansion_strategy": @@ -133,7 +137,8 @@ def _validate_gradient_kwargs(gradient_kwargs: dict) -> None: elif kwarg not in qml.gradients.SUPPORTED_GRADIENT_KWARGS: warnings.warn( f"Received gradient_kwarg {kwarg}, which is not included in the list of " - "standard qnode gradient kwargs." + "standard qnode gradient kwargs. Please specify all gradient kwargs through " + "the gradient_kwargs argument as a dictionary." ) @@ -284,9 +289,7 @@ class QNode: as the name suggests. If not provided, the device will determine the best choice automatically. For usage details, please refer to the :doc:`dynamic quantum circuits page `. - - Keyword Args: - **kwargs: Any additional keyword arguments provided are passed to the differentiation + gradient_kwargs (dict): A dictionary of keyword arguments that are passed to the differentiation method. Please refer to the :mod:`qml.gradients <.gradients>` module for details on supported options for your chosen gradient transform. @@ -505,10 +508,12 @@ def __init__( device_vjp: Union[None, bool] = False, postselect_mode: Literal[None, "hw-like", "fill-shots"] = None, mcm_method: Literal[None, "deferred", "one-shot", "tree-traversal"] = None, - **gradient_kwargs, + gradient_kwargs: Optional[dict] = None, + **kwargs, ): self._init_args = locals() del self._init_args["self"] + del self._init_args["kwargs"] if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -536,7 +541,16 @@ def __init__( if not isinstance(device, qml.devices.Device): device = qml.devices.LegacyDeviceFacade(device) + gradient_kwargs = gradient_kwargs or {} + if kwargs: + if any(k in qml.gradients.SUPPORTED_GRADIENT_KWARGS for k in list(kwargs.keys())): + warnings.warn( + f"Specifying gradient keyword arguments {list(kwargs.keys())} is deprecated and will be removed in v0.42. Instead, please specify all arguments in the gradient_kwargs argument.", + qml.PennyLaneDeprecationWarning, + ) + gradient_kwargs |= kwargs _validate_gradient_kwargs(gradient_kwargs) + if "shots" in inspect.signature(func).parameters: warnings.warn( "Detected 'shots' as an argument to the given quantum function. " @@ -553,17 +567,18 @@ def __init__( self.device = device self._interface = get_canonical_interface_name(interface) self.diff_method = diff_method - mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) cache = (max_diff > 1) if cache == "auto" else cache # execution keyword arguments + _validate_mcm_config(postselect_mode, mcm_method) self.execute_kwargs = { "grad_on_execution": grad_on_execution, "cache": cache, "cachesize": cachesize, "max_diff": max_diff, "device_vjp": device_vjp, - "mcm_config": mcm_config, + "postselect_mode": postselect_mode, + "mcm_method": mcm_method, } # internal data attributes @@ -676,16 +691,19 @@ def circuit(x): tensor(0.5403, dtype=torch.float64) """ if not kwargs: - valid_params = ( - set(self._init_args.copy().pop("gradient_kwargs")) - | qml.gradients.SUPPORTED_GRADIENT_KWARGS - ) + valid_params = set(self._init_args.copy()) | qml.gradients.SUPPORTED_GRADIENT_KWARGS raise ValueError( f"Must specify at least one configuration property to update. Valid properties are: {valid_params}." ) original_init_args = self._init_args.copy() - gradient_kwargs = original_init_args.pop("gradient_kwargs") - original_init_args.update(gradient_kwargs) + # gradient_kwargs defaults to None + original_init_args["gradient_kwargs"] = original_init_args["gradient_kwargs"] or {} + # nested dictionary update + new_gradient_kwargs = kwargs.pop("gradient_kwargs", {}) + old_gradient_kwargs = original_init_args.get("gradient_kwargs").copy() + old_gradient_kwargs.update(new_gradient_kwargs) + kwargs["gradient_kwargs"] = old_gradient_kwargs + original_init_args.update(kwargs) updated_qn = QNode(**original_init_args) # pylint: disable=protected-access diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 0ac083a768c..680d11f083e 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -14,7 +14,6 @@ """ Tests for capturing a qnode into jaxpr. """ -from dataclasses import asdict from functools import partial # pylint: disable=protected-access @@ -130,7 +129,6 @@ def circuit(x): assert eqn0.params["shots"] == qml.measurements.Shots(None) expected_kwargs = {"diff_method": "best"} expected_kwargs.update(circuit.execute_kwargs) - expected_kwargs.update(asdict(expected_kwargs.pop("mcm_config"))) assert eqn0.params["qnode_kwargs"] == expected_kwargs qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] @@ -339,18 +337,61 @@ def circuit(x): assert list(out.keys()) == ["a", "b"] -def test_qnode_jvp(): - """Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule.""" +class TestDifferentiation: - @qml.qnode(qml.device("default.qubit", wires=1)) - def circuit(x): - qml.RX(x, 0) - return qml.expval(qml.Z(0)) + def test_error_backprop_unsupported(self): + """Test an error is raised with backprop if the device does not support it.""" + + # pylint: disable=too-few-public-methods + class DummyDev(qml.devices.Device): + + def execute(self, *_, **__): + return 0 + + with pytest.raises(qml.QuantumFunctionError, match="does not support backprop"): + + @qml.qnode(DummyDev(wires=2), diff_method="backprop") + def _(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + def test_error_unsupported_diff_method(self): + """Test an error is raised for a non-backprop diff method.""" + + @qml.qnode(qml.device("default.qubit", wires=2), diff_method="parameter-shift") + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + with pytest.raises( + NotImplementedError, match="diff_method parameter-shift not yet implemented." + ): + jax.grad(circuit)(0.5) + + @pytest.mark.parametrize("diff_method", ("best", "backprop")) + def test_default_qubit_backprop(self, diff_method): + """Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule.""" + + @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + x = 0.9 + xt = -0.6 + jvp = jax.jvp(circuit, (x,), (xt,)) + assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt)) + + def test_no_gradients_with_lightning(self): + """Test that we get an error if we try and differentiate a lightning execution.""" + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) - x = 0.9 - xt = -0.6 - jvp = jax.jvp(circuit, (x,), (xt,)) - assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt)) + with pytest.raises(NotImplementedError, match=r"diff_method adjoint not yet implemented"): + jax.grad(circuit)(0.5) def test_qnode_jit(): diff --git a/tests/capture/test_custom_primitives.py b/tests/capture/test_custom_primitives.py new file mode 100644 index 00000000000..3d35e3e57e4 --- /dev/null +++ b/tests/capture/test_custom_primitives.py @@ -0,0 +1,48 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for PennyLane custom primitives. +""" +# pylint: disable=wrong-import-position +import pytest + +jax = pytest.importorskip("jax") + +from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive + +pytestmark = pytest.mark.jax + + +def test_qml_primitive_prim_type_default(): + """Test that the default prim_type of a QmlPrimitive is set correctly.""" + prim = QmlPrimitive("primitive") + assert prim._prim_type == PrimitiveType("default") # pylint: disable=protected-access + assert prim.prim_type == "default" + + +@pytest.mark.parametrize("cast_in_enum", [True, False]) +@pytest.mark.parametrize("prim_type", ["operator", "measurement", "transform", "higher_order"]) +def test_qml_primitive_prim_type_setter(prim_type, cast_in_enum): + """Test that the QmlPrimitive.prim_type setter works correctly""" + prim = QmlPrimitive("primitive") + prim.prim_type = PrimitiveType(prim_type) if cast_in_enum else prim_type + assert prim._prim_type == PrimitiveType(prim_type) # pylint: disable=protected-access + assert prim.prim_type == prim_type + + +def test_qml_primitive_prim_type_setter_invalid(): + """Test that setting an invalid prim_type raises an error""" + prim = QmlPrimitive("primitive") + with pytest.raises(ValueError, match="not a valid PrimitiveType"): + prim.prim_type = "blah" diff --git a/tests/capture/test_switches.py b/tests/capture/test_switches.py index 52f50321740..72a170205e7 100644 --- a/tests/capture/test_switches.py +++ b/tests/capture/test_switches.py @@ -32,10 +32,15 @@ def test_switches_with_jax(): def test_switches_without_jax(): """Test switches and status reporting function.""" - - assert qml.capture.enabled() is False - with pytest.raises(ImportError, match="plxpr requires JAX to be installed."): - qml.capture.enable() - assert qml.capture.enabled() is False - assert qml.capture.disable() is None - assert qml.capture.enabled() is False + # We want to skip the test if jax is installed + try: + # pylint: disable=import-outside-toplevel, unused-import + import jax + except ImportError: + + assert qml.capture.enabled() is False + with pytest.raises(ImportError, match="plxpr requires JAX to be installed."): + qml.capture.enable() + assert qml.capture.enabled() is False + assert qml.capture.disable() is None + assert qml.capture.enabled() is False diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py index 42faaa10809..32367c455f4 100644 --- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py +++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py @@ -389,7 +389,10 @@ def func(x, y, z): results1 = func1(*params) jaxpr = str(jax.make_jaxpr(func)(*params)) - assert "pure_callback" not in jaxpr + # will change once we solve the compilation overhead issue + # assert "pure_callback" not in jaxpr + # TODO: [sc-82874] + assert "pure_callback" in jaxpr func2 = jax.jit(func) results2 = func2(*params) diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py index 59a9098f7ce..31a2c978745 100644 --- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py +++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py @@ -142,15 +142,17 @@ def circuit(x): assert dev.tracker.totals["execute_and_derivative_batches"] == 1 @pytest.mark.parametrize("interface", ("jax", "jax-jit")) - def test_not_convert_to_numpy_with_jax(self, interface): + def test_convert_to_numpy_with_jax(self, interface): """Test that we will not convert to numpy when working with jax.""" - + # separate test so we can easily update it once we solve the + # compilation overhead issue + # TODO: [sc-82874] dev = qml.device("default.qubit") config = qml.devices.ExecutionConfig( gradient_method=qml.gradients.param_shift, interface=interface ) processed = dev.setup_execution_config(config) - assert not processed.convert_to_numpy + assert processed.convert_to_numpy def test_convert_to_numpy_with_adjoint(self): """Test that we will convert to numpy with adjoint.""" diff --git a/tests/gradients/core/test_hamiltonian_gradient.py b/tests/gradients/core/test_hamiltonian_gradient.py index 1bcb4bfc4fe..9474faf39a0 100644 --- a/tests/gradients/core/test_hamiltonian_gradient.py +++ b/tests/gradients/core/test_hamiltonian_gradient.py @@ -12,10 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for the gradients.hamiltonian module.""" +import pytest + import pennylane as qml from pennylane.gradients.hamiltonian_grad import hamiltonian_grad +def test_hamiltonian_grad_deprecation(): + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + with qml.queuing.AnnotatedQueue() as q: + qml.RY(0.3, wires=0) + qml.RX(0.5, wires=1) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.Hamiltonian([-1.5, 2.0], [qml.PauliZ(0), qml.PauliZ(1)])) + + tape = qml.tape.QuantumScript.from_queue(q) + tape.trainable_params = {2, 3} + hamiltonian_grad(tape, idx=0) + + def test_behaviour(): """Test that the function behaves as expected.""" @@ -29,10 +46,16 @@ def test_behaviour(): tape = qml.tape.QuantumScript.from_queue(q) tape.trainable_params = {2, 3} - tapes, processing_fn = hamiltonian_grad(tape, idx=0) + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + tapes, processing_fn = hamiltonian_grad(tape, idx=0) res1 = processing_fn(dev.execute(tapes)) - tapes, processing_fn = hamiltonian_grad(tape, idx=1) + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + tapes, processing_fn = hamiltonian_grad(tape, idx=1) res2 = processing_fn(dev.execute(tapes)) with qml.queuing.AnnotatedQueue() as q1: diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index c684d53b0c6..b06b7c711d0 100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -1385,7 +1385,10 @@ def test_simple_qnode_expval(self, num_split_times, shots, tol, seed): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1415,7 +1418,10 @@ def test_simple_qnode_expval_two_evolves(self, num_split_times, shots, tol, seed ham_y = qml.pulse.constant * qml.PauliX(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_x)(params[0], T_x) @@ -1444,7 +1450,10 @@ def test_simple_qnode_probs(self, num_split_times, shots, tol, seed): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1471,7 +1480,10 @@ def test_simple_qnode_probs_expval(self, num_split_times, shots, tol, seed): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1490,6 +1502,7 @@ def circuit(params): assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0) jax.clear_caches() + @pytest.mark.xfail # TODO: [sc-82874] @pytest.mark.parametrize("num_split_times", [1, 2]) @pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"]) def test_simple_qnode_jit(self, num_split_times, time_interface): @@ -1503,7 +1516,10 @@ def test_simple_qnode_jit(self, num_split_times, time_interface): ham_single_q_const = qml.pulse.constant * qml.PauliY(0) @qml.qnode( - dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times + dev, + interface="jax", + diff_method=stoch_pulse_grad, + gradient_kwargs={"num_split_times": num_split_times}, ) def circuit(params, T=None): qml.evolve(ham_single_q_const)(params, T) @@ -1542,8 +1558,7 @@ def ansatz(params): dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - sampler_seed=seed, + gradient_kwargs={"num_split_times": num_split_times, "sampler_seed": seed}, ) qnode_backprop = qml.QNode(ansatz, dev, interface="jax") @@ -1574,8 +1589,7 @@ def test_qnode_probs_expval_broadcasting(self, num_split_times, shots, tol, seed dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - use_broadcasting=True, + gradient_kwargs={"num_split_times": num_split_times, "use_broadcasting": True}, ) def circuit(params): qml.evolve(ham_single_q_const)(params, T) @@ -1619,18 +1633,22 @@ def ansatz(params): dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - use_broadcasting=True, - sampler_seed=seed, + gradient_kwargs={ + "num_split_times": num_split_times, + "use_broadcasting": True, + "sampler_seed": seed, + }, ) circuit_no_bc = qml.QNode( ansatz, dev, interface="jax", diff_method=stoch_pulse_grad, - num_split_times=num_split_times, - use_broadcasting=False, - sampler_seed=seed, + gradient_kwargs={ + "num_split_times": num_split_times, + "use_broadcasting": False, + "sampler_seed": seed, + }, ) params = [jnp.array(0.4)] jac_bc = jax.jacobian(circuit_bc)(params) @@ -1684,9 +1702,7 @@ def ansatz(params): dev, interface="jax", diff_method=qml.gradients.stoch_pulse_grad, - num_split_times=7, - use_broadcasting=True, - sampler_seed=seed, + gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True}, ) cost_jax = qml.QNode(ansatz, dev, interface="jax") params = (0.42,) @@ -1729,9 +1745,7 @@ def ansatz(params): dev, interface="jax", diff_method=qml.gradients.stoch_pulse_grad, - num_split_times=7, - use_broadcasting=True, - sampler_seed=seed, + gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True}, ) cost_jax = qml.QNode(ansatz, dev, interface="jax") diff --git a/tests/gradients/finite_diff/test_spsa_gradient.py b/tests/gradients/finite_diff/test_spsa_gradient.py index 1bd2a198bca..413bc8ff6ec 100644 --- a/tests/gradients/finite_diff/test_spsa_gradient.py +++ b/tests/gradients/finite_diff/test_spsa_gradient.py @@ -161,7 +161,7 @@ def test_invalid_sampler_rng(self): """Tests that if sampler_rng has an unexpected type, an error is raised.""" dev = qml.device("default.qubit", wires=1) - @qml.qnode(dev, diff_method="spsa", sampler_rng="foo") + @qml.qnode(dev, diff_method="spsa", gradient_kwargs={"sampler_rng": "foo"}) def circuit(param): qml.RX(param, wires=0) return qml.expval(qml.PauliZ(0)) diff --git a/tests/gradients/parameter_shift/test_cv_gradients.py b/tests/gradients/parameter_shift/test_cv_gradients.py index c9709753283..55955396ab7 100644 --- a/tests/gradients/parameter_shift/test_cv_gradients.py +++ b/tests/gradients/parameter_shift/test_cv_gradients.py @@ -268,7 +268,11 @@ def qf(x, y): grad_F = jax.grad(qf)(*par) - @qml.qnode(device=gaussian_dev, diff_method="parameter-shift", force_order2=True) + @qml.qnode( + device=gaussian_dev, + diff_method="parameter-shift", + gradient_kwargs={"force_order2": True}, + ) def qf2(x, y): qml.Displacement(0.5, 0, wires=[0]) qml.Squeezing(x, 0, wires=[0]) diff --git a/tests/gradients/parameter_shift/test_parameter_shift.py b/tests/gradients/parameter_shift/test_parameter_shift.py index f35a4b4fc95..4b741c5a66b 100644 --- a/tests/gradients/parameter_shift/test_parameter_shift.py +++ b/tests/gradients/parameter_shift/test_parameter_shift.py @@ -3473,58 +3473,6 @@ def test_trainable_coeffs(self, tol, broadcast): assert np.allclose(res[0], expected[0], atol=tol, rtol=0) assert np.allclose(res[1], expected[1], atol=tol, rtol=0) - def test_multiple_hamiltonians(self, tol, broadcast): - """Test multiple trainable Hamiltonian coefficients""" - dev = qml.device("default.qubit", wires=2) - - obs = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)] - coeffs = np.array([0.1, 0.2, 0.3]) - a, b, c = coeffs - H1 = qml.Hamiltonian(coeffs, obs) - - obs = [qml.PauliZ(0)] - coeffs = np.array([0.7]) - d = coeffs[0] - H2 = qml.Hamiltonian(coeffs, obs) - - weights = np.array([0.4, 0.5]) - x, y = weights - - with qml.queuing.AnnotatedQueue() as q: - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - qml.expval(H1) - qml.expval(H2) - - tape = qml.tape.QuantumScript.from_queue(q) - tape.trainable_params = {0, 1, 2, 4, 5} - - res = dev.execute([tape]) - expected = [-c * np.sin(x) * np.sin(y) + np.cos(x) * (a + b * np.sin(y)), d * np.cos(x)] - assert np.allclose(res, expected, atol=tol, rtol=0) - - tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast) - # two shifts per rotation gate (in one batched tape if broadcasting), - # one circuit per trainable H term - assert len(tapes) == 2 * (1 if broadcast else 2) - - res = fn(dev.execute(tapes)) - assert isinstance(res, tuple) - assert len(res) == 2 - assert len(res[0]) == 2 - assert len(res[1]) == 2 - - expected = [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - ], - [-d * np.sin(x), 0], - ] - - assert np.allclose(np.stack(res), expected, atol=tol, rtol=0) - @staticmethod def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False): """Cost function for gradient tests""" @@ -3547,95 +3495,8 @@ def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False): jac = fn(dev.execute(tapes)) return jac - @staticmethod - def cost_fn_expected(weights, coeffs1, coeffs2): - """Analytic jacobian of cost_fn above""" - a, b, c = coeffs1 - d = coeffs2[0] - x, y = weights - return [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - ], - [-d * np.sin(x), 0], - ] - - @pytest.mark.autograd - def test_autograd(self, tol, broadcast): - """Test gradient of multiple trainable Hamiltonian coefficients - using autograd""" - coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True) - coeffs2 = np.array([0.7], requires_grad=True) - weights = np.array([0.4, 0.5], requires_grad=True) - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(res, np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # res = qml.jacobian(self.cost_fn)(weights, coeffs1, coeffs2, dev=dev) - # assert np.allclose(res[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(res[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.tf - def test_tf(self, tol, broadcast): - """Test gradient of multiple trainable Hamiltonian coefficients - using tf""" - import tensorflow as tf - - coeffs1 = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64) - coeffs2 = tf.Variable([0.7], dtype=tf.float64) - weights = tf.Variable([0.4, 0.5], dtype=tf.float64) - - dev = qml.device("default.qubit", wires=2) - - with tf.GradientTape() as _: - jac = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - - expected = self.cost_fn_expected(weights.numpy(), coeffs1.numpy(), coeffs2.numpy()) - assert np.allclose(jac[0], np.array(expected)[0], atol=tol, rtol=0) - assert np.allclose(jac[1], np.array(expected)[1], atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero. - # When activating the following, rename the GradientTape above from _ to t - # --- - # hess = t.jacobian(jac, [coeffs1, coeffs2]) - # assert np.allclose(hess[0][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[1][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.torch - def test_torch(self, tol, broadcast): - """Test gradient of multiple trainable Hamiltonian coefficients - using torch""" - import torch - - coeffs1 = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64, requires_grad=True) - coeffs2 = torch.tensor([0.7], dtype=torch.float64, requires_grad=True) - weights = torch.tensor([0.4, 0.5], dtype=torch.float64, requires_grad=True) - - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected( - weights.detach().numpy(), coeffs1.detach().numpy(), coeffs2.detach().numpy() - ) - res = tuple(tuple(_r.detach() for _r in r) for r in res) - assert np.allclose(res, expected, atol=tol, rtol=0) - - # second derivative wrt to Hamiltonian coefficients should be zero - # hess = torch.autograd.functional.jacobian( - # lambda *args: self.cost_fn(*args, dev, broadcast), (weights, coeffs1, coeffs2) - # ) - # assert np.allclose(hess[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - @pytest.mark.jax - def test_jax(self, tol, broadcast): + def test_jax(self, broadcast): """Test gradient of multiple trainable Hamiltonian coefficients using JAX""" import jax @@ -3647,19 +3508,11 @@ def test_jax(self, tol, broadcast): weights = jnp.array([0.4, 0.5]) dev = qml.device("default.qubit", wires=2) - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(np.array(res)[:, :2], np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # second derivative wrt to Hamiltonian coefficients should be zero - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) + with pytest.warns( + qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated" + ): + with pytest.warns(UserWarning, match="Please use qml.gradients.split_to_single_terms"): + self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) @pytest.mark.autograd diff --git a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py index 34478ccecd0..e3151735e98 100644 --- a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py +++ b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py @@ -1970,12 +1970,12 @@ def expval(self, observable, **kwargs): dev = DeviceSupporingSpecialObservable(wires=1, shots=None) - @qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast}) def qnode(x): qml.RY(x, wires=0) return qml.expval(SpecialObservable(wires=0)) - @qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast}) def reference_qnode(x): qml.RY(x, wires=0) return qml.expval(qml.PauliZ(wires=0)) @@ -2128,221 +2128,6 @@ def test_trainable_coeffs(self, broadcast, tol): for r in res: assert qml.math.allclose(r, expected, atol=shot_vec_tol) - @pytest.mark.xfail(reason="TODO") - def test_multiple_hamiltonians(self, mocker, broadcast, tol): - """Test multiple trainable Hamiltonian coefficients""" - shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=2, shots=shot_vec) - spy = mocker.spy(qml.gradients, "hamiltonian_grad") - - obs = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)] - coeffs = np.array([0.1, 0.2, 0.3]) - a, b, c = coeffs - H1 = qml.Hamiltonian(coeffs, obs) - - obs = [qml.PauliZ(0)] - coeffs = np.array([0.7]) - d = coeffs[0] - H2 = qml.Hamiltonian(coeffs, obs) - - weights = np.array([0.4, 0.5]) - x, y = weights - - with qml.queuing.AnnotatedQueue() as q: - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - qml.expval(H1) - qml.expval(H2) - - tape = qml.tape.QuantumScript.from_queue(q, shots=shot_vec) - tape.trainable_params = {0, 1, 2, 4, 5} - - res = dev.execute([tape]) - expected = [-c * np.sin(x) * np.sin(y) + np.cos(x) * (a + b * np.sin(y)), d * np.cos(x)] - assert np.allclose(res, expected, atol=tol, rtol=0) - - if broadcast: - with pytest.raises( - NotImplementedError, match="Broadcasting with multiple measurements" - ): - qml.gradients.param_shift(tape, broadcast=broadcast) - return - - tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast) - # two shifts per rotation gate, one circuit per trainable H term - assert len(tapes) == 2 * 2 + 3 - spy.assert_called() - - res = fn(dev.execute(tapes)) - assert isinstance(res, tuple) - assert len(res) == 2 - assert len(res[0]) == 5 - assert len(res[1]) == 5 - - expected = [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - np.cos(x), - -(np.sin(x) * np.sin(y)), - 0, - ], - [-d * np.sin(x), 0, 0, 0, np.cos(x)], - ] - - assert np.allclose(np.stack(res), expected, atol=tol, rtol=0) - - @staticmethod - def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False): - """Cost function for gradient tests""" - obs1 = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)] - H1 = qml.Hamiltonian(coeffs1, obs1) - - obs2 = [qml.PauliZ(0)] - H2 = qml.Hamiltonian(coeffs2, obs2) - - with qml.queuing.AnnotatedQueue() as q: - qml.RX(weights[0], wires=0) - qml.RY(weights[1], wires=1) - qml.CNOT(wires=[0, 1]) - qml.expval(H1) - qml.expval(H2) - - tape = qml.tape.QuantumScript.from_queue(q, shots=dev.shots) - tape.trainable_params = {0, 1, 2, 3, 4, 5} - tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast) - return fn(dev.execute(tapes)) - - @staticmethod - def cost_fn_expected(weights, coeffs1, coeffs2): - """Analytic jacobian of cost_fn above""" - a, b, c = coeffs1 - d = coeffs2[0] - x, y = weights - return [ - [ - -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)), - b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x), - np.cos(x), - np.cos(x) * np.sin(y), - -(np.sin(x) * np.sin(y)), - 0, - ], - [-d * np.sin(x), 0, 0, 0, 0, np.cos(x)], - ] - - @pytest.mark.xfail(reason="TODO") - @pytest.mark.autograd - def test_autograd(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients - using autograd""" - coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True) - coeffs2 = np.array([0.7], requires_grad=True) - weights = np.array([0.4, 0.5], requires_grad=True) - shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=2, shots=shot_vec) - - if broadcast: - with pytest.raises( - NotImplementedError, match="Broadcasting with multiple measurements" - ): - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - return - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(res, np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # res = qml.jacobian(self.cost_fn)(weights, coeffs1, coeffs2, dev=dev) - # assert np.allclose(res[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(res[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.xfail(reason="TODO") - @pytest.mark.tf - def test_tf(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients using tf""" - import tensorflow as tf - - coeffs1 = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64) - coeffs2 = tf.Variable([0.7], dtype=tf.float64) - weights = tf.Variable([0.4, 0.5], dtype=tf.float64) - - shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=2, shots=shot_vec) - - with tf.GradientTape() as _: - jac = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - - expected = self.cost_fn_expected(weights.numpy(), coeffs1.numpy(), coeffs2.numpy()) - assert np.allclose(jac[0], np.array(expected)[0], atol=tol, rtol=0) - assert np.allclose(jac[1], np.array(expected)[1], atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # When activating the following, rename the GradientTape above from _ to t - # --- - # hess = t.jacobian(jac, [coeffs1, coeffs2]) - # assert np.allclose(hess[0][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[1][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.torch - def test_torch(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients - using torch""" - import torch - - coeffs1 = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64, requires_grad=True) - coeffs2 = torch.tensor([0.7], dtype=torch.float64, requires_grad=True) - weights = torch.tensor([0.4, 0.5], dtype=torch.float64, requires_grad=True) - - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected( - weights.detach().numpy(), coeffs1.detach().numpy(), coeffs2.detach().numpy() - ) - for actual, _expected in zip(res, expected): - for val, exp_val in zip(actual, _expected): - assert qml.math.allclose(val.detach(), exp_val, atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # hess = torch.autograd.functional.jacobian( - # lambda *args: self.cost_fn(*args, dev, broadcast), (weights, coeffs1, coeffs2) - # ) - # assert np.allclose(hess[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - # assert np.allclose(hess[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - - @pytest.mark.jax - def test_jax(self, broadcast, tol): - """Test gradient of multiple trainable Hamiltonian coefficients - using JAX""" - import jax - - jnp = jax.numpy - - coeffs1 = jnp.array([0.1, 0.2, 0.3]) - coeffs2 = jnp.array([0.7]) - weights = jnp.array([0.4, 0.5]) - dev = qml.device("default.qubit", wires=2) - - res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast) - expected = self.cost_fn_expected(weights, coeffs1, coeffs2) - assert np.allclose(res, np.array(expected), atol=tol, rtol=0) - - # TODO: test when Hessians are supported with the new return types - # second derivative wrt to Hamiltonian coefficients should be zero - # --- - # second derivative wrt to Hamiltonian coefficients should be zero - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0) - - # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast) - # assert np.allclose(res[:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0) - pauliz = qml.PauliZ(wires=0) proj = qml.Projector([1], wires=0) diff --git a/tests/measurements/test_probs.py b/tests/measurements/test_probs.py index 07bd8fcd463..b5af8118b07 100644 --- a/tests/measurements/test_probs.py +++ b/tests/measurements/test_probs.py @@ -338,7 +338,7 @@ def circuit(): @pytest.mark.jax @pytest.mark.parametrize("shots", (None, 500)) @pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1))) - @pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2])) + @pytest.mark.parametrize("params", (np.pi / 2, [np.pi / 2, np.pi / 2, np.pi / 2])) def test_integration_jax(self, tol_stochastic, shots, obs, params, seed): """Test the probability is correct for a known state preparation when jitted with JAX.""" jax = pytest.importorskip("jax") @@ -359,7 +359,9 @@ def circuit(x): # expected probability, using [00, 01, 10, 11] # ordering, is [0.5, 0.5, 0, 0] - assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params)) + # TODO: [sc-82874] + # revert once we are able to jit end to end without extreme compilation overheads + assert "pure_callback" in str(jax.make_jaxpr(circuit)(params)) res = jax.jit(circuit)(params) expected = np.array([0.5, 0.5, 0, 0]) diff --git a/tests/ops/functions/test_matrix.py b/tests/ops/functions/test_matrix.py index 5adc8912cd2..a1f82e6077a 100644 --- a/tests/ops/functions/test_matrix.py +++ b/tests/ops/functions/test_matrix.py @@ -683,6 +683,7 @@ def circuit(theta): assert np.allclose(matrix, expected_matrix) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.catalyst @pytest.mark.external def test_catalyst(self): diff --git a/tests/ops/op_math/test_exp.py b/tests/ops/op_math/test_exp.py index 5abead84595..6ca4c68091b 100644 --- a/tests/ops/op_math/test_exp.py +++ b/tests/ops/op_math/test_exp.py @@ -768,6 +768,7 @@ def circ(phi): grad = jax.grad(circ)(phi) assert qml.math.allclose(grad, -jnp.sin(phi)) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.catalyst @pytest.mark.external def test_catalyst_qnode(self): diff --git a/tests/resource/test_specs.py b/tests/resource/test_specs.py index a02b35ef97b..5a5764e7153 100644 --- a/tests/resource/test_specs.py +++ b/tests/resource/test_specs.py @@ -33,10 +33,15 @@ class TestSpecsTransform: """Tests for the transform specs using the QNode""" def sample_circuit(self): + @qml.transforms.merge_rotations @qml.transforms.undo_swaps @qml.transforms.cancel_inverses - @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4) + @qml.qnode( + qml.device("default.qubit"), + diff_method="parameter-shift", + gradient_kwargs={"shifts": pnp.pi / 4}, + ) def circuit(x): qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1)) qml.RX(x, wires=0) @@ -222,7 +227,11 @@ def test_splitting_transforms(self): @qml.transforms.split_non_commuting @qml.transforms.merge_rotations - @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4) + @qml.qnode( + qml.device("default.qubit"), + diff_method="parameter-shift", + gradient_kwargs={"shifts": pnp.pi / 4}, + ) def circuit(x): qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1)) qml.RX(x, wires=0) diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py index 885def86e60..a63a2d20129 100644 --- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py +++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py @@ -431,7 +431,7 @@ def circuit(coeffs): qml.MottonenStatePreparation(coeffs, wires=[0, 1]) return qml.probs(wires=[0, 1]) - circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05) + circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", gradient_kwargs={"h": 0.05}) circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift") circuit_exact = qml.QNode(circuit, dev_no_shots) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 38d61cd86cf..6eb3fa851a4 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -76,6 +76,7 @@ def test_compiler(self): assert qml.compiler.available("catalyst") assert qml.compiler.available_compilers() == ["catalyst", "cuda_quantum"] + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_active_compiler(self): """Test `qml.compiler.active_compiler` inside a simple circuit""" dev = qml.device("lightning.qubit", wires=2) @@ -91,6 +92,7 @@ def circuit(phi, theta): assert jnp.allclose(circuit(jnp.pi, jnp.pi / 2), 1.0) assert jnp.allclose(qml.qjit(circuit)(jnp.pi, jnp.pi / 2), -1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_active(self): """Test `qml.compiler.active` inside a simple circuit""" dev = qml.device("lightning.qubit", wires=2) @@ -114,6 +116,7 @@ def test_jax_enable_x64(self, jax_enable_x64): qml.compiler.active() assert jax.config.jax_enable_x64 is jax_enable_x64 + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_circuit(self): """Test JIT compilation of a circuit with 2-qubit""" dev = qml.device("lightning.qubit", wires=2) @@ -128,6 +131,7 @@ def circuit(theta): assert jnp.allclose(circuit(0.5), 0.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_aot(self): """Test AOT compilation of a circuit with 2-qubit""" @@ -152,6 +156,7 @@ def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)): ) assert jnp.allclose(result, expected) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.parametrize( "_in,_out", [ @@ -196,6 +201,7 @@ def workflow1(params1, params2): result = workflow1(params1, params2) assert jnp.allclose(result, expected) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_return_value_dict(self): """Test pytree return values.""" dev = qml.device("lightning.qubit", wires=2) @@ -218,6 +224,7 @@ def circuit1(params): assert jnp.allclose(result["w0"], expected["w0"]) assert jnp.allclose(result["w1"], expected["w1"]) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_python_if(self): """Test JIT compilation with the autograph support""" dev = qml.device("lightning.qubit", wires=2) @@ -235,6 +242,7 @@ def circuit(x: int): assert jnp.allclose(circuit(3), 0.0) assert jnp.allclose(circuit(5), 1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_compilation_opt(self): """Test user-configurable compilation options""" dev = qml.device("lightning.qubit", wires=2) @@ -250,6 +258,7 @@ def circuit(x: float): result_header = "func.func public @circuit(%arg0: tensor) -> tensor" assert result_header in mlir_str + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_adjoint(self): """Test JIT compilation with adjoint support""" dev = qml.device("lightning.qubit", wires=2) @@ -273,6 +282,7 @@ def func(): assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1])) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_qjit_adjoint_lazy(self): """Test that the lazy kwarg is supported.""" dev = qml.device("lightning.qubit", wires=2) @@ -287,6 +297,7 @@ def workflow_pl(theta, wires): assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1])) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_control(self): """Test that control works with qjit.""" dev = qml.device("lightning.qubit", wires=2) @@ -317,6 +328,7 @@ def cond_fn(): class TestCatalystControlFlow: """Test ``qml.qjit`` with Catalyst's control-flow operations""" + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_alternating_while_loop(self): """Test simple while loop.""" dev = qml.device("lightning.qubit", wires=1) @@ -334,6 +346,7 @@ def loop(v): assert jnp.allclose(circuit(1), -1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_nested_while_loops(self): """Test nested while loops.""" dev = qml.device("lightning.qubit", wires=1) @@ -393,6 +406,7 @@ def loop(v): expected = [qml.PauliX(0) for i in range(4)] _ = [qml.assert_equal(i, j) for i, j in zip(tape.operations, expected)] + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_dynamic_wires_for_loops(self): """Test for loops with iteration index-dependant wires.""" dev = qml.device("lightning.qubit", wires=6) @@ -414,6 +428,7 @@ def loop_fn(i): assert jnp.allclose(circuit(6), expected) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_nested_for_loops(self): """Test nested for loops.""" dev = qml.device("lightning.qubit", wires=4) @@ -445,6 +460,7 @@ def inner(j): assert jnp.allclose(circuit(4), jnp.eye(2**4)[0]) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_for_loop_python_fallback(self): """Test that qml.for_loop fallsback to Python interpretation if Catalyst is not available""" @@ -496,6 +512,7 @@ def inner(j): _ = [qml.assert_equal(i, j) for i, j in zip(res, expected)] + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond(self): """Test condition with simple true_fn""" dev = qml.device("lightning.qubit", wires=1) @@ -514,6 +531,7 @@ def ansatz_true(): assert jnp.allclose(circuit(1.4), 1.0) assert jnp.allclose(circuit(1.6), 0.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond_with_else(self): """Test condition with simple true_fn and false_fn""" dev = qml.device("lightning.qubit", wires=1) @@ -535,6 +553,7 @@ def ansatz_false(): assert jnp.allclose(circuit(1.4), 0.16996714) assert jnp.allclose(circuit(1.6), 0.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond_with_elif(self): """Test condition with a simple elif branch""" dev = qml.device("lightning.qubit", wires=1) @@ -558,6 +577,7 @@ def false_fn(): assert jnp.allclose(circuit(1.2), 0.13042371) assert jnp.allclose(circuit(jnp.pi), -1.0) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_cond_with_elifs(self): """Test condition with multiple elif branches""" dev = qml.device("lightning.qubit", wires=1) @@ -630,6 +650,7 @@ def conditional_false_fn(): # pylint: disable=unused-variable class TestCatalystGrad: """Test ``qml.qjit`` with Catalyst's grad operations""" + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_grad_classical_preprocessing(self): """Test the grad transformation with classical preprocessing.""" @@ -647,6 +668,7 @@ def circuit(x): assert jnp.allclose(workflow(2.0), -jnp.pi) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_grad_with_postprocessing(self): """Test the grad transformation with classical preprocessing and postprocessing.""" dev = qml.device("lightning.qubit", wires=1) @@ -665,6 +687,7 @@ def loss(theta): assert jnp.allclose(workflow(1.0), 5.04324559) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_grad_with_multiple_qnodes(self): """Test the grad transformation with multiple QNodes with their own differentiation methods.""" dev = qml.device("lightning.qubit", wires=1) @@ -703,6 +726,7 @@ def dsquare(x: float): assert jnp.allclose(dsquare(2.3), 4.6) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_jacobian_diff_method(self): """Test the Jacobian transformation with the device diff_method.""" dev = qml.device("lightning.qubit", wires=1) @@ -721,6 +745,7 @@ def workflow(p: float): assert jnp.allclose(result, reference) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_jacobian_auto(self): """Test the Jacobian transformation with 'auto'.""" dev = qml.device("lightning.qubit", wires=1) @@ -740,6 +765,7 @@ def circuit(x): assert jnp.allclose(result, reference) + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") def test_jacobian_fd(self): """Test the Jacobian transformation with 'fd'.""" dev = qml.device("lightning.qubit", wires=1) @@ -838,6 +864,7 @@ def f(x): vjp(x, dy) +@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") class TestCatalystSample: """Test qml.sample with Catalyst.""" @@ -858,6 +885,7 @@ def circuit(x): assert circuit(jnp.pi) == 1 +@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") class TestCatalystMCMs: """Test dynamic_one_shot with Catalyst.""" diff --git a/tests/test_qnode.py b/tests/test_qnode.py index c60ae352e4b..5d270d1f9b8 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -36,6 +36,17 @@ def dummyfunc(): return None +def test_additional_kwargs_is_deprecated(): + """Test that passing gradient_kwargs as additional kwargs raises a deprecation warning.""" + dev = qml.device("default.qubit", wires=1) + + with pytest.warns( + qml.PennyLaneDeprecationWarning, + match=r"Specifying gradient keyword arguments \[\'atol\'\] is deprecated", + ): + QNode(dummyfunc, dev, atol=1) + + # pylint: disable=unused-argument class CustomDevice(qml.devices.Device): """A null device that just returns 0.""" @@ -145,24 +156,21 @@ def test_update_gradient_kwargs(self): """Test that gradient kwargs are updated correctly""" dev = qml.device("default.qubit") - @qml.qnode(dev, atol=1) + @qml.qnode(dev, gradient_kwargs={"atol": 1}) def circuit(x): qml.RZ(x, wires=0) qml.CNOT(wires=[0, 1]) qml.RY(x, wires=1) return qml.expval(qml.PauliZ(1)) - assert len(circuit.gradient_kwargs) == 1 - assert list(circuit.gradient_kwargs.keys()) == ["atol"] + assert set(circuit.gradient_kwargs.keys()) == {"atol"} - new_atol_circuit = circuit.update(atol=2) - assert len(new_atol_circuit.gradient_kwargs) == 1 - assert list(new_atol_circuit.gradient_kwargs.keys()) == ["atol"] + new_atol_circuit = circuit.update(gradient_kwargs={"atol": 2}) + assert set(new_atol_circuit.gradient_kwargs.keys()) == {"atol"} assert new_atol_circuit.gradient_kwargs["atol"] == 2 - new_kwarg_circuit = circuit.update(h=1) - assert len(new_kwarg_circuit.gradient_kwargs) == 2 - assert list(new_kwarg_circuit.gradient_kwargs.keys()) == ["atol", "h"] + new_kwarg_circuit = circuit.update(gradient_kwargs={"h": 1}) + assert set(new_kwarg_circuit.gradient_kwargs.keys()) == {"atol", "h"} assert new_kwarg_circuit.gradient_kwargs["atol"] == 1 assert new_kwarg_circuit.gradient_kwargs["h"] == 1 @@ -170,7 +178,7 @@ def circuit(x): UserWarning, match="Received gradient_kwarg blah, which is not included in the list of standard qnode gradient kwargs.", ): - circuit.update(blah=1) + circuit.update(gradient_kwargs={"blah": 1}) def test_update_multiple_arguments(self): """Test that multiple parameters can be updated at once.""" @@ -194,7 +202,7 @@ def test_update_transform_program(self): dev = qml.device("default.qubit", wires=2) @qml.transforms.combine_global_phases - @qml.qnode(dev, atol=1) + @qml.qnode(dev) def circuit(x): qml.RZ(x, wires=0) qml.GlobalPhase(phi=1) @@ -248,7 +256,7 @@ def circuit(return_type): def test_expansion_strategy_error(self): """Test that an error is raised if expansion_strategy is passed to the qnode.""" - with pytest.raises(ValueError, match=r"'expansion_strategy' is no longer"): + with pytest.raises(ValueError, match="'expansion_strategy' is no longer"): @qml.qnode(qml.device("default.qubit"), expansion_strategy="device") def _(): @@ -453,7 +461,7 @@ def test_unrecognized_kwargs_raise_warning(self): with warnings.catch_warnings(record=True) as w: - @qml.qnode(dev, random_kwarg=qml.gradients.finite_diff) + @qml.qnode(dev, gradient_kwargs={"random_kwarg": qml.gradients.finite_diff}) def circuit(params): qml.RX(params[0], wires=0) return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0)) @@ -846,7 +854,7 @@ def test_single_expectation_value_with_argnum_one(self, diff_method, tol): y = pnp.array(-0.654, requires_grad=True) @qnode( - dev, diff_method=diff_method, argnum=[1] + dev, diff_method=diff_method, gradient_kwargs={"argnum": [1]} ) # <--- we only choose one trainable parameter def circuit(x, y): qml.RX(x, wires=[0]) @@ -1320,11 +1328,11 @@ def ansatz0(): return qml.expval(qml.X(0)) with pytest.raises(ValueError, match="'shots' is not a valid gradient_kwarg."): - qml.QNode(ansatz0, dev, shots=100) + qml.QNode(ansatz0, dev, gradient_kwargs={"shots": 100}) with pytest.raises(ValueError, match="'shots' is not a valid gradient_kwarg."): - @qml.qnode(dev, shots=100) + @qml.qnode(dev, gradient_kwargs={"shots": 100}) def _(): return qml.expval(qml.X(0)) @@ -1918,14 +1926,17 @@ def circuit(x, mp): return mp(qml.PauliZ(0)) _ = circuit(1.8, qml.expval, shots=10) - assert circuit.execute_kwargs["mcm_config"] == original_config + assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode + assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method if mcm_method != "one-shot": _ = circuit(1.8, qml.expval) - assert circuit.execute_kwargs["mcm_config"] == original_config + assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode + assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method _ = circuit(1.8, qml.expval, shots=10) - assert circuit.execute_kwargs["mcm_config"] == original_config + assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode + assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method class TestTapeExpansion: diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py index 26c46687934..5a8c63c7d22 100644 --- a/tests/test_qnode_legacy.py +++ b/tests/test_qnode_legacy.py @@ -201,7 +201,7 @@ def test_unrecognized_kwargs_raise_warning(self): with warnings.catch_warnings(record=True) as w: - @qml.qnode(dev, random_kwarg=qml.gradients.finite_diff) + @qml.qnode(dev, gradient_kwargs={"random_kwarg": qml.gradients.finite_diff}) def circuit(params): qml.RX(params[0], wires=0) return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0)) @@ -627,7 +627,7 @@ def test_single_expectation_value_with_argnum_one(self, diff_method, tol): y = pnp.array(-0.654, requires_grad=True) @qnode( - dev, diff_method=diff_method, argnum=[1] + dev, diff_method=diff_method, gradient_kwargs={"argnum": [1]} ) # <--- we only choose one trainable parameter def circuit(x, y): qml.RX(x, wires=[0]) diff --git a/tests/transforms/core/test_transform_dispatcher.py b/tests/transforms/core/test_transform_dispatcher.py index 4a24627c739..6784c3b0a6e 100644 --- a/tests/transforms/core/test_transform_dispatcher.py +++ b/tests/transforms/core/test_transform_dispatcher.py @@ -216,6 +216,7 @@ def test_the_transform_container_attributes(self): class TestTransformDispatcher: # pylint: disable=too-many-public-methods """Test the transform function (validate and dispatch).""" + @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452") @pytest.mark.catalyst @pytest.mark.external def test_error_on_qjit(self): diff --git a/tests/transforms/test_add_noise.py b/tests/transforms/test_add_noise.py index 779dcf91e36..3beab611105 100644 --- a/tests/transforms/test_add_noise.py +++ b/tests/transforms/test_add_noise.py @@ -414,7 +414,7 @@ def test_add_noise_level(self, level1, level2): @qml.transforms.undo_swaps @qml.transforms.merge_rotations @qml.transforms.cancel_inverses - @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4}) def f(w, x, y, z): qml.RX(w, wires=0) qml.RY(x, wires=1) @@ -447,7 +447,7 @@ def test_add_noise_level_with_final(self): @qml.transforms.undo_swaps @qml.transforms.merge_rotations @qml.transforms.cancel_inverses - @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4}) def f(w, x, y, z): qml.RX(w, wires=0) qml.RY(x, wires=1) diff --git a/tests/workflow/interfaces/execute/test_execute.py b/tests/workflow/interfaces/execute/test_execute.py index 8bea335ff7e..4a253727261 100644 --- a/tests/workflow/interfaces/execute/test_execute.py +++ b/tests/workflow/interfaces/execute/test_execute.py @@ -57,6 +57,28 @@ def test_execute_legacy_device(): assert qml.math.allclose(res[0], np.cos(0.1)) +def test_mcm_config_deprecation(mocker): + """Test that mcm_config argument has been deprecated.""" + + tape = qml.tape.QuantumScript( + [qml.RX(qml.numpy.array(1.0), 0)], [qml.expval(qml.Z(0))], shots=10 + ) + dev = qml.device("default.qubit") + + with dev.tracker: + with pytest.warns( + qml.PennyLaneDeprecationWarning, + match="The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.", + ): + spy = mocker.spy(qml.dynamic_one_shot, "_transform") + qml.execute( + (tape,), + dev, + mcm_config=qml.devices.MCMConfig(mcm_method="one-shot", postselect_mode=None), + ) + spy.assert_called_once() + + def test_config_deprecation(): """Test that the config argument has been deprecated.""" diff --git a/tests/workflow/interfaces/qnode/test_autograd_qnode.py b/tests/workflow/interfaces/qnode/test_autograd_qnode.py index ac30592926b..18e5f70cb83 100644 --- a/tests/workflow/interfaces/qnode/test_autograd_qnode.py +++ b/tests/workflow/interfaces/qnode/test_autograd_qnode.py @@ -137,14 +137,15 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, tol, dev device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA a = np.array(0.1, requires_grad=True) b = np.array(0.2, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -183,14 +184,15 @@ def test_jacobian_no_evaluate( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA a = np.array(0.1, requires_grad=True) b = np.array(0.2, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -222,13 +224,14 @@ def test_jacobian_options(self, interface, dev, diff_method, grad_on_execution, a = np.array([0.1, 0.2], requires_grad=True) + gradient_kwargs = {"h": 1e-8, "approx_order": 2} + @qnode( dev, interface=interface, - h=1e-8, - order=2, grad_on_execution=grad_on_execution, device_vjp=device_vjp, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -408,10 +411,10 @@ def test_differentiable_expand( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 10 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 10 tol = TOL_FOR_SPSA # pylint: disable=too-few-public-methods @@ -429,7 +432,7 @@ def decomposition(self): a = np.array(0.1, requires_grad=False) p = np.array([0.1, 0.2, 0.3], requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, p): qml.RX(a, wires=0) U3(p[0], p[1], p[2], wires=0) @@ -555,15 +558,15 @@ def test_probability_differentiation( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -593,14 +596,15 @@ def test_multiple_probability_differentiation( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -660,14 +664,15 @@ def test_ragged_differentiation( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -711,8 +716,9 @@ def test_ragged_differentiation_variance( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("Hadamard gradient does not support variances.") @@ -720,7 +726,7 @@ def test_ragged_differentiation_variance( x = np.array(0.543, requires_grad=True) y = np.array(-0.654, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -826,12 +832,13 @@ def test_chained_gradient_value( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA dev1 = qml.device("default.qubit") - @qnode(dev1, **kwargs) + @qnode(dev1, **kwargs, gradient_kwargs=gradient_kwargs) def circuit1(a, b, c): qml.RX(a, wires=0) qml.RX(b, wires=1) @@ -1350,8 +1357,9 @@ def test_projector( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("Hadamard gradient does not support variances.") @@ -1359,7 +1367,7 @@ def test_projector( P = np.array(state, requires_grad=False) x, y = np.array([0.765, -0.654], requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=0) qml.RY(y, wires=1) @@ -1498,16 +1506,17 @@ def test_hamiltonian_expansion_analytic( device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method in ["adjoint", "hadamard"]: pytest.skip("The diff method requested does not yet support Hamiltonians") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 10 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 10 tol = TOL_FOR_SPSA obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) qml.templates.AngleEmbedding(data, wires=[0, 1]) @@ -1584,7 +1593,7 @@ def test_hamiltonian_finite_shots( grad_on_execution=grad_on_execution, max_diff=max_diff, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) diff --git a/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py index 87e534fdb52..54072ae442e 100644 --- a/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py @@ -52,7 +52,7 @@ def test_jac_single_measurement_param( """For one measurement and one param, the gradient is a float.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -74,7 +74,7 @@ def test_jac_single_measurement_multiple_param( """For one measurement and multiple param, the gradient is a tuple of arrays.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -100,7 +100,7 @@ def test_jacobian_single_measurement_multiple_param_array( """For one measurement and multiple param as a single array params, the gradient is an array.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -123,7 +123,7 @@ def test_jacobian_single_measurement_param_probs( dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -146,7 +146,7 @@ def test_jacobian_single_measurement_probs_multiple_param( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -173,7 +173,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -198,7 +198,7 @@ def test_jacobian_expval_expval_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=1, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=1, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -223,7 +223,7 @@ def test_jacobian_expval_expval_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -250,7 +250,7 @@ def test_jacobian_var_var_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -275,7 +275,7 @@ def test_jacobian_var_var_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -299,7 +299,7 @@ def test_jacobian_multiple_measurement_single_param( """The jacobian of multiple measurements with a single params return an array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -322,7 +322,7 @@ def test_jacobian_multiple_measurement_multiple_param( """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -349,7 +349,7 @@ def test_jacobian_multiple_measurement_multiple_param_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -382,7 +382,7 @@ def test_hessian_expval_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -412,7 +412,7 @@ def test_hessian_expval_multiple_param_array( params = np.array([0.1, 0.2]) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -440,7 +440,7 @@ def test_hessian_var_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -470,7 +470,7 @@ def test_hessian_var_multiple_param_array( params = np.array([0.1, 0.2]) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -500,7 +500,7 @@ def test_hessian_probs_expval_multiple_params( par_0 = np.array(0.1) par_1 = np.array(0.2) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -533,7 +533,7 @@ def test_hessian_expval_probs_multiple_param_array( params = np.array([0.1, 0.2]) - @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -571,7 +571,7 @@ def test_single_expectation_value( x = np.array(0.543) y = np.array(-0.654) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -604,7 +604,7 @@ def test_prob_expectation_values( x = np.array(0.543) y = np.array(-0.654) - @qnode(dev, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) diff --git a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py index 862806999c6..02a586be118 100644 --- a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py +++ b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py @@ -230,7 +230,7 @@ def decomposition(self): interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, p): qml.RX(a, wires=0) @@ -266,14 +266,15 @@ def test_jacobian_options( a = np.array([0.1, 0.2], requires_grad=True) + gradient_kwargs = {"h": 1e-8, "approx_order": 2} + @qnode( get_device(dev_name, wires=1, seed=seed), interface=interface, diff_method="finite-diff", - h=1e-8, - approx_order=2, grad_on_execution=grad_on_execution, device_vjp=device_vjp, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -323,7 +324,7 @@ def test_diff_expval_expval( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -389,7 +390,7 @@ def test_jacobian_no_evaluate( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -456,7 +457,7 @@ def test_diff_single_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -512,7 +513,7 @@ def test_diff_multi_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -601,7 +602,7 @@ def test_diff_expval_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -679,7 +680,7 @@ def test_diff_expval_probs_sub_argnums( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -740,7 +741,7 @@ def test_diff_var_probs( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -1130,7 +1131,7 @@ def test_second_derivative( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1183,7 +1184,7 @@ def test_hessian( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1237,7 +1238,7 @@ def test_hessian_vector_valued( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1300,7 +1301,7 @@ def test_hessian_vector_valued_postprocessing( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RX(x[0], wires=0) @@ -1366,7 +1367,7 @@ def test_hessian_vector_valued_separate_args( grad_on_execution=grad_on_execution, max_diff=2, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -1478,7 +1479,7 @@ def test_projector( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=0) @@ -1593,7 +1594,7 @@ def test_hamiltonian_expansion_analytic( grad_on_execution=grad_on_execution, max_diff=max_diff, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) @@ -1663,7 +1664,7 @@ def test_hamiltonian_finite_shots( grad_on_execution=grad_on_execution, max_diff=max_diff, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) @@ -1862,7 +1863,7 @@ def test_gradient( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -2013,7 +2014,7 @@ def test_gradient_scalar_cost_vector_valued_qnode( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -3062,7 +3063,7 @@ def test_single_measurement( grad_on_execution=grad_on_execution, cache=False, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -3123,7 +3124,7 @@ def test_multi_measurements( diff_method=diff_method, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(a, b): qml.RY(a, wires=0) diff --git a/tests/workflow/interfaces/qnode/test_jax_qnode.py b/tests/workflow/interfaces/qnode/test_jax_qnode.py index ae1cb53a398..ae7ea130193 100644 --- a/tests/workflow/interfaces/qnode/test_jax_qnode.py +++ b/tests/workflow/interfaces/qnode/test_jax_qnode.py @@ -188,10 +188,10 @@ def test_differentiable_expand( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 10 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 10 tol = TOL_FOR_SPSA class U3(qml.U3): # pylint:disable=too-few-public-methods @@ -206,7 +206,7 @@ def decomposition(self): a = jax.numpy.array(0.1) p = jax.numpy.array([0.1, 0.2, 0.3]) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, p): qml.RX(a, wires=0) U3(p[0], p[1], p[2], wires=0) @@ -240,12 +240,13 @@ def test_jacobian_options( a = jax.numpy.array([0.1, 0.2]) + gradient_kwargs = {"h": 1e-8, "approx_order": 2} + @qnode( get_device(dev_name, wires=1, seed=seed), interface=interface, diff_method="finite-diff", - h=1e-8, - approx_order=2, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -273,9 +274,9 @@ def test_diff_expval_expval( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } - + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -283,7 +284,7 @@ def test_diff_expval_expval( a = jax.numpy.array(0.1) b = jax.numpy.array(0.2) - @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -334,14 +335,15 @@ def test_jacobian_no_evaluate( if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA a = jax.numpy.array(0.1) b = jax.numpy.array(0.2) - @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -387,8 +389,9 @@ def test_diff_single_probs( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -396,7 +399,7 @@ def test_diff_single_probs( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -436,8 +439,9 @@ def test_diff_multi_probs( "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -445,7 +449,7 @@ def test_diff_multi_probs( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -517,8 +521,9 @@ def test_diff_expval_probs( "grad_on_execution": grad_on_execution, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") @@ -526,7 +531,7 @@ def test_diff_expval_probs( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -599,7 +604,7 @@ def test_diff_expval_probs_sub_argnums( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -646,18 +651,19 @@ def test_diff_var_probs( "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "hadamard": pytest.skip("Hadamard does not support var") if "lightning" in dev_name: pytest.xfail("lightning device_vjp not compatible with jax.jacobian.") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -968,13 +974,14 @@ def test_second_derivative( "max_diff": 2, } + gradient_kwargs = {} if diff_method == "adjoint": pytest.skip("Adjoint does not second derivative.") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA - @qnode(get_device(dev_name, wires=0, seed=seed), **kwargs) + @qnode(get_device(dev_name, wires=0, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x): qml.RY(x[0], wires=0) qml.RX(x[1], wires=0) @@ -1024,7 +1031,7 @@ def test_hessian( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1079,7 +1086,7 @@ def test_hessian_vector_valued( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1142,7 +1149,7 @@ def test_hessian_vector_valued_postprocessing( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x): qml.RX(x[0], wires=0) @@ -1208,7 +1215,7 @@ def test_hessian_vector_valued_separate_args( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -1316,7 +1323,7 @@ def test_projector( interface=interface, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=0) @@ -1425,7 +1432,7 @@ def test_split_non_commuting_analytic( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=max_diff, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) @@ -1513,7 +1520,7 @@ def test_hamiltonian_finite_shots( grad_on_execution=grad_on_execution, device_vjp=device_vjp, max_diff=max_diff, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = weights.reshape(1, -1) diff --git a/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py index 697a1e90223..210ba601586 100644 --- a/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py @@ -60,7 +60,7 @@ def test_jac_single_measurement_param( """For one measurement and one param, the gradient is a float.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -86,7 +86,7 @@ def test_jac_single_measurement_multiple_param( """For one measurement and multiple param, the gradient is a tuple of arrays.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -115,7 +115,7 @@ def test_jacobian_single_measurement_multiple_param_array( """For one measurement and multiple param as a single array params, the gradient is an array.""" dev = qml.device(dev_name, wires=1, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -141,7 +141,7 @@ def test_jacobian_single_measurement_param_probs( dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -167,7 +167,7 @@ def test_jacobian_single_measurement_probs_multiple_param( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -199,7 +199,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array( the correct dimension""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -227,7 +227,13 @@ def test_jacobian_expval_expval_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=1, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=1, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -264,7 +270,7 @@ def test_jacobian_expval_expval_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -298,7 +304,7 @@ def test_jacobian_var_var_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -336,7 +342,7 @@ def test_jacobian_var_var_multiple_params_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -367,7 +373,7 @@ def test_jacobian_multiple_measurement_single_param( """The jacobian of multiple measurements with a single params return an array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -398,7 +404,7 @@ def test_jacobian_multiple_measurement_multiple_param( """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -438,7 +444,7 @@ def test_jacobian_multiple_measurement_multiple_param_array( """The jacobian of multiple measurements with a multiple params array return a single array.""" dev = qml.device(dev_name, wires=2, shots=shots) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -471,7 +477,13 @@ def test_hessian_expval_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -508,7 +520,13 @@ def test_hessian_expval_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -534,7 +552,13 @@ def test_hessian_var_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -571,7 +595,13 @@ def test_hessian_var_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -597,7 +627,13 @@ def test_hessian_probs_expval_multiple_params( par_0 = jax.numpy.array(0.1) par_1 = jax.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -655,7 +691,13 @@ def test_hessian_expval_probs_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -687,7 +729,13 @@ def test_hessian_probs_var_multiple_params( par_0 = qml.numpy.array(0.1) par_1 = qml.numpy.array(0.2) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -742,7 +790,13 @@ def test_hessian_var_probs_multiple_param_array( params = jax.numpy.array([0.1, 0.2]) - @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs) + @qnode( + dev, + interface=interface, + diff_method=diff_method, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -792,7 +846,7 @@ def test_single_expectation_value( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -827,7 +881,7 @@ def test_prob_expectation_values( x = jax.numpy.array(0.543) y = jax.numpy.array(-0.654) - @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs) + @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py index f24eda4b382..6a20f264ec6 100644 --- a/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py @@ -73,7 +73,7 @@ def test_jac_single_measurement_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a, wires=0) @@ -101,7 +101,7 @@ def test_jac_single_measurement_multiple_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b, **_): qml.RY(a, wires=0) @@ -133,7 +133,7 @@ def test_jacobian_single_measurement_multiple_param_array( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a[0], wires=0) @@ -162,7 +162,7 @@ def test_jacobian_single_measurement_param_probs( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a, wires=0) @@ -191,7 +191,7 @@ def test_jacobian_single_measurement_probs_multiple_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b, **_): qml.RY(a, wires=0) @@ -224,7 +224,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a[0], wires=0) @@ -255,7 +255,7 @@ def test_jacobian_expval_expval_multiple_params( diff_method=diff_method, interface=interface, max_diff=1, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y, **_): qml.RX(x, wires=[0]) @@ -285,7 +285,7 @@ def test_jacobian_expval_expval_multiple_params_array( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a[0], wires=0) @@ -314,7 +314,7 @@ def test_jacobian_multiple_measurement_single_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a, wires=0) @@ -342,7 +342,7 @@ def test_jacobian_multiple_measurement_multiple_param( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, b, **_): qml.RY(a, wires=0) @@ -374,7 +374,7 @@ def test_jacobian_multiple_measurement_multiple_param_array( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(a, **_): qml.RY(a[0], wires=0) @@ -421,7 +421,7 @@ def test_hessian_expval_multiple_params( diff_method=diff_method, interface=interface, max_diff=2, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y, **_): qml.RX(x, wires=[0]) @@ -471,7 +471,7 @@ def test_single_expectation_value( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y, **_): qml.RX(x, wires=[0]) @@ -509,7 +509,7 @@ def test_prob_expectation_values( qml.device(dev_name, seed=seed), diff_method=diff_method, interface=interface, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y, **_): qml.RX(x, wires=[0]) diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py b/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py index 85a39cc588f..79f19bd2f4d 100644 --- a/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py +++ b/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py @@ -158,6 +158,7 @@ def circuit(p1, p2=y, **kwargs): def test_jacobian(self, dev, diff_method, grad_on_execution, device_vjp, tol, interface, seed): """Test jacobian calculation""" + gradient_kwargs = {} kwargs = { "diff_method": diff_method, "grad_on_execution": grad_on_execution, @@ -165,14 +166,14 @@ def test_jacobian(self, dev, diff_method, grad_on_execution, device_vjp, tol, in "device_vjp": device_vjp, } if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA a = tf.Variable(0.1, dtype=tf.float64) b = tf.Variable(0.2, dtype=tf.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -200,14 +201,15 @@ def test_jacobian_options(self, dev, diff_method, grad_on_execution, device_vjp, a = tf.Variable([0.1, 0.2]) + gradient_kwargs = {"approx_order": 2, "h": 1e-8} + @qnode( dev, interface=interface, - h=1e-8, - approx_order=2, diff_method=diff_method, grad_on_execution=grad_on_execution, device_vjp=device_vjp, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -240,7 +242,7 @@ def test_changing_trainability( diff_method=diff_method, grad_on_execution=grad_on_execution, device_vjp=device_vjp, - **diff_kwargs, + gradient_kwargs=diff_kwargs, ) def circuit(a, b): qml.RY(a, wires=0) @@ -370,6 +372,7 @@ def test_differentiable_expand( ): """Test that operation and nested tapes expansion is differentiable""" + gradient_kwargs = {} kwargs = { "diff_method": diff_method, "grad_on_execution": grad_on_execution, @@ -377,8 +380,8 @@ def test_differentiable_expand( "device_vjp": device_vjp, } if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA class U3(qml.U3): @@ -393,7 +396,7 @@ def decomposition(self): a = np.array(0.1) p = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, p): qml.RX(a, wires=0) U3(p[0], p[1], p[2], wires=0) @@ -548,15 +551,16 @@ def test_probability_differentiation( "interface": interface, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA x = tf.Variable(0.543, dtype=tf.float64) y = tf.Variable(-0.654, dtype=tf.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -605,15 +609,16 @@ def test_ragged_differentiation( "interface": interface, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA x = tf.Variable(0.543, dtype=tf.float64) y = tf.Variable(-0.654, dtype=tf.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -943,13 +948,14 @@ def test_projector( "interface": interface, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method == "adjoint": pytest.skip("adjoint supports either all expvals or all diagonal measurements.") if diff_method == "hadamard": pytest.skip("Variance not implemented yet.") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA if dev.name == "reference.qubit": pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)") @@ -959,7 +965,7 @@ def test_projector( x, y = 0.765, -0.654 weights = tf.Variable([x, y], dtype=tf.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(weights): qml.RX(weights[0], wires=0) qml.RY(weights[1], wires=1) @@ -1129,16 +1135,17 @@ def test_hamiltonian_expansion_analytic( "interface": interface, "device_vjp": device_vjp, } + gradient_kwargs = {} if diff_method in ["adjoint", "hadamard"]: pytest.skip("The adjoint/hadamard method does not yet support Hamiltonians") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(data, weights, coeffs): weights = tf.reshape(weights, [1, -1]) qml.templates.AngleEmbedding(data, wires=[0, 1]) @@ -1211,7 +1218,7 @@ def test_hamiltonian_finite_shots( max_diff=max_diff, interface=interface, device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = tf.reshape(weights, [1, -1]) diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py index 037469657bc..fafd85e7e1e 100644 --- a/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py +++ b/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py @@ -70,7 +70,7 @@ def test_jac_single_measurement_param( ): """For one measurement and one param, the gradient is a float.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -92,7 +92,7 @@ def test_jac_single_measurement_multiple_param( ): """For one measurement and multiple param, the gradient is a tuple of arrays.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -118,7 +118,7 @@ def test_jacobian_single_measurement_multiple_param_array( ): """For one measurement and multiple param as a single array params, the gradient is an array.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -141,7 +141,7 @@ def test_jacobian_single_measurement_param_probs( """For a multi dimensional measurement (probs), check that a single array is returned with the correct dimension""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -164,7 +164,7 @@ def test_jacobian_single_measurement_probs_multiple_param( """For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with the correct dimension""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -191,7 +191,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array( """For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with the correct dimension""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -216,7 +216,13 @@ def test_jacobian_expval_expval_multiple_params( par_0 = tf.Variable(0.1) par_1 = tf.Variable(0.2) - @qnode(dev, diff_method=diff_method, interface=interface, max_diff=1, **gradient_kwargs) + @qnode( + dev, + diff_method=diff_method, + interface=interface, + max_diff=1, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -240,7 +246,7 @@ def test_jacobian_expval_expval_multiple_params_array( ): """The jacobian of multiple measurements with a multiple params array return a single array.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -263,7 +269,7 @@ def test_jacobian_multiple_measurement_single_param( ): """The jacobian of multiple measurements with a single params return an array.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a, wires=0) qml.RX(0.2, wires=0) @@ -285,7 +291,7 @@ def test_jacobian_multiple_measurement_multiple_param( ): """The jacobian of multiple measurements with a multiple params return a tuple of arrays.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=0) @@ -311,7 +317,7 @@ def test_jacobian_multiple_measurement_multiple_param_array( ): """The jacobian of multiple measurements with a multiple params array return a single array.""" - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(a): qml.RY(a[0], wires=0) qml.RX(a[1], wires=0) @@ -343,7 +349,13 @@ def test_hessian_expval_multiple_params( par_0 = tf.Variable(0.1, dtype=tf.float64) par_1 = tf.Variable(0.2, dtype=tf.float64) - @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs) + @qnode( + dev, + diff_method=diff_method, + interface=interface, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -373,7 +385,13 @@ def test_hessian_expval_multiple_param_array( params = tf.Variable([0.1, 0.2], dtype=tf.float64) - @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs) + @qnode( + dev, + diff_method=diff_method, + interface=interface, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -400,7 +418,13 @@ def test_hessian_probs_expval_multiple_params( par_0 = tf.Variable(0.1, dtype=tf.float64) par_1 = tf.Variable(0.2, dtype=tf.float64) - @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs) + @qnode( + dev, + diff_method=diff_method, + interface=interface, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -430,7 +454,13 @@ def test_hessian_expval_probs_multiple_param_array( params = tf.Variable([0.1, 0.2], dtype=tf.float64) - @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs) + @qnode( + dev, + diff_method=diff_method, + interface=interface, + max_diff=2, + gradient_kwargs=gradient_kwargs, + ) def circuit(x): qml.RX(x[0], wires=[0]) qml.RY(x[1], wires=[1]) @@ -467,7 +497,7 @@ def test_single_expectation_value( x = tf.Variable(0.543, dtype=tf.float64) y = tf.Variable(-0.654, dtype=tf.float64) - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -500,7 +530,7 @@ def test_prob_expectation_values( x = tf.Variable(0.543, dtype=tf.float64) y = tf.Variable(-0.654, dtype=tf.float64) - @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs) + @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) diff --git a/tests/workflow/interfaces/qnode/test_torch_qnode.py b/tests/workflow/interfaces/qnode/test_torch_qnode.py index e9581898522..a2922a7ce62 100644 --- a/tests/workflow/interfaces/qnode/test_torch_qnode.py +++ b/tests/workflow/interfaces/qnode/test_torch_qnode.py @@ -173,9 +173,10 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, device_v interface=interface, device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA a_val = 0.1 @@ -184,7 +185,7 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, device_v a = torch.tensor(a_val, dtype=torch.float64, requires_grad=True) b = torch.tensor(b_val, dtype=torch.float64, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, b): qml.RY(a, wires=0) qml.RX(b, wires=1) @@ -276,14 +277,15 @@ def test_jacobian_options( a = torch.tensor([0.1, 0.2], requires_grad=True) + gradient_kwargs = {"h": 1e-8, "approx_order": 2} + @qnode( dev, diff_method=diff_method, grad_on_execution=grad_on_execution, interface=interface, - h=1e-8, - approx_order=2, device_vjp=device_vjp, + gradient_kwargs=gradient_kwargs, ) def circuit(a): qml.RY(a[0], wires=0) @@ -483,9 +485,10 @@ def test_differentiable_expand( interface=interface, device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA class U3(qml.U3): # pylint:disable=too-few-public-methods @@ -501,7 +504,7 @@ def decomposition(self): p_val = [0.1, 0.2, 0.3] p = torch.tensor(p_val, dtype=torch.float64, requires_grad=True) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(a, p): qml.RX(a, wires=0) U3(p[0], p[1], p[2], wires=0) @@ -644,9 +647,9 @@ def test_probability_differentiation( with prob and expval outputs""" if "lightning" in getattr(dev, "name", "").lower(): pytest.xfail("lightning does not support measureing probabilities with adjoint.") - kwargs = {} + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) tol = TOL_FOR_SPSA x_val = 0.543 @@ -660,7 +663,7 @@ def test_probability_differentiation( grad_on_execution=grad_on_execution, interface=interface, device_vjp=device_vjp, - **kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(x, y): qml.RX(x, wires=[0]) @@ -708,9 +711,10 @@ def test_ragged_differentiation( interface=interface, device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA x_val = 0.543 @@ -718,7 +722,7 @@ def test_ragged_differentiation( x = torch.tensor(x_val, requires_grad=True, dtype=torch.float64) y = torch.tensor(y_val, requires_grad=True, dtype=torch.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=[0]) qml.RY(y, wires=[1]) @@ -813,7 +817,7 @@ def test_hessian(self, interface, dev, diff_method, grad_on_execution, device_vj max_diff=2, interface=interface, device_vjp=device_vjp, - **options, + gradient_kwargs=options, ) def circuit(x): qml.RY(x[0], wires=0) @@ -866,7 +870,7 @@ def test_hessian_vector_valued( max_diff=2, interface=interface, device_vjp=device_vjp, - **options, + gradient_kwargs=options, ) def circuit(x): qml.RY(x[0], wires=0) @@ -928,7 +932,7 @@ def test_hessian_ragged(self, interface, dev, diff_method, grad_on_execution, de max_diff=2, interface=interface, device_vjp=device_vjp, - **options, + gradient_kwargs=options, ) def circuit(x): qml.RY(x[0], wires=0) @@ -1003,7 +1007,7 @@ def test_hessian_vector_valued_postprocessing( max_diff=2, interface=interface, device_vjp=device_vjp, - **options, + gradient_kwargs=options, ) def circuit(x): qml.RX(x[0], wires=0) @@ -1100,11 +1104,12 @@ def test_projector( grad_on_execution=grad_on_execution, device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "adjoint": pytest.skip("adjoint supports either all expvals or all diagonal measurements") if diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("Hadamard does not support variances.") @@ -1116,7 +1121,7 @@ def test_projector( x, y = 0.765, -0.654 weights = torch.tensor([x, y], requires_grad=True, dtype=torch.float64) - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(x, y): qml.RX(x, wires=0) qml.RY(y, wires=1) @@ -1285,18 +1290,19 @@ def test_hamiltonian_expansion_analytic( interface="torch", device_vjp=device_vjp, ) + gradient_kwargs = {} if diff_method == "adjoint": pytest.skip("The adjoint method does not yet support Hamiltonians") elif diff_method == "spsa": - kwargs["sampler_rng"] = np.random.default_rng(seed) - kwargs["num_directions"] = 20 + gradient_kwargs["sampler_rng"] = np.random.default_rng(seed) + gradient_kwargs["num_directions"] = 20 tol = TOL_FOR_SPSA elif diff_method == "hadamard": pytest.skip("The hadamard method does not yet support Hamiltonians") obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)] - @qnode(dev, **kwargs) + @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs) def circuit(data, weights, coeffs): weights = torch.reshape(weights, [1, -1]) qml.templates.AngleEmbedding(data, wires=[0, 1]) @@ -1389,7 +1395,7 @@ def test_hamiltonian_finite_shots( max_diff=max_diff, interface="torch", device_vjp=device_vjp, - **gradient_kwargs, + gradient_kwargs=gradient_kwargs, ) def circuit(data, weights, coeffs): weights = torch.reshape(weights, [1, -1]) diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py index f11eaee459c..d4288da65d9 100644 --- a/tests/workflow/test_construct_batch.py +++ b/tests/workflow/test_construct_batch.py @@ -72,7 +72,7 @@ def test_get_transform_program_diff_method_transform(self): @partial(qml.transforms.compile, num_passes=2) @partial(qml.transforms.merge_rotations, atol=1e-5) @qml.transforms.cancel_inverses - @qml.qnode(dev, diff_method="parameter-shift", shifts=2) + @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": 2}) def circuit(): return qml.expval(qml.PauliZ(0))