[CIR][CIRGen] Add support for __atomic_add_fetch
This introduces CIRGen and LLVM lowering for the first of a family of
these atomic operations; incremental follow-up work should generalize
the current constructs.
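
For illustration, a call such as __atomic_add_fetch(p, 2, __ATOMIC_SEQ_CST)
with int *p is now expected to emit CIR along these lines (a sketch; the SSA
value names and the !s32i type spelling are assumed, not taken from this
commit's tests):

  // C source: int f(int *p) { return __atomic_add_fetch(p, 2, __ATOMIC_SEQ_CST); }
  %res = cir.atomic.add_fetch(%p : !cir.ptr<!s32i>, %two : !s32i, seq_cst) : !s32i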
bcardosolopes authored and lanza committed Apr 3, 2024
1 parent d18777b commit e58a597
Showing 10 changed files with 269 additions and 35 deletions.
15 changes: 14 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRDialect.h
@@ -43,6 +43,7 @@ namespace impl {
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
LogicalResult verifySameFirstOperandAndResultType(Operation *op);
LogicalResult verifySameSecondOperandAndResultType(Operation *op);
LogicalResult verifySameFirstSecondOperandAndResultType(Operation *op);
} // namespace impl

@@ -59,7 +60,19 @@ class SameFirstOperandAndResultType
};

/// This class provides verification for ops that are known to have the same
/// first operand and result type.
/// second operand and result type.
///
template <typename ConcreteType>
class SameSecondOperandAndResultType
: public TraitBase<ConcreteType, SameSecondOperandAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameSecondOperandAndResultType(op);
}
};

/// This class provides verification for ops that are known to have the same
/// first, second operand and result type.
///
template <typename ConcreteType>
class SameFirstSecondOperandAndResultType
58 changes: 52 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -39,6 +39,17 @@ include "mlir/IR/SymbolInterfaces.td"
class CIR_Op<string mnemonic, list<Trait> traits = []> :
Op<CIR_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// CIR Op Traits
//===----------------------------------------------------------------------===//

def SameFirstOperandAndResultType :
NativeOpTrait<"SameFirstOperandAndResultType">;
def SameSecondOperandAndResultType :
NativeOpTrait<"SameSecondOperandAndResultType">;
def SameFirstSecondOperandAndResultType :
NativeOpTrait<"SameFirstSecondOperandAndResultType">;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
@@ -109,6 +120,7 @@ def CastOp : CIR_Op<"cast", [Pure]> {

// The input and output types should match the cast kind.
let hasVerifier = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
@@ -183,9 +195,6 @@ def PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
// PtrStrideOp
//===----------------------------------------------------------------------===//

def SameFirstOperandAndResultType :
NativeOpTrait<"SameFirstOperandAndResultType">;

def PtrStrideOp : CIR_Op<"ptr_stride",
[Pure, SameFirstOperandAndResultType]> {
let summary = "Pointer access with stride";
@@ -2933,9 +2942,6 @@ def MemChrOp : CIR_Op<"libc.memchr"> {
// StdFindOp
//===----------------------------------------------------------------------===//

def SameFirstSecondOperandAndResultType :
NativeOpTrait<"SameFirstSecondOperandAndResultType">;

def StdFindOp : CIR_Op<"std.find", [SameFirstSecondOperandAndResultType]> {
let arguments = (ins FlatSymbolRefAttr:$original_fn,
CIR_AnyType:$first,
@@ -3412,6 +3418,46 @@ def IsConstantOp : CIR_Op<"is_constant", [Pure]> {
}];
}

//===----------------------------------------------------------------------===//
// Atomic operations
//===----------------------------------------------------------------------===//

def MemOrderRelaxed : I32EnumAttrCase<"Relaxed", 0, "relaxed">;
def MemOrderConsume : I32EnumAttrCase<"Consume", 1, "consume">;
def MemOrderAcquire : I32EnumAttrCase<"Acquire", 2, "acquire">;
def MemOrderRelease : I32EnumAttrCase<"Release", 3, "release">;
def MemOrderAcqRel : I32EnumAttrCase<"AcquireRelease", 4, "acq_rel">;
def MemOrderSeqCst : I32EnumAttrCase<"SequentiallyConsistent", 5, "seq_cst">;

def MemOrder : I32EnumAttr<
"MemOrder",
"Memory order according to C++11 memory model",
[MemOrderRelaxed, MemOrderConsume, MemOrderAcquire,
MemOrderRelease, MemOrderAcqRel, MemOrderSeqCst]> {
let cppNamespace = "::mlir::cir";
}

def AtomicAddFetch : CIR_Op<"atomic.add_fetch",
[Pure, SameSecondOperandAndResultType]> {
let summary = "Represents the __atomic_add_fetch builtin";
let description = [{}];
let results = (outs CIR_AnyIntOrFloat:$result);
let arguments = (ins IntOrFPPtr:$ptr, CIR_AnyIntOrFloat:$val,
Arg<MemOrder, "memory order">:$mem_order,
UnitAttr:$is_volatile);

let assemblyFormat = [{
`(`
$ptr `:` type($ptr) `,`
$val `:` type($val) `,`
$mem_order `)`
(`volatile` $is_volatile^)?
`:` type($result) attr-dict
}];

let hasVerifier = 0;
}
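
Given the assembly format above, the volatile form of the op would print
roughly as follows (a sketch; the operand names and the !cir.float type
spelling are assumed):

  %res = cir.atomic.add_fetch(%ptr : !cir.ptr<!cir.float>, %val : !cir.float, relaxed) volatile : !cir.float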

//===----------------------------------------------------------------------===//
// Operations Lowered Directly to LLVM IR
//
12 changes: 12 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROpsEnums.h
@@ -115,6 +115,18 @@ LLVM_ATTRIBUTE_UNUSED static bool isValidLinkage(GlobalLinkageKind L) {
isLinkOnceLinkage(L);
}

bool operator<(mlir::cir::MemOrder, mlir::cir::MemOrder) = delete;
bool operator>(mlir::cir::MemOrder, mlir::cir::MemOrder) = delete;
bool operator<=(mlir::cir::MemOrder, mlir::cir::MemOrder) = delete;
bool operator>=(mlir::cir::MemOrder, mlir::cir::MemOrder) = delete;

// Check that an integral value, which isn't known to fit within the enum's
// range, is a valid CIR memory order (the CIR counterpart of
// llvm::isValidAtomicOrderingCABI).
template <typename Int> inline bool isValidCIRAtomicOrderingCABI(Int I) {
return (Int)mlir::cir::MemOrder::Relaxed <= I &&
I <= (Int)mlir::cir::MemOrder::SequentiallyConsistent;
}

} // namespace cir
} // namespace mlir

11 changes: 11 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -169,6 +169,7 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
// Constraints

def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;

//===----------------------------------------------------------------------===//
// PointerType
@@ -373,6 +374,16 @@ def VoidPtr : Type<
"mlir::cir::VoidType::get($_builder.getContext()))"> {
}

// Pointer to int, float or double
def IntOrFPPtr : Type<
And<[
CPred<"$_self.isa<::mlir::cir::PointerType>()">,
CPred<"$_self.cast<::mlir::cir::PointerType>()"
".getPointee().isa<::mlir::cir::IntType,"
"::mlir::cir::SingleType, ::mlir::cir::DoubleType>()">,
]>, "{int,void}*"> {
}

// Pointer to struct
def StructPtr : Type<
And<[
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/Address.h
@@ -79,6 +79,7 @@ class Address {
/// Return address with different element type, but same pointer and
/// alignment.
Address withElementType(mlir::Type ElemTy) const {
// TODO(cir): hasOffset() check
return Address(getPointer(), ElemTy, getAlignment(), isKnownNonNull());
}

72 changes: 54 additions & 18 deletions clang/lib/CIR/CodeGen/CIRGenAtomic.cpp
@@ -269,6 +269,10 @@ static Address buildValToTemp(CIRGenFunction &CGF, Expr *E) {
}

Address AtomicInfo::castToAtomicIntPointer(Address addr) const {
auto intTy = addr.getElementType().dyn_cast<mlir::cir::IntType>();
// Don't bother with int casts if the integer size is the same.
if (intTy && intTy.getWidth() == AtomicSizeInBits)
return addr;
auto ty = CGF.getBuilder().getUIntNTy(AtomicSizeInBits);
return addr.withElementType(ty);
}
@@ -314,10 +318,12 @@ static mlir::cir::IntAttr getConstOpIntAttr(mlir::Value v) {
static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,
Address Ptr, Address Val1, Address Val2,
mlir::Value IsWeak, mlir::Value FailureOrder,
uint64_t Size, llvm::AtomicOrdering Order,
uint64_t Size, mlir::cir::MemOrder Order,
uint8_t Scope) {
assert(!UnimplementedFeature::syncScopeID());
StringRef Op;
[[maybe_unused]] bool PostOpMinMax = false;
auto loc = CGF.getLoc(E->getSourceRange());

switch (E->getOp()) {
case AtomicExpr::AO__c11_atomic_init:
@@ -375,18 +381,19 @@ static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,

case AtomicExpr::AO__atomic_add_fetch:
case AtomicExpr::AO__scoped_atomic_add_fetch:
llvm_unreachable("NYI");
// In LLVM codegen, the post operation codegen is tracked here.
[[fallthrough]];
case AtomicExpr::AO__c11_atomic_fetch_add:
case AtomicExpr::AO__hip_atomic_fetch_add:
case AtomicExpr::AO__opencl_atomic_fetch_add:
case AtomicExpr::AO__atomic_fetch_add:
case AtomicExpr::AO__scoped_atomic_fetch_add:
llvm_unreachable("NYI");
Op = mlir::cir::AtomicAddFetch::getOperationName();
break;

case AtomicExpr::AO__atomic_sub_fetch:
case AtomicExpr::AO__scoped_atomic_sub_fetch:
// In LLVM codegen, the post operation codegen is tracked here.
llvm_unreachable("NYI");
[[fallthrough]];
case AtomicExpr::AO__c11_atomic_fetch_sub:
@@ -423,6 +430,7 @@ static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,

case AtomicExpr::AO__atomic_and_fetch:
case AtomicExpr::AO__scoped_atomic_and_fetch:
// In LLVM codegen, the post operation codegen is tracked here.
llvm_unreachable("NYI");
[[fallthrough]];
case AtomicExpr::AO__c11_atomic_fetch_and:
@@ -435,6 +443,7 @@ static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,

case AtomicExpr::AO__atomic_or_fetch:
case AtomicExpr::AO__scoped_atomic_or_fetch:
// In LLVM codegen, the post operation codegen is tracked here.
llvm_unreachable("NYI");
[[fallthrough]];
case AtomicExpr::AO__c11_atomic_fetch_or:
@@ -447,6 +456,7 @@ static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,

case AtomicExpr::AO__atomic_xor_fetch:
case AtomicExpr::AO__scoped_atomic_xor_fetch:
// In LLVM codegen, the post operation codegen is tracked here.
llvm_unreachable("NYI");
[[fallthrough]];
case AtomicExpr::AO__c11_atomic_fetch_xor:
@@ -459,6 +469,7 @@ static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,

case AtomicExpr::AO__atomic_nand_fetch:
case AtomicExpr::AO__scoped_atomic_nand_fetch:
// In LLVM codegen, the post operation codegen is tracked here.
llvm_unreachable("NYI");
[[fallthrough]];
case AtomicExpr::AO__c11_atomic_fetch_nand:
@@ -467,13 +478,38 @@ static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *E, Address Dest,
llvm_unreachable("NYI");
break;
}
llvm_unreachable("NYI");

assert(Op.size() && "expected operation name to build");
auto &builder = CGF.getBuilder();

auto LoadVal1 = builder.createLoad(loc, Val1);

SmallVector<mlir::Value> atomicOperands = {Ptr.getPointer(), LoadVal1};
SmallVector<mlir::Type> atomicResTys = {
Ptr.getPointer().getType().cast<mlir::cir::PointerType>().getPointee()};
auto orderAttr = mlir::cir::MemOrderAttr::get(builder.getContext(), Order);
auto RMWI = builder.create(loc, builder.getStringAttr(Op), atomicOperands,
atomicResTys, {});
RMWI->setAttr("mem_order", orderAttr);
if (E->isVolatile())
RMWI->setAttr("is_volatile", mlir::UnitAttr::get(builder.getContext()));
auto Result = RMWI->getResult(0);

if (PostOpMinMax)
llvm_unreachable("NYI");

// This should be handled in LowerToLLVM.cpp, still tracking here for now.
if (E->getOp() == AtomicExpr::AO__atomic_nand_fetch ||
E->getOp() == AtomicExpr::AO__scoped_atomic_nand_fetch)
llvm_unreachable("NYI");

builder.createStore(loc, Result, Dest);
}

static void buildAtomicOp(CIRGenFunction &CGF, AtomicExpr *Expr, Address Dest,
Address Ptr, Address Val1, Address Val2,
mlir::Value IsWeak, mlir::Value FailureOrder,
uint64_t Size, llvm::AtomicOrdering Order,
uint64_t Size, mlir::cir::MemOrder Order,
mlir::Value Scope) {
auto ScopeModel = Expr->getScopeModel();

@@ -1011,34 +1047,34 @@ RValue CIRGenFunction::buildAtomicExpr(AtomicExpr *E) {
// We should not ever get to a case where the ordering isn't a valid CABI
// value, but it's hard to enforce that in general.
auto ord = ordAttr.getUInt();
if (llvm::isValidAtomicOrderingCABI(ord)) {
switch ((llvm::AtomicOrderingCABI)ord) {
case llvm::AtomicOrderingCABI::relaxed:
if (mlir::cir::isValidCIRAtomicOrderingCABI(ord)) {
switch ((mlir::cir::MemOrder)ord) {
case mlir::cir::MemOrder::Relaxed:
buildAtomicOp(*this, E, Dest, Ptr, Val1, Val2, IsWeak, OrderFail, Size,
llvm::AtomicOrdering::Monotonic, Scope);
mlir::cir::MemOrder::Relaxed, Scope);
break;
case llvm::AtomicOrderingCABI::consume:
case llvm::AtomicOrderingCABI::acquire:
case mlir::cir::MemOrder::Consume:
case mlir::cir::MemOrder::Acquire:
if (IsStore)
break; // Avoid crashing on code with undefined behavior
buildAtomicOp(*this, E, Dest, Ptr, Val1, Val2, IsWeak, OrderFail, Size,
llvm::AtomicOrdering::Acquire, Scope);
mlir::cir::MemOrder::Acquire, Scope);
break;
case llvm::AtomicOrderingCABI::release:
case mlir::cir::MemOrder::Release:
if (IsLoad)
break; // Avoid crashing on code with undefined behavior
buildAtomicOp(*this, E, Dest, Ptr, Val1, Val2, IsWeak, OrderFail, Size,
llvm::AtomicOrdering::Release, Scope);
mlir::cir::MemOrder::Release, Scope);
break;
case llvm::AtomicOrderingCABI::acq_rel:
case mlir::cir::MemOrder::AcquireRelease:
if (IsLoad || IsStore)
break; // Avoid crashing on code with undefined behavior
buildAtomicOp(*this, E, Dest, Ptr, Val1, Val2, IsWeak, OrderFail, Size,
llvm::AtomicOrdering::AcquireRelease, Scope);
mlir::cir::MemOrder::AcquireRelease, Scope);
break;
case llvm::AtomicOrderingCABI::seq_cst:
case mlir::cir::MemOrder::SequentiallyConsistent:
buildAtomicOp(*this, E, Dest, Ptr, Val1, Val2, IsWeak, OrderFail, Size,
llvm::AtomicOrdering::SequentiallyConsistent, Scope);
mlir::cir::MemOrder::SequentiallyConsistent, Scope);
break;
}
}
6 changes: 4 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -742,8 +742,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}

mlir::Value createLoad(mlir::Location loc, Address addr) {
return create<mlir::cir::LoadOp>(loc, addr.getElementType(),
addr.getPointer());
auto ptrTy = addr.getPointer().getType().dyn_cast<mlir::cir::PointerType>();
return create<mlir::cir::LoadOp>(
loc, addr.getElementType(),
createElementBitCast(loc, addr, ptrTy.getPointee()).getPointer());
}

mlir::Value createAlignedLoad(mlir::Location loc, mlir::Type ty,
29 changes: 29 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -511,6 +511,20 @@ LogicalResult CastOp::verify() {
llvm_unreachable("Unknown CastOp kind?");
}

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
if (getKind() != mlir::cir::CastKind::integral)
return {};
if (getSrc().getType() != getResult().getType())
return {};
// TODO: for sign differences, it's possible in certain conditions to
// create a new attribute that's capable of representing the source.
SmallVector<mlir::OpFoldResult, 1> foldResults;
auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
if (foldOrder.succeeded() && foldResults[0].is<mlir::Attribute>())
return foldResults[0].get<mlir::Attribute>();
return {};
}
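
In effect, a same-type integral cast folds to whatever attribute its source
operation folds to, e.g. (a sketch, assuming this dialect's constant syntax):

  %0 = cir.const(#cir.int<3> : !s32i) : !s32i
  %1 = cir.cast(integral, %0 : !s32i), !s32i
  // uses of %1 fold to the attribute #cir.int<3> : !s32i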

//===----------------------------------------------------------------------===//
// VecCreateOp
//===----------------------------------------------------------------------===//
@@ -2373,6 +2387,21 @@ mlir::OpTrait::impl::verifySameFirstOperandAndResultType(Operation *op) {
return success();
}

LogicalResult
mlir::OpTrait::impl::verifySameSecondOperandAndResultType(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 2)) || failed(verifyOneResult(op)))
return failure();

auto type = op->getResult(0).getType();
auto opType = op->getOperand(1).getType();

if (type != opType)
return op->emitOpError()
<< "requires the same type for first operand and result";

return success();
}

LogicalResult
mlir::OpTrait::impl::verifySameFirstSecondOperandAndResultType(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 3)) || failed(verifyOneResult(op)))