Skip to content

Commit

Permalink
Merge pull request #2346 from KhronosGroup/fix-2336
Browse files Browse the repository at this point in the history
MSL: Handle OpPtrAccessChain with ArrayStride
  • Loading branch information
HansKristian-Work authored Jun 19, 2024
2 parents 5d127b9 + 7e469d0 commit d79ba7d
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct Registers
{
device float3* a;
device float3* b;
uint2 c;
uint2 d;
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(64u, 1u, 1u);

kernel void main0(constant Registers& _7 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
device float3* _41 = reinterpret_cast<device float3*>(as_type<ulong>(_7.c));
*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_7.a) + gl_GlobalInvocationID.x * 12) = float3(*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_7.a) + gl_GlobalInvocationID.x * 12)) + _7.b[gl_GlobalInvocationID.x];
*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_41) + gl_GlobalInvocationID.x * 12) = float3(*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_41) + gl_GlobalInvocationID.x * 12)) + (reinterpret_cast<device float3*>(as_type<ulong>(_7.d)))[gl_GlobalInvocationID.x];
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
; SPIR-V
; Version: 1.0
; Generator: Khronos Glslang Reference Front End; 11
; Bound: 66
; Schema: 0
OpCapability Shader
OpCapability PhysicalStorageBufferAddresses
OpExtension "SPV_KHR_physical_storage_buffer"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel PhysicalStorageBuffer64 GLSL450
OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID
OpExecutionMode %main LocalSize 64 1 1
OpSource GLSL 450
OpSourceExtension "GL_EXT_buffer_reference"
OpSourceExtension "GL_EXT_buffer_reference_uvec2"
OpSourceExtension "GL_EXT_scalar_block_layout"
OpName %main "main"
OpName %Registers "Registers"
OpMemberName %Registers 0 "a"
OpMemberName %Registers 1 "b"
OpMemberName %Registers 2 "c"
OpMemberName %Registers 3 "d"
OpName %_ ""
OpName %gl_GlobalInvocationID "gl_GlobalInvocationID"
OpMemberDecorate %Registers 0 Offset 0
OpMemberDecorate %Registers 1 Offset 8
OpMemberDecorate %Registers 2 Offset 16
OpMemberDecorate %Registers 3 Offset 24
OpDecorate %Registers Block
OpDecorate %v3float_stride12_ptr ArrayStride 12
OpDecorate %v3float_stride16_ptr ArrayStride 16
OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%_ptr_PhysicalStorageBuffer_v3float = OpTypePointer PhysicalStorageBuffer %v3float
%v3float_stride12_ptr = OpTypePointer PhysicalStorageBuffer %v3float
%v3float_stride16_ptr = OpTypePointer PhysicalStorageBuffer %v3float
%v3float_stride12_ptr_push = OpTypePointer PushConstant %v3float_stride12_ptr
%v3float_stride16_ptr_push = OpTypePointer PushConstant %v3float_stride16_ptr
%v2uint_ptr = OpTypePointer PushConstant %v2uint
%Registers = OpTypeStruct %v3float_stride12_ptr %v3float_stride16_ptr %v2uint %v2uint
%_ptr_PushConstant_Registers = OpTypePointer PushConstant %Registers
%_ = OpVariable %_ptr_PushConstant_Registers PushConstant
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
%uint_0 = OpConstant %uint 0
%_ptr_Input_uint = OpTypePointer Input %uint
%int_1 = OpConstant %int 1
%int_2 = OpConstant %int 2
%_ptr_PushConstant_v2uint = OpTypePointer PushConstant %v2uint
%int_3 = OpConstant %int 3
%uint_64 = OpConstant %uint 64
%uint_1 = OpConstant %uint 1
%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_64 %uint_1 %uint_1
%main = OpFunction %void None %3
%5 = OpLabel
%29 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0
%index = OpLoad %uint %29

%ptr_member_0 = OpAccessChain %v3float_stride12_ptr_push %_ %int_0
%ptr0 = OpLoad %v3float_stride12_ptr %ptr_member_0

%ptr_member_1 = OpAccessChain %v3float_stride16_ptr_push %_ %int_1
%ptr1 = OpLoad %v3float_stride16_ptr %ptr_member_1

%ptr_member_2 = OpAccessChain %v2uint_ptr %_ %int_2
%ptr2v = OpLoad %v2uint %ptr_member_2
%ptr2 = OpBitcast %v3float_stride12_ptr %ptr2v

%ptr_member_3 = OpAccessChain %v2uint_ptr %_ %int_3
%ptr3v = OpLoad %v2uint %ptr_member_3
%ptr3 = OpBitcast %v3float_stride16_ptr %ptr3v

%ptr0_chain = OpPtrAccessChain %v3float_stride12_ptr %ptr0 %index
%ptr1_chain = OpPtrAccessChain %v3float_stride16_ptr %ptr1 %index
%ptr2_chain = OpPtrAccessChain %v3float_stride12_ptr %ptr2 %index
%ptr3_chain = OpPtrAccessChain %v3float_stride16_ptr %ptr3 %index

