From 04e6200e0e8375d20be6bb27319151aecf0613f5 Mon Sep 17 00:00:00 2001 From: David Olsen Date: Fri, 5 Apr 2024 18:26:14 -0700 Subject: [PATCH] [CIR] GNU vector type cleanup (#531) This is the final commit for issue #284. Vector types other than GNU vector types will be covered by other yet-to-be-created issues. Now that GNU vector types (the ones defined via the vector_size attribute) are implemented, do a final cleanup of the assertions and other checks related to vector types. Remove `UnimplementedFeature::cirVectorType()`. Deal with the remaining calls to that function. When the that is not yet implemented has to do with Arm SVE vectors, the assert was changed to `UnimplementedFeature::scalableVectors()` instead. The assertion was removed in cases where the code correctly handle GNU vector types. While cleaning up the assertion checks, I noticed that BinOp handling of vector types wasn't quite complete. Any special handling for integer or floating-point types wasn't happening when the operands were vector types. To fix this, split `BinOpInfo::Ty` into two fields, `FullType` and `CompType`. `FullType` is the type of the operands. `CompType` is normally the same as `FullType`, but is the element type when `FullType` is a vector type. --- clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 94 +++++++++---------- .../CodeGen/UnimplementedFeatureGuarding.h | 5 +- 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index e6093c52ca84..22c8ffdf96ff 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -37,7 +37,9 @@ struct BinOpInfo { mlir::Value LHS; mlir::Value RHS; SourceRange Loc; - QualType Ty; // Computation Type. + QualType FullType; // Type of operands and result + QualType CompType; // Type used for computations. Element type + // for vectors, otherwise same as FullType. BinaryOperator::Opcode Opcode; // Opcode of BinOp to perform FPOptions FPFeatures; const Expr *E; // Entire expr, for error unsupported. May not be binop. @@ -749,7 +751,11 @@ class ScalarExprEmitter : public StmtVisitor { BinOpInfo Result; Result.LHS = Visit(E->getLHS()); Result.RHS = Visit(E->getRHS()); - Result.Ty = E->getType(); + Result.FullType = E->getType(); + Result.CompType = E->getType(); + if (auto VecType = dyn_cast_or_null(E->getType())) { + Result.CompType = VecType->getElementType(); + } Result.Opcode = E->getOpcode(); Result.Loc = E->getSourceRange(); // TODO: Result.FPFeatures @@ -850,7 +856,7 @@ class ScalarExprEmitter : public StmtVisitor { // a vector. mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode()); return Builder.create( - CGF.getLoc(BOInfo.Loc), CGF.getCIRType(BOInfo.Ty), Kind, + CGF.getLoc(BOInfo.Loc), CGF.getCIRType(BOInfo.FullType), Kind, BOInfo.LHS, BOInfo.RHS); } } @@ -869,15 +875,9 @@ class ScalarExprEmitter : public StmtVisitor { mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode()); return Builder.create(CGF.getLoc(BOInfo.Loc), - CGF.getCIRType(BOInfo.Ty), Kind, - BOInfo.LHS, BOInfo.RHS); + CGF.getCIRType(BOInfo.FullType), + Kind, BOInfo.LHS, BOInfo.RHS); } - - // If this is a vector comparison, sign extend the result to the - // appropriate vector integer type and return it (don't convert to - // bool). - if (LHSTy->isVectorType()) - assert(0 && "not implemented"); } else { // Complex Comparison: can only be an equality comparison. assert(0 && "not implemented"); } @@ -994,10 +994,7 @@ class ScalarExprEmitter : public StmtVisitor { assert(!SrcType->isMatrixType() && !DstType->isMatrixType() && "Internal error: conversion between matrix type and scalar type"); - // TODO(CIR): Support VectorTypes - assert(!UnimplementedFeature::cirVectorType() && "NYI: vector cast"); - - // Finally, we have the arithmetic types: real int/float. + // Finally, we have the arithmetic types or vectors of arithmetic types. mlir::Value Res = nullptr; mlir::Type ResTy = DstTy; @@ -1214,18 +1211,18 @@ static mlir::Value buildPointerArithmetic(CIRGenFunction &CGF, mlir::Value ScalarExprEmitter::buildMul(const BinOpInfo &Ops) { return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Mul, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Mul, Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildDiv(const BinOpInfo &Ops) { return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Div, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Div, Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildRem(const BinOpInfo &Ops) { return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Rem, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Rem, Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildAdd(const BinOpInfo &Ops) { @@ -1234,25 +1231,25 @@ mlir::Value ScalarExprEmitter::buildAdd(const BinOpInfo &Ops) { return buildPointerArithmetic(CGF, Ops, /*isSubtraction=*/false); return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Add, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Add, Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) { // The LHS is always a pointer if either side is. if (!Ops.LHS.getType().isa()) { - if (Ops.Ty->isSignedIntegerOrEnumerationType()) { + if (Ops.CompType->isSignedIntegerOrEnumerationType()) { switch (CGF.getLangOpts().getSignedOverflowBehavior()) { case LangOptions::SOB_Defined: { llvm_unreachable("NYI"); return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS); } case LangOptions::SOB_Undefined: if (!CGF.SanOpts.has(SanitizerKind::SignedIntegerOverflow)) return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS); [[fallthrough]]; case LangOptions::SOB_Trapping: @@ -1262,17 +1259,16 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) { } } - if (Ops.Ty->isConstantMatrixType()) { + if (Ops.FullType->isConstantMatrixType()) { llvm_unreachable("NYI"); } - if (Ops.Ty->isUnsignedIntegerType() && + if (Ops.CompType->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) llvm_unreachable("NYI"); - assert(!UnimplementedFeature::cirVectorType()); - if (Ops.LHS.getType().isa()) { + if (Ops.CompType->isFloatingType()) { CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures); return Builder.createFSub(Ops.LHS, Ops.RHS); } @@ -1281,8 +1277,8 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) { llvm_unreachable("NYI"); return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Sub, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Sub, Ops.LHS, Ops.RHS); } // If the RHS is not a pointer, then we have normal pointer @@ -1313,12 +1309,12 @@ mlir::Value ScalarExprEmitter::buildShl(const BinOpInfo &Ops) { // promote or truncate the RHS to the same size as the LHS. bool SanitizeSignedBase = CGF.SanOpts.has(SanitizerKind::ShiftBase) && - Ops.Ty->hasSignedIntegerRepresentation() && + Ops.CompType->hasSignedIntegerRepresentation() && !CGF.getLangOpts().isSignedOverflowDefined() && !CGF.getLangOpts().CPlusPlus20; bool SanitizeUnsignedBase = CGF.SanOpts.has(SanitizerKind::UnsignedShiftBase) && - Ops.Ty->hasUnsignedIntegerRepresentation(); + Ops.CompType->hasUnsignedIntegerRepresentation(); bool SanitizeBase = SanitizeSignedBase || SanitizeUnsignedBase; bool SanitizeExponent = CGF.SanOpts.has(SanitizerKind::ShiftExponent); @@ -1331,7 +1327,7 @@ mlir::Value ScalarExprEmitter::buildShl(const BinOpInfo &Ops) { } return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), Ops.LHS, Ops.RHS, + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), Ops.LHS, Ops.RHS, CGF.getBuilder().getUnitAttr()); } @@ -1355,23 +1351,23 @@ mlir::Value ScalarExprEmitter::buildShr(const BinOpInfo &Ops) { // Note that we don't need to distinguish unsigned treatment at this // point since it will be handled later by LLVM lowering. return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildAnd(const BinOpInfo &Ops) { return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::And, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::And, Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildXor(const BinOpInfo &Ops) { return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Xor, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Xor, Ops.LHS, Ops.RHS); } mlir::Value ScalarExprEmitter::buildOr(const BinOpInfo &Ops) { return Builder.create( - CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.Ty), mlir::cir::BinOpKind::Or, - Ops.LHS, Ops.RHS); + CGF.getLoc(Ops.Loc), CGF.getCIRType(Ops.FullType), + mlir::cir::BinOpKind::Or, Ops.LHS, Ops.RHS); } // Emit code for an explicit or implicit cast. Implicit @@ -1410,7 +1406,6 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { auto Src = Visit(const_cast(E)); mlir::Type DstTy = CGF.convertType(DestTy); - assert(!UnimplementedFeature::cirVectorType()); assert(!UnimplementedFeature::addressSpace()); if (CGF.SanOpts.has(SanitizerKind::CFIUnrelatedCast)) { llvm_unreachable("NYI"); @@ -1426,12 +1421,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { // If Src is a fixed vector and Dst is a scalable vector, and both have the // same element type, use the llvm.vector.insert intrinsic to perform the // bitcast. - assert(!UnimplementedFeature::cirVectorType()); + assert(!UnimplementedFeature::scalableVectors()); // If Src is a scalable vector and Dst is a fixed vector, and both have the // same element type, use the llvm.vector.extract intrinsic to perform the // bitcast. - assert(!UnimplementedFeature::cirVectorType()); + assert(!UnimplementedFeature::scalableVectors()); // Perform VLAT <-> VLST bitcast through memory. // TODO: since the llvm.experimental.vector.{insert,extract} intrinsics @@ -1439,7 +1434,8 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { // need to keep this around for bitcasts between VLAT <-> VLST where // the element types of the vectors are not the same, until we figure // out a better way of doing these casts. - assert(!UnimplementedFeature::cirVectorType()); + assert(!UnimplementedFeature::scalableVectors()); + return CGF.getBuilder().createBitcast(CGF.getLoc(E->getSourceRange()), Src, DstTy); } @@ -1881,7 +1877,11 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue( // Emit the RHS first. __block variables need to have the rhs evaluated // first, plus this should improve codegen a little. OpInfo.RHS = Visit(E->getRHS()); - OpInfo.Ty = E->getComputationResultType(); + OpInfo.FullType = E->getComputationResultType(); + OpInfo.CompType = OpInfo.FullType; + if (auto VecType = dyn_cast_or_null(OpInfo.FullType)) { + OpInfo.CompType = VecType->getElementType(); + } OpInfo.Opcode = E->getOpcode(); OpInfo.FPFeatures = E->getFPFeaturesInEffect(CGF.getLangOpts()); OpInfo.E = E; diff --git a/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h b/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h index 06bd8201834c..1a8d1328f90c 100644 --- a/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h +++ b/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h @@ -23,10 +23,7 @@ struct UnimplementedFeature { static bool tbaa() { return false; } static bool cleanups() { return false; } - // cir::VectorType is in progress, so cirVectorType() will go away soon. - // Start adding feature flags for more advanced vector types and operations - // that will take longer to implement. - static bool cirVectorType() { return false; } + // GNU vectors are done, but other kinds of vectors haven't been implemented. static bool scalableVectors() { return false; } static bool vectorConstants() { return false; }