Skip to content

Commit

Permalink
[CIR] shufflevector and convertvector built-ins (llvm#530)
Browse files Browse the repository at this point in the history
Implement `__builtin_shufflevector` and `__builtin_convertvector` in
ClangIR. This change contributes to the implemention of issue llvm#284.

`__builtin_convertvector` is implemented as a cast. LLVM IR uses the
same instructions for arithmetic conversions of both individual scalars
and entire vectors. So ClangIR does the same. The code for handling
conversions, in both CodeGen and Lowering, is cleaned up to correctly
handle vector types. To simplify the lowering code and avoid `if
(type.isa<VectorType>())` statements everywhere, the utility function
`elementTypeIfVector` was added to `LowerToLLVM.cpp`.

`__builtin_shufflevector` has two forms, only one of which appears to be
documented.

The documented form, which takes a variable-sized list of integer
constants for the indices, is implemented with the new ClangIR operation
`cir.vec.shuffle.ints`. This operation is lowered to the
`llvm.shufflevector` op.

The undocumented form, which gets the indices from a vector operand, is
implemented with the new ClangIR operation `cir.vec.shuffle.vec`. LLVM
IR does not have an instruction for this, so it gets lowered to a long
series of `llvm.extractelement` and `llvm.insertelement` operations.
  • Loading branch information
dkolsen-pgi authored and lanza committed Apr 17, 2024
1 parent 80599ac commit b7827f2
Show file tree
Hide file tree
Showing 7 changed files with 388 additions and 53 deletions.
63 changes: 63 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2198,6 +2198,69 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecShuffle
//===----------------------------------------------------------------------===//

// TODO: Create an interface that both VecShuffleOp and VecShuffleDynamicOp
// implement. This could be useful for passes that don't care how the vector
// shuffle was specified.

def VecShuffleOp : CIR_Op<"vec.shuffle",
[Pure, AllTypesMatch<["vec1", "vec2"]>]> {
let summary = "Combine two vectors using indices passed as constant integers";
let description = [{
The `cir.vec.shuffle` operation implements the documented form of Clang's
__builtin_shufflevector, where the indices of the shuffled result are
integer constants.

The two input vectors, which must have the same type, are concatenated.
Each of the integer constant arguments is interpreted as an index into that
concatenated vector, with a value of -1 meaning that the result value
doesn't matter. The result vector, which must have the same element type as
the input vectors and the same number of elements as the list of integer
constant indices, is constructed by taking the elements at the given
indices from the concatenated vector. The size of the result vector does
not have to match the size of the individual input vectors or of the
concatenated vector.
}];
let arguments = (ins CIR_VectorType:$vec1, CIR_VectorType:$vec2,
ArrayAttr:$indices);
let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
`(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:`
qualified(type($result)) attr-dict
}];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecShuffleDynamic
//===----------------------------------------------------------------------===//

def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
[Pure, AllTypesMatch<["vec", "result"]>]> {
let summary = "Shuffle a vector using indices in another vector";
let description = [{
The `cir.vec.shuffle.dynamic` operation implements the undocumented form of
Clang's __builtin_shufflevector, where the indices of the shuffled result
can be runtime values.

There are two input vectors, which must have the same number of elements.
The second input vector must have an integral element type. The elements of
the second vector are interpreted as indices into the first vector. The
result vector is constructed by taking the elements from the first input
vector from the indices indicated by the elements of the second vector.
}];
let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
$vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))
attr-dict
}];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// BaseClassAddr
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 46 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,37 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
llvm_unreachable("NYI");
}
mlir::Value VisitShuffleVectorExpr(ShuffleVectorExpr *E) {
llvm_unreachable("NYI");
if (E->getNumSubExprs() == 2) {
// The undocumented form of __builtin_shufflevector.
mlir::Value InputVec = Visit(E->getExpr(0));
mlir::Value IndexVec = Visit(E->getExpr(1));
return CGF.builder.create<mlir::cir::VecShuffleDynamicOp>(
CGF.getLoc(E->getSourceRange()), InputVec, IndexVec);
} else {
// The documented form of __builtin_shufflevector, where the indices are
// a variable number of integer constants. The constants will be stored
// in an ArrayAttr.
mlir::Value Vec1 = Visit(E->getExpr(0));
mlir::Value Vec2 = Visit(E->getExpr(1));
SmallVector<mlir::Attribute, 8> Indices;
for (unsigned i = 2; i < E->getNumSubExprs(); ++i) {
Indices.push_back(mlir::cir::IntAttr::get(
CGF.builder.getSInt64Ty(),
E->getExpr(i)
->EvaluateKnownConstInt(CGF.getContext())
.getSExtValue()));
}
return CGF.builder.create<mlir::cir::VecShuffleOp>(
CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()), Vec1,
Vec2, CGF.builder.getArrayAttr(Indices));
}
}
mlir::Value VisitConvertVectorExpr(ConvertVectorExpr *E) {
llvm_unreachable("NYI");
// __builtin_convertvector is an element-wise cast, and is implemented as a
// regular cast. The back end handles casts of vectors correctly.
return buildScalarConversion(Visit(E->getSrcExpr()),
E->getSrcExpr()->getType(), E->getType(),
E->getSourceRange().getBegin());
}
mlir::Value VisitMemberExpr(MemberExpr *E);
mlir::Value VisitExtVectorelementExpr(Expr *E) { llvm_unreachable("NYI"); }
Expand Down Expand Up @@ -1725,9 +1752,9 @@ mlir::Value ScalarExprEmitter::VisitUnaryLNot(const UnaryOperator *E) {
}

