diff --git a/circuit_knitting/cutting/cutting_experiments.py b/circuit_knitting/cutting/cutting_experiments.py index ae434d2a2..4411d3687 100644 --- a/circuit_knitting/cutting/cutting_experiments.py +++ b/circuit_knitting/cutting/cutting_experiments.py @@ -19,12 +19,8 @@ import numpy as np from qiskit.circuit import QuantumCircuit, ClassicalRegister from qiskit.quantum_info import PauliList -from qiskit.transpiler import PassManager -from qiskit.transpiler.passes import RemoveResetInZeroState, DAGFixedPoint -from qiskit.passmanager.flow_controllers import DoWhileController from ..utils.iteration import strict_zip -from ..utils.transpiler_passes import RemoveFinalReset, ConsolidateResets from ..utils.observable_grouping import ObservableCollection, CommutingObservableGroup from .qpd import ( WeightType, @@ -62,12 +58,6 @@ def generate_cutting_experiments( The coefficients will always be returned as a 1D array -- one coefficient for each unique sample. - Note that this function also runs some transpiler passes on each generated - circuit, namely :class:`~qiskit.transpiler.passes.RemoveResetInZeroState`, - :class:`.RemoveFinalReset`, and :class:`.ConsolidateResets`, in order to - remove unnecessary :class:`~qiskit.circuit.library.Reset`\ s from the - circuit that are added by the subexperiment decompositions for cut wires. - Args: circuits: The circuit(s) to partition and separate observables: The observable(s) to evaluate for each unique sample @@ -172,20 +162,11 @@ def generate_cutting_experiments( # https://github.com/Qiskit-Extensions/circuit-knitting-toolbox/issues/452. # While we are at it, we also consolidate each run of multiple resets # (which can arise when re-using qubits) into a single reset. - pass_manager = PassManager() - passes = [ - RemoveResetInZeroState(), - RemoveFinalReset(), - ConsolidateResets(), - DAGFixedPoint(), - ] - pass_manager.append( - DoWhileController( - passes, do_while=lambda property_set: not property_set["dag_fixed_point"] - ) - ) - for label, subexperiments in subexperiments_dict.items(): - subexperiments_dict[label] = pass_manager.run(subexperiments) + for subexperiments in subexperiments_dict.values(): + for circ in subexperiments: + _remove_resets_in_zero_state(circ) + _remove_final_resets(circ) + _consolidate_resets(circ) # If the input was a single quantum circuit, return the subexperiments as a list subexperiments_out: list[QuantumCircuit] | dict[Hashable, list[QuantumCircuit]] = ( @@ -389,3 +370,84 @@ def _get_pauli_indices(cog: CommutingObservableGroup) -> list[int]: if not pauli_indices: pauli_indices = [0] return pauli_indices + + +def _consolidate_resets( + circuit: QuantumCircuit, inplace: bool = True +) -> QuantumCircuit: + """Consolidate redundant resets into a single reset.""" + if not inplace: # pragma: no cover + circuit = circuit.copy() + + # Keep up with whether the previous instruction on a given qubit was a reset + resets = [False] * circuit.num_qubits + + # Remove resets which are immediately following other resets + remove_ids = [] + for i, inst in enumerate(circuit.data): + qargs = [circuit.find_bit(q).index for q in inst.qubits] + if inst.operation.name == "reset": + if resets[qargs[0]]: + remove_ids.append(i) + else: + resets[qargs[0]] = True + else: + for q in qargs: + resets[q] = False + + for i in sorted(remove_ids, reverse=True): + del circuit.data[i] + + return circuit + + +def _remove_resets_in_zero_state( + circuit: QuantumCircuit, inplace: bool = True +) -> QuantumCircuit: + """Remove resets if they are the first instruction on a qubit.""" + if not inplace: # pragma: no cover + circuit = circuit.copy() + + # Keep up with which qubits have at least one non-reset instruction + active_qubits = set() + remove_ids = [] + for i, inst in enumerate(circuit.data): + qargs = [circuit.find_bit(q).index for q in inst.qubits] + if inst.operation.name == "reset": + if qargs[0] not in active_qubits: + remove_ids.append(i) + else: + for q in qargs: + active_qubits.add(q) + + for i in sorted(remove_ids, reverse=True): + del circuit.data[i] + + return circuit + + +def _remove_final_resets( + circuit: QuantumCircuit, inplace: bool = True +) -> QuantumCircuit: + """Remove resets if they are the final instruction on a qubit.""" + if not inplace: # pragma: no cover + circuit = circuit.copy() + + # Keep up with whether we are at the end of a qubit + # We iterate in reverse, so all qubits begin in the "end" state + qubit_ended = set(range(circuit.num_qubits)) + remove_ids = [] + num_inst = len(circuit.data) + for i, inst in enumerate(reversed(circuit.data)): + qargs = [circuit.find_bit(q).index for q in inst.qubits] + if inst.operation.name == "reset": + if qargs[0] in qubit_ended: + remove_ids.append(num_inst - 1 - i) + else: + for q in qargs: + qubit_ended.discard(q) + + for i in sorted(remove_ids, reverse=True): + del circuit.data[i] + + return circuit diff --git a/releasenotes/notes/subexperiment-gen-speedup-41a4e8679353d1d9.yaml b/releasenotes/notes/subexperiment-gen-speedup-41a4e8679353d1d9.yaml new file mode 100644 index 000000000..6891d972c --- /dev/null +++ b/releasenotes/notes/subexperiment-gen-speedup-41a4e8679353d1d9.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + The :func:`.generate_cutting_experiments` function has been optimized for faster execution. diff --git a/test/cutting/test_cutting_experiments.py b/test/cutting/test_cutting_experiments.py index 6d435f34a..4da5edca6 100644 --- a/test/cutting/test_cutting_experiments.py +++ b/test/cutting/test_cutting_experiments.py @@ -15,7 +15,7 @@ import pytest import numpy as np from qiskit.quantum_info import PauliList, Pauli -from qiskit.circuit import QuantumCircuit +from qiskit.circuit import QuantumCircuit, QuantumRegister from qiskit.circuit.library.standard_gates import CXGate from circuit_knitting.cutting.qpd import ( @@ -30,6 +30,9 @@ from circuit_knitting.cutting.cutting_experiments import ( _append_measurement_register, _append_measurement_circuit, + _remove_final_resets, + _consolidate_resets, + _remove_resets_in_zero_state, ) @@ -219,3 +222,104 @@ def test_append_measurement_circuit(self): e_info.value.args[0] == "Quantum circuit qubit count (2) does not match qubit count of observable(s) (1). Try providing `qubit_locations` explicitly." ) + + def test_consolidate_double_reset(self): + """Consolidate a pair of resets. + qr0:--|0>--|0>-- ==> qr0:--|0>-- + """ + qr = QuantumRegister(1, "qr") + circuit = QuantumCircuit(qr) + circuit.reset(qr) + circuit.reset(qr) + + expected = QuantumCircuit(qr) + expected.reset(qr) + + _consolidate_resets(circuit) + + self.assertEqual(expected, circuit) + + def test_two_resets(self): + """Remove two final resets + qr0:--[H]-|0>-|0>-- ==> qr0:--[H]-- + """ + qr = QuantumRegister(1, "qr") + circuit = QuantumCircuit(qr) + circuit.h(qr[0]) + circuit.reset(qr[0]) + circuit.reset(qr[0]) + + expected = QuantumCircuit(qr) + expected.h(qr[0]) + + _remove_final_resets(circuit) + + self.assertEqual(expected, circuit) + + def test_optimize_single_reset_in_diff_qubits(self): + """Remove a single final reset in different qubits + qr0:--[H]--|0>-- qr0:--[H]-- + ==> + qr1:--[X]--|0>-- qr1:--[X]---- + """ + qr = QuantumRegister(2, "qr") + circuit = QuantumCircuit(qr) + circuit.h(0) + circuit.x(1) + circuit.reset(qr) + + expected = QuantumCircuit(qr) + expected.h(0) + expected.x(1) + + _remove_final_resets(circuit) + self.assertEqual(expected, circuit) + + def test_optimize_single_reset(self): + """Remove a single final reset + qr0:--[H]--|0>-- ==> qr0:--[H]-- + """ + qr = QuantumRegister(1, "qr") + circuit = QuantumCircuit(qr) + circuit.h(0) + circuit.reset(qr) + + expected = QuantumCircuit(qr) + expected.h(0) + + _remove_final_resets(circuit) + + self.assertEqual(expected, circuit) + + def test_dont_optimize_non_final_reset(self): + """Do not remove reset if not final instruction + qr0:--|0>--[H]-- ==> qr0:--|0>--[H]-- + """ + qr = QuantumRegister(1, "qr") + circuit = QuantumCircuit(qr) + circuit.reset(qr) + circuit.h(qr) + + expected = QuantumCircuit(qr) + expected.reset(qr) + expected.h(qr) + + _remove_final_resets(circuit) + + self.assertEqual(expected, circuit) + + def test_remove_reset_in_zero_state(self): + """Remove reset if first instruction on qubit + qr0:--|0>--[H]-- ==> qr0:--|0>--[H]-- + """ + qr = QuantumRegister(1, "qr") + circuit = QuantumCircuit(qr) + circuit.reset(qr) + circuit.h(qr) + + expected = QuantumCircuit(qr) + expected.h(qr) + + _remove_resets_in_zero_state(circuit) + + self.assertEqual(expected, circuit)