Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove quantum dialect components at the end of quantum-to-ion lowering #1482

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions mlir/include/Ion/IR/IonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ def PulseOp : Ion_Op<"pulse"> {

let arguments = (ins
AnyFloat: $time,
QubitType: $in_qubit,
AnyTypeOf<[IonType, QubitType]>: $in_qubit,
BeamAttr: $beam,
Builtin_FloatAttr: $phase
);

let assemblyFormat = [{
`(` $time `:` type($time) `)` $in_qubit attr-dict
`(` $time `:` type($time) `)` $in_qubit `:` type($in_qubit) attr-dict
}];
}

Expand All @@ -172,10 +172,10 @@ def ParallelProtocolOp : Ion_Op<"parallelprotocol", [SingleBlockImplicitTerminat
let summary = "Represent a parallel protocol of pulses.";

let arguments = (ins
Variadic<QubitType>: $in_qubits
Variadic<AnyTypeOf<[IonType, QubitType]>>: $in_qubits
);

let results = (outs Variadic<QubitType>:$out_qubits);
let results = (outs Variadic<AnyTypeOf<[IonType, QubitType]>>:$out_qubits);
let regions = (region SizedRegion<1>:$region);

let builders = [
Expand All @@ -192,15 +192,15 @@ def ParallelProtocolOp : Ion_Op<"parallelprotocol", [SingleBlockImplicitTerminat
}];

let assemblyFormat = [{
`(` $in_qubits `)` attr-dict `:` type($out_qubits) $region
`(` ($in_qubits^ `:` type($in_qubits))? `)` attr-dict `:` type($out_qubits) $region
}];
}

