Skip to content

Commit

Permalink
[CIR][CodeGen] Fix address space of result pointer type of array deca…
Browse files Browse the repository at this point in the history
…y cast op (llvm#812)

There are two occurrences of `cir.cast(array_to_ptrdecay, ...)` that
drop address spaces unexpectedly for its result pointer type. This PR
fixes them with the source address space.

```mlir
// Before
%1 = cir.cast(array_to_ptrdecay, %0 : !cir.ptr<!cir.array<!s32i x 32>, addrspace(offload_local)>), !cir.ptr<!s32i>
// After
%1 = cir.cast(array_to_ptrdecay, %0 : !cir.ptr<!cir.array<!s32i x 32>, addrspace(offload_local)>), !cir.ptr<!s32i, addrspace(offload_local)>
```
  • Loading branch information
seven-mile authored and smeenai committed Oct 9, 2024
1 parent 5c13864 commit feecd53
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 3 deletions.
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ mlir::Value CIRGenBuilderTy::maybeBuildArrayDecay(mlir::Location loc,
::mlir::dyn_cast<::mlir::cir::ArrayType>(arrayPtrTy.getPointee());

if (arrayTy) {
auto addrSpace = ::mlir::cast_if_present<::mlir::cir::AddressSpaceAttr>(
arrayPtrTy.getAddrSpace());
mlir::cir::PointerType flatPtrTy =
mlir::cir::PointerType::get(getContext(), arrayTy.getEltType());
getPointerTo(arrayTy.getEltType(), addrSpace);
return create<mlir::cir::CastOp>(
loc, flatPtrTy, mlir::cir::CastKind::array_to_ptrdecay, arrayPtr);
}
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,10 @@ void AggExprEmitter::buildArrayInit(Address DestPtr, mlir::cir::ArrayType AType,
QualType elementPtrType = CGF.getContext().getPointerType(elementType);

auto cirElementType = CGF.convertType(elementType);
auto cirElementPtrType = mlir::cir::PointerType::get(
CGF.getBuilder().getContext(), cirElementType);
auto cirAddrSpace = mlir::cast_if_present<mlir::cir::AddressSpaceAttr>(
DestPtr.getType().getAddrSpace());
auto cirElementPtrType =
CGF.getBuilder().getPointerTo(cirElementType, cirAddrSpace);
auto loc = CGF.getLoc(ExprToVisit->getSourceRange());

// Cast from cir.ptr<cir.array<elementType> to cir.ptr<elementType>
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ LogicalResult CastOp::verify() {
if (!arrayPtrTy || !flatPtrTy)
return emitOpError() << "requires !cir.ptr type for source and result";

if (arrayPtrTy.getAddrSpace() != flatPtrTy.getAddrSpace()) {
return emitOpError()
<< "requires same address space for source and result";
}

auto arrayTy =
mlir::dyn_cast<mlir::cir::ArrayType>(arrayPtrTy.getPointee());
if (!arrayTy)
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/array-decay.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-cir -triple spirv64-unknown-unknown %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-llvm -triple spirv64-unknown-unknown %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM

// CIR: @func1
// LLVM: @func1
kernel void func1(global int *data) {
local int arr[32];

local int *ptr = arr;
// CIR: cir.cast(array_to_ptrdecay, %{{[0-9]+}} : !cir.ptr<!cir.array<!s32i x 32>, addrspace(offload_local)>), !cir.ptr<!s32i, addrspace(offload_local)>
// CIR-NEXT: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!s32i, addrspace(offload_local)>, !cir.ptr<!cir.ptr<!s32i, addrspace(offload_local)>, addrspace(offload_private)>

// LLVM: store ptr addrspace(3) @func1.arr, ptr %{{[0-9]+}}
}

// CIR: @func2
// LLVM: @func2
kernel void func2(global int *data) {
private int arr[32] = {data[2]};
// CIR: %{{[0-9]+}} = cir.cast(array_to_ptrdecay, %{{[0-9]+}} : !cir.ptr<!cir.array<!s32i x 32>, addrspace(offload_private)>), !cir.ptr<!s32i, addrspace(offload_private)>

// LLVM: %{{[0-9]+}} = getelementptr i32, ptr %3, i32 0
}
13 changes: 13 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,16 @@ module {
cir.return
}
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @array_to_ptrdecay_addrspace() {
%0 = cir.alloca !cir.array<!s32i x 32>, !cir.ptr<!cir.array<!s32i x 32>, addrspace(offload_private)>, ["x", init]
// expected-error@+1 {{requires same address space for source and result}}
%1 = cir.cast(array_to_ptrdecay, %0 : !cir.ptr<!cir.array<!s32i x 32>, addrspace(offload_private)>), !cir.ptr<!s32i>
cir.return
}
}

0 comments on commit feecd53

Please sign in to comment.