Skip to content

Commit

Permalink
Merge pull request #77 from CQCL/develop
Browse files Browse the repository at this point in the history
v0.3.4
  • Loading branch information
SamDuffield authored Feb 22, 2023
2 parents 7c2502f + 1f031c5 commit 55324b8
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 62 deletions.
47 changes: 25 additions & 22 deletions qujax/densitytensor_observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,10 @@

from jax import numpy as jnp
from jax import random
from jax.lax import fori_loop

from qujax.densitytensor import _kraus_single, partial_trace
from qujax.statetensor_observable import _get_tensor_to_expectation_func
from qujax.utils import (
bitstrings_to_integers,
sample_integers,
statetensor_to_densitytensor,
)
from qujax.statetensor_observable import _get_tensor_to_expectation_func, sample_probs
from qujax.utils import bitstrings_to_integers, check_hermitian


def densitytensor_to_single_expectation(
Expand Down Expand Up @@ -79,6 +74,15 @@ def get_densitytensor_to_sampled_expectation_func(
Converts strings (or arrays) representing Hermitian matrices, qubit indices and
coefficients into a function that converts a densitytensor into a sampled expected value.
On a quantum device, measurements are always taken in the computational basis, as such
sampled expectation values should be taken with respect to an observable that commutes
with the Pauli Z - a warning will be raised if it does not.
qujax applies an importance sampling heuristic for sampled expectation values that only
reflects the physical notion of measurement in the case that the observable commutes with Z.
In the case that it does not, the expectation value will still be asymptotically unbiased
but not representative of an experiment on a real quantum device.
Args:
hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z').
Expand All @@ -95,34 +99,33 @@ def get_densitytensor_to_sampled_expectation_func(
hermitian_seq_seq, qubits_seq_seq, coefficients
)

for hermitian_seq in hermitian_seq_seq:
for h in hermitian_seq:
check_hermitian(h, check_z_commutes=True)

def densitytensor_to_sampled_expectation_func(
statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int
densitytensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int
) -> float:
"""
Maps statetensor to sampled expected value.
Maps densitytensor to sampled expected value.
Args:
statetensor: Input statetensor.
densitytensor: Input densitytensor.
random_key: JAX random key
n_samps: Number of samples contributing to sampled expectation.
Returns:
Sampled expected value (float).
"""
sampled_integers = sample_integers(random_key, statetensor, n_samps)
sampled_probs = fori_loop(
0,
n_samps,
lambda i, sv: sv.at[sampled_integers[i]].add(1),
jnp.zeros(statetensor.size),
)

sampled_probs /= n_samps
sampled_dt = statetensor_to_densitytensor(
jnp.sqrt(sampled_probs).reshape(statetensor.shape)
n_qubits = densitytensor.ndim // 2
dm = densitytensor.reshape((2**n_qubits, 2**n_qubits))
measure_probs = jnp.diag(dm).real
sampled_probs = sample_probs(measure_probs, random_key, n_samps)
iweights = jnp.sqrt(sampled_probs / measure_probs)
return densitytensor_to_expectation_func(
densitytensor * jnp.outer(iweights, iweights).reshape(densitytensor.shape)
)
return densitytensor_to_expectation_func(sampled_dt)

return densitytensor_to_sampled_expectation_func

Expand Down
70 changes: 53 additions & 17 deletions qujax/statetensor_observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax.lax import fori_loop

from qujax.statetensor import apply_gate
from qujax.utils import check_hermitian, paulis, sample_integers
from qujax.utils import check_hermitian, paulis


def statetensor_to_single_expectation(
Expand Down Expand Up @@ -89,12 +89,12 @@ def _get_tensor_to_expectation_func(

hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq]

def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float:
def tensor_to_expectation_func(tensor: jnp.ndarray) -> float:
"""
Maps statetensor to expected value.
Maps tensor to expected value.
Args:
statetensor: Input statetensor.
tensor: Input tensor.
Returns:
Expected value (float).
Expand All @@ -103,10 +103,10 @@ def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float:
for hermitian, qubit_inds, coeff in zip(
hermitian_tensors, qubits_seq_seq, coefficients
):
out += coeff * contraction_function(statetensor, hermitian, qubit_inds)
out += coeff * contraction_function(tensor, hermitian, qubit_inds)
return out

return statetensor_to_expectation_func
return tensor_to_expectation_func


def get_statetensor_to_expectation_func(
Expand Down Expand Up @@ -149,6 +149,15 @@ def get_statetensor_to_sampled_expectation_func(
Converts strings (or arrays) representing Hermitian matrices, qubit indices and
coefficients into a function that converts a statetensor into a sampled expected value.
On a quantum device, measurements are always taken in the computational basis, as such
sampled expectation values should be taken with respect to an observable that commutes
with the Pauli Z - a warning will be raised if it does not.
qujax applies an importance sampling heuristic for sampled expectation values that only
reflects the physical notion of measurement in the case that the observable commutes with Z.
In the case that it does not, the expectation value will still be asymptotically unbiased
but not representative of an experiment on a real quantum device.
Args:
hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z').
Expand All @@ -165,6 +174,10 @@ def get_statetensor_to_sampled_expectation_func(
hermitian_seq_seq, qubits_seq_seq, coefficients
)

for hermitian_seq in hermitian_seq_seq:
for h in hermitian_seq:
check_hermitian(h, check_z_commutes=True)

def statetensor_to_sampled_expectation_func(
statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int
) -> float:
Expand All @@ -179,16 +192,39 @@ def statetensor_to_sampled_expectation_func(
Returns:
Sampled expected value (float).
"""
sampled_integers = sample_integers(random_key, statetensor, n_samps)
sampled_probs = fori_loop(
0,
n_samps,
lambda i, sv: sv.at[sampled_integers[i]].add(1),
jnp.zeros(statetensor.size),
)

sampled_probs /= n_samps
sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape)
return statetensor_to_expectation_func(sampled_st)
measure_probs = jnp.abs(statetensor) ** 2
sampled_probs = sample_probs(measure_probs, random_key, n_samps)
iweights = jnp.sqrt(sampled_probs / measure_probs)
return statetensor_to_expectation_func(statetensor * iweights)

return statetensor_to_sampled_expectation_func


def sample_probs(
measure_probs: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int
):
"""
Generate an empirical distribution from a probability distribution.
Args:
measure_probs: Probability distribution.
random_key: JAX random key
n_samps: Number of samples contributing to empirical distribution.
Returns:
Empirical distribution (jnp.ndarray).
"""
measure_probs_flat = measure_probs.flatten()
sampled_integers = random.choice(
random_key,
a=jnp.arange(measure_probs.size),
shape=(n_samps,),
p=measure_probs_flat,
)
sampled_probs = fori_loop(
0,
n_samps,
lambda i, sv: sv.at[sampled_integers[i]].add(1 / n_samps),
jnp.zeros_like(measure_probs_flat),
)
return sampled_probs.reshape(measure_probs.shape)
19 changes: 18 additions & 1 deletion qujax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections.abc
from inspect import signature
from typing import Callable, Iterable, List, Optional, Protocol, Sequence, Tuple, Union
from warnings import warn

from jax import numpy as jnp
from jax import random
Expand Down Expand Up @@ -69,12 +70,13 @@ def check_unitary(gate: Gate):
raise TypeError(f"Gate not unitary: {gate}")


def check_hermitian(hermitian: Union[str, jnp.ndarray]):
def check_hermitian(hermitian: Union[str, jnp.ndarray], check_z_commutes: bool = False):
"""
Checks whether a matrix or tensor is Hermitian.
Args:
hermitian: array containing potentially Hermitian matrix or tensor
check_z_commutes: boolean on whether to check if the matrix commutes with Z
"""
if isinstance(hermitian, str):
Expand All @@ -83,12 +85,27 @@ def check_hermitian(hermitian: Union[str, jnp.ndarray]):
f"qujax only accepts {tuple(paulis.keys())} as Hermitian strings,"
"received: {hermitian}"
)
n_qubits = 1
hermitian_mat = paulis[hermitian]

else:
n_qubits = hermitian.ndim // 2
hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits)
if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()):
raise TypeError(f"Array not Hermitian: {hermitian}")

if check_z_commutes:
big_z = jnp.diag(jnp.where(jnp.arange(2**n_qubits) % 2 == 0, 1, -1))
z_commutes = jnp.allclose(hermitian_mat @ big_z, big_z @ hermitian_mat)
if not z_commutes:
warn(
"Hermitian matrix does not commute with Z. \n"
"For sampled expectation values, this may lead to unexpected results, "
"measurements on a quantum device are always taken in the computational basis. "
"Additional gates can be applied in the circuit to change the basis such "
"that an observable that commutes with Z can be measured."
)


def _arrayify_inds(
param_inds_seq: Sequence[Union[None, Sequence[int]]]
Expand Down
2 changes: 1 addition & 1 deletion qujax/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.3"
__version__ = "0.3.4"
Loading

0 comments on commit 55324b8

Please sign in to comment.