Skip to content

Commit

Permalink
[CIR][CIRGen] Complex unary increment and decrement operator (llvm#790)
Browse files Browse the repository at this point in the history
This PR adds CIRGen and LLVMIR lowering for unary increment and
decrement expressions of complex types.

Currently blocked by llvm#789 .
  • Loading branch information
Lancern authored and smeenai committed Oct 9, 2024
1 parent bb91807 commit 1dcbd27
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 8 deletions.
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
LValue LV = buildLValue(E->getSubExpr());

if (E->getType()->isAnyComplexType()) {
assert(0 && "not implemented");
buildComplexPrePostIncDec(E, LV, isInc, true /*isPre*/);
} else {
buildScalarPrePostIncDec(E, LV, isInc, isPre);
}
Expand Down
34 changes: 31 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {

// Operators.
mlir::Value VisitPrePostIncDec(const UnaryOperator *E, bool isInc,
bool isPre) {
llvm_unreachable("NYI");
}
bool isPre);
mlir::Value VisitUnaryPostDec(const UnaryOperator *E) {
return VisitPrePostIncDec(E, false, false);
}
Expand Down Expand Up @@ -524,6 +522,12 @@ mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *E) {
return CGF.buildCallExpr(E).getComplexVal();
}

mlir::Value ComplexExprEmitter::VisitPrePostIncDec(const UnaryOperator *E,
bool isInc, bool isPre) {
LValue LV = CGF.buildLValue(E->getSubExpr());
return CGF.buildComplexPrePostIncDec(E, LV, isInc, isPre);
}

mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *E,
QualType PromotionType) {
QualType promotionTy = PromotionType.isNull()
Expand Down Expand Up @@ -956,3 +960,27 @@ LValue CIRGenFunction::buildComplexCompoundAssignmentLValue(
RValue Val;
return ComplexExprEmitter(*this).buildCompoundAssignLValue(E, Op, Val);
}

mlir::Value CIRGenFunction::buildComplexPrePostIncDec(const UnaryOperator *E,
LValue LV, bool isInc,
bool isPre) {
mlir::Value InVal = buildLoadOfComplex(LV, E->getExprLoc());

auto Loc = getLoc(E->getExprLoc());
auto OpKind =
isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec;
mlir::Value IncVal = builder.createUnaryOp(Loc, OpKind, InVal);

// Store the updated result through the lvalue.
buildStoreOfComplex(Loc, IncVal, LV, /*init*/ false);
if (getLangOpts().OpenMP)
llvm_unreachable("NYI");

// If this is a postinc, return the value read from memory, otherwise use the
// updated value.
return isPre ? IncVal : InVal;
}

mlir::Value CIRGenFunction::buildLoadOfComplex(LValue src, SourceLocation loc) {
return ComplexExprEmitter(*this).buildLoadOfLValue(src, loc);
}
5 changes: 5 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::Value buildScalarPrePostIncDec(const UnaryOperator *E, LValue LV,
bool isInc, bool isPre);
mlir::Value buildComplexPrePostIncDec(const UnaryOperator *E, LValue LV,
bool isInc, bool isPre);

// Wrapper for function prototype sources. Wraps either a FunctionProtoType or
// an ObjCMethodDecl.
Expand Down Expand Up @@ -799,6 +801,9 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Value buildLoadOfScalar(LValue lvalue, clang::SourceLocation Loc);
mlir::Value buildLoadOfScalar(LValue lvalue, mlir::Location Loc);

/// Load a complex number from the specified l-value.
mlir::Value buildLoadOfComplex(LValue src, SourceLocation loc);

Address buildLoadOfReference(LValue RefLVal, mlir::Location Loc,
LValueBaseInfo *PointeeBaseInfo = nullptr);
LValue buildLoadOfReferenceLValue(LValue RefLVal, mlir::Location Loc);
Expand Down
10 changes: 6 additions & 4 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,6 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {

auto loc = op.getLoc();
auto opKind = op.getKind();
assert((opKind == mlir::cir::UnaryOpKind::Plus ||
opKind == mlir::cir::UnaryOpKind::Minus ||
opKind == mlir::cir::UnaryOpKind::Not) &&
"invalid unary op kind on complex numbers");

CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
Expand All @@ -372,6 +368,12 @@ void LoweringPreparePass::lowerUnaryOp(UnaryOp op) {
mlir::Value resultReal;
mlir::Value resultImag;
switch (opKind) {
case mlir::cir::UnaryOpKind::Inc:
case mlir::cir::UnaryOpKind::Dec:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
resultImag = operandImag;
break;

case mlir::cir::UnaryOpKind::Plus:
case mlir::cir::UnaryOpKind::Minus:
resultReal = builder.createUnaryOp(loc, opKind, operandReal);
Expand Down
140 changes: 140 additions & 0 deletions clang/test/CIR/CodeGen/complex-arithmetic.c
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,143 @@ void builtin_conj() {
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#RESI]], 1

// CHECK: }

void pre_increment() {
++cd1;
++ci1;
}

// CLANG: @pre_increment
// CPPLANG: @_Z13pre_incrementv

// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void post_increment() {
cd1++;
ci1++;
}

// CLANG: @post_increment
// CPPLANG: @_Z14post_incrementv

// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(inc, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(inc, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double 1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = add i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void pre_decrement() {
--cd1;
--ci1;
}

// CLANG: @pre_decrement
// CPPLANG: @_Z13pre_decrementv

// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

void post_decrement() {
cd1--;
ci1--;
}

// CLANG: @post_decrement
// CPPLANG: @_Z14post_decrementv

// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!cir.double>, !cir.complex<!cir.double>
// CIRGEN: %{{.+}} = cir.unary(dec, %{{.+}}) : !cir.complex<!s32i>, !cir.complex<!s32i>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!cir.double> -> !cir.double
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !cir.double, !cir.double
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !cir.double -> !cir.complex<!cir.double>

// CIR: %[[#R:]] = cir.complex.real %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#I:]] = cir.complex.imag %{{.+}} : !cir.complex<!s32i> -> !s32i
// CIR-NEXT: %[[#IR:]] = cir.unary(dec, %[[#R]]) : !s32i, !s32i
// CIR-NEXT: %{{.+}} = cir.complex.create %[[#IR]], %[[#I]] : !s32i -> !cir.complex<!s32i>

// LLVM: %[[#R:]] = extractvalue { double, double } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { double, double } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = fadd double -1.000000e+00, %[[#R]]
// LLVM-NEXT: %[[#A:]] = insertvalue { double, double } undef, double %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { double, double } %[[#A]], double %[[#I]], 1

// LLVM: %[[#R:]] = extractvalue { i32, i32 } %{{.+}}, 0
// LLVM-NEXT: %[[#I:]] = extractvalue { i32, i32 } %{{.+}}, 1
// LLVM-NEXT: %[[#IR:]] = sub i32 %[[#R]], 1
// LLVM-NEXT: %[[#A:]] = insertvalue { i32, i32 } undef, i32 %[[#IR]], 0
// LLVM-NEXT: %{{.+}} = insertvalue { i32, i32 } %[[#A]], i32 %[[#I]], 1

// CHECK: }

0 comments on commit 1dcbd27

Please sign in to comment.