diff --git a/.github/stable/all_interfaces.txt b/.github/stable/all_interfaces.txt
index b53f5569d44..84489804731 100644
--- a/.github/stable/all_interfaces.txt
+++ b/.github/stable/all_interfaces.txt
@@ -38,7 +38,7 @@ isort==5.13.2
jax==0.4.28
jaxlib==0.4.28
Jinja2==3.1.5
-keras==3.7.0
+keras==3.8.0
kiwisolver==1.4.8
lazy-object-proxy==1.10.0
libclang==18.1.1
@@ -68,12 +68,12 @@ nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.1.105
opt_einsum==3.4.0
-optree==0.13.1
+optree==0.14.0
osqp==0.6.7.post3
packaging==24.2
pandas==2.2.3
pathspec==0.12.1
-PennyLane_Lightning==0.40.0
+PennyLane_Lightning==0.41.0
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
@@ -83,7 +83,7 @@ protobuf==4.25.5
py==1.11.0
py-cpuinfo==9.0.0
pydot==3.0.4
-Pygments==2.19.0
+Pygments==2.19.1
pylint==2.7.4
pyparsing==3.2.1
pytest==8.3.4
@@ -101,7 +101,8 @@ qdldl==0.1.7.post5
requests==2.32.3
rich==13.9.4
rustworkx==0.15.1
-scipy==1.15.0
+scipy==1.15.1
+scipy-openblas32==0.3.28.0.2
scs==3.2.7.post2
six==1.17.0
smmap==5.0.2
@@ -115,14 +116,14 @@ termcolor==2.5.0
tf_keras==2.16.0
toml==0.10.2
tomli==2.2.1
-tomli_w==1.1.0
+tomli_w==1.2.0
tomlkit==0.13.2
torch==2.3.0
triton==2.3.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
-virtualenv==20.28.1
+virtualenv==20.29.1
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.12.1
diff --git a/.github/stable/core.txt b/.github/stable/core.txt
index fc24ed2dff0..925553c4c2f 100644
--- a/.github/stable/core.txt
+++ b/.github/stable/core.txt
@@ -43,7 +43,7 @@ osqp==0.6.7.post3
packaging==24.2
pandas==2.2.3
pathspec==0.12.1
-PennyLane_Lightning==0.40.0
+PennyLane_Lightning==0.41.0
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
@@ -52,7 +52,7 @@ prompt_toolkit==3.0.48
py==1.11.0
py-cpuinfo==9.0.0
pydot==3.0.4
-Pygments==2.19.0
+Pygments==2.19.1
pylint==2.7.4
pyparsing==3.2.1
pytest==8.3.4
@@ -70,7 +70,8 @@ qdldl==0.1.7.post5
requests==2.32.3
rich==13.9.4
rustworkx==0.15.1
-scipy==1.15.0
+scipy==1.15.1
+scipy-openblas32==0.3.28.0.2
scs==3.2.7.post2
six==1.17.0
smmap==5.0.2
@@ -78,11 +79,11 @@ tach==0.13.1
termcolor==2.5.0
toml==0.10.2
tomli==2.2.1
-tomli_w==1.1.0
+tomli_w==1.2.0
tomlkit==0.13.2
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
-virtualenv==20.28.1
+virtualenv==20.29.1
wcwidth==0.2.13
wrapt==1.12.1
diff --git a/.github/stable/external.txt b/.github/stable/external.txt
index 5cb949a78d7..80052907bc0 100644
--- a/.github/stable/external.txt
+++ b/.github/stable/external.txt
@@ -27,14 +27,14 @@ clarabel==0.9.0
click==8.1.8
comm==0.2.2
contourpy==1.3.1
-cotengra==0.6.2
+cotengra==0.7.0
coverage==7.6.10
cryptography==44.0.0
cvxopt==1.3.2
cvxpy==1.6.0
cycler==0.12.1
cytoolz==1.0.1
-debugpy==1.8.11
+debugpy==1.8.12
decorator==5.1.1
defusedxml==0.7.1
diastatic-malt==2.15.2
@@ -60,8 +60,8 @@ h11==0.14.0
h5py==3.12.1
httpcore==1.0.7
httpx==0.28.1
-ibm-cloud-sdk-core==3.22.0
-ibm-platform-services==0.59.0
+ibm-cloud-sdk-core==3.22.1
+ibm-platform-services==0.59.1
identify==2.6.5
idna==3.10
iniconfig==2.0.0
@@ -89,7 +89,7 @@ jupyterlab==4.3.4
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==1.1.11
-keras==3.7.0
+keras==3.8.0
kiwisolver==1.4.8
lark==1.1.9
lazy-object-proxy==1.10.0
@@ -120,7 +120,7 @@ numba==0.60.0
numpy==1.26.4
opt_einsum==3.4.0
optax==0.2.4
-optree==0.13.1
+optree==0.14.0
osqp==0.6.7.post3
overrides==7.7.0
packaging==24.2
@@ -129,8 +129,8 @@ pandocfilters==1.5.1
parso==0.8.4
pathspec==0.12.1
pbr==6.1.0
-PennyLane-Catalyst==0.10.0.dev39
-PennyLane-qiskit @ git+https://github.com/PennyLaneAI/pennylane-qiskit.git@40a4d24f126e51e0e3e28a4cd737f883a6fd5ebc
+PennyLane-Catalyst==0.11.0.dev1
+PennyLane-qiskit @ git+https://github.com/PennyLaneAI/pennylane-qiskit.git@b46fbca3372979534bc33af701194df548cf4b16
PennyLane_Lightning==0.40.0
PennyLane_Lightning_Kokkos==0.40.0.dev41
pexpect==4.9.0
@@ -148,10 +148,10 @@ pure_eval==0.2.3
py==1.11.0
py-cpuinfo==9.0.0
pycparser==2.22
-pydantic==2.10.4
+pydantic==2.10.5
pydantic_core==2.27.2
pydot==3.0.4
-Pygments==2.19.0
+Pygments==2.19.1
PyJWT==2.10.1
pylint==2.7.4
pyparsing==3.2.1
@@ -173,11 +173,11 @@ pyzmq==26.2.0
pyzx==0.8.0
qdldl==0.1.7.post5
qiskit==1.2.4
-qiskit-aer==0.15.1
+qiskit-aer==0.16.0
qiskit-ibm-provider==0.11.0
qiskit-ibm-runtime==0.29.0
quimb==1.10.0
-referencing==0.35.1
+referencing==0.36.1
requests==2.32.3
requests_ntlm==1.3.0
rfc3339-validator==0.1.4
@@ -211,7 +211,7 @@ tf_keras==2.16.0
tinycss2==1.4.0
toml==0.10.2
tomli==2.2.1
-tomli_w==1.1.0
+tomli_w==1.2.0
tomlkit==0.13.2
toolz==1.0.0
tornado==6.4.2
@@ -222,12 +222,12 @@ typing_extensions==4.12.2
tzdata==2024.2
uri-template==1.3.0
urllib3==2.3.0
-virtualenv==20.28.1
+virtualenv==20.29.1
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
-websockets==14.1
+websockets==14.2
Werkzeug==3.1.3
widgetsnbextension==3.6.10
wrapt==1.12.1
diff --git a/.github/stable/jax.txt b/.github/stable/jax.txt
index ed2c0b32037..504bcff8e47 100644
--- a/.github/stable/jax.txt
+++ b/.github/stable/jax.txt
@@ -37,7 +37,7 @@ markdown-it-py==3.0.0
matplotlib==3.10.0
mccabe==0.6.1
mdurl==0.1.2
-ml_dtypes==0.5.0
+ml_dtypes==0.5.1
mypy-extensions==1.0.0
networkx==3.4.2
nodeenv==1.9.1
@@ -47,7 +47,7 @@ osqp==0.6.7.post3
packaging==24.2
pandas==2.2.3
pathspec==0.12.1
-PennyLane_Lightning==0.40.0
+PennyLane_Lightning==0.41.0
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
@@ -56,7 +56,7 @@ prompt_toolkit==3.0.48
py==1.11.0
py-cpuinfo==9.0.0
pydot==3.0.4
-Pygments==2.19.0
+Pygments==2.19.1
pylint==2.7.4
pyparsing==3.2.1
pytest==8.3.4
@@ -74,7 +74,8 @@ qdldl==0.1.7.post5
requests==2.32.3
rich==13.9.4
rustworkx==0.15.1
-scipy==1.15.0
+scipy==1.15.1
+scipy-openblas32==0.3.28.0.2
scs==3.2.7.post2
six==1.17.0
smmap==5.0.2
@@ -82,11 +83,11 @@ tach==0.13.1
termcolor==2.5.0
toml==0.10.2
tomli==2.2.1
-tomli_w==1.1.0
+tomli_w==1.2.0
tomlkit==0.13.2
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
-virtualenv==20.28.1
+virtualenv==20.29.1
wcwidth==0.2.13
wrapt==1.12.1
diff --git a/.github/stable/tf.txt b/.github/stable/tf.txt
index 634fbd09526..02d559768eb 100644
--- a/.github/stable/tf.txt
+++ b/.github/stable/tf.txt
@@ -34,7 +34,7 @@ identify==2.6.5
idna==3.10
iniconfig==2.0.0
isort==5.13.2
-keras==3.7.0
+keras==3.8.0
kiwisolver==1.4.8
lazy-object-proxy==1.10.0
libclang==18.1.1
@@ -51,12 +51,12 @@ networkx==3.4.2
nodeenv==1.9.1
numpy==1.26.4
opt_einsum==3.4.0
-optree==0.13.1
+optree==0.14.0
osqp==0.6.7.post3
packaging==24.2
pandas==2.2.3
pathspec==0.12.1
-PennyLane_Lightning==0.40.0
+PennyLane_Lightning==0.41.0
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
@@ -66,7 +66,7 @@ protobuf==4.25.5
py==1.11.0
py-cpuinfo==9.0.0
pydot==3.0.4
-Pygments==2.19.0
+Pygments==2.19.1
pylint==2.7.4
pyparsing==3.2.1
pytest==8.3.4
@@ -84,7 +84,8 @@ qdldl==0.1.7.post5
requests==2.32.3
rich==13.9.4
rustworkx==0.15.1
-scipy==1.15.0
+scipy==1.15.1
+scipy-openblas32==0.3.28.0.2
scs==3.2.7.post2
six==1.17.0
smmap==5.0.2
@@ -97,12 +98,12 @@ termcolor==2.5.0
tf_keras==2.16.0
toml==0.10.2
tomli==2.2.1
-tomli_w==1.1.0
+tomli_w==1.2.0
tomlkit==0.13.2
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
-virtualenv==20.28.1
+virtualenv==20.29.1
wcwidth==0.2.13
Werkzeug==3.1.3
wrapt==1.12.1
diff --git a/.github/stable/torch.txt b/.github/stable/torch.txt
index 3ab4c0e93f2..46b5f061df6 100644
--- a/.github/stable/torch.txt
+++ b/.github/stable/torch.txt
@@ -59,7 +59,7 @@ osqp==0.6.7.post3
packaging==24.2
pandas==2.2.3
pathspec==0.12.1
-PennyLane_Lightning==0.40.0
+PennyLane_Lightning==0.41.0
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
@@ -68,7 +68,7 @@ prompt_toolkit==3.0.48
py==1.11.0
py-cpuinfo==9.0.0
pydot==3.0.4
-Pygments==2.19.0
+Pygments==2.19.1
pylint==2.7.4
pyparsing==3.2.1
pytest==8.3.4
@@ -86,7 +86,8 @@ qdldl==0.1.7.post5
requests==2.32.3
rich==13.9.4
rustworkx==0.15.1
-scipy==1.15.0
+scipy==1.15.1
+scipy-openblas32==0.3.28.0.2
scs==3.2.7.post2
six==1.17.0
smmap==5.0.2
@@ -95,13 +96,13 @@ tach==0.13.1
termcolor==2.5.0
toml==0.10.2
tomli==2.2.1
-tomli_w==1.1.0
+tomli_w==1.2.0
tomlkit==0.13.2
torch==2.3.0
triton==2.3.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
-virtualenv==20.28.1
+virtualenv==20.29.1
wcwidth==0.2.13
wrapt==1.12.1
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 02b95edc395..5683d05880c 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -16,15 +16,18 @@
name: "Documentation check"
on:
+ merge_group:
+ types:
+ - checks_requested
pull_request:
types:
- opened
- reopened
- synchronize
- ready_for_review
- # Scheduled trigger on Wednesdays at 3:00am UTC
+ # Scheduled trigger on Monday at 2:47am UTC
schedule:
- - cron: "0 3 * * 3"
+ - cron: "47 2 * * 1"
permissions: write-all
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 0a29e842bf1..ba5d6ce7776 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -1,5 +1,8 @@
name: Formatting check
on:
+ merge_group:
+ types:
+ - checks_requested
pull_request:
types:
- opened
diff --git a/.github/workflows/interface-dependency-versions.yml b/.github/workflows/interface-dependency-versions.yml
index f567321b27c..d8520064dc1 100644
--- a/.github/workflows/interface-dependency-versions.yml
+++ b/.github/workflows/interface-dependency-versions.yml
@@ -26,7 +26,7 @@ on:
description: The version of PyTorch to use for testing
required: false
type: string
- default: '2.3.0'
+ default: '2.5.0'
outputs:
jax-version:
description: The version of JAX to use
@@ -72,6 +72,6 @@ jobs:
outputs:
jax-version: jax==${{ steps.jax.outputs.version }} jaxlib==${{ steps.jax.outputs.version }}
tensorflow-version: tensorflow~=${{ steps.tensorflow.outputs.version }} tf-keras~=${{ steps.tensorflow.outputs.version }}
- pytorch-version: torch==${{ steps.pytorch.outputs.version }}
+ pytorch-version: torch~=${{ steps.pytorch.outputs.version }}
catalyst-nightly: ${{ steps.catalyst.outputs.nightly }}
pennylane-lightning-latest: ${{ steps.pennylane-lightning.outputs.latest }}
diff --git a/.github/workflows/module-validation.yml b/.github/workflows/module-validation.yml
index 0db4f16faeb..66af8bc218e 100644
--- a/.github/workflows/module-validation.yml
+++ b/.github/workflows/module-validation.yml
@@ -1,6 +1,9 @@
name: Validate module imports
on:
+ merge_group:
+ types:
+ - checks_requested
pull_request:
types:
- opened
diff --git a/.github/workflows/tests-gpu.yml b/.github/workflows/tests-gpu.yml
index bd8cd5dae67..92a7bae8ac9 100644
--- a/.github/workflows/tests-gpu.yml
+++ b/.github/workflows/tests-gpu.yml
@@ -3,6 +3,9 @@ on:
push:
branches:
- master
+ merge_group:
+ types:
+ - checks_requested
pull_request:
types:
- opened
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 63c8074e524..620dc39ac57 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -14,6 +14,9 @@ on:
# Scheduled trigger on Monday at 2:47am UTC
schedule:
- cron: "47 2 * * 1"
+ merge_group:
+ types:
+ - checks_requested
concurrency:
group: unit-tests-${{ github.ref }}
diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst
index ff6048e18be..a61b8338222 100644
--- a/doc/development/deprecations.rst
+++ b/doc/development/deprecations.rst
@@ -9,6 +9,25 @@ deprecations are listed below.
Pending deprecations
--------------------
+* The ``mcm_config`` argument to ``qml.execute`` has been deprecated.
+ Instead, use the ``mcm_method`` and ``postselect_mode`` arguments.
+
+ - Deprecated in v0.41
+ - Will be removed in v0.42
+
+* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated
+ and will be removed in v0.42. The gradient keyword arguments should be passed to the new
+ keyword argument ``gradient_kwargs`` via an explicit dictionary, like ``gradient_kwargs={"h": 1e-4}``.
+
+ - Deprecated in v0.41
+ - Will be removed in v0.42
+
+* The `qml.gradients.hamiltonian_grad` function has been deprecated.
+ This gradient recipe is not required with the new operator arithmetic system.
+
+ - Deprecated in v0.41
+ - Will be removed in v0.42
+
* The ``inner_transform_program`` and ``config`` keyword arguments in ``qml.execute`` have been deprecated.
If more detailed control over the execution is required, use ``qml.workflow.run`` with these arguments instead.
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 89e21560387..88280392227 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -29,11 +29,11 @@
'parameter-shift'
```
-* Finite shot and parameter-shift executions on `default.qubit` can now
- be natively jitted end-to-end, leading to performance improvements.
- Devices can now configure whether or not ML framework data is sent to them
- via an `ExecutionConfig.convert_to_numpy` parameter.
+* Devices can now configure whether or not ML framework data is sent to them
+ via an `ExecutionConfig.convert_to_numpy` parameter. This is not used on
+ `default.qubit` due to compilation overheads when jitting.
[(#6788)](https://github.com/PennyLaneAI/pennylane/pull/6788)
+ [(#6869)](https://github.com/PennyLaneAI/pennylane/pull/6869)
* The coefficients of observables now have improved differentiability.
[(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598)
@@ -44,6 +44,9 @@
* An informative error is raised when a `QNode` with `diff_method=None` is differentiated.
[(#6770)](https://github.com/PennyLaneAI/pennylane/pull/6770)
+* The requested `diff_method` is now validated when program capture is enabled.
+ [(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852)
+
Breaking changes π
* `MultiControlledX` no longer accepts strings as control values.
@@ -51,6 +54,7 @@
* The input argument `control_wires` of `MultiControlledX` has been removed.
[(#6832)](https://github.com/PennyLaneAI/pennylane/pull/6832)
+ [(#6862)](https://github.com/PennyLaneAI/pennylane/pull/6862)
* `qml.execute` now has a collection of keyword-only arguments.
[(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598)
@@ -80,17 +84,40 @@
Deprecations π
+* The `mcm_method` keyword in `qml.execute` is deprecated. Instead, use the ``mcm_method`` and ``postselect_mode`` arguments.
+ [(#6807)](https://github.com/PennyLaneAI/pennylane/pull/6807)
+
+* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated
+ and will be removed in v0.42. The gradient keyword arguments should be passed to the new
+ keyword argument `gradient_kwargs` via an explicit dictionary. This change will improve qnode argument
+ validation.
+ [(#6828)](https://github.com/PennyLaneAI/pennylane/pull/6828)
+
+* The `qml.gradients.hamiltonian_grad` function has been deprecated.
+ This gradient recipe is not required with the new operator arithmetic system.
+ [(#6849)](https://github.com/PennyLaneAI/pennylane/pull/6849)
+
* The ``inner_transform_program`` and ``config`` keyword arguments in ``qml.execute`` have been deprecated.
If more detailed control over the execution is required, use ``qml.workflow.run`` with these arguments instead.
[(#6822)](https://github.com/PennyLaneAI/pennylane/pull/6822)
Internal changes βοΈ
+* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module.
+ This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives.
+ Consequently, `QmlPrimitive` is now used to define all PennyLane primitives.
+ [(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847)
+
Documentation π
-* Updated documentation for vibrational Hamiltonians
+* The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`,
+ `qml.qchem.localize_normal_modes`, and `qml.qchem.VibrationalPES` have been updated to include better
+ code examples.
[(#6717)](https://github.com/PennyLaneAI/pennylane/pull/6717)
+* Fixed a typo in the code example for `qml.labs.dla.lie_closure_dense`.
+ [(#6858)](https://github.com/PennyLaneAI/pennylane/pull/6858)
+
Bug fixes π
* `BasisState` now casts its input to integers.
@@ -101,8 +128,10 @@
This release contains contributions from (in alphabetical order):
Yushao Chen,
+Isaac De Vlugt,
Diksha Dhawan,
Pietropaolo Frisoni,
Marcus GisslΓ©n,
Christina Lee,
+Mudit Pandey,
Andrija Paurevic
diff --git a/pennylane/_version.py b/pennylane/_version.py
index 704c372802e..9295240495c 100644
--- a/pennylane/_version.py
+++ b/pennylane/_version.py
@@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""
-__version__ = "0.41.0-dev10"
+__version__ = "0.41.0-dev14"
diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py
index 4af11cb6198..2b7314f7c2e 100644
--- a/pennylane/capture/base_interpreter.py
+++ b/pennylane/capture/base_interpreter.py
@@ -25,8 +25,6 @@
from .flatfn import FlatFn
from .primitives import (
- AbstractMeasurement,
- AbstractOperator,
adjoint_transform_prim,
cond_prim,
ctrl_transform_prim,
@@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
self._env[constvar] = const
for eqn in jaxpr.eqns:
+ primitive = eqn.primitive
+ custom_handler = self._primitive_registrations.get(primitive, None)
- custom_handler = self._primitive_registrations.get(eqn.primitive, None)
if custom_handler:
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
- elif isinstance(eqn.outvars[0].aval, AbstractOperator):
+ elif getattr(primitive, "prim_type", "") == "operator":
outvals = self.interpret_operation_eqn(eqn)
- elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
+ elif getattr(primitive, "prim_type", "") == "measurement":
outvals = self.interpret_measurement_eqn(eqn)
else:
invals = [self.read(invar) for invar in eqn.invars]
- outvals = eqn.primitive.bind(*invals, **eqn.params)
+ outvals = primitive.bind(*invals, **eqn.params)
- if not eqn.primitive.multiple_results:
+ if not primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
self._env[outvar] = outval
diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py
index 482f692df69..ba7c3846693 100644
--- a/pennylane/capture/capture_diff.py
+++ b/pennylane/capture/capture_diff.py
@@ -24,34 +24,6 @@
has_jax = False
-@lru_cache
-def create_non_interpreted_prim():
- """Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace
- and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive.
- """
-
- if not has_jax: # pragma: no cover
- return None
-
- # pylint: disable=too-few-public-methods
- class NonInterpPrimitive(jax.core.Primitive):
- """A subclass to JAX's Primitive that works like a Python function
- when evaluating JVPTracers and BatchTracers."""
-
- def bind_with_trace(self, trace, args, params):
- """Bind the ``NonInterpPrimitive`` with a trace.
-
- If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
- Otherwise, the bind call of JAX's standard Primitive is used."""
- if isinstance(
- trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)
- ):
- return self.impl(*args, **params)
- return super().bind_with_trace(trace, args, params)
-
- return NonInterpPrimitive
-
-
@lru_cache
def _get_grad_prim():
"""Create a primitive for gradient computations.
@@ -60,8 +32,11 @@ def _get_grad_prim():
if not has_jax: # pragma: no cover
return None
- grad_prim = create_non_interpreted_prim()("grad")
+ from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
+
+ grad_prim = NonInterpPrimitive("grad")
grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init
+ grad_prim.prim_type = "higher_order"
# pylint: disable=too-many-arguments
@grad_prim.def_impl
@@ -91,8 +66,14 @@ def _get_jacobian_prim():
"""Create a primitive for Jacobian computations.
This primitive is used when capturing ``qml.jacobian``.
"""
- jacobian_prim = create_non_interpreted_prim()("jacobian")
+ if not has_jax: # pragma: no cover
+ return None
+
+ from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
+
+ jacobian_prim = NonInterpPrimitive("jacobian")
jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init
+ jacobian_prim.prim_type = "higher_order"
# pylint: disable=too-many-arguments
@jacobian_prim.def_impl
diff --git a/pennylane/capture/capture_measurements.py b/pennylane/capture/capture_measurements.py
index 59bf5490679..e23457bc7b7 100644
--- a/pennylane/capture/capture_measurements.py
+++ b/pennylane/capture/capture_measurements.py
@@ -128,7 +128,10 @@ def create_measurement_obs_primitive(
if not has_jax:
return None
- primitive = jax.core.Primitive(name + "_obs")
+ from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
+
+ primitive = NonInterpPrimitive(name + "_obs")
+ primitive.prim_type = "measurement"
@primitive.def_impl
def _(obs, **kwargs):
@@ -165,7 +168,10 @@ def create_measurement_mcm_primitive(
if not has_jax:
return None
- primitive = jax.core.Primitive(name + "_mcm")
+ from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
+
+ primitive = NonInterpPrimitive(name + "_mcm")
+ primitive.prim_type = "measurement"
@primitive.def_impl
def _(*mcms, single_mcm=True, **kwargs):
@@ -200,7 +206,10 @@ def create_measurement_wires_primitive(
if not has_jax:
return None
- primitive = jax.core.Primitive(name + "_wires")
+ from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
+
+ primitive = NonInterpPrimitive(name + "_wires")
+ primitive.prim_type = "measurement"
@primitive.def_impl
def _(*args, has_eigvals=False, **kwargs):
diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py
index 23c98f38944..2124b5b9fe4 100644
--- a/pennylane/capture/capture_operators.py
+++ b/pennylane/capture/capture_operators.py
@@ -20,8 +20,6 @@
import pennylane as qml
-from .capture_diff import create_non_interpreted_prim
-
has_jax = True
try:
import jax
@@ -103,7 +101,10 @@ def create_operator_primitive(
if not has_jax:
return None
- primitive = create_non_interpreted_prim()(operator_type.__name__)
+ from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
+
+ primitive = NonInterpPrimitive(operator_type.__name__)
+ primitive.prim_type = "operator"
@primitive.def_impl
def _(*args, **kwargs):
diff --git a/pennylane/capture/custom_primitives.py b/pennylane/capture/custom_primitives.py
new file mode 100644
index 00000000000..183ae05771b
--- /dev/null
+++ b/pennylane/capture/custom_primitives.py
@@ -0,0 +1,64 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This submodule offers custom primitives for the PennyLane capture module.
+"""
+from enum import Enum
+from typing import Union
+
+import jax
+
+
+class PrimitiveType(Enum):
+ """Enum to define valid set of primitive classes"""
+
+ DEFAULT = "default"
+ OPERATOR = "operator"
+ MEASUREMENT = "measurement"
+ HIGHER_ORDER = "higher_order"
+ TRANSFORM = "transform"
+
+
+# pylint: disable=too-few-public-methods,abstract-method
+class QmlPrimitive(jax.core.Primitive):
+ """A subclass for JAX's Primitive that differentiates between different
+ classes of primitives."""
+
+ _prim_type: PrimitiveType = PrimitiveType.DEFAULT
+
+ @property
+ def prim_type(self):
+ """Value of Enum representing the primitive type to differentiate between various
+ sets of PennyLane primitives."""
+ return self._prim_type.value
+
+ @prim_type.setter
+ def prim_type(self, value: Union[str, PrimitiveType]):
+ """Setter for QmlPrimitive.prim_type."""
+ self._prim_type = PrimitiveType(value)
+
+
+# pylint: disable=too-few-public-methods,abstract-method
+class NonInterpPrimitive(QmlPrimitive):
+ """A subclass to JAX's Primitive that works like a Python function
+ when evaluating JVPTracers and BatchTracers."""
+
+ def bind_with_trace(self, trace, args, params):
+ """Bind the ``NonInterpPrimitive`` with a trace.
+
+ If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
+ Otherwise, the bind call of JAX's standard Primitive is used."""
+ if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)):
+ return self.impl(*args, **params)
+ return super().bind_with_trace(trace, args, params)
diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md
index 84feef9786f..71033aac3ee 100644
--- a/pennylane/capture/explanations.md
+++ b/pennylane/capture/explanations.md
@@ -255,7 +255,7 @@ class MyClass(metaclass=MyMetaClass):
self.kwargs = kwargs
```
- Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}.
+ Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}.
And that we have set a class property `a`
@@ -272,7 +272,7 @@ But can we actually create instances of these classes?
```python
>> obj = MyClass(0.1, a=2)
>>> obj
-creating an instance of type with (0.1,), {'a': 2}.
+creating an instance of type with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```
@@ -294,7 +294,7 @@ class MyClass2(metaclass=MetaClass2):
self.args = args
```
-You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.
+You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.
Using a metaclass, we can hijack what happens when a type is called.
@@ -425,3 +425,103 @@ Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `A
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```
+
+# Non-interpreted primitives
+
+Some of the primitives in the capture module have a somewhat non-standard requirement for the
+behaviour under differentiation or batching: they should ignore that an input is a differentiation
+or batching tracer and just execute the standard implementation on them.
+
+We will look at an example to make the necessity for such a non-interpreted primitive clear.
+
+Consider a finite-difference differentiation routine together with some test function `fun`.
+
+```python
+def finite_diff_impl(x, fun, delta):
+ """Finite difference differentiation routine. Only supports differentiating
+ a function `fun` with a single scalar argument, for simplicity."""
+
+ out_plus = fun(x + delta)
+ out_minus = fun(x - delta)
+ return tuple((out_p - out_m) / (2 * delta) for out_p, out_m in zip(out_plus, out_minus))
+
+def fun(x):
+ return (x**2, 4 * x - 3, x**23)
+```
+
+Now suppose we want to turn this into a primitive. We could just promote it to a standard
+`jax.core.Primitive` as
+
+```python
+import jax
+
+fd_prim = jax.core.Primitive("finite_diff")
+fd_prim.multiple_results = True
+fd_prim.def_impl(finite_diff_impl)
+
+def finite_diff(x, fun, delta=1e-5):
+ return fd_prim.bind(x, fun, delta)
+```
+
+This allows us to use the forward pass as usual (to compute the first-order derivative):
+
+```pycon
+>>> finite_diff(1., fun, delta=1e-6)
+(2.000000000002, 3.999999999892978, 23.000000001216492)
+```
+
+Now if we want to make this primitive differentiable (with automatic
+differentiation/backprop, not by using a higher-order finite difference scheme),
+we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example
+that we could use to implement a finite difference scheme that is readily differentiable. This is
+somewhat beside the point because we did not identify the possibility of using any of those
+alternatives in the PennyLane code).
+
+However, the finite difference rule is just a standard
+algebraic function making use of calls to `fun` and some elementary operations, so ideally
+we would like to just use the chain rule as it is known to the automatic differentiation framework. A JVP rule would
+then just manually re-implement this chain rule, which we'd rather not do.
+
+Instead, we define a non-interpreted type of primitive and create such a primitive
+for our finite difference method. We also create the usual method that binds the
+primitive to inputs.
+
+```python
+class NonInterpPrimitive(jax.core.Primitive):
+ """A subclass to JAX's Primitive that works like a Python function
+ when evaluating JVPTracers."""
+
+ def bind_with_trace(self, trace, args, params):
+ """Bind the ``NonInterpPrimitive`` with a trace.
+ If the trace is a ``JVPTrace``, it falls back to a standard Python function call.
+ Otherwise, the bind call of JAX's standard Primitive is used."""
+ if isinstance(trace, jax.interpreters.ad.JVPTrace):
+ return self.impl(*args, **params)
+ return super().bind_with_trace(trace, args, params)
+
+fd_prim_2 = NonInterpPrimitive("finite_diff_2")
+fd_prim_2.multiple_results = True
+fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer
+
+def finite_diff_2(x, fun, delta=1e-5):
+ return fd_prim_2.bind(x, fun, delta)
+```
+
+Now we can use the primitive in a differentiable workflow, without defining a JVP rule
+that just repeats the chain rule:
+
+```pycon
+>>> # Forward execution of finite_diff_2 (-> first-order derivative)
+>>> finite_diff_2(1., fun, delta=1e-6)
+(2.000000000002, 3.999999999892978, 23.000000001216492)
+>>> # Differentiation of finite_diff_2 (-> second-order derivative)
+>>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6)
+(Array(1.9375, dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True), Array(498., dtype=float32, weak_type=True))
+```
+
+In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, quantum operators
+have non-interpreted primitives as well. This is because their differentiation is performed
+by the surrounding QNode primitive rather than through the standard chain rule that acts
+"locally" (in the circuit). In short, we only want gates to store their tracers (which will help
+determine the differentiability of gate arguments, for example), but not to do anything with them.
+
diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py
index 08d88988b79..797a7437abb 100644
--- a/pennylane/compiler/qjit_api.py
+++ b/pennylane/compiler/qjit_api.py
@@ -17,7 +17,6 @@
from collections.abc import Callable
import pennylane as qml
-from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.capture.flatfn import FlatFn
from .compiler import (
@@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable:
def _get_while_loop_qfunc_prim():
"""Get the while_loop primitive for quantum functions."""
- import jax # pylint: disable=import-outside-toplevel
+ # pylint: disable=import-outside-toplevel
+ import jax
- while_loop_prim = create_non_interpreted_prim()("while_loop")
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+ while_loop_prim = NonInterpPrimitive("while_loop")
while_loop_prim.multiple_results = True
+ while_loop_prim.prim_type = "higher_order"
@while_loop_prim.def_impl
def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):
@@ -626,10 +629,14 @@ def _decorator(body_fn):
def _get_for_loop_qfunc_prim():
"""Get the loop_for primitive for quantum functions."""
- import jax # pylint: disable=import-outside-toplevel
+ # pylint: disable=import-outside-toplevel
+ import jax
+
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
- for_loop_prim = create_non_interpreted_prim()("for_loop")
+ for_loop_prim = NonInterpPrimitive("for_loop")
for_loop_prim.multiple_results = True
+ for_loop_prim.prim_type = "higher_order"
# pylint: disable=too-many-arguments
@for_loop_prim.def_impl
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index 48f8ffbc686..38d5e1b34a1 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -591,13 +591,15 @@ def _setup_execution_config(self, execution_config: ExecutionConfig) -> Executio
"""
updated_values = {}
- jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
- updated_values["convert_to_numpy"] = (
- execution_config.interface not in jax_interfaces
- or execution_config.gradient_method == "adjoint"
- # need numpy to use caching, and need caching higher order derivatives
- or execution_config.derivative_order > 1
- )
+ # uncomment once compilation overhead with jitting improved
+ # TODO: [sc-82874]
+ # jax_interfaces = {qml.math.Interface.JAX, qml.math.Interface.JAX_JIT}
+ # updated_values["convert_to_numpy"] = (
+ # execution_config.interface not in jax_interfaces
+ # or execution_config.gradient_method == "adjoint"
+ # # need numpy to use caching, and need caching higher order derivatives
+ # or execution_config.derivative_order > 1
+ # )
for option in execution_config.device_options:
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")
@@ -643,6 +645,7 @@ def execute(
prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
if max_workers is None:
+
return tuple(
_simulate_wrapper(
c,
diff --git a/pennylane/gradients/hamiltonian_grad.py b/pennylane/gradients/hamiltonian_grad.py
index e83d942dcc9..95769bdc6d1 100644
--- a/pennylane/gradients/hamiltonian_grad.py
+++ b/pennylane/gradients/hamiltonian_grad.py
@@ -13,6 +13,8 @@
# limitations under the License.
"""Contains a gradient recipe for the coefficients of Hamiltonians."""
# pylint: disable=protected-access,unnecessary-lambda
+import warnings
+
import pennylane as qml
@@ -20,10 +22,19 @@ def hamiltonian_grad(tape, idx):
"""Computes the tapes necessary to get the gradient of a tape with respect to
a Hamiltonian observable's coefficients.
+ .. warning::
+ This function is deprecated and will be removed in v0.42. This gradient recipe is not
+ required for the new operator arithmetic of PennyLane.
+
Args:
tape (qml.tape.QuantumTape): tape with a single Hamiltonian expectation as measurement
idx (int): index of parameter that we differentiate with respect to
"""
+ warnings.warn(
+ "The 'hamiltonian_grad' function is deprecated and will be removed in v0.42. "
+ "This gradient recipe is not required for the new operator arithmetic system.",
+ qml.PennyLaneDeprecationWarning,
+ )
op, m_pos, p_idx = tape.get_operation(idx)
diff --git a/pennylane/gradients/parameter_shift.py b/pennylane/gradients/parameter_shift.py
index 6fcdd17df19..26a2b0e4a8b 100644
--- a/pennylane/gradients/parameter_shift.py
+++ b/pennylane/gradients/parameter_shift.py
@@ -15,6 +15,7 @@
This module contains functions for computing the parameter-shift gradient
of a qubit-based quantum tape.
"""
+import warnings
from functools import partial
import numpy as np
@@ -372,6 +373,12 @@ def expval_param_shift(
op, op_idx, _ = tape.get_operation(idx)
if op.name == "LinearCombination":
+ warnings.warn(
+ "Please use qml.gradients.split_to_single_terms so that the ML framework "
+ "can compute the gradients of the coefficients.",
+ UserWarning,
+ )
+
# operation is a Hamiltonian
if tape[op_idx].return_type is not qml.measurements.Expectation:
raise ValueError(
diff --git a/pennylane/labs/dla/lie_closure_dense.py b/pennylane/labs/dla/lie_closure_dense.py
index 71131cd1064..a2fb76d02e9 100644
--- a/pennylane/labs/dla/lie_closure_dense.py
+++ b/pennylane/labs/dla/lie_closure_dense.py
@@ -97,9 +97,11 @@ def lie_closure_dense(
Compute the Lie closure of the isotropic Heisenberg model with generators :math:`\{X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1}\}_{i=0}^{n-1}`.
+ >>> from pennylane import X, Y, Z
+ >>> from pennylane.labs.dla import lie_closure_dense
>>> n = 5
>>> gens = [X(i) @ X(i+1) + Y(i) @ Y(i+1) + Z(i) @ Z(i+1) for i in range(n-1)]
- >>> g = lie_closure_mat(gens, n)
+ >>> g = lie_closure_dense(gens, n)
The result is a ``numpy`` array. We can turn the matrices back into PennyLane operators by employing :func:`~batched_pauli_decompose`.
diff --git a/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py b/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py
index 6d59fea1b96..440d053dee0 100644
--- a/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py
+++ b/pennylane/labs/tests/resource_estimation/ops/op_math/test_controlled_ops.py
@@ -643,22 +643,17 @@ class TestResourceMultiControlledX:
"""Test the ResourceMultiControlledX operation"""
res_ops = (
- re.ResourceMultiControlledX(control_wires=[0], wires=["t"], control_values=[1]),
- re.ResourceMultiControlledX(control_wires=[0, 1], wires=["t"], control_values=[1, 1]),
- re.ResourceMultiControlledX(control_wires=[0, 1, 2], wires=["t"], control_values=[1, 1, 1]),
+ re.ResourceMultiControlledX(wires=[0, "t"], control_values=[1]),
+ re.ResourceMultiControlledX(wires=[0, 1, "t"], control_values=[1, 1]),
+ re.ResourceMultiControlledX(wires=[0, 1, 2, "t"], control_values=[1, 1, 1]),
+ re.ResourceMultiControlledX(wires=[0, 1, 2, 3, 4, "t"], control_values=[1, 1, 1, 1, 1]),
+ re.ResourceMultiControlledX(wires=[0, "t"], control_values=[0], work_wires=["w1"]),
re.ResourceMultiControlledX(
- control_wires=[0, 1, 2, 3, 4], wires=["t"], control_values=[1, 1, 1, 1, 1]
+ wires=[0, 1, "t"], control_values=[1, 0], work_wires=["w1", "w2"]
),
+ re.ResourceMultiControlledX(wires=[0, 1, 2, "t"], control_values=[0, 0, 1]),
re.ResourceMultiControlledX(
- control_wires=[0], wires=["t"], control_values=[0], work_wires=["w1"]
- ),
- re.ResourceMultiControlledX(
- control_wires=[0, 1], wires=["t"], control_values=[1, 0], work_wires=["w1", "w2"]
- ),
- re.ResourceMultiControlledX(control_wires=[0, 1, 2], wires=["t"], control_values=[0, 0, 1]),
- re.ResourceMultiControlledX(
- control_wires=[0, 1, 2, 3, 4],
- wires=["t"],
+ wires=[0, 1, 2, 3, 4, "t"],
control_values=[1, 0, 0, 1, 0],
work_wires=["w1"],
),
@@ -732,8 +727,7 @@ def test_resource_params(self, op, params):
def test_resource_adjoint(self):
"""Test that the adjoint resources are as expected"""
op = re.ResourceMultiControlledX(
- control_wires=[0, 1, 2, 3, 4],
- wires=["t"],
+ wires=[0, 1, 2, 3, 4, "t"],
control_values=[1, 0, 0, 1, 0],
work_wires=["w1"],
)
@@ -777,7 +771,7 @@ def test_resource_adjoint(self):
)
def test_resource_controlled(self, ctrl_wires, ctrl_values, work_wires, expected_res):
"""Test that the controlled resources are as expected"""
- op = re.ResourceMultiControlledX(control_wires=[0], wires=["t"], control_values=[1])
+ op = re.ResourceMultiControlledX(wires=[0, "t"], control_values=[1])
num_ctrl_wires = len(ctrl_wires)
num_ctrl_values = len([v for v in ctrl_values if not v])
@@ -806,8 +800,7 @@ def test_resource_controlled(self, ctrl_wires, ctrl_values, work_wires, expected
def test_resource_pow(self, z, expected_res):
"""Test that the pow resources are as expected"""
op = re.ResourceMultiControlledX(
- control_wires=[0, 1, 2, 3, 4],
- wires=["t"],
+ wires=[0, 1, 2, 3, 4, "t"],
control_values=[1, 0, 0, 1, 0],
work_wires=["w1"],
)
diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py
index 5cdcd8cd708..3c9bdc8f1a8 100644
--- a/pennylane/measurements/mid_measure.py
+++ b/pennylane/measurements/mid_measure.py
@@ -243,9 +243,12 @@ def _create_mid_measure_primitive():
measurement.
"""
- import jax # pylint: disable=import-outside-toplevel
+ # pylint: disable=import-outside-toplevel
+ import jax
- mid_measure_p = jax.core.Primitive("measure")
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+ mid_measure_p = NonInterpPrimitive("measure")
@mid_measure_p.def_impl
def _(wires, reset=False, postselect=None):
diff --git a/pennylane/measurements/sample.py b/pennylane/measurements/sample.py
index cbc4d0a0bdb..18da1b9d284 100644
--- a/pennylane/measurements/sample.py
+++ b/pennylane/measurements/sample.py
@@ -228,6 +228,8 @@ def shape(self, shots: Optional[int] = None, num_device_wires: int = 0) -> tuple
)
if self.obs:
num_values_per_shot = 1 # one single eigenvalue
+ elif self.mv is not None:
+ num_values_per_shot = 1 if isinstance(self.mv, MeasurementValue) else len(self.mv)
else:
# one value per wire
num_values_per_shot = len(self.wires) if len(self.wires) > 0 else num_device_wires
diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py
index 400f2fc83c0..5bda04440d4 100644
--- a/pennylane/ops/op_math/adjoint.py
+++ b/pennylane/ops/op_math/adjoint.py
@@ -18,7 +18,6 @@
from typing import Callable, overload
import pennylane as qml
-from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.compiler import compiler
from pennylane.math import conj, moveaxis, transpose
from pennylane.operation import Observable, Operation, Operator
@@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy):
def _get_adjoint_qfunc_prim():
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
# if capture is enabled, jax should be installed
- import jax # pylint: disable=import-outside-toplevel
+ # pylint: disable=import-outside-toplevel
+ import jax
+
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
- adjoint_prim = create_non_interpreted_prim()("adjoint_transform")
+ adjoint_prim = NonInterpPrimitive("adjoint_transform")
adjoint_prim.multiple_results = True
+ adjoint_prim.prim_type = "higher_order"
@adjoint_prim.def_impl
def _(*args, jaxpr, lazy, n_consts):
diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py
index deace92e73c..a15fdafff1d 100644
--- a/pennylane/ops/op_math/condition.py
+++ b/pennylane/ops/op_math/condition.py
@@ -20,7 +20,6 @@
import pennylane as qml
from pennylane import QueuingManager
-from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.capture.flatfn import FlatFn
from pennylane.compiler import compiler
from pennylane.measurements import MeasurementValue
@@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement
def _get_cond_qfunc_prim():
"""Get the cond primitive for quantum functions."""
- import jax # pylint: disable=import-outside-toplevel
+ # pylint: disable=import-outside-toplevel
+ import jax
- cond_prim = create_non_interpreted_prim()("cond")
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+ cond_prim = NonInterpPrimitive("cond")
cond_prim.multiple_results = True
+ cond_prim.prim_type = "higher_order"
@cond_prim.def_impl
def _(*all_args, jaxpr_branches, consts_slices, args_slice):
diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py
index d49209660c6..17e62323223 100644
--- a/pennylane/ops/op_math/controlled.py
+++ b/pennylane/ops/op_math/controlled.py
@@ -28,7 +28,6 @@
import pennylane as qml
from pennylane import math as qmlmath
from pennylane import operation
-from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.compiler import compiler
from pennylane.operation import Operator
from pennylane.wires import Wires, WiresLike
@@ -233,10 +232,15 @@ def wrapper(*args, **kwargs):
def _get_ctrl_qfunc_prim():
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
# if capture is enabled, jax should be installed
- import jax # pylint: disable=import-outside-toplevel
- ctrl_prim = create_non_interpreted_prim()("ctrl_transform")
+ # pylint: disable=import-outside-toplevel
+ import jax
+
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
+
+ ctrl_prim = NonInterpPrimitive("ctrl_transform")
ctrl_prim.multiple_results = True
+ ctrl_prim.prim_type = "higher_order"
@ctrl_prim.def_impl
def _(*args, n_control, jaxpr, control_values, work_wires, n_consts):
diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py
index e00bda09c8d..1cefb724cd2 100644
--- a/pennylane/transforms/core/transform_dispatcher.py
+++ b/pennylane/transforms/core/transform_dispatcher.py
@@ -540,12 +540,13 @@ def final_transform(self):
def _create_transform_primitive(name):
try:
# pylint: disable=import-outside-toplevel
- import jax
+ from pennylane.capture.custom_primitives import NonInterpPrimitive
except ImportError:
return None
- transform_prim = jax.core.Primitive(name + "_transform")
+ transform_prim = NonInterpPrimitive(name + "_transform")
transform_prim.multiple_results = True
+ transform_prim.prim_type = "transform"
@transform_prim.def_impl
def _(
diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py
index 418e68941a6..85dc0320fb7 100644
--- a/pennylane/transforms/optimization/cancel_inverses.py
+++ b/pennylane/transforms/optimization/cancel_inverses.py
@@ -70,7 +70,7 @@ def _get_plxpr_cancel_inverses(): # pylint: disable=missing-function-docstring,
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr
- from pennylane.capture import AbstractMeasurement, AbstractOperator, PlxprInterpreter
+ from pennylane.capture import PlxprInterpreter
from pennylane.operation import Operator
except ImportError: # pragma: no cover
return None, None
@@ -204,15 +204,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
self.interpret_all_previous_ops()
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
- elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractOperator):
+ elif getattr(eqn.primitive, "prim_type", "") == "operator":
outvals = self.interpret_operation_eqn(eqn)
- elif len(eqn.outvars) > 0 and isinstance(eqn.outvars[0].aval, AbstractMeasurement):
+ elif getattr(eqn.primitive, "prim_type", "") == "measurement":
self.interpret_all_previous_ops()
outvals = self.interpret_measurement_eqn(eqn)
else:
# Transform primitives don't have custom handlers, so we check for them here
# to purge the stored ops in self.previous_ops
- if eqn.primitive.name.endswith("_transform"):
+ if getattr(eqn.primitive, "prim_type", "") == "transform":
self.interpret_all_previous_ops()
invals = [self.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)
diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py
index 05f6b196440..2552770a159 100644
--- a/pennylane/workflow/_capture_qnode.py
+++ b/pennylane/workflow/_capture_qnode.py
@@ -107,7 +107,6 @@
"""
from copy import copy
-from dataclasses import asdict
from functools import partial
from numbers import Number
from warnings import warn
@@ -117,6 +116,7 @@
import pennylane as qml
from pennylane.capture import FlatFn
+from pennylane.capture.custom_primitives import QmlPrimitive
from pennylane.typing import TensorLike
@@ -177,8 +177,9 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(
return shapes
-qnode_prim = jax.core.Primitive("qnode")
+qnode_prim = QmlPrimitive("qnode")
qnode_prim.multiple_results = True
+qnode_prim.prim_type = "higher_order"
# pylint: disable=too-many-arguments, unused-argument
@@ -249,7 +250,6 @@ def _qnode_batching_rule(
"using parameter broadcasting to a quantum operation that supports batching.",
UserWarning,
)
-
# To resolve this ambiguity, we might add more properties to the AbstractOperator
# class to indicate which operators support batching and check them here.
# As above, at this stage we raise a warning and give the user full flexibility.
@@ -277,15 +277,43 @@ def _qnode_batching_rule(
return result, (0,) * len(result)
+### JVP CALCULATION #########################################################
+# This structure will change as we add more diff methods
+
+
def _make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan
-def _qnode_jvp(args, tangents, **impl_kwargs):
+def _backprop(args, tangents, **impl_kwargs):
tangents = tuple(map(_make_zero, tangents, args))
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)
+diff_method_map = {"backprop": _backprop}
+
+
+def _resolve_diff_method(diff_method: str, device) -> str:
+ # check if best is backprop
+ if diff_method == "best":
+ config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax")
+ diff_method = device.setup_execution_config(config).gradient_method
+
+ if diff_method not in diff_method_map:
+ raise NotImplementedError(f"diff_method {diff_method} not yet implemented.")
+
+ return diff_method
+
+
+def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs):
+ diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device)
+ return diff_method_map[diff_method](
+ args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs
+ )
+
+
+### END JVP CALCULATION #######################################################
+
ad.primitive_jvps[qnode_prim] = _qnode_jvp
batching.primitive_batchers[qnode_prim] = _qnode_batching_rule
@@ -375,8 +403,7 @@ def f(x):
qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args)
execute_kwargs = copy(qnode.execute_kwargs)
- mcm_config = asdict(execute_kwargs.pop("mcm_config"))
- qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
+ qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs}
flat_args = jax.tree_util.tree_leaves(args)
diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py
index cfe93f322d2..fb9ac5ef708 100644
--- a/pennylane/workflow/execution.py
+++ b/pennylane/workflow/execution.py
@@ -47,14 +47,16 @@ def execute(
diff_method: Optional[Union[Callable, SupportedDiffMethods, TransformDispatcher]] = None,
interface: Optional[InterfaceLike] = Interface.AUTO,
*,
+ transform_program: TransformProgram = None,
grad_on_execution: Literal[bool, "best"] = "best",
cache: Union[None, bool, dict, Cache] = True,
cachesize: int = 10000,
max_diff: int = 1,
device_vjp: Union[bool, None] = False,
+ postselect_mode=None,
+ mcm_method=None,
gradient_kwargs: dict = None,
- transform_program: TransformProgram = None,
- mcm_config: "qml.devices.MCMConfig" = None,
+ mcm_config: "qml.devices.MCMConfig" = "unset",
config="unset",
inner_transform="unset",
) -> ResultBatch:
@@ -86,10 +88,18 @@ def execute(
(classical) computational overhead during the backward pass.
device_vjp=False (Optional[bool]): whether or not to use the device-provided Jacobian
product if it is available.
- mcm_config (dict): Dictionary containing configuration options for handling
- mid-circuit measurements.
+ postselect_mode (str): Configuration for handling shots with mid-circuit measurement
+ postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
+ keep the same number of shots. Default is ``None``.
+ mcm_method (str): Strategy to use when executing circuits with mid-circuit measurements.
+ ``"deferred"`` is ignored. If mid-circuit measurements are found in the circuit,
+ the device will use ``"tree-traversal"`` if specified and the ``"one-shot"`` method
+ otherwise. For usage details, please refer to the
+ :doc:`dynamic quantum circuits page `.
gradient_kwargs (dict): dictionary of keyword arguments to pass when
determining the gradients of tapes.
+ mcm_config="unset": **DEPRECATED**. This keyword argument has been replaced by ``postselect_mode``
+ and ``mcm_method`` and will be removed in v0.42.
config="unset": **DEPRECATED**. This keyword argument has been deprecated and
will be removed in v0.42.
inner_transform="unset": **DEPRECATED**. This keyword argument has been deprecated
@@ -173,6 +183,14 @@ def cost_fn(params, x):
qml.PennyLaneDeprecationWarning,
)
+ if mcm_config != "unset":
+ warn(
+ "The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.",
+ qml.PennyLaneDeprecationWarning,
+ )
+ mcm_method = mcm_config.mcm_method
+ postselect_mode = mcm_config.postselect_mode
+
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
(
@@ -209,7 +227,7 @@ def cost_fn(params, x):
gradient_method=diff_method,
grad_on_execution=None if grad_on_execution == "best" else grad_on_execution,
use_device_jacobian_product=device_vjp,
- mcm_config=mcm_config or {},
+ mcm_config=qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method),
gradient_keyword_arguments=gradient_kwargs or {},
derivative_order=max_diff,
)
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 29430c9a645..ea1a9e33900 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -107,6 +107,10 @@ def _to_qfunc_output_type(
return qml.pytrees.unflatten(results, qfunc_output_structure)
+def _validate_mcm_config(postselect_mode: str, mcm_method: str) -> None:
+ qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method)
+
+
def _validate_gradient_kwargs(gradient_kwargs: dict) -> None:
for kwarg in gradient_kwargs:
if kwarg == "expansion_strategy":
@@ -133,7 +137,8 @@ def _validate_gradient_kwargs(gradient_kwargs: dict) -> None:
elif kwarg not in qml.gradients.SUPPORTED_GRADIENT_KWARGS:
warnings.warn(
f"Received gradient_kwarg {kwarg}, which is not included in the list of "
- "standard qnode gradient kwargs."
+ "standard qnode gradient kwargs. Please specify all gradient kwargs through "
+ "the gradient_kwargs argument as a dictionary."
)
@@ -284,9 +289,7 @@ class QNode:
as the name suggests. If not provided,
the device will determine the best choice automatically. For usage details, please refer to the
:doc:`dynamic quantum circuits page `.
-
- Keyword Args:
- **kwargs: Any additional keyword arguments provided are passed to the differentiation
+ gradient_kwargs (dict): A dictionary of keyword arguments that are passed to the differentiation
method. Please refer to the :mod:`qml.gradients <.gradients>` module for details
on supported options for your chosen gradient transform.
@@ -505,10 +508,12 @@ def __init__(
device_vjp: Union[None, bool] = False,
postselect_mode: Literal[None, "hw-like", "fill-shots"] = None,
mcm_method: Literal[None, "deferred", "one-shot", "tree-traversal"] = None,
- **gradient_kwargs,
+ gradient_kwargs: Optional[dict] = None,
+ **kwargs,
):
self._init_args = locals()
del self._init_args["self"]
+ del self._init_args["kwargs"]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -536,7 +541,16 @@ def __init__(
if not isinstance(device, qml.devices.Device):
device = qml.devices.LegacyDeviceFacade(device)
+ gradient_kwargs = gradient_kwargs or {}
+ if kwargs:
+ if any(k in qml.gradients.SUPPORTED_GRADIENT_KWARGS for k in list(kwargs.keys())):
+ warnings.warn(
+ f"Specifying gradient keyword arguments {list(kwargs.keys())} is deprecated and will be removed in v0.42. Instead, please specify all arguments in the gradient_kwargs argument.",
+ qml.PennyLaneDeprecationWarning,
+ )
+ gradient_kwargs |= kwargs
_validate_gradient_kwargs(gradient_kwargs)
+
if "shots" in inspect.signature(func).parameters:
warnings.warn(
"Detected 'shots' as an argument to the given quantum function. "
@@ -553,17 +567,18 @@ def __init__(
self.device = device
self._interface = get_canonical_interface_name(interface)
self.diff_method = diff_method
- mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode)
cache = (max_diff > 1) if cache == "auto" else cache
# execution keyword arguments
+ _validate_mcm_config(postselect_mode, mcm_method)
self.execute_kwargs = {
"grad_on_execution": grad_on_execution,
"cache": cache,
"cachesize": cachesize,
"max_diff": max_diff,
"device_vjp": device_vjp,
- "mcm_config": mcm_config,
+ "postselect_mode": postselect_mode,
+ "mcm_method": mcm_method,
}
# internal data attributes
@@ -676,16 +691,19 @@ def circuit(x):
tensor(0.5403, dtype=torch.float64)
"""
if not kwargs:
- valid_params = (
- set(self._init_args.copy().pop("gradient_kwargs"))
- | qml.gradients.SUPPORTED_GRADIENT_KWARGS
- )
+ valid_params = set(self._init_args.copy()) | qml.gradients.SUPPORTED_GRADIENT_KWARGS
raise ValueError(
f"Must specify at least one configuration property to update. Valid properties are: {valid_params}."
)
original_init_args = self._init_args.copy()
- gradient_kwargs = original_init_args.pop("gradient_kwargs")
- original_init_args.update(gradient_kwargs)
+ # gradient_kwargs defaults to None
+ original_init_args["gradient_kwargs"] = original_init_args["gradient_kwargs"] or {}
+ # nested dictionary update
+ new_gradient_kwargs = kwargs.pop("gradient_kwargs", {})
+ old_gradient_kwargs = original_init_args.get("gradient_kwargs").copy()
+ old_gradient_kwargs.update(new_gradient_kwargs)
+ kwargs["gradient_kwargs"] = old_gradient_kwargs
+
original_init_args.update(kwargs)
updated_qn = QNode(**original_init_args)
# pylint: disable=protected-access
diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py
index 0ac083a768c..680d11f083e 100644
--- a/tests/capture/test_capture_qnode.py
+++ b/tests/capture/test_capture_qnode.py
@@ -14,7 +14,6 @@
"""
Tests for capturing a qnode into jaxpr.
"""
-from dataclasses import asdict
from functools import partial
# pylint: disable=protected-access
@@ -130,7 +129,6 @@ def circuit(x):
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best"}
expected_kwargs.update(circuit.execute_kwargs)
- expected_kwargs.update(asdict(expected_kwargs.pop("mcm_config")))
assert eqn0.params["qnode_kwargs"] == expected_kwargs
qfunc_jaxpr = eqn0.params["qfunc_jaxpr"]
@@ -339,18 +337,61 @@ def circuit(x):
assert list(out.keys()) == ["a", "b"]
-def test_qnode_jvp():
- """Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule."""
+class TestDifferentiation:
- @qml.qnode(qml.device("default.qubit", wires=1))
- def circuit(x):
- qml.RX(x, 0)
- return qml.expval(qml.Z(0))
+ def test_error_backprop_unsupported(self):
+ """Test an error is raised with backprop if the device does not support it."""
+
+ # pylint: disable=too-few-public-methods
+ class DummyDev(qml.devices.Device):
+
+ def execute(self, *_, **__):
+ return 0
+
+ with pytest.raises(qml.QuantumFunctionError, match="does not support backprop"):
+
+ @qml.qnode(DummyDev(wires=2), diff_method="backprop")
+ def _(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ def test_error_unsupported_diff_method(self):
+ """Test an error is raised for a non-backprop diff method."""
+
+ @qml.qnode(qml.device("default.qubit", wires=2), diff_method="parameter-shift")
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ with pytest.raises(
+ NotImplementedError, match="diff_method parameter-shift not yet implemented."
+ ):
+ jax.grad(circuit)(0.5)
+
+ @pytest.mark.parametrize("diff_method", ("best", "backprop"))
+ def test_default_qubit_backprop(self, diff_method):
+ """Test that JAX can compute the JVP of the QNode primitive via a registered JVP rule."""
+
+ @qml.qnode(qml.device("default.qubit", wires=1), diff_method=diff_method)
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
+
+ x = 0.9
+ xt = -0.6
+ jvp = jax.jvp(circuit, (x,), (xt,))
+ assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt))
+
+ def test_no_gradients_with_lightning(self):
+ """Test that we get an error if we try and differentiate a lightning execution."""
+
+ @qml.qnode(qml.device("lightning.qubit", wires=2))
+ def circuit(x):
+ qml.RX(x, 0)
+ return qml.expval(qml.Z(0))
- x = 0.9
- xt = -0.6
- jvp = jax.jvp(circuit, (x,), (xt,))
- assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt))
+ with pytest.raises(NotImplementedError, match=r"diff_method adjoint not yet implemented"):
+ jax.grad(circuit)(0.5)
def test_qnode_jit():
diff --git a/tests/capture/test_custom_primitives.py b/tests/capture/test_custom_primitives.py
new file mode 100644
index 00000000000..3d35e3e57e4
--- /dev/null
+++ b/tests/capture/test_custom_primitives.py
@@ -0,0 +1,48 @@
+# Copyright 2024 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Unit tests for PennyLane custom primitives.
+"""
+# pylint: disable=wrong-import-position
+import pytest
+
+jax = pytest.importorskip("jax")
+
+from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive
+
+pytestmark = pytest.mark.jax
+
+
+def test_qml_primitive_prim_type_default():
+ """Test that the default prim_type of a QmlPrimitive is set correctly."""
+ prim = QmlPrimitive("primitive")
+ assert prim._prim_type == PrimitiveType("default") # pylint: disable=protected-access
+ assert prim.prim_type == "default"
+
+
+@pytest.mark.parametrize("cast_in_enum", [True, False])
+@pytest.mark.parametrize("prim_type", ["operator", "measurement", "transform", "higher_order"])
+def test_qml_primitive_prim_type_setter(prim_type, cast_in_enum):
+ """Test that the QmlPrimitive.prim_type setter works correctly"""
+ prim = QmlPrimitive("primitive")
+ prim.prim_type = PrimitiveType(prim_type) if cast_in_enum else prim_type
+ assert prim._prim_type == PrimitiveType(prim_type) # pylint: disable=protected-access
+ assert prim.prim_type == prim_type
+
+
+def test_qml_primitive_prim_type_setter_invalid():
+ """Test that setting an invalid prim_type raises an error"""
+ prim = QmlPrimitive("primitive")
+ with pytest.raises(ValueError, match="not a valid PrimitiveType"):
+ prim.prim_type = "blah"
diff --git a/tests/capture/test_switches.py b/tests/capture/test_switches.py
index 52f50321740..72a170205e7 100644
--- a/tests/capture/test_switches.py
+++ b/tests/capture/test_switches.py
@@ -32,10 +32,15 @@ def test_switches_with_jax():
def test_switches_without_jax():
"""Test switches and status reporting function."""
-
- assert qml.capture.enabled() is False
- with pytest.raises(ImportError, match="plxpr requires JAX to be installed."):
- qml.capture.enable()
- assert qml.capture.enabled() is False
- assert qml.capture.disable() is None
- assert qml.capture.enabled() is False
+ # We want to skip the test if jax is installed
+ try:
+ # pylint: disable=import-outside-toplevel, unused-import
+ import jax
+ except ImportError:
+
+ assert qml.capture.enabled() is False
+ with pytest.raises(ImportError, match="plxpr requires JAX to be installed."):
+ qml.capture.enable()
+ assert qml.capture.enabled() is False
+ assert qml.capture.disable() is None
+ assert qml.capture.enabled() is False
diff --git a/tests/devices/default_qubit/test_default_qubit_native_mcm.py b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
index 42faaa10809..32367c455f4 100644
--- a/tests/devices/default_qubit/test_default_qubit_native_mcm.py
+++ b/tests/devices/default_qubit/test_default_qubit_native_mcm.py
@@ -389,7 +389,10 @@ def func(x, y, z):
results1 = func1(*params)
jaxpr = str(jax.make_jaxpr(func)(*params))
- assert "pure_callback" not in jaxpr
+ # will change once we solve the compilation overhead issue
+ # assert "pure_callback" not in jaxpr
+ # TODO: [sc-82874]
+ assert "pure_callback" in jaxpr
func2 = jax.jit(func)
results2 = func2(*params)
diff --git a/tests/devices/default_qubit/test_default_qubit_preprocessing.py b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
index 59a9098f7ce..31a2c978745 100644
--- a/tests/devices/default_qubit/test_default_qubit_preprocessing.py
+++ b/tests/devices/default_qubit/test_default_qubit_preprocessing.py
@@ -142,15 +142,17 @@ def circuit(x):
assert dev.tracker.totals["execute_and_derivative_batches"] == 1
@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
- def test_not_convert_to_numpy_with_jax(self, interface):
+ def test_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""
-
+ # separate test so we can easily update it once we solve the
+ # compilation overhead issue
+ # TODO: [sc-82874]
dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
- assert not processed.convert_to_numpy
+ assert processed.convert_to_numpy
def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
diff --git a/tests/gradients/core/test_hamiltonian_gradient.py b/tests/gradients/core/test_hamiltonian_gradient.py
index 1bcb4bfc4fe..9474faf39a0 100644
--- a/tests/gradients/core/test_hamiltonian_gradient.py
+++ b/tests/gradients/core/test_hamiltonian_gradient.py
@@ -12,10 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the gradients.hamiltonian module."""
+import pytest
+
import pennylane as qml
from pennylane.gradients.hamiltonian_grad import hamiltonian_grad
+def test_hamiltonian_grad_deprecation():
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated"
+ ):
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.RY(0.3, wires=0)
+ qml.RX(0.5, wires=1)
+ qml.CNOT(wires=[0, 1])
+ qml.expval(qml.Hamiltonian([-1.5, 2.0], [qml.PauliZ(0), qml.PauliZ(1)]))
+
+ tape = qml.tape.QuantumScript.from_queue(q)
+ tape.trainable_params = {2, 3}
+ hamiltonian_grad(tape, idx=0)
+
+
def test_behaviour():
"""Test that the function behaves as expected."""
@@ -29,10 +46,16 @@ def test_behaviour():
tape = qml.tape.QuantumScript.from_queue(q)
tape.trainable_params = {2, 3}
- tapes, processing_fn = hamiltonian_grad(tape, idx=0)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated"
+ ):
+ tapes, processing_fn = hamiltonian_grad(tape, idx=0)
res1 = processing_fn(dev.execute(tapes))
- tapes, processing_fn = hamiltonian_grad(tape, idx=1)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated"
+ ):
+ tapes, processing_fn = hamiltonian_grad(tape, idx=1)
res2 = processing_fn(dev.execute(tapes))
with qml.queuing.AnnotatedQueue() as q1:
diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py
index c684d53b0c6..b06b7c711d0 100644
--- a/tests/gradients/core/test_pulse_gradient.py
+++ b/tests/gradients/core/test_pulse_gradient.py
@@ -1385,7 +1385,10 @@ def test_simple_qnode_expval(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
@qml.qnode(
- dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
+ dev,
+ interface="jax",
+ diff_method=stoch_pulse_grad,
+ gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
@@ -1415,7 +1418,10 @@ def test_simple_qnode_expval_two_evolves(self, num_split_times, shots, tol, seed
ham_y = qml.pulse.constant * qml.PauliX(0)
@qml.qnode(
- dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
+ dev,
+ interface="jax",
+ diff_method=stoch_pulse_grad,
+ gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_x)(params[0], T_x)
@@ -1444,7 +1450,10 @@ def test_simple_qnode_probs(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
@qml.qnode(
- dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
+ dev,
+ interface="jax",
+ diff_method=stoch_pulse_grad,
+ gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
@@ -1471,7 +1480,10 @@ def test_simple_qnode_probs_expval(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
@qml.qnode(
- dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
+ dev,
+ interface="jax",
+ diff_method=stoch_pulse_grad,
+ gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
@@ -1490,6 +1502,7 @@ def circuit(params):
assert qml.math.allclose(j[0], e, atol=tol, rtol=0.0)
jax.clear_caches()
+ @pytest.mark.xfail # TODO: [sc-82874]
@pytest.mark.parametrize("num_split_times", [1, 2])
@pytest.mark.parametrize("time_interface", ["python", "numpy", "jax"])
def test_simple_qnode_jit(self, num_split_times, time_interface):
@@ -1503,7 +1516,10 @@ def test_simple_qnode_jit(self, num_split_times, time_interface):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
@qml.qnode(
- dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
+ dev,
+ interface="jax",
+ diff_method=stoch_pulse_grad,
+ gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params, T=None):
qml.evolve(ham_single_q_const)(params, T)
@@ -1542,8 +1558,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=stoch_pulse_grad,
- num_split_times=num_split_times,
- sampler_seed=seed,
+ gradient_kwargs={"num_split_times": num_split_times, "sampler_seed": seed},
)
qnode_backprop = qml.QNode(ansatz, dev, interface="jax")
@@ -1574,8 +1589,7 @@ def test_qnode_probs_expval_broadcasting(self, num_split_times, shots, tol, seed
dev,
interface="jax",
diff_method=stoch_pulse_grad,
- num_split_times=num_split_times,
- use_broadcasting=True,
+ gradient_kwargs={"num_split_times": num_split_times, "use_broadcasting": True},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
@@ -1619,18 +1633,22 @@ def ansatz(params):
dev,
interface="jax",
diff_method=stoch_pulse_grad,
- num_split_times=num_split_times,
- use_broadcasting=True,
- sampler_seed=seed,
+ gradient_kwargs={
+ "num_split_times": num_split_times,
+ "use_broadcasting": True,
+ "sampler_seed": seed,
+ },
)
circuit_no_bc = qml.QNode(
ansatz,
dev,
interface="jax",
diff_method=stoch_pulse_grad,
- num_split_times=num_split_times,
- use_broadcasting=False,
- sampler_seed=seed,
+ gradient_kwargs={
+ "num_split_times": num_split_times,
+ "use_broadcasting": False,
+ "sampler_seed": seed,
+ },
)
params = [jnp.array(0.4)]
jac_bc = jax.jacobian(circuit_bc)(params)
@@ -1684,9 +1702,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
- num_split_times=7,
- use_broadcasting=True,
- sampler_seed=seed,
+ gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True},
)
cost_jax = qml.QNode(ansatz, dev, interface="jax")
params = (0.42,)
@@ -1729,9 +1745,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
- num_split_times=7,
- use_broadcasting=True,
- sampler_seed=seed,
+ gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True},
)
cost_jax = qml.QNode(ansatz, dev, interface="jax")
diff --git a/tests/gradients/finite_diff/test_spsa_gradient.py b/tests/gradients/finite_diff/test_spsa_gradient.py
index 1bd2a198bca..413bc8ff6ec 100644
--- a/tests/gradients/finite_diff/test_spsa_gradient.py
+++ b/tests/gradients/finite_diff/test_spsa_gradient.py
@@ -161,7 +161,7 @@ def test_invalid_sampler_rng(self):
"""Tests that if sampler_rng has an unexpected type, an error is raised."""
dev = qml.device("default.qubit", wires=1)
- @qml.qnode(dev, diff_method="spsa", sampler_rng="foo")
+ @qml.qnode(dev, diff_method="spsa", gradient_kwargs={"sampler_rng": "foo"})
def circuit(param):
qml.RX(param, wires=0)
return qml.expval(qml.PauliZ(0))
diff --git a/tests/gradients/parameter_shift/test_cv_gradients.py b/tests/gradients/parameter_shift/test_cv_gradients.py
index c9709753283..55955396ab7 100644
--- a/tests/gradients/parameter_shift/test_cv_gradients.py
+++ b/tests/gradients/parameter_shift/test_cv_gradients.py
@@ -268,7 +268,11 @@ def qf(x, y):
grad_F = jax.grad(qf)(*par)
- @qml.qnode(device=gaussian_dev, diff_method="parameter-shift", force_order2=True)
+ @qml.qnode(
+ device=gaussian_dev,
+ diff_method="parameter-shift",
+ gradient_kwargs={"force_order2": True},
+ )
def qf2(x, y):
qml.Displacement(0.5, 0, wires=[0])
qml.Squeezing(x, 0, wires=[0])
diff --git a/tests/gradients/parameter_shift/test_parameter_shift.py b/tests/gradients/parameter_shift/test_parameter_shift.py
index f35a4b4fc95..4b741c5a66b 100644
--- a/tests/gradients/parameter_shift/test_parameter_shift.py
+++ b/tests/gradients/parameter_shift/test_parameter_shift.py
@@ -3473,58 +3473,6 @@ def test_trainable_coeffs(self, tol, broadcast):
assert np.allclose(res[0], expected[0], atol=tol, rtol=0)
assert np.allclose(res[1], expected[1], atol=tol, rtol=0)
- def test_multiple_hamiltonians(self, tol, broadcast):
- """Test multiple trainable Hamiltonian coefficients"""
- dev = qml.device("default.qubit", wires=2)
-
- obs = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)]
- coeffs = np.array([0.1, 0.2, 0.3])
- a, b, c = coeffs
- H1 = qml.Hamiltonian(coeffs, obs)
-
- obs = [qml.PauliZ(0)]
- coeffs = np.array([0.7])
- d = coeffs[0]
- H2 = qml.Hamiltonian(coeffs, obs)
-
- weights = np.array([0.4, 0.5])
- x, y = weights
-
- with qml.queuing.AnnotatedQueue() as q:
- qml.RX(weights[0], wires=0)
- qml.RY(weights[1], wires=1)
- qml.CNOT(wires=[0, 1])
- qml.expval(H1)
- qml.expval(H2)
-
- tape = qml.tape.QuantumScript.from_queue(q)
- tape.trainable_params = {0, 1, 2, 4, 5}
-
- res = dev.execute([tape])
- expected = [-c * np.sin(x) * np.sin(y) + np.cos(x) * (a + b * np.sin(y)), d * np.cos(x)]
- assert np.allclose(res, expected, atol=tol, rtol=0)
-
- tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast)
- # two shifts per rotation gate (in one batched tape if broadcasting),
- # one circuit per trainable H term
- assert len(tapes) == 2 * (1 if broadcast else 2)
-
- res = fn(dev.execute(tapes))
- assert isinstance(res, tuple)
- assert len(res) == 2
- assert len(res[0]) == 2
- assert len(res[1]) == 2
-
- expected = [
- [
- -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)),
- b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x),
- ],
- [-d * np.sin(x), 0],
- ]
-
- assert np.allclose(np.stack(res), expected, atol=tol, rtol=0)
-
@staticmethod
def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False):
"""Cost function for gradient tests"""
@@ -3547,95 +3495,8 @@ def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False):
jac = fn(dev.execute(tapes))
return jac
- @staticmethod
- def cost_fn_expected(weights, coeffs1, coeffs2):
- """Analytic jacobian of cost_fn above"""
- a, b, c = coeffs1
- d = coeffs2[0]
- x, y = weights
- return [
- [
- -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)),
- b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x),
- ],
- [-d * np.sin(x), 0],
- ]
-
- @pytest.mark.autograd
- def test_autograd(self, tol, broadcast):
- """Test gradient of multiple trainable Hamiltonian coefficients
- using autograd"""
- coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True)
- coeffs2 = np.array([0.7], requires_grad=True)
- weights = np.array([0.4, 0.5], requires_grad=True)
- dev = qml.device("default.qubit", wires=2)
-
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- expected = self.cost_fn_expected(weights, coeffs1, coeffs2)
- assert np.allclose(res, np.array(expected), atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero
- # ---
- # res = qml.jacobian(self.cost_fn)(weights, coeffs1, coeffs2, dev=dev)
- # assert np.allclose(res[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
- # assert np.allclose(res[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
- @pytest.mark.tf
- def test_tf(self, tol, broadcast):
- """Test gradient of multiple trainable Hamiltonian coefficients
- using tf"""
- import tensorflow as tf
-
- coeffs1 = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64)
- coeffs2 = tf.Variable([0.7], dtype=tf.float64)
- weights = tf.Variable([0.4, 0.5], dtype=tf.float64)
-
- dev = qml.device("default.qubit", wires=2)
-
- with tf.GradientTape() as _:
- jac = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
-
- expected = self.cost_fn_expected(weights.numpy(), coeffs1.numpy(), coeffs2.numpy())
- assert np.allclose(jac[0], np.array(expected)[0], atol=tol, rtol=0)
- assert np.allclose(jac[1], np.array(expected)[1], atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero.
- # When activating the following, rename the GradientTape above from _ to t
- # ---
- # hess = t.jacobian(jac, [coeffs1, coeffs2])
- # assert np.allclose(hess[0][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
- # assert np.allclose(hess[1][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
- @pytest.mark.torch
- def test_torch(self, tol, broadcast):
- """Test gradient of multiple trainable Hamiltonian coefficients
- using torch"""
- import torch
-
- coeffs1 = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64, requires_grad=True)
- coeffs2 = torch.tensor([0.7], dtype=torch.float64, requires_grad=True)
- weights = torch.tensor([0.4, 0.5], dtype=torch.float64, requires_grad=True)
-
- dev = qml.device("default.qubit", wires=2)
-
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- expected = self.cost_fn_expected(
- weights.detach().numpy(), coeffs1.detach().numpy(), coeffs2.detach().numpy()
- )
- res = tuple(tuple(_r.detach() for _r in r) for r in res)
- assert np.allclose(res, expected, atol=tol, rtol=0)
-
- # second derivative wrt to Hamiltonian coefficients should be zero
- # hess = torch.autograd.functional.jacobian(
- # lambda *args: self.cost_fn(*args, dev, broadcast), (weights, coeffs1, coeffs2)
- # )
- # assert np.allclose(hess[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
- # assert np.allclose(hess[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
@pytest.mark.jax
- def test_jax(self, tol, broadcast):
+ def test_jax(self, broadcast):
"""Test gradient of multiple trainable Hamiltonian coefficients
using JAX"""
import jax
@@ -3647,19 +3508,11 @@ def test_jax(self, tol, broadcast):
weights = jnp.array([0.4, 0.5])
dev = qml.device("default.qubit", wires=2)
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- expected = self.cost_fn_expected(weights, coeffs1, coeffs2)
- assert np.allclose(np.array(res)[:, :2], np.array(expected), atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero
- # ---
- # second derivative wrt to Hamiltonian coefficients should be zero
- # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast)
- # assert np.allclose(res[:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
-
- # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast)
- # assert np.allclose(res[:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning, match="The 'hamiltonian_grad' function is deprecated"
+ ):
+ with pytest.warns(UserWarning, match="Please use qml.gradients.split_to_single_terms"):
+ self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
@pytest.mark.autograd
diff --git a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py
index 34478ccecd0..e3151735e98 100644
--- a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py
+++ b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py
@@ -1970,12 +1970,12 @@ def expval(self, observable, **kwargs):
dev = DeviceSupporingSpecialObservable(wires=1, shots=None)
- @qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast)
+ @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast})
def qnode(x):
qml.RY(x, wires=0)
return qml.expval(SpecialObservable(wires=0))
- @qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast)
+ @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast})
def reference_qnode(x):
qml.RY(x, wires=0)
return qml.expval(qml.PauliZ(wires=0))
@@ -2128,221 +2128,6 @@ def test_trainable_coeffs(self, broadcast, tol):
for r in res:
assert qml.math.allclose(r, expected, atol=shot_vec_tol)
- @pytest.mark.xfail(reason="TODO")
- def test_multiple_hamiltonians(self, mocker, broadcast, tol):
- """Test multiple trainable Hamiltonian coefficients"""
- shot_vec = many_shots_shot_vector
- dev = qml.device("default.qubit", wires=2, shots=shot_vec)
- spy = mocker.spy(qml.gradients, "hamiltonian_grad")
-
- obs = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)]
- coeffs = np.array([0.1, 0.2, 0.3])
- a, b, c = coeffs
- H1 = qml.Hamiltonian(coeffs, obs)
-
- obs = [qml.PauliZ(0)]
- coeffs = np.array([0.7])
- d = coeffs[0]
- H2 = qml.Hamiltonian(coeffs, obs)
-
- weights = np.array([0.4, 0.5])
- x, y = weights
-
- with qml.queuing.AnnotatedQueue() as q:
- qml.RX(weights[0], wires=0)
- qml.RY(weights[1], wires=1)
- qml.CNOT(wires=[0, 1])
- qml.expval(H1)
- qml.expval(H2)
-
- tape = qml.tape.QuantumScript.from_queue(q, shots=shot_vec)
- tape.trainable_params = {0, 1, 2, 4, 5}
-
- res = dev.execute([tape])
- expected = [-c * np.sin(x) * np.sin(y) + np.cos(x) * (a + b * np.sin(y)), d * np.cos(x)]
- assert np.allclose(res, expected, atol=tol, rtol=0)
-
- if broadcast:
- with pytest.raises(
- NotImplementedError, match="Broadcasting with multiple measurements"
- ):
- qml.gradients.param_shift(tape, broadcast=broadcast)
- return
-
- tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast)
- # two shifts per rotation gate, one circuit per trainable H term
- assert len(tapes) == 2 * 2 + 3
- spy.assert_called()
-
- res = fn(dev.execute(tapes))
- assert isinstance(res, tuple)
- assert len(res) == 2
- assert len(res[0]) == 5
- assert len(res[1]) == 5
-
- expected = [
- [
- -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)),
- b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x),
- np.cos(x),
- -(np.sin(x) * np.sin(y)),
- 0,
- ],
- [-d * np.sin(x), 0, 0, 0, np.cos(x)],
- ]
-
- assert np.allclose(np.stack(res), expected, atol=tol, rtol=0)
-
- @staticmethod
- def cost_fn(weights, coeffs1, coeffs2, dev=None, broadcast=False):
- """Cost function for gradient tests"""
- obs1 = [qml.PauliZ(0), qml.PauliZ(0) @ qml.PauliX(1), qml.PauliY(0)]
- H1 = qml.Hamiltonian(coeffs1, obs1)
-
- obs2 = [qml.PauliZ(0)]
- H2 = qml.Hamiltonian(coeffs2, obs2)
-
- with qml.queuing.AnnotatedQueue() as q:
- qml.RX(weights[0], wires=0)
- qml.RY(weights[1], wires=1)
- qml.CNOT(wires=[0, 1])
- qml.expval(H1)
- qml.expval(H2)
-
- tape = qml.tape.QuantumScript.from_queue(q, shots=dev.shots)
- tape.trainable_params = {0, 1, 2, 3, 4, 5}
- tapes, fn = qml.gradients.param_shift(tape, broadcast=broadcast)
- return fn(dev.execute(tapes))
-
- @staticmethod
- def cost_fn_expected(weights, coeffs1, coeffs2):
- """Analytic jacobian of cost_fn above"""
- a, b, c = coeffs1
- d = coeffs2[0]
- x, y = weights
- return [
- [
- -c * np.cos(x) * np.sin(y) - np.sin(x) * (a + b * np.sin(y)),
- b * np.cos(x) * np.cos(y) - c * np.cos(y) * np.sin(x),
- np.cos(x),
- np.cos(x) * np.sin(y),
- -(np.sin(x) * np.sin(y)),
- 0,
- ],
- [-d * np.sin(x), 0, 0, 0, 0, np.cos(x)],
- ]
-
- @pytest.mark.xfail(reason="TODO")
- @pytest.mark.autograd
- def test_autograd(self, broadcast, tol):
- """Test gradient of multiple trainable Hamiltonian coefficients
- using autograd"""
- coeffs1 = np.array([0.1, 0.2, 0.3], requires_grad=True)
- coeffs2 = np.array([0.7], requires_grad=True)
- weights = np.array([0.4, 0.5], requires_grad=True)
- shot_vec = many_shots_shot_vector
- dev = qml.device("default.qubit", wires=2, shots=shot_vec)
-
- if broadcast:
- with pytest.raises(
- NotImplementedError, match="Broadcasting with multiple measurements"
- ):
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- return
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- expected = self.cost_fn_expected(weights, coeffs1, coeffs2)
- assert np.allclose(res, np.array(expected), atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero
- # ---
- # res = qml.jacobian(self.cost_fn)(weights, coeffs1, coeffs2, dev=dev)
- # assert np.allclose(res[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
- # assert np.allclose(res[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
- @pytest.mark.xfail(reason="TODO")
- @pytest.mark.tf
- def test_tf(self, broadcast, tol):
- """Test gradient of multiple trainable Hamiltonian coefficients using tf"""
- import tensorflow as tf
-
- coeffs1 = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64)
- coeffs2 = tf.Variable([0.7], dtype=tf.float64)
- weights = tf.Variable([0.4, 0.5], dtype=tf.float64)
-
- shot_vec = many_shots_shot_vector
- dev = qml.device("default.qubit", wires=2, shots=shot_vec)
-
- with tf.GradientTape() as _:
- jac = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
-
- expected = self.cost_fn_expected(weights.numpy(), coeffs1.numpy(), coeffs2.numpy())
- assert np.allclose(jac[0], np.array(expected)[0], atol=tol, rtol=0)
- assert np.allclose(jac[1], np.array(expected)[1], atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero
- # When activating the following, rename the GradientTape above from _ to t
- # ---
- # hess = t.jacobian(jac, [coeffs1, coeffs2])
- # assert np.allclose(hess[0][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
- # assert np.allclose(hess[1][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
- @pytest.mark.torch
- def test_torch(self, broadcast, tol):
- """Test gradient of multiple trainable Hamiltonian coefficients
- using torch"""
- import torch
-
- coeffs1 = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64, requires_grad=True)
- coeffs2 = torch.tensor([0.7], dtype=torch.float64, requires_grad=True)
- weights = torch.tensor([0.4, 0.5], dtype=torch.float64, requires_grad=True)
-
- dev = qml.device("default.qubit", wires=2)
-
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- expected = self.cost_fn_expected(
- weights.detach().numpy(), coeffs1.detach().numpy(), coeffs2.detach().numpy()
- )
- for actual, _expected in zip(res, expected):
- for val, exp_val in zip(actual, _expected):
- assert qml.math.allclose(val.detach(), exp_val, atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero
- # hess = torch.autograd.functional.jacobian(
- # lambda *args: self.cost_fn(*args, dev, broadcast), (weights, coeffs1, coeffs2)
- # )
- # assert np.allclose(hess[1][:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
- # assert np.allclose(hess[2][:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
- @pytest.mark.jax
- def test_jax(self, broadcast, tol):
- """Test gradient of multiple trainable Hamiltonian coefficients
- using JAX"""
- import jax
-
- jnp = jax.numpy
-
- coeffs1 = jnp.array([0.1, 0.2, 0.3])
- coeffs2 = jnp.array([0.7])
- weights = jnp.array([0.4, 0.5])
- dev = qml.device("default.qubit", wires=2)
-
- res = self.cost_fn(weights, coeffs1, coeffs2, dev, broadcast)
- expected = self.cost_fn_expected(weights, coeffs1, coeffs2)
- assert np.allclose(res, np.array(expected), atol=tol, rtol=0)
-
- # TODO: test when Hessians are supported with the new return types
- # second derivative wrt to Hamiltonian coefficients should be zero
- # ---
- # second derivative wrt to Hamiltonian coefficients should be zero
- # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast)
- # assert np.allclose(res[:, 2:5], np.zeros([2, 3, 3]), atol=tol, rtol=0)
-
- # res = jax.jacobian(self.cost_fn, argnums=1)(weights, coeffs1, coeffs2, dev, broadcast)
- # assert np.allclose(res[:, -1], np.zeros([2, 1, 1]), atol=tol, rtol=0)
-
pauliz = qml.PauliZ(wires=0)
proj = qml.Projector([1], wires=0)
diff --git a/tests/measurements/test_probs.py b/tests/measurements/test_probs.py
index 07bd8fcd463..b5af8118b07 100644
--- a/tests/measurements/test_probs.py
+++ b/tests/measurements/test_probs.py
@@ -338,7 +338,7 @@ def circuit():
@pytest.mark.jax
@pytest.mark.parametrize("shots", (None, 500))
@pytest.mark.parametrize("obs", ([0, 1], qml.PauliZ(0) @ qml.PauliZ(1)))
- @pytest.mark.parametrize("params", ([np.pi / 2], [np.pi / 2, np.pi / 2, np.pi / 2]))
+ @pytest.mark.parametrize("params", (np.pi / 2, [np.pi / 2, np.pi / 2, np.pi / 2]))
def test_integration_jax(self, tol_stochastic, shots, obs, params, seed):
"""Test the probability is correct for a known state preparation when jitted with JAX."""
jax = pytest.importorskip("jax")
@@ -359,7 +359,9 @@ def circuit(x):
# expected probability, using [00, 01, 10, 11]
# ordering, is [0.5, 0.5, 0, 0]
- assert "pure_callback" not in str(jax.make_jaxpr(circuit)(params))
+ # TODO: [sc-82874]
+ # revert once we are able to jit end to end without extreme compilation overheads
+ assert "pure_callback" in str(jax.make_jaxpr(circuit)(params))
res = jax.jit(circuit)(params)
expected = np.array([0.5, 0.5, 0, 0])
diff --git a/tests/ops/functions/test_matrix.py b/tests/ops/functions/test_matrix.py
index 5adc8912cd2..a1f82e6077a 100644
--- a/tests/ops/functions/test_matrix.py
+++ b/tests/ops/functions/test_matrix.py
@@ -683,6 +683,7 @@ def circuit(theta):
assert np.allclose(matrix, expected_matrix)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.catalyst
@pytest.mark.external
def test_catalyst(self):
diff --git a/tests/ops/op_math/test_exp.py b/tests/ops/op_math/test_exp.py
index 5abead84595..6ca4c68091b 100644
--- a/tests/ops/op_math/test_exp.py
+++ b/tests/ops/op_math/test_exp.py
@@ -768,6 +768,7 @@ def circ(phi):
grad = jax.grad(circ)(phi)
assert qml.math.allclose(grad, -jnp.sin(phi))
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.catalyst
@pytest.mark.external
def test_catalyst_qnode(self):
diff --git a/tests/resource/test_specs.py b/tests/resource/test_specs.py
index a02b35ef97b..5a5764e7153 100644
--- a/tests/resource/test_specs.py
+++ b/tests/resource/test_specs.py
@@ -33,10 +33,15 @@ class TestSpecsTransform:
"""Tests for the transform specs using the QNode"""
def sample_circuit(self):
+
@qml.transforms.merge_rotations
@qml.transforms.undo_swaps
@qml.transforms.cancel_inverses
- @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4)
+ @qml.qnode(
+ qml.device("default.qubit"),
+ diff_method="parameter-shift",
+ gradient_kwargs={"shifts": pnp.pi / 4},
+ )
def circuit(x):
qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1))
qml.RX(x, wires=0)
@@ -222,7 +227,11 @@ def test_splitting_transforms(self):
@qml.transforms.split_non_commuting
@qml.transforms.merge_rotations
- @qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4)
+ @qml.qnode(
+ qml.device("default.qubit"),
+ diff_method="parameter-shift",
+ gradient_kwargs={"shifts": pnp.pi / 4},
+ )
def circuit(x):
qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1))
qml.RX(x, wires=0)
diff --git a/tests/templates/test_state_preparations/test_mottonen_state_prep.py b/tests/templates/test_state_preparations/test_mottonen_state_prep.py
index 885def86e60..a63a2d20129 100644
--- a/tests/templates/test_state_preparations/test_mottonen_state_prep.py
+++ b/tests/templates/test_state_preparations/test_mottonen_state_prep.py
@@ -431,7 +431,7 @@ def circuit(coeffs):
qml.MottonenStatePreparation(coeffs, wires=[0, 1])
return qml.probs(wires=[0, 1])
- circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05)
+ circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", gradient_kwargs={"h": 0.05})
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 38d61cd86cf..6eb3fa851a4 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -76,6 +76,7 @@ def test_compiler(self):
assert qml.compiler.available("catalyst")
assert qml.compiler.available_compilers() == ["catalyst", "cuda_quantum"]
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_active_compiler(self):
"""Test `qml.compiler.active_compiler` inside a simple circuit"""
dev = qml.device("lightning.qubit", wires=2)
@@ -91,6 +92,7 @@ def circuit(phi, theta):
assert jnp.allclose(circuit(jnp.pi, jnp.pi / 2), 1.0)
assert jnp.allclose(qml.qjit(circuit)(jnp.pi, jnp.pi / 2), -1.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_active(self):
"""Test `qml.compiler.active` inside a simple circuit"""
dev = qml.device("lightning.qubit", wires=2)
@@ -114,6 +116,7 @@ def test_jax_enable_x64(self, jax_enable_x64):
qml.compiler.active()
assert jax.config.jax_enable_x64 is jax_enable_x64
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_circuit(self):
"""Test JIT compilation of a circuit with 2-qubit"""
dev = qml.device("lightning.qubit", wires=2)
@@ -128,6 +131,7 @@ def circuit(theta):
assert jnp.allclose(circuit(0.5), 0.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_aot(self):
"""Test AOT compilation of a circuit with 2-qubit"""
@@ -152,6 +156,7 @@ def circuit(x: complex, z: ShapedArray(shape=(3,), dtype=jnp.float64)):
)
assert jnp.allclose(result, expected)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.parametrize(
"_in,_out",
[
@@ -196,6 +201,7 @@ def workflow1(params1, params2):
result = workflow1(params1, params2)
assert jnp.allclose(result, expected)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_return_value_dict(self):
"""Test pytree return values."""
dev = qml.device("lightning.qubit", wires=2)
@@ -218,6 +224,7 @@ def circuit1(params):
assert jnp.allclose(result["w0"], expected["w0"])
assert jnp.allclose(result["w1"], expected["w1"])
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_python_if(self):
"""Test JIT compilation with the autograph support"""
dev = qml.device("lightning.qubit", wires=2)
@@ -235,6 +242,7 @@ def circuit(x: int):
assert jnp.allclose(circuit(3), 0.0)
assert jnp.allclose(circuit(5), 1.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_compilation_opt(self):
"""Test user-configurable compilation options"""
dev = qml.device("lightning.qubit", wires=2)
@@ -250,6 +258,7 @@ def circuit(x: float):
result_header = "func.func public @circuit(%arg0: tensor) -> tensor"
assert result_header in mlir_str
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_adjoint(self):
"""Test JIT compilation with adjoint support"""
dev = qml.device("lightning.qubit", wires=2)
@@ -273,6 +282,7 @@ def func():
assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1]))
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_qjit_adjoint_lazy(self):
"""Test that the lazy kwarg is supported."""
dev = qml.device("lightning.qubit", wires=2)
@@ -287,6 +297,7 @@ def workflow_pl(theta, wires):
assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1]))
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_control(self):
"""Test that control works with qjit."""
dev = qml.device("lightning.qubit", wires=2)
@@ -317,6 +328,7 @@ def cond_fn():
class TestCatalystControlFlow:
"""Test ``qml.qjit`` with Catalyst's control-flow operations"""
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_alternating_while_loop(self):
"""Test simple while loop."""
dev = qml.device("lightning.qubit", wires=1)
@@ -334,6 +346,7 @@ def loop(v):
assert jnp.allclose(circuit(1), -1.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_nested_while_loops(self):
"""Test nested while loops."""
dev = qml.device("lightning.qubit", wires=1)
@@ -393,6 +406,7 @@ def loop(v):
expected = [qml.PauliX(0) for i in range(4)]
_ = [qml.assert_equal(i, j) for i, j in zip(tape.operations, expected)]
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_dynamic_wires_for_loops(self):
"""Test for loops with iteration index-dependant wires."""
dev = qml.device("lightning.qubit", wires=6)
@@ -414,6 +428,7 @@ def loop_fn(i):
assert jnp.allclose(circuit(6), expected)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_nested_for_loops(self):
"""Test nested for loops."""
dev = qml.device("lightning.qubit", wires=4)
@@ -445,6 +460,7 @@ def inner(j):
assert jnp.allclose(circuit(4), jnp.eye(2**4)[0])
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_for_loop_python_fallback(self):
"""Test that qml.for_loop fallsback to Python
interpretation if Catalyst is not available"""
@@ -496,6 +512,7 @@ def inner(j):
_ = [qml.assert_equal(i, j) for i, j in zip(res, expected)]
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond(self):
"""Test condition with simple true_fn"""
dev = qml.device("lightning.qubit", wires=1)
@@ -514,6 +531,7 @@ def ansatz_true():
assert jnp.allclose(circuit(1.4), 1.0)
assert jnp.allclose(circuit(1.6), 0.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond_with_else(self):
"""Test condition with simple true_fn and false_fn"""
dev = qml.device("lightning.qubit", wires=1)
@@ -535,6 +553,7 @@ def ansatz_false():
assert jnp.allclose(circuit(1.4), 0.16996714)
assert jnp.allclose(circuit(1.6), 0.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond_with_elif(self):
"""Test condition with a simple elif branch"""
dev = qml.device("lightning.qubit", wires=1)
@@ -558,6 +577,7 @@ def false_fn():
assert jnp.allclose(circuit(1.2), 0.13042371)
assert jnp.allclose(circuit(jnp.pi), -1.0)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_cond_with_elifs(self):
"""Test condition with multiple elif branches"""
dev = qml.device("lightning.qubit", wires=1)
@@ -630,6 +650,7 @@ def conditional_false_fn(): # pylint: disable=unused-variable
class TestCatalystGrad:
"""Test ``qml.qjit`` with Catalyst's grad operations"""
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_grad_classical_preprocessing(self):
"""Test the grad transformation with classical preprocessing."""
@@ -647,6 +668,7 @@ def circuit(x):
assert jnp.allclose(workflow(2.0), -jnp.pi)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_grad_with_postprocessing(self):
"""Test the grad transformation with classical preprocessing and postprocessing."""
dev = qml.device("lightning.qubit", wires=1)
@@ -665,6 +687,7 @@ def loss(theta):
assert jnp.allclose(workflow(1.0), 5.04324559)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_grad_with_multiple_qnodes(self):
"""Test the grad transformation with multiple QNodes with their own differentiation methods."""
dev = qml.device("lightning.qubit", wires=1)
@@ -703,6 +726,7 @@ def dsquare(x: float):
assert jnp.allclose(dsquare(2.3), 4.6)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_jacobian_diff_method(self):
"""Test the Jacobian transformation with the device diff_method."""
dev = qml.device("lightning.qubit", wires=1)
@@ -721,6 +745,7 @@ def workflow(p: float):
assert jnp.allclose(result, reference)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_jacobian_auto(self):
"""Test the Jacobian transformation with 'auto'."""
dev = qml.device("lightning.qubit", wires=1)
@@ -740,6 +765,7 @@ def circuit(x):
assert jnp.allclose(result, reference)
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
def test_jacobian_fd(self):
"""Test the Jacobian transformation with 'fd'."""
dev = qml.device("lightning.qubit", wires=1)
@@ -838,6 +864,7 @@ def f(x):
vjp(x, dy)
+@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
class TestCatalystSample:
"""Test qml.sample with Catalyst."""
@@ -858,6 +885,7 @@ def circuit(x):
assert circuit(jnp.pi) == 1
+@pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
class TestCatalystMCMs:
"""Test dynamic_one_shot with Catalyst."""
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index c60ae352e4b..5d270d1f9b8 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -36,6 +36,17 @@ def dummyfunc():
return None
+def test_additional_kwargs_is_deprecated():
+ """Test that passing gradient_kwargs as additional kwargs raises a deprecation warning."""
+ dev = qml.device("default.qubit", wires=1)
+
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning,
+ match=r"Specifying gradient keyword arguments \[\'atol\'\] is deprecated",
+ ):
+ QNode(dummyfunc, dev, atol=1)
+
+
# pylint: disable=unused-argument
class CustomDevice(qml.devices.Device):
"""A null device that just returns 0."""
@@ -145,24 +156,21 @@ def test_update_gradient_kwargs(self):
"""Test that gradient kwargs are updated correctly"""
dev = qml.device("default.qubit")
- @qml.qnode(dev, atol=1)
+ @qml.qnode(dev, gradient_kwargs={"atol": 1})
def circuit(x):
qml.RZ(x, wires=0)
qml.CNOT(wires=[0, 1])
qml.RY(x, wires=1)
return qml.expval(qml.PauliZ(1))
- assert len(circuit.gradient_kwargs) == 1
- assert list(circuit.gradient_kwargs.keys()) == ["atol"]
+ assert set(circuit.gradient_kwargs.keys()) == {"atol"}
- new_atol_circuit = circuit.update(atol=2)
- assert len(new_atol_circuit.gradient_kwargs) == 1
- assert list(new_atol_circuit.gradient_kwargs.keys()) == ["atol"]
+ new_atol_circuit = circuit.update(gradient_kwargs={"atol": 2})
+ assert set(new_atol_circuit.gradient_kwargs.keys()) == {"atol"}
assert new_atol_circuit.gradient_kwargs["atol"] == 2
- new_kwarg_circuit = circuit.update(h=1)
- assert len(new_kwarg_circuit.gradient_kwargs) == 2
- assert list(new_kwarg_circuit.gradient_kwargs.keys()) == ["atol", "h"]
+ new_kwarg_circuit = circuit.update(gradient_kwargs={"h": 1})
+ assert set(new_kwarg_circuit.gradient_kwargs.keys()) == {"atol", "h"}
assert new_kwarg_circuit.gradient_kwargs["atol"] == 1
assert new_kwarg_circuit.gradient_kwargs["h"] == 1
@@ -170,7 +178,7 @@ def circuit(x):
UserWarning,
match="Received gradient_kwarg blah, which is not included in the list of standard qnode gradient kwargs.",
):
- circuit.update(blah=1)
+ circuit.update(gradient_kwargs={"blah": 1})
def test_update_multiple_arguments(self):
"""Test that multiple parameters can be updated at once."""
@@ -194,7 +202,7 @@ def test_update_transform_program(self):
dev = qml.device("default.qubit", wires=2)
@qml.transforms.combine_global_phases
- @qml.qnode(dev, atol=1)
+ @qml.qnode(dev)
def circuit(x):
qml.RZ(x, wires=0)
qml.GlobalPhase(phi=1)
@@ -248,7 +256,7 @@ def circuit(return_type):
def test_expansion_strategy_error(self):
"""Test that an error is raised if expansion_strategy is passed to the qnode."""
- with pytest.raises(ValueError, match=r"'expansion_strategy' is no longer"):
+ with pytest.raises(ValueError, match="'expansion_strategy' is no longer"):
@qml.qnode(qml.device("default.qubit"), expansion_strategy="device")
def _():
@@ -453,7 +461,7 @@ def test_unrecognized_kwargs_raise_warning(self):
with warnings.catch_warnings(record=True) as w:
- @qml.qnode(dev, random_kwarg=qml.gradients.finite_diff)
+ @qml.qnode(dev, gradient_kwargs={"random_kwarg": qml.gradients.finite_diff})
def circuit(params):
qml.RX(params[0], wires=0)
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0))
@@ -846,7 +854,7 @@ def test_single_expectation_value_with_argnum_one(self, diff_method, tol):
y = pnp.array(-0.654, requires_grad=True)
@qnode(
- dev, diff_method=diff_method, argnum=[1]
+ dev, diff_method=diff_method, gradient_kwargs={"argnum": [1]}
) # <--- we only choose one trainable parameter
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -1320,11 +1328,11 @@ def ansatz0():
return qml.expval(qml.X(0))
with pytest.raises(ValueError, match="'shots' is not a valid gradient_kwarg."):
- qml.QNode(ansatz0, dev, shots=100)
+ qml.QNode(ansatz0, dev, gradient_kwargs={"shots": 100})
with pytest.raises(ValueError, match="'shots' is not a valid gradient_kwarg."):
- @qml.qnode(dev, shots=100)
+ @qml.qnode(dev, gradient_kwargs={"shots": 100})
def _():
return qml.expval(qml.X(0))
@@ -1918,14 +1926,17 @@ def circuit(x, mp):
return mp(qml.PauliZ(0))
_ = circuit(1.8, qml.expval, shots=10)
- assert circuit.execute_kwargs["mcm_config"] == original_config
+ assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode
+ assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method
if mcm_method != "one-shot":
_ = circuit(1.8, qml.expval)
- assert circuit.execute_kwargs["mcm_config"] == original_config
+ assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode
+ assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method
_ = circuit(1.8, qml.expval, shots=10)
- assert circuit.execute_kwargs["mcm_config"] == original_config
+ assert circuit.execute_kwargs["postselect_mode"] == original_config.postselect_mode
+ assert circuit.execute_kwargs["mcm_method"] == original_config.mcm_method
class TestTapeExpansion:
diff --git a/tests/test_qnode_legacy.py b/tests/test_qnode_legacy.py
index 26c46687934..5a8c63c7d22 100644
--- a/tests/test_qnode_legacy.py
+++ b/tests/test_qnode_legacy.py
@@ -201,7 +201,7 @@ def test_unrecognized_kwargs_raise_warning(self):
with warnings.catch_warnings(record=True) as w:
- @qml.qnode(dev, random_kwarg=qml.gradients.finite_diff)
+ @qml.qnode(dev, gradient_kwargs={"random_kwarg": qml.gradients.finite_diff})
def circuit(params):
qml.RX(params[0], wires=0)
return qml.expval(qml.PauliZ(0)), qml.var(qml.PauliZ(0))
@@ -627,7 +627,7 @@ def test_single_expectation_value_with_argnum_one(self, diff_method, tol):
y = pnp.array(-0.654, requires_grad=True)
@qnode(
- dev, diff_method=diff_method, argnum=[1]
+ dev, diff_method=diff_method, gradient_kwargs={"argnum": [1]}
) # <--- we only choose one trainable parameter
def circuit(x, y):
qml.RX(x, wires=[0])
diff --git a/tests/transforms/core/test_transform_dispatcher.py b/tests/transforms/core/test_transform_dispatcher.py
index 4a24627c739..6784c3b0a6e 100644
--- a/tests/transforms/core/test_transform_dispatcher.py
+++ b/tests/transforms/core/test_transform_dispatcher.py
@@ -216,6 +216,7 @@ def test_the_transform_container_attributes(self):
class TestTransformDispatcher: # pylint: disable=too-many-public-methods
"""Test the transform function (validate and dispatch)."""
+ @pytest.mark.xfail(reason="https://github.com/PennyLaneAI/catalyst/pull/1452")
@pytest.mark.catalyst
@pytest.mark.external
def test_error_on_qjit(self):
diff --git a/tests/transforms/test_add_noise.py b/tests/transforms/test_add_noise.py
index 779dcf91e36..3beab611105 100644
--- a/tests/transforms/test_add_noise.py
+++ b/tests/transforms/test_add_noise.py
@@ -414,7 +414,7 @@ def test_add_noise_level(self, level1, level2):
@qml.transforms.undo_swaps
@qml.transforms.merge_rotations
@qml.transforms.cancel_inverses
- @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4)
+ @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4})
def f(w, x, y, z):
qml.RX(w, wires=0)
qml.RY(x, wires=1)
@@ -447,7 +447,7 @@ def test_add_noise_level_with_final(self):
@qml.transforms.undo_swaps
@qml.transforms.merge_rotations
@qml.transforms.cancel_inverses
- @qml.qnode(dev, diff_method="parameter-shift", shifts=np.pi / 4)
+ @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4})
def f(w, x, y, z):
qml.RX(w, wires=0)
qml.RY(x, wires=1)
diff --git a/tests/workflow/interfaces/execute/test_execute.py b/tests/workflow/interfaces/execute/test_execute.py
index 8bea335ff7e..4a253727261 100644
--- a/tests/workflow/interfaces/execute/test_execute.py
+++ b/tests/workflow/interfaces/execute/test_execute.py
@@ -57,6 +57,28 @@ def test_execute_legacy_device():
assert qml.math.allclose(res[0], np.cos(0.1))
+def test_mcm_config_deprecation(mocker):
+ """Test that mcm_config argument has been deprecated."""
+
+ tape = qml.tape.QuantumScript(
+ [qml.RX(qml.numpy.array(1.0), 0)], [qml.expval(qml.Z(0))], shots=10
+ )
+ dev = qml.device("default.qubit")
+
+ with dev.tracker:
+ with pytest.warns(
+ qml.PennyLaneDeprecationWarning,
+ match="The mcm_config argument is deprecated and will be removed in v0.42, use mcm_method and postselect_mode instead.",
+ ):
+ spy = mocker.spy(qml.dynamic_one_shot, "_transform")
+ qml.execute(
+ (tape,),
+ dev,
+ mcm_config=qml.devices.MCMConfig(mcm_method="one-shot", postselect_mode=None),
+ )
+ spy.assert_called_once()
+
+
def test_config_deprecation():
"""Test that the config argument has been deprecated."""
diff --git a/tests/workflow/interfaces/qnode/test_autograd_qnode.py b/tests/workflow/interfaces/qnode/test_autograd_qnode.py
index ac30592926b..18e5f70cb83 100644
--- a/tests/workflow/interfaces/qnode/test_autograd_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_autograd_qnode.py
@@ -137,14 +137,15 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, tol, dev
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
a = np.array(0.1, requires_grad=True)
b = np.array(0.2, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=1)
@@ -183,14 +184,15 @@ def test_jacobian_no_evaluate(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
a = np.array(0.1, requires_grad=True)
b = np.array(0.2, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=1)
@@ -222,13 +224,14 @@ def test_jacobian_options(self, interface, dev, diff_method, grad_on_execution,
a = np.array([0.1, 0.2], requires_grad=True)
+ gradient_kwargs = {"h": 1e-8, "approx_order": 2}
+
@qnode(
dev,
interface=interface,
- h=1e-8,
- order=2,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a):
qml.RY(a[0], wires=0)
@@ -408,10 +411,10 @@ def test_differentiable_expand(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
-
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 10
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 10
tol = TOL_FOR_SPSA
# pylint: disable=too-few-public-methods
@@ -429,7 +432,7 @@ def decomposition(self):
a = np.array(0.1, requires_grad=False)
p = np.array([0.1, 0.2, 0.3], requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, p):
qml.RX(a, wires=0)
U3(p[0], p[1], p[2], wires=0)
@@ -555,15 +558,15 @@ def test_probability_differentiation(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
-
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
x = np.array(0.543, requires_grad=True)
y = np.array(-0.654, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -593,14 +596,15 @@ def test_multiple_probability_differentiation(
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
x = np.array(0.543, requires_grad=True)
y = np.array(-0.654, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -660,14 +664,15 @@ def test_ragged_differentiation(
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
x = np.array(0.543, requires_grad=True)
y = np.array(-0.654, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -711,8 +716,9 @@ def test_ragged_differentiation_variance(
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
elif diff_method == "hadamard":
pytest.skip("Hadamard gradient does not support variances.")
@@ -720,7 +726,7 @@ def test_ragged_differentiation_variance(
x = np.array(0.543, requires_grad=True)
y = np.array(-0.654, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -826,12 +832,13 @@ def test_chained_gradient_value(
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
dev1 = qml.device("default.qubit")
- @qnode(dev1, **kwargs)
+ @qnode(dev1, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit1(a, b, c):
qml.RX(a, wires=0)
qml.RX(b, wires=1)
@@ -1350,8 +1357,9 @@ def test_projector(
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
elif diff_method == "hadamard":
pytest.skip("Hadamard gradient does not support variances.")
@@ -1359,7 +1367,7 @@ def test_projector(
P = np.array(state, requires_grad=False)
x, y = np.array([0.765, -0.654], requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
@@ -1498,16 +1506,17 @@ def test_hamiltonian_expansion_analytic(
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method in ["adjoint", "hadamard"]:
pytest.skip("The diff method requested does not yet support Hamiltonians")
elif diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 10
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 10
tol = TOL_FOR_SPSA
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(data, weights, coeffs):
weights = weights.reshape(1, -1)
qml.templates.AngleEmbedding(data, wires=[0, 1])
@@ -1584,7 +1593,7 @@ def test_hamiltonian_finite_shots(
grad_on_execution=grad_on_execution,
max_diff=max_diff,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = weights.reshape(1, -1)
diff --git a/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py
index 87e534fdb52..54072ae442e 100644
--- a/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py
+++ b/tests/workflow/interfaces/qnode/test_autograd_qnode_shot_vector.py
@@ -52,7 +52,7 @@ def test_jac_single_measurement_param(
"""For one measurement and one param, the gradient is a float."""
dev = qml.device(dev_name, wires=1, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -74,7 +74,7 @@ def test_jac_single_measurement_multiple_param(
"""For one measurement and multiple param, the gradient is a tuple of arrays."""
dev = qml.device(dev_name, wires=1, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -100,7 +100,7 @@ def test_jacobian_single_measurement_multiple_param_array(
"""For one measurement and multiple param as a single array params, the gradient is an array."""
dev = qml.device(dev_name, wires=1, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -123,7 +123,7 @@ def test_jacobian_single_measurement_param_probs(
dimension"""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -146,7 +146,7 @@ def test_jacobian_single_measurement_probs_multiple_param(
the correct dimension"""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -173,7 +173,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array(
the correct dimension"""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -198,7 +198,7 @@ def test_jacobian_expval_expval_multiple_params(
par_0 = np.array(0.1)
par_1 = np.array(0.2)
- @qnode(dev, diff_method=diff_method, max_diff=1, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=1, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -223,7 +223,7 @@ def test_jacobian_expval_expval_multiple_params_array(
"""The jacobian of multiple measurements with a multiple params array return a single array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -250,7 +250,7 @@ def test_jacobian_var_var_multiple_params(
par_0 = np.array(0.1)
par_1 = np.array(0.2)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -275,7 +275,7 @@ def test_jacobian_var_var_multiple_params_array(
"""The jacobian of multiple measurements with a multiple params array return a single array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -299,7 +299,7 @@ def test_jacobian_multiple_measurement_single_param(
"""The jacobian of multiple measurements with a single params return an array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -322,7 +322,7 @@ def test_jacobian_multiple_measurement_multiple_param(
"""The jacobian of multiple measurements with a multiple params return a tuple of arrays."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -349,7 +349,7 @@ def test_jacobian_multiple_measurement_multiple_param_array(
"""The jacobian of multiple measurements with a multiple params array return a single array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -382,7 +382,7 @@ def test_hessian_expval_multiple_params(
par_0 = np.array(0.1)
par_1 = np.array(0.2)
- @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -412,7 +412,7 @@ def test_hessian_expval_multiple_param_array(
params = np.array([0.1, 0.2])
- @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs)
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -440,7 +440,7 @@ def test_hessian_var_multiple_params(
par_0 = np.array(0.1)
par_1 = np.array(0.2)
- @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -470,7 +470,7 @@ def test_hessian_var_multiple_param_array(
params = np.array([0.1, 0.2])
- @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs)
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -500,7 +500,7 @@ def test_hessian_probs_expval_multiple_params(
par_0 = np.array(0.1)
par_1 = np.array(0.2)
- @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -533,7 +533,7 @@ def test_hessian_expval_probs_multiple_param_array(
params = np.array([0.1, 0.2])
- @qnode(dev, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, max_diff=2, gradient_kwargs=gradient_kwargs)
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -571,7 +571,7 @@ def test_single_expectation_value(
x = np.array(0.543)
y = np.array(-0.654)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -604,7 +604,7 @@ def test_prob_expectation_values(
x = np.array(0.543)
y = np.array(-0.654)
- @qnode(dev, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
diff --git a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
index 862806999c6..02a586be118 100644
--- a/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
@@ -230,7 +230,7 @@ def decomposition(self):
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, p):
qml.RX(a, wires=0)
@@ -266,14 +266,15 @@ def test_jacobian_options(
a = np.array([0.1, 0.2], requires_grad=True)
+ gradient_kwargs = {"h": 1e-8, "approx_order": 2}
+
@qnode(
get_device(dev_name, wires=1, seed=seed),
interface=interface,
diff_method="finite-diff",
- h=1e-8,
- approx_order=2,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a):
qml.RY(a[0], wires=0)
@@ -323,7 +324,7 @@ def test_diff_expval_expval(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
@@ -389,7 +390,7 @@ def test_jacobian_no_evaluate(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
@@ -456,7 +457,7 @@ def test_diff_single_probs(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -512,7 +513,7 @@ def test_diff_multi_probs(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -601,7 +602,7 @@ def test_diff_expval_probs(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -679,7 +680,7 @@ def test_diff_expval_probs_sub_argnums(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **kwargs,
+ gradient_kwargs=kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -740,7 +741,7 @@ def test_diff_var_probs(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -1130,7 +1131,7 @@ def test_second_derivative(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -1183,7 +1184,7 @@ def test_hessian(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -1237,7 +1238,7 @@ def test_hessian_vector_valued(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -1300,7 +1301,7 @@ def test_hessian_vector_valued_postprocessing(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RX(x[0], wires=0)
@@ -1366,7 +1367,7 @@ def test_hessian_vector_valued_separate_args(
grad_on_execution=grad_on_execution,
max_diff=2,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
@@ -1478,7 +1479,7 @@ def test_projector(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=0)
@@ -1593,7 +1594,7 @@ def test_hamiltonian_expansion_analytic(
grad_on_execution=grad_on_execution,
max_diff=max_diff,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = weights.reshape(1, -1)
@@ -1663,7 +1664,7 @@ def test_hamiltonian_finite_shots(
grad_on_execution=grad_on_execution,
max_diff=max_diff,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = weights.reshape(1, -1)
@@ -1862,7 +1863,7 @@ def test_gradient(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -2013,7 +2014,7 @@ def test_gradient_scalar_cost_vector_valued_qnode(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -3062,7 +3063,7 @@ def test_single_measurement(
grad_on_execution=grad_on_execution,
cache=False,
device_vjp=device_vjp,
- **kwargs,
+ gradient_kwargs=kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
@@ -3123,7 +3124,7 @@ def test_multi_measurements(
diff_method=diff_method,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **kwargs,
+ gradient_kwargs=kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
diff --git a/tests/workflow/interfaces/qnode/test_jax_qnode.py b/tests/workflow/interfaces/qnode/test_jax_qnode.py
index ae1cb53a398..ae7ea130193 100644
--- a/tests/workflow/interfaces/qnode/test_jax_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_jax_qnode.py
@@ -188,10 +188,10 @@ def test_differentiable_expand(
"grad_on_execution": grad_on_execution,
"device_vjp": device_vjp,
}
-
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 10
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 10
tol = TOL_FOR_SPSA
class U3(qml.U3): # pylint:disable=too-few-public-methods
@@ -206,7 +206,7 @@ def decomposition(self):
a = jax.numpy.array(0.1)
p = jax.numpy.array([0.1, 0.2, 0.3])
- @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, p):
qml.RX(a, wires=0)
U3(p[0], p[1], p[2], wires=0)
@@ -240,12 +240,13 @@ def test_jacobian_options(
a = jax.numpy.array([0.1, 0.2])
+ gradient_kwargs = {"h": 1e-8, "approx_order": 2}
+
@qnode(
get_device(dev_name, wires=1, seed=seed),
interface=interface,
diff_method="finite-diff",
- h=1e-8,
- approx_order=2,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a):
qml.RY(a[0], wires=0)
@@ -273,9 +274,9 @@ def test_diff_expval_expval(
"grad_on_execution": grad_on_execution,
"device_vjp": device_vjp,
}
-
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
if "lightning" in dev_name:
pytest.xfail("lightning device_vjp not compatible with jax.jacobian.")
@@ -283,7 +284,7 @@ def test_diff_expval_expval(
a = jax.numpy.array(0.1)
b = jax.numpy.array(0.2)
- @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=1)
@@ -334,14 +335,15 @@ def test_jacobian_no_evaluate(
if "lightning" in dev_name:
pytest.xfail("lightning device_vjp not compatible with jax.jacobian.")
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
a = jax.numpy.array(0.1)
b = jax.numpy.array(0.2)
- @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=1)
@@ -387,8 +389,9 @@ def test_diff_single_probs(
"grad_on_execution": grad_on_execution,
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
if "lightning" in dev_name:
pytest.xfail("lightning device_vjp not compatible with jax.jacobian.")
@@ -396,7 +399,7 @@ def test_diff_single_probs(
x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)
- @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=2, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -436,8 +439,9 @@ def test_diff_multi_probs(
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
if "lightning" in dev_name:
pytest.xfail("lightning device_vjp not compatible with jax.jacobian.")
@@ -445,7 +449,7 @@ def test_diff_multi_probs(
x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)
- @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -517,8 +521,9 @@ def test_diff_expval_probs(
"grad_on_execution": grad_on_execution,
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
if "lightning" in dev_name:
pytest.xfail("lightning device_vjp not compatible with jax.jacobian.")
@@ -526,7 +531,7 @@ def test_diff_expval_probs(
x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)
- @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -599,7 +604,7 @@ def test_diff_expval_probs_sub_argnums(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **kwargs,
+ gradient_kwargs=kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -646,18 +651,19 @@ def test_diff_var_probs(
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "hadamard":
pytest.skip("Hadamard does not support var")
if "lightning" in dev_name:
pytest.xfail("lightning device_vjp not compatible with jax.jacobian.")
elif diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)
- @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=1, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -968,13 +974,14 @@ def test_second_derivative(
"max_diff": 2,
}
+ gradient_kwargs = {}
if diff_method == "adjoint":
pytest.skip("Adjoint does not second derivative.")
elif diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
- @qnode(get_device(dev_name, wires=0, seed=seed), **kwargs)
+ @qnode(get_device(dev_name, wires=0, seed=seed), **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x):
qml.RY(x[0], wires=0)
qml.RX(x[1], wires=0)
@@ -1024,7 +1031,7 @@ def test_hessian(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -1079,7 +1086,7 @@ def test_hessian_vector_valued(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -1142,7 +1149,7 @@ def test_hessian_vector_valued_postprocessing(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x):
qml.RX(x[0], wires=0)
@@ -1208,7 +1215,7 @@ def test_hessian_vector_valued_separate_args(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
@@ -1316,7 +1323,7 @@ def test_projector(
interface=interface,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=0)
@@ -1425,7 +1432,7 @@ def test_split_non_commuting_analytic(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=max_diff,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = weights.reshape(1, -1)
@@ -1513,7 +1520,7 @@ def test_hamiltonian_finite_shots(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
max_diff=max_diff,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = weights.reshape(1, -1)
diff --git a/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py
index 697a1e90223..210ba601586 100644
--- a/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py
+++ b/tests/workflow/interfaces/qnode/test_jax_qnode_shot_vector.py
@@ -60,7 +60,7 @@ def test_jac_single_measurement_param(
"""For one measurement and one param, the gradient is a float."""
dev = qml.device(dev_name, wires=1, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -86,7 +86,7 @@ def test_jac_single_measurement_multiple_param(
"""For one measurement and multiple param, the gradient is a tuple of arrays."""
dev = qml.device(dev_name, wires=1, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -115,7 +115,7 @@ def test_jacobian_single_measurement_multiple_param_array(
"""For one measurement and multiple param as a single array params, the gradient is an array."""
dev = qml.device(dev_name, wires=1, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -141,7 +141,7 @@ def test_jacobian_single_measurement_param_probs(
dimension"""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -167,7 +167,7 @@ def test_jacobian_single_measurement_probs_multiple_param(
the correct dimension"""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -199,7 +199,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array(
the correct dimension"""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -227,7 +227,13 @@ def test_jacobian_expval_expval_multiple_params(
par_0 = jax.numpy.array(0.1)
par_1 = jax.numpy.array(0.2)
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=1, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=1,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -264,7 +270,7 @@ def test_jacobian_expval_expval_multiple_params_array(
"""The jacobian of multiple measurements with a multiple params array return a single array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -298,7 +304,7 @@ def test_jacobian_var_var_multiple_params(
par_0 = jax.numpy.array(0.1)
par_1 = jax.numpy.array(0.2)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -336,7 +342,7 @@ def test_jacobian_var_var_multiple_params_array(
"""The jacobian of multiple measurements with a multiple params array return a single array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -367,7 +373,7 @@ def test_jacobian_multiple_measurement_single_param(
"""The jacobian of multiple measurements with a single params return an array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -398,7 +404,7 @@ def test_jacobian_multiple_measurement_multiple_param(
"""The jacobian of multiple measurements with a multiple params return a tuple of arrays."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -438,7 +444,7 @@ def test_jacobian_multiple_measurement_multiple_param_array(
"""The jacobian of multiple measurements with a multiple params array return a single array."""
dev = qml.device(dev_name, wires=2, shots=shots)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -471,7 +477,13 @@ def test_hessian_expval_multiple_params(
par_0 = jax.numpy.array(0.1)
par_1 = jax.numpy.array(0.2)
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -508,7 +520,13 @@ def test_hessian_expval_multiple_param_array(
params = jax.numpy.array([0.1, 0.2])
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -534,7 +552,13 @@ def test_hessian_var_multiple_params(
par_0 = jax.numpy.array(0.1)
par_1 = jax.numpy.array(0.2)
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -571,7 +595,13 @@ def test_hessian_var_multiple_param_array(
params = jax.numpy.array([0.1, 0.2])
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -597,7 +627,13 @@ def test_hessian_probs_expval_multiple_params(
par_0 = jax.numpy.array(0.1)
par_1 = jax.numpy.array(0.2)
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -655,7 +691,13 @@ def test_hessian_expval_probs_multiple_param_array(
params = jax.numpy.array([0.1, 0.2])
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -687,7 +729,13 @@ def test_hessian_probs_var_multiple_params(
par_0 = qml.numpy.array(0.1)
par_1 = qml.numpy.array(0.2)
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -742,7 +790,13 @@ def test_hessian_var_probs_multiple_param_array(
params = jax.numpy.array([0.1, 0.2])
- @qnode(dev, interface=interface, diff_method=diff_method, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ interface=interface,
+ diff_method=diff_method,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -792,7 +846,7 @@ def test_single_expectation_value(
x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -827,7 +881,7 @@ def test_prob_expectation_values(
x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)
- @qnode(dev, interface=interface, diff_method=diff_method, **gradient_kwargs)
+ @qnode(dev, interface=interface, diff_method=diff_method, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py
index f24eda4b382..6a20f264ec6 100644
--- a/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py
+++ b/tests/workflow/interfaces/qnode/test_tensorflow_autograph_qnode_shot_vector.py
@@ -73,7 +73,7 @@ def test_jac_single_measurement_param(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a, wires=0)
@@ -101,7 +101,7 @@ def test_jac_single_measurement_multiple_param(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b, **_):
qml.RY(a, wires=0)
@@ -133,7 +133,7 @@ def test_jacobian_single_measurement_multiple_param_array(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a[0], wires=0)
@@ -162,7 +162,7 @@ def test_jacobian_single_measurement_param_probs(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a, wires=0)
@@ -191,7 +191,7 @@ def test_jacobian_single_measurement_probs_multiple_param(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b, **_):
qml.RY(a, wires=0)
@@ -224,7 +224,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a[0], wires=0)
@@ -255,7 +255,7 @@ def test_jacobian_expval_expval_multiple_params(
diff_method=diff_method,
interface=interface,
max_diff=1,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y, **_):
qml.RX(x, wires=[0])
@@ -285,7 +285,7 @@ def test_jacobian_expval_expval_multiple_params_array(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a[0], wires=0)
@@ -314,7 +314,7 @@ def test_jacobian_multiple_measurement_single_param(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a, wires=0)
@@ -342,7 +342,7 @@ def test_jacobian_multiple_measurement_multiple_param(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, b, **_):
qml.RY(a, wires=0)
@@ -374,7 +374,7 @@ def test_jacobian_multiple_measurement_multiple_param_array(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a, **_):
qml.RY(a[0], wires=0)
@@ -421,7 +421,7 @@ def test_hessian_expval_multiple_params(
diff_method=diff_method,
interface=interface,
max_diff=2,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y, **_):
qml.RX(x, wires=[0])
@@ -471,7 +471,7 @@ def test_single_expectation_value(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y, **_):
qml.RX(x, wires=[0])
@@ -509,7 +509,7 @@ def test_prob_expectation_values(
qml.device(dev_name, seed=seed),
diff_method=diff_method,
interface=interface,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y, **_):
qml.RX(x, wires=[0])
diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py b/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
index 85a39cc588f..79f19bd2f4d 100644
--- a/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_tensorflow_qnode.py
@@ -158,6 +158,7 @@ def circuit(p1, p2=y, **kwargs):
def test_jacobian(self, dev, diff_method, grad_on_execution, device_vjp, tol, interface, seed):
"""Test jacobian calculation"""
+ gradient_kwargs = {}
kwargs = {
"diff_method": diff_method,
"grad_on_execution": grad_on_execution,
@@ -165,14 +166,14 @@ def test_jacobian(self, dev, diff_method, grad_on_execution, device_vjp, tol, in
"device_vjp": device_vjp,
}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
a = tf.Variable(0.1, dtype=tf.float64)
b = tf.Variable(0.2, dtype=tf.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=1)
@@ -200,14 +201,15 @@ def test_jacobian_options(self, dev, diff_method, grad_on_execution, device_vjp,
a = tf.Variable([0.1, 0.2])
+ gradient_kwargs = {"approx_order": 2, "h": 1e-8}
+
@qnode(
dev,
interface=interface,
- h=1e-8,
- approx_order=2,
diff_method=diff_method,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a):
qml.RY(a[0], wires=0)
@@ -240,7 +242,7 @@ def test_changing_trainability(
diff_method=diff_method,
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
- **diff_kwargs,
+ gradient_kwargs=diff_kwargs,
)
def circuit(a, b):
qml.RY(a, wires=0)
@@ -370,6 +372,7 @@ def test_differentiable_expand(
):
"""Test that operation and nested tapes expansion
is differentiable"""
+ gradient_kwargs = {}
kwargs = {
"diff_method": diff_method,
"grad_on_execution": grad_on_execution,
@@ -377,8 +380,8 @@ def test_differentiable_expand(
"device_vjp": device_vjp,
}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
class U3(qml.U3):
@@ -393,7 +396,7 @@ def decomposition(self):
a = np.array(0.1)
p = tf.Variable([0.1, 0.2, 0.3], dtype=tf.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, p):
qml.RX(a, wires=0)
U3(p[0], p[1], p[2], wires=0)
@@ -548,15 +551,16 @@ def test_probability_differentiation(
"interface": interface,
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
x = tf.Variable(0.543, dtype=tf.float64)
y = tf.Variable(-0.654, dtype=tf.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -605,15 +609,16 @@ def test_ragged_differentiation(
"interface": interface,
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
x = tf.Variable(0.543, dtype=tf.float64)
y = tf.Variable(-0.654, dtype=tf.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -943,13 +948,14 @@ def test_projector(
"interface": interface,
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method == "adjoint":
pytest.skip("adjoint supports either all expvals or all diagonal measurements.")
if diff_method == "hadamard":
pytest.skip("Variance not implemented yet.")
elif diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
if dev.name == "reference.qubit":
pytest.xfail("diagonalize_measurements do not support projectors (sc-72911)")
@@ -959,7 +965,7 @@ def test_projector(
x, y = 0.765, -0.654
weights = tf.Variable([x, y], dtype=tf.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(weights):
qml.RX(weights[0], wires=0)
qml.RY(weights[1], wires=1)
@@ -1129,16 +1135,17 @@ def test_hamiltonian_expansion_analytic(
"interface": interface,
"device_vjp": device_vjp,
}
+ gradient_kwargs = {}
if diff_method in ["adjoint", "hadamard"]:
pytest.skip("The adjoint/hadamard method does not yet support Hamiltonians")
elif diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(data, weights, coeffs):
weights = tf.reshape(weights, [1, -1])
qml.templates.AngleEmbedding(data, wires=[0, 1])
@@ -1211,7 +1218,7 @@ def test_hamiltonian_finite_shots(
max_diff=max_diff,
interface=interface,
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = tf.reshape(weights, [1, -1])
diff --git a/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py b/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py
index 037469657bc..fafd85e7e1e 100644
--- a/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py
+++ b/tests/workflow/interfaces/qnode/test_tensorflow_qnode_shot_vector.py
@@ -70,7 +70,7 @@ def test_jac_single_measurement_param(
):
"""For one measurement and one param, the gradient is a float."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -92,7 +92,7 @@ def test_jac_single_measurement_multiple_param(
):
"""For one measurement and multiple param, the gradient is a tuple of arrays."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -118,7 +118,7 @@ def test_jacobian_single_measurement_multiple_param_array(
):
"""For one measurement and multiple param as a single array params, the gradient is an array."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -141,7 +141,7 @@ def test_jacobian_single_measurement_param_probs(
"""For a multi dimensional measurement (probs), check that a single array is returned with the correct
dimension"""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -164,7 +164,7 @@ def test_jacobian_single_measurement_probs_multiple_param(
"""For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with
the correct dimension"""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -191,7 +191,7 @@ def test_jacobian_single_measurement_probs_multiple_param_single_array(
"""For a multi dimensional measurement (probs), check that a single tuple is returned containing arrays with
the correct dimension"""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -216,7 +216,13 @@ def test_jacobian_expval_expval_multiple_params(
par_0 = tf.Variable(0.1)
par_1 = tf.Variable(0.2)
- @qnode(dev, diff_method=diff_method, interface=interface, max_diff=1, **gradient_kwargs)
+ @qnode(
+ dev,
+ diff_method=diff_method,
+ interface=interface,
+ max_diff=1,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -240,7 +246,7 @@ def test_jacobian_expval_expval_multiple_params_array(
):
"""The jacobian of multiple measurements with a multiple params array return a single array."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -263,7 +269,7 @@ def test_jacobian_multiple_measurement_single_param(
):
"""The jacobian of multiple measurements with a single params return an array."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
@@ -285,7 +291,7 @@ def test_jacobian_multiple_measurement_multiple_param(
):
"""The jacobian of multiple measurements with a multiple params return a tuple of arrays."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=0)
@@ -311,7 +317,7 @@ def test_jacobian_multiple_measurement_multiple_param_array(
):
"""The jacobian of multiple measurements with a multiple params array return a single array."""
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(a):
qml.RY(a[0], wires=0)
qml.RX(a[1], wires=0)
@@ -343,7 +349,13 @@ def test_hessian_expval_multiple_params(
par_0 = tf.Variable(0.1, dtype=tf.float64)
par_1 = tf.Variable(0.2, dtype=tf.float64)
- @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ diff_method=diff_method,
+ interface=interface,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -373,7 +385,13 @@ def test_hessian_expval_multiple_param_array(
params = tf.Variable([0.1, 0.2], dtype=tf.float64)
- @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ diff_method=diff_method,
+ interface=interface,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -400,7 +418,13 @@ def test_hessian_probs_expval_multiple_params(
par_0 = tf.Variable(0.1, dtype=tf.float64)
par_1 = tf.Variable(0.2, dtype=tf.float64)
- @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ diff_method=diff_method,
+ interface=interface,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -430,7 +454,13 @@ def test_hessian_expval_probs_multiple_param_array(
params = tf.Variable([0.1, 0.2], dtype=tf.float64)
- @qnode(dev, diff_method=diff_method, interface=interface, max_diff=2, **gradient_kwargs)
+ @qnode(
+ dev,
+ diff_method=diff_method,
+ interface=interface,
+ max_diff=2,
+ gradient_kwargs=gradient_kwargs,
+ )
def circuit(x):
qml.RX(x[0], wires=[0])
qml.RY(x[1], wires=[1])
@@ -467,7 +497,7 @@ def test_single_expectation_value(
x = tf.Variable(0.543, dtype=tf.float64)
y = tf.Variable(-0.654, dtype=tf.float64)
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -500,7 +530,7 @@ def test_prob_expectation_values(
x = tf.Variable(0.543, dtype=tf.float64)
y = tf.Variable(-0.654, dtype=tf.float64)
- @qnode(dev, diff_method=diff_method, interface=interface, **gradient_kwargs)
+ @qnode(dev, diff_method=diff_method, interface=interface, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
diff --git a/tests/workflow/interfaces/qnode/test_torch_qnode.py b/tests/workflow/interfaces/qnode/test_torch_qnode.py
index e9581898522..a2922a7ce62 100644
--- a/tests/workflow/interfaces/qnode/test_torch_qnode.py
+++ b/tests/workflow/interfaces/qnode/test_torch_qnode.py
@@ -173,9 +173,10 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, device_v
interface=interface,
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
a_val = 0.1
@@ -184,7 +185,7 @@ def test_jacobian(self, interface, dev, diff_method, grad_on_execution, device_v
a = torch.tensor(a_val, dtype=torch.float64, requires_grad=True)
b = torch.tensor(b_val, dtype=torch.float64, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, b):
qml.RY(a, wires=0)
qml.RX(b, wires=1)
@@ -276,14 +277,15 @@ def test_jacobian_options(
a = torch.tensor([0.1, 0.2], requires_grad=True)
+ gradient_kwargs = {"h": 1e-8, "approx_order": 2}
+
@qnode(
dev,
diff_method=diff_method,
grad_on_execution=grad_on_execution,
interface=interface,
- h=1e-8,
- approx_order=2,
device_vjp=device_vjp,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(a):
qml.RY(a[0], wires=0)
@@ -483,9 +485,10 @@ def test_differentiable_expand(
interface=interface,
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
class U3(qml.U3): # pylint:disable=too-few-public-methods
@@ -501,7 +504,7 @@ def decomposition(self):
p_val = [0.1, 0.2, 0.3]
p = torch.tensor(p_val, dtype=torch.float64, requires_grad=True)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(a, p):
qml.RX(a, wires=0)
U3(p[0], p[1], p[2], wires=0)
@@ -644,9 +647,9 @@ def test_probability_differentiation(
with prob and expval outputs"""
if "lightning" in getattr(dev, "name", "").lower():
pytest.xfail("lightning does not support measureing probabilities with adjoint.")
- kwargs = {}
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
tol = TOL_FOR_SPSA
x_val = 0.543
@@ -660,7 +663,7 @@ def test_probability_differentiation(
grad_on_execution=grad_on_execution,
interface=interface,
device_vjp=device_vjp,
- **kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(x, y):
qml.RX(x, wires=[0])
@@ -708,9 +711,10 @@ def test_ragged_differentiation(
interface=interface,
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
x_val = 0.543
@@ -718,7 +722,7 @@ def test_ragged_differentiation(
x = torch.tensor(x_val, requires_grad=True, dtype=torch.float64)
y = torch.tensor(y_val, requires_grad=True, dtype=torch.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=[0])
qml.RY(y, wires=[1])
@@ -813,7 +817,7 @@ def test_hessian(self, interface, dev, diff_method, grad_on_execution, device_vj
max_diff=2,
interface=interface,
device_vjp=device_vjp,
- **options,
+ gradient_kwargs=options,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -866,7 +870,7 @@ def test_hessian_vector_valued(
max_diff=2,
interface=interface,
device_vjp=device_vjp,
- **options,
+ gradient_kwargs=options,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -928,7 +932,7 @@ def test_hessian_ragged(self, interface, dev, diff_method, grad_on_execution, de
max_diff=2,
interface=interface,
device_vjp=device_vjp,
- **options,
+ gradient_kwargs=options,
)
def circuit(x):
qml.RY(x[0], wires=0)
@@ -1003,7 +1007,7 @@ def test_hessian_vector_valued_postprocessing(
max_diff=2,
interface=interface,
device_vjp=device_vjp,
- **options,
+ gradient_kwargs=options,
)
def circuit(x):
qml.RX(x[0], wires=0)
@@ -1100,11 +1104,12 @@ def test_projector(
grad_on_execution=grad_on_execution,
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "adjoint":
pytest.skip("adjoint supports either all expvals or all diagonal measurements")
if diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
elif diff_method == "hadamard":
pytest.skip("Hadamard does not support variances.")
@@ -1116,7 +1121,7 @@ def test_projector(
x, y = 0.765, -0.654
weights = torch.tensor([x, y], requires_grad=True, dtype=torch.float64)
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
@@ -1285,18 +1290,19 @@ def test_hamiltonian_expansion_analytic(
interface="torch",
device_vjp=device_vjp,
)
+ gradient_kwargs = {}
if diff_method == "adjoint":
pytest.skip("The adjoint method does not yet support Hamiltonians")
elif diff_method == "spsa":
- kwargs["sampler_rng"] = np.random.default_rng(seed)
- kwargs["num_directions"] = 20
+ gradient_kwargs["sampler_rng"] = np.random.default_rng(seed)
+ gradient_kwargs["num_directions"] = 20
tol = TOL_FOR_SPSA
elif diff_method == "hadamard":
pytest.skip("The hadamard method does not yet support Hamiltonians")
obs = [qml.PauliX(0), qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.PauliZ(1)]
- @qnode(dev, **kwargs)
+ @qnode(dev, **kwargs, gradient_kwargs=gradient_kwargs)
def circuit(data, weights, coeffs):
weights = torch.reshape(weights, [1, -1])
qml.templates.AngleEmbedding(data, wires=[0, 1])
@@ -1389,7 +1395,7 @@ def test_hamiltonian_finite_shots(
max_diff=max_diff,
interface="torch",
device_vjp=device_vjp,
- **gradient_kwargs,
+ gradient_kwargs=gradient_kwargs,
)
def circuit(data, weights, coeffs):
weights = torch.reshape(weights, [1, -1])
diff --git a/tests/workflow/test_construct_batch.py b/tests/workflow/test_construct_batch.py
index f11eaee459c..d4288da65d9 100644
--- a/tests/workflow/test_construct_batch.py
+++ b/tests/workflow/test_construct_batch.py
@@ -72,7 +72,7 @@ def test_get_transform_program_diff_method_transform(self):
@partial(qml.transforms.compile, num_passes=2)
@partial(qml.transforms.merge_rotations, atol=1e-5)
@qml.transforms.cancel_inverses
- @qml.qnode(dev, diff_method="parameter-shift", shifts=2)
+ @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": 2})
def circuit():
return qml.expval(qml.PauliZ(0))