def YieldOp : Ion_Op<"yield", [Pure, ReturnLike, Terminator, ParentOneOf<["ParallelProtocolOp"]>]> {
let summary = "Return results from parallel protocol regions";

let arguments = (ins
Variadic<QubitType>:$results
Variadic<IonType>:$results
);

let assemblyFormat = [{
Expand Down
3 changes: 0 additions & 3 deletions mlir/include/Ion/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def QuantumToIonPass : Pass<"quantum-to-ion"> {
Option<"Gate2PulseDecompTomlLoc", "gate-to-pulse-toml-loc",
"std::string", /*default=*/"\"\"",
"Toml file location for the ion hardware gate-to-pulse decomposition parameters.">,
Option<"LoadIon", "load-ion",
"bool", /*default=*/"true",
"Whether to load the physical parameters for the ion (e.g. mass, charge, spin) into the IR.">,
];

let dependentDialects = [
Expand Down
17 changes: 14 additions & 3 deletions mlir/lib/Ion/IR/IonOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,26 @@ void ParallelProtocolOp::build(OpBuilder &builder, OperationState &result, Value
{
OpBuilder::InsertionGuard guard(builder);
Location loc = result.location;

Type ionType = IonType::get(result.getContext());

// The parallel protocol op can interact with the outside world by accepting
// either ion types or qubit types.
// We allow qubit types because during `quantum-to-ion`, during gate-to-pulse
// decomposition, we still need the core quantum dialect to track SSA def use
// chains of qubit values.
// After gate-to-pulse decomposition is finished, we change all parallel protocol
// ops to return ion types.
// Note that the body region is shielded from the outside, so it's block can
// have an ion type argument directly
result.addOperands(inQubits);
for (Value v : inQubits)
for (Value v : inQubits) {
result.addTypes(v.getType());
}

Region *bodyRegion = result.addRegion();
Block *bodyBlock = builder.createBlock(bodyRegion);
for (Value v : inQubits) {
bodyBlock->addArgument(v.getType(), v.getLoc());
bodyBlock->addArgument(ionType, v.getLoc());
}

builder.setInsertionPointToStart(bodyBlock);
Expand Down
46 changes: 37 additions & 9 deletions mlir/lib/Ion/Transforms/quantum_to_ion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ struct QuantumToIonPass : impl::QuantumToIonPassBase<QuantumToIonPass> {
void runOnOperation() final
{
func::FuncOp op = cast<func::FuncOp>(getOperation());
// auto module = getOperation();
auto &context = getContext();
ConversionTarget target(context);

Expand All @@ -78,12 +77,16 @@ struct QuantumToIonPass : impl::QuantumToIonPassBase<QuantumToIonPass> {

OQDDatabaseManager dataManager(DeviceTomlLoc, QubitTomlLoc, Gate2PulseDecompTomlLoc);

if (LoadIon) {
// FIXME(?): we only load Yb171 ion since the hardware ion species is unlikely to change
MLIRContext *ctx = op->getContext();
IRRewriter builder(ctx);
Ion ion = dataManager.getIonParams().at("Yb171");
// FIXME(?): we only load Yb171 ion since the hardware ion species is unlikely to change
MLIRContext *ctx = op->getContext();
IRRewriter builder(ctx);
Ion ion = dataManager.getIonParams().at("Yb171");

// First, we need to convert each qubit to an ion
// A qubit is initilized as an extract op from an alloc op in quantum dialect
llvm::DenseMap</*quantum*/ Value, /*ion*/ Value> qubitMap;

op->walk([&](quantum::ExtractOp qExtract) {
SmallVector<Attribute> levels, transitions;
for (const Level &level : ion.levels) {
levels.push_back(cast<Attribute>(getLevelAttr(ctx, builder, level)));
Expand All @@ -92,20 +95,45 @@ struct QuantumToIonPass : impl::QuantumToIonPassBase<QuantumToIonPass> {
transitions.push_back(cast<Attribute>(getTransitionAttr(ctx, builder, transition)));
}

builder.setInsertionPointToStart(&(op->getRegion(0).front()));
builder.create<ion::IonOp>(
builder.setInsertionPointAfter(qExtract);
ion::IonOp ionOp = builder.create<ion::IonOp>(
op->getLoc(), IonType::get(ctx), builder.getStringAttr(ion.name),
builder.getF64FloatAttr(ion.mass), builder.getF64FloatAttr(ion.charge),
builder.getI64VectorAttr(ion.position), builder.getArrayAttr(levels),
builder.getArrayAttr(transitions));
}

qubitMap.insert({qExtract.getQubit(), ionOp.getOutIon()});
});

// Then, we decompose the quantum gates on qubits to pulses on qubits
// We keep the quantum dialect at this stage since we still want the SSA def use
// chains from the quantum dialect.
RewritePatternSet ionPatterns(&getContext());
populateQuantumToIonPatterns(ionPatterns, dataManager);

if (failed(applyPartialConversion(op, target, std::move(ionPatterns)))) {
return signalPassFailure();
}

// Finally, to aid ion api stub generation, we eliminate all quantum dialect
// We replace uses and change all quantum.bit types to ion.ion types
for (auto [qubit, ion] : qubitMap) {
qubit.replaceAllUsesWith(ion);
qubit.getDefiningOp()->erase();
}

Type ionType = IonType::get(ctx);
op->walk([&](ion::ParallelProtocolOp ppOp) {
for (auto v : ppOp->getResults()) {
v.setType(ionType);
}
});

SmallVector<quantum::AllocOp> qAllocOps;
op->walk([&](quantum::AllocOp alloc) { qAllocOps.push_back(alloc); });
for (auto alloc : qAllocOps) {
alloc->erase();
}
}
};

Expand Down
40 changes: 20 additions & 20 deletions mlir/test/Ion/Dialect.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ func.func @example_pulse(%arg0: f64) -> !quantum.bit {
// CHECK: [[q0:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%1 = quantum.extract %0[0] : !quantum.reg -> !quantum.bit

// CHECK: ion.pulse(%arg0 : f64) [[q0]] {beam = #ion.beam<
// CHECK: ion.pulse(%arg0 : f64) [[q0]] : !quantum.bit {beam = #ion.beam<
// CHECK-SAME: transition_index = 0 : i64,
// CHECK-SAME: rabi = 1.010000e+01 : f64,
// CHECK-SAME: detuning = 1.111000e+01 : f64,
// CHECK-SAME: polarization = dense<[0, 1]> : tensor<2xi64>,
// CHECK-SAME: wavevector = dense<[0, 1]> : tensor<2xi64>>,
// CHECK-SAME: phase = 0.000000e+00 : f64}
ion.pulse(%arg0: f64) %1 {
ion.pulse(%arg0: f64) %1 : !quantum.bit {
beam=#ion.beam<
transition_index=0,
rabi=10.10,
Expand All @@ -49,17 +49,17 @@ func.func @example_parallel_protocol(%arg0: f64) -> !quantum.bit {
// CHECK: [[q0:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%1 = quantum.extract %0[0] : !quantum.reg -> !quantum.bit

// CHECK: [[paraproto:%.+]] = ion.parallelprotocol([[q0]]) : !quantum.bit {
%2 = ion.parallelprotocol(%1): !quantum.bit {
^bb0(%arg1: !quantum.bit):
// CHECK: ion.pulse(%arg0 : f64) %arg1 {beam = #ion.beam<
// CHECK: [[paraproto:%.+]] = ion.parallelprotocol([[q0]] : !quantum.bit) : !quantum.bit {
%2 = ion.parallelprotocol(%1 : !quantum.bit): !quantum.bit {
^bb0(%arg1: !ion.ion):
// CHECK: ion.pulse(%arg0 : f64) %arg1 : !ion.ion {beam = #ion.beam<
// CHECK-SAME: transition_index = 1 : i64,
// CHECK-SAME: rabi = 1.010000e+01 : f64,
// CHECK-SAME: detuning = 1.111000e+01 : f64,
// CHECK-SAME: polarization = dense<[0, 1]> : tensor<2xi64>,
// CHECK-SAME: wavevector = dense<[0, 1]> : tensor<2xi64>>,
// CHECK-SAME: phase = 0.000000e+00 : f64}
ion.pulse(%arg0: f64) %arg1 {
ion.pulse(%arg0: f64) %arg1 : !ion.ion {
beam=#ion.beam<
transition_index=1,
rabi=10.10,
Expand All @@ -69,14 +69,14 @@ func.func @example_parallel_protocol(%arg0: f64) -> !quantum.bit {
>,
phase=0.0
}
// CHECK: ion.pulse(%arg0 : f64) %arg1 {beam = #ion.beam<
// CHECK: ion.pulse(%arg0 : f64) %arg1 : !ion.ion {beam = #ion.beam<
// CHECK-SAME: transition_index = 0 : i64,
// CHECK-SAME: rabi = 1.010000e+01 : f64,
// CHECK-SAME: detuning = 1.111000e+01 : f64,
// CHECK-SAME: polarization = dense<[0, 1]> : tensor<2xi64>,
// CHECK-SAME: wavevector = dense<[0, 1]> : tensor<2xi64>>,
// CHECK-SAME: phase = 0.000000e+00 : f64}
ion.pulse(%arg0: f64) %arg1 {
ion.pulse(%arg0: f64) %arg1 : !ion.ion {
beam=#ion.beam<
transition_index=0,
rabi=10.10,
Expand All @@ -86,8 +86,8 @@ func.func @example_parallel_protocol(%arg0: f64) -> !quantum.bit {
>,
phase=0.0
}
// CHECK: ion.yield %arg1 : !quantum.bit
ion.yield %arg1: !quantum.bit
// CHECK: ion.yield %arg1 : !ion.ion
ion.yield %arg1: !ion.ion
}

// CHECK: return [[paraproto]] : !quantum.bit
Expand All @@ -103,17 +103,17 @@ func.func @example_parallel_protocol_two_qubits(%arg0: f64) -> (!quantum.bit, !q
// CHECK: [[q1:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %0[1] : !quantum.reg -> !quantum.bit

// CHECK: [[paraproto:%.+]]{{:2}} = ion.parallelprotocol([[q0]], [[q1]]) : !quantum.bit, !quantum.bit {
%3:2 = ion.parallelprotocol(%1, %2): !quantum.bit, !quantum.bit {
^bb0(%arg1: !quantum.bit, %arg2: !quantum.bit):
// CHECK: ion.pulse(%arg0 : f64) %arg1 {beam = #ion.beam<
// CHECK: [[paraproto:%.+]]{{:2}} = ion.parallelprotocol([[q0]], [[q1]] : !quantum.bit, !quantum.bit) : !quantum.bit, !quantum.bit {
%3:2 = ion.parallelprotocol(%1, %2 : !quantum.bit, !quantum.bit): !quantum.bit, !quantum.bit {
^bb0(%arg1: !ion.ion, %arg2: !ion.ion):
// CHECK: ion.pulse(%arg0 : f64) %arg1 : !ion.ion {beam = #ion.beam<
// CHECK-SAME: transition_index = 2 : i64,
// CHECK-SAME: rabi = 1.010000e+01 : f64,
// CHECK-SAME: detuning = 1.111000e+01 : f64,
// CHECK-SAME: polarization = dense<[0, 1]> : tensor<2xi64>,
// CHECK-SAME: wavevector = dense<[0, 1]> : tensor<2xi64>>,
// CHECK-SAME: phase = 0.000000e+00 : f64}
ion.pulse(%arg0: f64) %arg1 {
ion.pulse(%arg0: f64) %arg1 : !ion.ion {
beam=#ion.beam<
transition_index=2,
rabi=10.10,
Expand All @@ -123,14 +123,14 @@ func.func @example_parallel_protocol_two_qubits(%arg0: f64) -> (!quantum.bit, !q
>,
phase=0.0
}
// CHECK: ion.pulse(%arg0 : f64) %arg2 {beam = #ion.beam<
// CHECK: ion.pulse(%arg0 : f64) %arg2 : !ion.ion {beam = #ion.beam<
// CHECK-SAME: transition_index = 1 : i64,
// CHECK-SAME: rabi = 1.010000e+01 : f64,
// CHECK-SAME: detuning = 1.111000e+01 : f64,
// CHECK-SAME: polarization = dense<[0, 1]> : tensor<2xi64>,
// CHECK-SAME: wavevector = dense<[0, 1]> : tensor<2xi64>>,
// CHECK-SAME: phase = 0.000000e+00 : f64}
ion.pulse(%arg0: f64) %arg2 {
ion.pulse(%arg0: f64) %arg2 : !ion.ion {
beam=#ion.beam<
transition_index=1,
rabi=10.10,
Expand All @@ -140,8 +140,8 @@ func.func @example_parallel_protocol_two_qubits(%arg0: f64) -> (!quantum.bit, !q
>,
phase=0.0
}
// CHECK: ion.yield %arg1, %arg2 : !quantum.bit, !quantum.bit
ion.yield %arg1, %arg2: !quantum.bit, !quantum.bit
// CHECK: ion.yield %arg1, %arg2 : !ion.ion, !ion.ion
ion.yield %arg1, %arg2: !ion.ion, !ion.ion
}

// CHECK: return [[paraproto]]#0, [[paraproto]]#1 : !quantum.bit, !quantum.bit
Expand Down
Loading