diff --git a/qujax/densitytensor_observable.py b/qujax/densitytensor_observable.py index a28179a..9e0ca62 100644 --- a/qujax/densitytensor_observable.py +++ b/qujax/densitytensor_observable.py @@ -4,15 +4,10 @@ from jax import numpy as jnp from jax import random -from jax.lax import fori_loop from qujax.densitytensor import _kraus_single, partial_trace -from qujax.statetensor_observable import _get_tensor_to_expectation_func -from qujax.utils import ( - bitstrings_to_integers, - sample_integers, - statetensor_to_densitytensor, -) +from qujax.statetensor_observable import _get_tensor_to_expectation_func, sample_probs +from qujax.utils import bitstrings_to_integers, check_hermitian def densitytensor_to_single_expectation( @@ -79,6 +74,15 @@ def get_densitytensor_to_sampled_expectation_func( Converts strings (or arrays) representing Hermitian matrices, qubit indices and coefficients into a function that converts a densitytensor into a sampled expected value. + On a quantum device, measurements are always taken in the computational basis, as such + sampled expectation values should be taken with respect to an observable that commutes + with the Pauli Z - a warning will be raised if it does not. + + qujax applies an importance sampling heuristic for sampled expectation values that only + reflects the physical notion of measurement in the case that the observable commutes with Z. + In the case that it does not, the expectation value will still be asymptotically unbiased + but not representative of an experiment on a real quantum device. + Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). @@ -95,14 +99,18 @@ def get_densitytensor_to_sampled_expectation_func( hermitian_seq_seq, qubits_seq_seq, coefficients ) + for hermitian_seq in hermitian_seq_seq: + for h in hermitian_seq: + check_hermitian(h, check_z_commutes=True) + def densitytensor_to_sampled_expectation_func( - statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int + densitytensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int ) -> float: """ - Maps statetensor to sampled expected value. + Maps densitytensor to sampled expected value. Args: - statetensor: Input statetensor. + densitytensor: Input densitytensor. random_key: JAX random key n_samps: Number of samples contributing to sampled expectation. @@ -110,19 +118,14 @@ def densitytensor_to_sampled_expectation_func( Sampled expected value (float). """ - sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop( - 0, - n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size), - ) - - sampled_probs /= n_samps - sampled_dt = statetensor_to_densitytensor( - jnp.sqrt(sampled_probs).reshape(statetensor.shape) + n_qubits = densitytensor.ndim // 2 + dm = densitytensor.reshape((2**n_qubits, 2**n_qubits)) + measure_probs = jnp.diag(dm).real + sampled_probs = sample_probs(measure_probs, random_key, n_samps) + iweights = jnp.sqrt(sampled_probs / measure_probs) + return densitytensor_to_expectation_func( + densitytensor * jnp.outer(iweights, iweights).reshape(densitytensor.shape) ) - return densitytensor_to_expectation_func(sampled_dt) return densitytensor_to_sampled_expectation_func diff --git a/qujax/statetensor_observable.py b/qujax/statetensor_observable.py index 44ada59..4bce068 100644 --- a/qujax/statetensor_observable.py +++ b/qujax/statetensor_observable.py @@ -7,7 +7,7 @@ from jax.lax import fori_loop from qujax.statetensor import apply_gate -from qujax.utils import check_hermitian, paulis, sample_integers +from qujax.utils import check_hermitian, paulis def statetensor_to_single_expectation( @@ -89,12 +89,12 @@ def _get_tensor_to_expectation_func( hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] - def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: + def tensor_to_expectation_func(tensor: jnp.ndarray) -> float: """ - Maps statetensor to expected value. + Maps tensor to expected value. Args: - statetensor: Input statetensor. + tensor: Input tensor. Returns: Expected value (float). @@ -103,10 +103,10 @@ def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: for hermitian, qubit_inds, coeff in zip( hermitian_tensors, qubits_seq_seq, coefficients ): - out += coeff * contraction_function(statetensor, hermitian, qubit_inds) + out += coeff * contraction_function(tensor, hermitian, qubit_inds) return out - return statetensor_to_expectation_func + return tensor_to_expectation_func def get_statetensor_to_expectation_func( @@ -149,6 +149,15 @@ def get_statetensor_to_sampled_expectation_func( Converts strings (or arrays) representing Hermitian matrices, qubit indices and coefficients into a function that converts a statetensor into a sampled expected value. + On a quantum device, measurements are always taken in the computational basis, as such + sampled expectation values should be taken with respect to an observable that commutes + with the Pauli Z - a warning will be raised if it does not. + + qujax applies an importance sampling heuristic for sampled expectation values that only + reflects the physical notion of measurement in the case that the observable commutes with Z. + In the case that it does not, the expectation value will still be asymptotically unbiased + but not representative of an experiment on a real quantum device. + Args: hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). @@ -165,6 +174,10 @@ def get_statetensor_to_sampled_expectation_func( hermitian_seq_seq, qubits_seq_seq, coefficients ) + for hermitian_seq in hermitian_seq_seq: + for h in hermitian_seq: + check_hermitian(h, check_z_commutes=True) + def statetensor_to_sampled_expectation_func( statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int ) -> float: @@ -179,16 +192,39 @@ def statetensor_to_sampled_expectation_func( Returns: Sampled expected value (float). """ - sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop( - 0, - n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size), - ) - - sampled_probs /= n_samps - sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape) - return statetensor_to_expectation_func(sampled_st) + measure_probs = jnp.abs(statetensor) ** 2 + sampled_probs = sample_probs(measure_probs, random_key, n_samps) + iweights = jnp.sqrt(sampled_probs / measure_probs) + return statetensor_to_expectation_func(statetensor * iweights) return statetensor_to_sampled_expectation_func + + +def sample_probs( + measure_probs: jnp.ndarray, random_key: random.PRNGKeyArray, n_samps: int +): + """ + Generate an empirical distribution from a probability distribution. + + Args: + measure_probs: Probability distribution. + random_key: JAX random key + n_samps: Number of samples contributing to empirical distribution. + + Returns: + Empirical distribution (jnp.ndarray). + """ + measure_probs_flat = measure_probs.flatten() + sampled_integers = random.choice( + random_key, + a=jnp.arange(measure_probs.size), + shape=(n_samps,), + p=measure_probs_flat, + ) + sampled_probs = fori_loop( + 0, + n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1 / n_samps), + jnp.zeros_like(measure_probs_flat), + ) + return sampled_probs.reshape(measure_probs.shape) diff --git a/qujax/utils.py b/qujax/utils.py index 36b34c4..459a88a 100644 --- a/qujax/utils.py +++ b/qujax/utils.py @@ -3,6 +3,7 @@ import collections.abc from inspect import signature from typing import Callable, Iterable, List, Optional, Protocol, Sequence, Tuple, Union +from warnings import warn from jax import numpy as jnp from jax import random @@ -69,12 +70,13 @@ def check_unitary(gate: Gate): raise TypeError(f"Gate not unitary: {gate}") -def check_hermitian(hermitian: Union[str, jnp.ndarray]): +def check_hermitian(hermitian: Union[str, jnp.ndarray], check_z_commutes: bool = False): """ Checks whether a matrix or tensor is Hermitian. Args: hermitian: array containing potentially Hermitian matrix or tensor + check_z_commutes: boolean on whether to check if the matrix commutes with Z """ if isinstance(hermitian, str): @@ -83,12 +85,27 @@ def check_hermitian(hermitian: Union[str, jnp.ndarray]): f"qujax only accepts {tuple(paulis.keys())} as Hermitian strings," "received: {hermitian}" ) + n_qubits = 1 + hermitian_mat = paulis[hermitian] + else: n_qubits = hermitian.ndim // 2 hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): raise TypeError(f"Array not Hermitian: {hermitian}") + if check_z_commutes: + big_z = jnp.diag(jnp.where(jnp.arange(2**n_qubits) % 2 == 0, 1, -1)) + z_commutes = jnp.allclose(hermitian_mat @ big_z, big_z @ hermitian_mat) + if not z_commutes: + warn( + "Hermitian matrix does not commute with Z. \n" + "For sampled expectation values, this may lead to unexpected results, " + "measurements on a quantum device are always taken in the computational basis. " + "Additional gates can be applied in the circuit to change the basis such " + "that an observable that commutes with Z can be measured." + ) + def _arrayify_inds( param_inds_seq: Sequence[Union[None, Sequence[int]]] diff --git a/qujax/version.py b/qujax/version.py index e19434e..334b899 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = "0.3.3" +__version__ = "0.3.4" diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 3581219..b77cdd9 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -22,8 +22,8 @@ def test_single_expectation(): dt2 = qujax.statetensor_to_densitytensor(st2) ZZ = jnp.kron(Z, Z).reshape(2, 2, 2, 2) - est1 = qujax.statetensor_to_single_expectation(dt1, ZZ, [0, 1]) - est2 = qujax.statetensor_to_single_expectation(dt2, ZZ, [0, 1]) + est1 = qujax.statetensor_to_single_expectation(st1, ZZ, [0, 1]) + est2 = qujax.statetensor_to_single_expectation(st2, ZZ, [0, 1]) edt1 = qujax.densitytensor_to_single_expectation(dt1, ZZ, [0, 1]) edt2 = qujax.densitytensor_to_single_expectation(dt2, ZZ, [0, 1]) @@ -91,29 +91,28 @@ def brute_force_param_to_exp(p): assert jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5) -def test_ZZ_Y(): - config.update("jax_enable_x64", True) # Run this test with 64 bit precision +def _test_hermitian_observable( + hermitian_str_seq_seq, qubit_inds_seq, coefs, st_in=None +): + n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - n_qubits = 4 + if st_in is None: + state = ( + random.uniform(random.PRNGKey(2), shape=(2**n_qubits,)) * 2 + + 1.0j * random.uniform(random.PRNGKey(1), shape=(2**n_qubits,)) * 2 + ) + state /= jnp.linalg.norm(state) + st_in = state.reshape((2,) * n_qubits) - hermitian_str_seq_seq = [["Z", "Z"]] * (n_qubits - 1) + [["Y"]] * n_qubits - coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) + dt_in = qujax.statetensor_to_densitytensor(st_in) - qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [ - [i] for i in range(n_qubits) - ] st_to_exp = qujax.get_statetensor_to_expectation_func( hermitian_str_seq_seq, qubit_inds_seq, coefs ) - dt_to_exp = qujax.get_statetensor_to_expectation_func( + dt_to_exp = qujax.get_densitytensor_to_expectation_func( hermitian_str_seq_seq, qubit_inds_seq, coefs ) - state = random.uniform(random.PRNGKey(0), shape=(2**n_qubits,)) * 2 - state /= jnp.linalg.norm(state) - st_in = state.reshape((2,) * n_qubits) - dt_in = qujax.statetensor_to_densitytensor(st_in) - def big_hermitian_matrix(hermitian_str_seq, qubit_inds): qubit_arrs = [getattr(qujax.gates, s) for s in hermitian_str_seq] hermitian_arrs = [] @@ -139,7 +138,7 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): assert jnp.allclose(sum_big_hs, sum_big_hs.conj().T) sv = st_in.flatten() - true_exp = jnp.dot(sv, sum_big_hs @ sv.conj()).real + true_exp = jnp.dot(sv.conj(), sum_big_hs @ sv).real qujax_exp = st_to_exp(st_in) qujax_dt_exp = dt_to_exp(dt_in) @@ -156,16 +155,16 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func( hermitian_str_seq_seq, qubit_inds_seq, coefs ) - dt_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func( + dt_to_samp_exp = qujax.get_densitytensor_to_sampled_expectation_func( hermitian_str_seq_seq, qubit_inds_seq, coefs ) qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)( - st_in, random.PRNGKey(2), 1000000 + st_in, random.PRNGKey(1), 1000000 ) - qujax_samp_exp_dt = dt_to_samp_exp(st_in, random.PRNGKey(1), 1000000) + qujax_samp_exp_dt = dt_to_samp_exp(dt_in, random.PRNGKey(1), 1000000) qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)( - st_in, random.PRNGKey(2), 1000000 + dt_in, random.PRNGKey(1), 1000000 ) assert jnp.array(qujax_samp_exp).shape == () assert jnp.array(qujax_samp_exp).dtype.name[:5] == "float" @@ -175,6 +174,63 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): assert jnp.isclose(true_exp, qujax_samp_exp_dt_jit, rtol=1e-2) +def test_X(): + hermitian_str_seq_seq = ["X"] + qubit_inds_seq = [[0]] + coefs = [1] + + gates = ["H", "Rz"] + qubit = [[0], [0]] + param_ind = [[], [0]] + st_in = qujax.get_params_to_statetensor_func(gates, qubit, param_ind)(0.3) + + _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs, st_in) + + +def test_Y(): + n_qubits = 1 + + hermitian_str_seq_seq = ["Y"] * n_qubits + qubit_inds_seq = [[i] for i in range(n_qubits)] + coefs = jnp.ones(len(hermitian_str_seq_seq)) + + _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) + + +def test_Z(): + n_qubits = 1 + + hermitian_str_seq_seq = ["Z"] * n_qubits + qubit_inds_seq = [[i] for i in range(n_qubits)] + coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) + + _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) + + +def test_XYZ(): + n_qubits = 1 + + hermitian_str_seq_seq = ["X", "Y", "Z"] * n_qubits + qubit_inds_seq = [[i] for _ in range(3) for i in range(n_qubits)] + coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) + + _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) + + +def test_ZZ_Y(): + config.update("jax_enable_x64", True) # Run this test with 64 bit precision + + n_qubits = 4 + + hermitian_str_seq_seq = [["Z", "Z"]] * (n_qubits - 1) + [["Y"]] * n_qubits + qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [ + [i] for i in range(n_qubits) + ] + coefs = random.normal(random.PRNGKey(1), shape=(len(hermitian_str_seq_seq),)) + + _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) + + def test_sampling(): target_pmf = jnp.array([0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]) target_pmf /= target_pmf.sum()