// Conversion from bool, integral, or floating-point to integral or
// floating-point. Conversions involving other types are handled elsewhere.
// floating-point. Conversions involving other types are handled elsewhere.
// Conversion to bool is handled elsewhere because that's a comparison against
// zero, not a simple cast.
// zero, not a simple cast. This handles both individual scalars and vectors.
mlir::Value ScalarExprEmitter::buildScalarCast(
mlir::Value Src, QualType SrcType, QualType DstType, mlir::Type SrcTy,
mlir::Type DstTy, ScalarConversionOpts Opts) {
Expand All @@ -1736,9 +1763,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
if (SrcTy.isa<mlir::IntegerType>() || DstTy.isa<mlir::IntegerType>())
llvm_unreachable("Obsolete code. Don't use mlir::IntegerType with CIR.");

mlir::Type FullDstTy = DstTy;
if (SrcTy.isa<mlir::cir::VectorType>() &&
DstTy.isa<mlir::cir::VectorType>()) {
// Use the element types of the vectors to figure out the CastKind.
SrcTy = SrcTy.dyn_cast<mlir::cir::VectorType>().getEltType();
DstTy = DstTy.dyn_cast<mlir::cir::VectorType>().getEltType();
}
assert(!SrcTy.isa<mlir::cir::VectorType>() &&
!DstTy.isa<mlir::cir::VectorType>() &&
"buildScalarCast given a vector type and a non-vector type");

std::optional<mlir::cir::CastKind> CastKind;

if (SrcType->isBooleanType()) {
if (SrcTy.isa<mlir::cir::BoolType>()) {
if (Opts.TreatBooleanAsSigned)
llvm_unreachable("NYI: signed bool");
if (CGF.getBuilder().isInt(DstTy)) {
Expand Down Expand Up @@ -1768,7 +1806,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
CastKind = mlir::cir::CastKind::float_to_int;
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
// TODO: split this to createFPExt/createFPTrunc
return Builder.createFloatingCast(Src, DstTy);
return Builder.createFloatingCast(Src, FullDstTy);
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
Expand All @@ -1777,7 +1815,8 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
}

assert(CastKind.has_value() && "Internal error: CastKind not set.");
return Builder.create<mlir::cir::CastOp>(Src.getLoc(), DstTy, *CastKind, Src);
return Builder.create<mlir::cir::CastOp>(Src.getLoc(), FullDstTy, *CastKind,
Src);
}

LValue
Expand Down
61 changes: 56 additions & 5 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ LogicalResult CastOp::verify() {
auto resType = getResult().getType();
auto srcType = getSrc().getType();

if (srcType.isa<mlir::cir::VectorType>() &&
resType.isa<mlir::cir::VectorType>()) {
// Use the element type of the vector to verify the cast kind. (Except for
// bitcast, see below.)
srcType = srcType.dyn_cast<mlir::cir::VectorType>().getEltType();
resType = resType.dyn_cast<mlir::cir::VectorType>().getEltType();
}

switch (getKind()) {
case cir::CastKind::int_to_bool: {
if (!resType.isa<mlir::cir::BoolType>())
Expand Down Expand Up @@ -433,18 +441,20 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::bitcast: {
if ((!srcType.isa<mlir::cir::PointerType>() ||
!resType.isa<mlir::cir::PointerType>()) &&
(!srcType.isa<mlir::cir::VectorType>() ||
!resType.isa<mlir::cir::VectorType>()))
// This is the only cast kind where we don't want vector types to decay
// into the element type.
if ((!getSrc().getType().isa<mlir::cir::PointerType>() ||
!getResult().getType().isa<mlir::cir::PointerType>()) &&
(!getSrc().getType().isa<mlir::cir::VectorType>() ||
!getResult().getType().isa<mlir::cir::VectorType>()))
return emitOpError()
<< "requires !cir.ptr or !cir.vector type for source and result";
return success();
}
case cir::CastKind::floating: {
if (!srcType.isa<mlir::cir::CIRFPTypeInterface>() ||
!resType.isa<mlir::cir::CIRFPTypeInterface>())
return emitOpError() << "requries floating for source and result";
return emitOpError() << "requires floating for source and result";
return success();
}
case cir::CastKind::float_to_int: {
Expand Down Expand Up @@ -544,6 +554,47 @@ LogicalResult VecTernaryOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// VecShuffle
//===----------------------------------------------------------------------===//

LogicalResult VecShuffleOp::verify() {
// The number of elements in the indices array must match the number of
// elements in the result type.
if (getIndices().size() != getResult().getType().getSize()) {
return emitOpError() << ": the number of elements in " << getIndices()
<< " and " << getResult().getType() << " don't match";
}
// The element types of the two input vectors and of the result type must
// match.
if (getVec1().getType().getEltType() != getResult().getType().getEltType()) {
return emitOpError() << ": element types of " << getVec1().getType()
<< " and " << getResult().getType() << " don't match";
}
// The indices must all be integer constants
if (not std::all_of(getIndices().begin(), getIndices().end(),
[](mlir::Attribute attr) {
return attr.isa<mlir::cir::IntAttr>();
})) {
return emitOpError() << "all index values must be integers";
}
return success();
}

//===----------------------------------------------------------------------===//
// VecShuffleDynamic
//===----------------------------------------------------------------------===//

LogicalResult VecShuffleDynamicOp::verify() {
// The number of elements in the two input vectors must match.
if (getVec().getType().getSize() !=
getIndices().getType().cast<mlir::cir::VectorType>().getSize()) {
return emitOpError() << ": the number of elements in " << getVec().getType()
<< " and " << getIndices().getType() << " don't match";
}
return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit b7827f2

Please sign in to comment.