%loaded0 = OpLoad %v3float %ptr0_chain Aligned 4
%loaded1 = OpLoad %v3float %ptr1_chain Aligned 16
%loaded2 = OpLoad %v3float %ptr2_chain Aligned 4
%loaded3 = OpLoad %v3float %ptr3_chain Aligned 16

%added0 = OpFAdd %v3float %loaded0 %loaded1
%added1 = OpFAdd %v3float %loaded2 %loaded3
OpStore %ptr0_chain %added0 Aligned 4
OpStore %ptr2_chain %added1 Aligned 4

OpReturn
OpFunctionEnd
67 changes: 63 additions & 4 deletions spirv_glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_
string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && should_dereference(id))

if (is_pointer(type) && should_dereference(id))
return dereference_expression(type, to_enclosed_expression(id, register_expression_read));
else
return to_expression(id, register_expression_read);
Expand All @@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expre
string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
return address_of_expression(to_enclosed_expression(id, register_expression_read));
else
return to_unpacked_expression(id, register_expression_read);
Expand All @@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression
string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
return address_of_expression(to_enclosed_expression(id, register_expression_read));
else
return to_enclosed_unpacked_expression(id, register_expression_read);
Expand Down Expand Up @@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
}
else
{
append_index(index, is_literal, true);
if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
{
SPIRType tmp_type(OpTypeInt);
tmp_type.basetype = SPIRType::UInt64;
tmp_type.width = 64;
tmp_type.vecsize = 1;
tmp_type.columns = 1;

TypeID ptr_type_id = expression_type_id(base);
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
const SPIRType &pointee_type = get_pointee_type(ptr_type);

// This only runs in native pointer backends.
// Can replace reinterpret_cast with a backend string if ever needed.
// We expect this to count as a de-reference.
// This leaks some MSL details, but feels slightly overkill to
// add yet another virtual interface just for this.
auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
get_decoration(ptr_type_id, DecorationArrayStride));

if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
{
is_packed = true;
expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
" *>(", intptr_expr, ")");
}
else
{
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
}
}
else
append_index(index, is_literal, true);
}

if (type->basetype == SPIRType::ControlPointArray)
Expand Down Expand Up @@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
return ret;
}

uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
{
SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
}

string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
AccessChainMeta *meta, bool ptr_chain)
{
Expand Down Expand Up @@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
{
AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
if (ptr_chain)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
// PtrAccessChain could get complicated.
TypeID type_id = expression_type_id(base);
if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
{
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
// Using packed hacks only gets us so far, and is not designed to deal with pointer to
// random values. It works for structs though.
auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
uint32_t physical_stride = get_physical_type_stride(pointee_type);
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
if (physical_stride != requested_stride)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
if (is_vector(pointee_type))
flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
}
}
}

return access_chain_internal(base, indices, count, flags, meta);
}
}
Expand Down
8 changes: 7 additions & 1 deletion spirv_glsl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ enum AccessChainFlagBits
ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
};
typedef uint32_t AccessChainFlags;

Expand Down Expand Up @@ -753,6 +755,10 @@ class CompilerGLSL : public Compiler
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
AccessChainMeta *meta);

// Only meaningful on backends with physical pointer support ala MSL.
// Relevant for PtrAccessChain / BDA.
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;

spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);

Expand Down
17 changes: 13 additions & 4 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4803,7 +4803,7 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32
return false;
}

if (!mbr_type.array.empty())
if (is_array(mbr_type))
{
// If we have an array type, array stride must match exactly with SPIR-V.

Expand Down Expand Up @@ -17050,13 +17050,21 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type,
return msl_size;
}

uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
{
// This should only be relevant for plain types such as scalars and vectors?
// If we're pointing to a struct, it will recursively pick up packed/row-major state.
return get_declared_type_size_msl(type, false, false);
}

// Returns the byte size of a struct member.
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
// Pointers take 8 bytes each
// Match both pointer and array-of-pointer here.
if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
{
uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize);
uint32_t type_size = 8;

// Work our way through potentially layered arrays,
// stopping when we hit a pointer that is not also an array.
Expand Down Expand Up @@ -17131,9 +17139,10 @@ uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t
// Returns the byte alignment of a type.
uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
// Pointers aligns on multiples of 8 bytes
// Pointers align on multiples of 8 bytes.
// Deliberately ignore array-ness here. It's not relevant for alignment.
if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
return 8 * (type.vecsize == 3 ? 4 : type.vecsize);
return 8;

switch (type.basetype)
{
Expand Down
2 changes: 2 additions & 0 deletions spirv_msl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,8 @@ class CompilerMSL : public CompilerGLSL

uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const;

uint32_t get_physical_type_stride(const SPIRType &type) const override;

// MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output.
// These values can change depending on various extended decorations which control packing rules.
// We need to make these rules match up with SPIR-V declared rules.
Expand Down

0 comments on commit d79ba7d

Please sign in to comment.