[CIR] GNU vector type cleanup (#531)
This is the final commit for issue #284. Vector types other than GNU
vector types will be covered by other yet-to-be-created issues.

Now that GNU vector types (the ones defined via the `vector_size`
attribute) are implemented, do a final cleanup of the assertions and
other checks related to vector types.
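
A minimal sketch of the kind of code this covers (the `vec4i` typedef
and the `scale` function are illustrative examples, not taken from this
commit):

```c++
// A GNU vector type: four ints packed into a 16-byte vector.
typedef int vec4i __attribute__((vector_size(16)));

// Element-wise arithmetic on GNU vectors flows through the same
// binary-operator codegen paths that this commit cleans up.
vec4i scale(vec4i v, vec4i w) { return v * w; }
```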

Remove `UnimplementedFeature::cirVectorType()` and deal with the
remaining calls to that function. Where the feature that is not yet
implemented has to do with Arm SVE vectors, the assertion was changed to
`UnimplementedFeature::scalableVectors()` instead. The assertion was
removed in cases where the code correctly handles GNU vector types.

While cleaning up the assertion checks, I noticed that BinOp handling of
vector types wasn't quite complete: the special handling for integer or
floating-point types wasn't happening when the operands were vector
types. To fix this, split `BinOpInfo::Ty` into two fields, `FullType`
and `CompType`. `FullType` is the type of the operands and the result.
`CompType` is normally the same as `FullType`, but is the element type
when `FullType` is a vector type.
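
As an illustration of the split, again using the assumed `vec4i` example
from above:

```c++
// Illustrative mapping between the two fields:
//
//   expression        FullType   CompType
//   int   * int       int        int     (scalar: both fields agree)
//   vec4i * vec4i     vec4i      int     (vector: element type)
//
// Checks such as isSignedIntegerOrEnumerationType() now consult
// CompType, so they apply when the operands are vector types too.
```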
dkolsen-pgi authored and lanza committed Apr 6, 2024
1 parent 84851f8 commit 4b49152
Showing 2 changed files with 48 additions and 51 deletions.
94 changes: 47 additions & 47 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -37,7 +37,9 @@ struct BinOpInfo {
   mlir::Value LHS;
   mlir::Value RHS;
   SourceRange Loc;
-  QualType Ty; // Computation Type.
+  QualType FullType; // Type of operands and result
+  QualType CompType; // Type used for computations. Element type
+                     // for vectors, otherwise same as FullType.
   BinaryOperator::Opcode Opcode; // Opcode of BinOp to perform
   FPOptions FPFeatures;
   const Expr *E; // Entire expr, for error unsupported. May not be binop.
@@ -749,7 +751,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
     BinOpInfo Result;
     Result.LHS = Visit(E->getLHS());
     Result.RHS = Visit(E->getRHS());
-    Result.Ty = E->getType();
+    Result.FullType = E->getType();
+    Result.CompType = E->getType();
+    if (auto VecType = dyn_cast_or_null<VectorType>(E->getType())) {
+      Result.CompType = VecType->getElementType();
+    }
     Result.Opcode = E->getOpcode();
     Result.Loc = E->getSourceRange();
     // TODO: Result.FPFeatures
Expand Down Expand Up @@ -850,7 +856,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
// a vector.
mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode());
return Builder.create<mlir::cir::VecCmpOp>(
CGF.getLoc(BOInfo.Loc), CGF.getCIRType(BOInfo.Ty), Kind,
CGF.getLoc(BOInfo.Loc), CGF.getCIRType(BOInfo.FullType), Kind,
BOInfo.LHS, BOInfo.RHS);
}
}
@@ -869,15 +875,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {

       mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode());
       return Builder.create<mlir::cir::CmpOp>(CGF.getLoc(BOInfo.Loc),
-                                              CGF.getCIRType(BOInfo.Ty), Kind,
-                                              BOInfo.LHS, BOInfo.RHS);
+                                              CGF.getCIRType(BOInfo.FullType),
+                                              Kind, BOInfo.LHS, BOInfo.RHS);
     }
-
-    // If this is a vector comparison, sign extend the result to the
-    // appropriate vector integer type and return it (don't convert to
-    // bool).
-    if (LHSTy->isVectorType())
-      assert(0 && "not implemented");
   } else { // Complex Comparison: can only be an equality comparison.
     assert(0 && "not implemented");
   }
@@ -994,10 +994,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
   assert(!SrcType->isMatrixType() && !DstType->isMatrixType() &&
          "Internal error: conversion between matrix type and scalar type");

-  // TODO(CIR): Support VectorTypes
-  assert(!UnimplementedFeature::cirVectorType() && "NYI: vector cast");
-
-  // Finally, we have the arithmetic types: real int/float.
+  // Finally, we have the arithmetic types or vectors of arithmetic types.
   mlir::Value Res = nullptr;
   mlir::Type ResTy = DstTy;

@@ -1214,18 +1211,18 @@ static mlir::Value buildPointerArithmetic(CIRGenFunction &CGF,

 mlir::Value ScalarExprEmitter::buildMul(const BinOpInfo &Ops) {
   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Mul,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Mul, Ops.LHS, Ops.RHS);
 }
 mlir::Value ScalarExprEmitter::buildDiv(const BinOpInfo &Ops) {
   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Div,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Div, Ops.LHS, Ops.RHS);
 }
 mlir::Value ScalarExprEmitter::buildRem(const BinOpInfo &Ops) {
   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Rem,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Rem, Ops.LHS, Ops.RHS);
 }

 mlir::Value ScalarExprEmitter::buildAdd(const BinOpInfo &Ops) {
@@ -1234,25 +1231,25 @@ mlir::Value ScalarExprEmitter::buildAdd(const BinOpInfo &Ops) {
     return buildPointerArithmetic(CGF, Ops, /*isSubtraction=*/false);

   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Add,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Add, Ops.LHS, Ops.RHS);
 }

 mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
   // The LHS is always a pointer if either side is.
   if (!Ops.LHS.getType().isa<mlir::cir::PointerType>()) {
-    if (Ops.Ty->isSignedIntegerOrEnumerationType()) {
+    if (Ops.CompType->isSignedIntegerOrEnumerationType()) {
       switch (CGF.getLangOpts().getSignedOverflowBehavior()) {
       case LangOptions::SOB_Defined: {
         llvm_unreachable("NYI");
         return Builder.create<mlir::cir::BinOp>(
-            CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty),
+            CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
             mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS);
       }
       case LangOptions::SOB_Undefined:
         if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow))
           return Builder.create<mlir::cir::BinOp>(
-              CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty),
+              CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
               mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS);
         [[fallthrough]];
       case LangOptions::SOB_Trapping:
@@ -1262,17 +1259,16 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
     }
   }

-  if (Ops.Ty->isConstantMatrixType()) {
+  if (Ops.FullType->isConstantMatrixType()) {
     llvm_unreachable("NYI");
   }

-  if (Ops.Ty->isUnsignedIntegerType() &&
+  if (Ops.CompType->isUnsignedIntegerType() &&
       CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
       !CanElideOverflowCheck(CGF.getContext(), Ops))
     llvm_unreachable("NYI");

-  assert(!UnimplementedFeature::cirVectorType());
-  if (Ops.LHS.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
+  if (Ops.CompType->isFloatingType()) {
     CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
     return Builder.createFSub(Ops.LHS, Ops.RHS);
   }
@@ -1281,8 +1277,8 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
     llvm_unreachable("NYI");

   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Sub,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS);
 }

 // If the RHS is not a pointer, then we have normal pointer
@@ -1313,12 +1309,12 @@ mlir::Value ScalarExprEmitter::buildShl(const BinOpInfo &Ops) {
   // promote or truncate the RHS to the same size as the LHS.

   bool SanitizeSignedBase = CGF.SanOpts.has(SanitizerKind::ShiftBase) &&
-                            Ops.Ty->hasSignedIntegerRepresentation() &&
+                            Ops.CompType->hasSignedIntegerRepresentation() &&
                             !CGF.getLangOpts().isSignedOverflowDefined() &&
                             !CGF.getLangOpts().CPlusPlus20;
   bool SanitizeUnsignedBase =
       CGF.SanOpts.has(SanitizerKind::UnsignedShiftBase) &&
-      Ops.Ty->hasUnsignedIntegerRepresentation();
+      Ops.CompType->hasUnsignedIntegerRepresentation();
   bool SanitizeBase = SanitizeSignedBase || SanitizeUnsignedBase;
   bool SanitizeExponent = CGF.SanOpts.has(SanitizerKind::ShiftExponent);

@@ -1331,7 +1327,7 @@ mlir::Value ScalarExprEmitter::buildShl(const BinOpInfo &Ops) {
   }

   return Builder.create<mlir::cir::ShiftOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), Ops.LHS, Ops.RHS,
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), Ops.LHS, Ops.RHS,
       CGF.getBuilder().getUnitAttr());
 }

@@ -1355,23 +1351,23 @@ mlir::Value ScalarExprEmitter::buildShr(const BinOpInfo &Ops) {
   // Note that we don't need to distinguish unsigned treatment at this
   // point since it will be handled later by LLVM lowering.
   return Builder.create<mlir::cir::ShiftOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), Ops.LHS, Ops.RHS);
 }

 mlir::Value ScalarExprEmitter::buildAnd(const BinOpInfo &Ops) {
   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::And,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::And, Ops.LHS, Ops.RHS);
 }
 mlir::Value ScalarExprEmitter::buildXor(const BinOpInfo &Ops) {
   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Xor,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Xor, Ops.LHS, Ops.RHS);
 }
 mlir::Value ScalarExprEmitter::buildOr(const BinOpInfo &Ops) {
   return Builder.create<mlir::cir::BinOp>(
-      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Or,
-      Ops.LHS, Ops.RHS);
+      CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType),
+      mlir::cir::BinOpKind::Or, Ops.LHS, Ops.RHS);
 }

 // Emit code for an explicit or implicit cast. Implicit
@@ -1410,7 +1406,6 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
   auto Src = Visit(const_cast<Expr *>(E));
   mlir::Type DstTy = CGF.convertType(DestTy);

-  assert(!UnimplementedFeature::cirVectorType());
   assert(!UnimplementedFeature::addressSpace());
   if (CGF.SanOpts.has(SanitizerKind::CFIUnrelatedCast)) {
     llvm_unreachable("NYI");
@@ -1426,20 +1421,21 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     // If Src is a fixed vector and Dst is a scalable vector, and both have the
     // same element type, use the llvm.vector.insert intrinsic to perform the
     // bitcast.
-    assert(!UnimplementedFeature::cirVectorType());
+    assert(!UnimplementedFeature::scalableVectors());

     // If Src is a scalable vector and Dst is a fixed vector, and both have the
     // same element type, use the llvm.vector.extract intrinsic to perform the
     // bitcast.
-    assert(!UnimplementedFeature::cirVectorType());
+    assert(!UnimplementedFeature::scalableVectors());

     // Perform VLAT <-> VLST bitcast through memory.
     // TODO: since the llvm.experimental.vector.{insert,extract} intrinsics
     //       require the element types of the vectors to be the same, we
     //       need to keep this around for bitcasts between VLAT <-> VLST where
     //       the element types of the vectors are not the same, until we figure
     //       out a better way of doing these casts.
-    assert(!UnimplementedFeature::cirVectorType());
+    assert(!UnimplementedFeature::scalableVectors());
+
     return CGF.getBuilder().createBitcast(CGF.getLoc(E->getSourceRange()), Src,
                                           DstTy);
   }
@@ -1881,7 +1877,11 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
   // Emit the RHS first. __block variables need to have the rhs evaluated
   // first, plus this should improve codegen a little.
   OpInfo.RHS = Visit(E->getRHS());
-  OpInfo.Ty = E->getComputationResultType();
+  OpInfo.FullType = E->getComputationResultType();
+  OpInfo.CompType = OpInfo.FullType;
+  if (auto VecType = dyn_cast_or_null<VectorType>(OpInfo.FullType)) {
+    OpInfo.CompType = VecType->getElementType();
+  }
   OpInfo.Opcode = E->getOpcode();
   OpInfo.FPFeatures = E->getFPFeaturesInEffect(CGF.getLangOpts());
   OpInfo.E = E;
5 changes: 1 addition & 4 deletions clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h
@@ -23,10 +23,7 @@ struct UnimplementedFeature {
   static bool tbaa() { return false; }
   static bool cleanups() { return false; }

-  // cir::VectorType is in progress, so cirVectorType() will go away soon.
-  // Start adding feature flags for more advanced vector types and operations
-  // that will take longer to implement.
-  static bool cirVectorType() { return false; }
+  // GNU vectors are done, but other kinds of vectors haven't been implemented.
   static bool scalableVectors() { return false; }
   static bool vectorConstants() { return false; }

