diff --git a/reference/shaders-msl-no-opt/asm/comp/ptr-access-chain-custom-array-stride.asm.msl23.comp b/reference/shaders-msl-no-opt/asm/comp/ptr-access-chain-custom-array-stride.asm.msl23.comp new file mode 100644 index 000000000..3cb0b3a9f --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/ptr-access-chain-custom-array-stride.asm.msl23.comp @@ -0,0 +1,22 @@ +#include <metal_stdlib> +#include <simd/simd.h> + +using namespace metal; + +struct Registers +{ + device float3* a; + device float3* b; + uint2 c; + uint2 d; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(64u, 1u, 1u); + +kernel void main0(constant Registers& _7 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + device float3* _41 = reinterpret_cast<device float3*>(as_type<ulong>(_7.c)); + *reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_7.a) + gl_GlobalInvocationID.x * 12) = float3(*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_7.a) + gl_GlobalInvocationID.x * 12)) + _7.b[gl_GlobalInvocationID.x]; + *reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_41) + gl_GlobalInvocationID.x * 12) = float3(*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(_41) + gl_GlobalInvocationID.x * 12)) + (reinterpret_cast<device float3*>(as_type<ulong>(_7.d)))[gl_GlobalInvocationID.x]; +} + diff --git a/shaders-msl-no-opt/asm/comp/ptr-access-chain-custom-array-stride.asm.msl23.comp b/shaders-msl-no-opt/asm/comp/ptr-access-chain-custom-array-stride.asm.msl23.comp new file mode 100644 index 000000000..298b4e750 --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/ptr-access-chain-custom-array-stride.asm.msl23.comp @@ -0,0 +1,98 @@ +; SPIR-V +; Version: 1.0 +; Generator: Khronos Glslang Reference Front End; 11 +; Bound: 66 +; Schema: 0 + OpCapability Shader + OpCapability PhysicalStorageBufferAddresses + OpExtension "SPV_KHR_physical_storage_buffer" + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel PhysicalStorageBuffer64 GLSL450 + OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationID + OpExecutionMode %main LocalSize 64 1 1 + OpSource GLSL 450 + OpSourceExtension "GL_EXT_buffer_reference" + 
OpSourceExtension "GL_EXT_buffer_reference_uvec2" + OpSourceExtension "GL_EXT_scalar_block_layout" + OpName %main "main" + OpName %Registers "Registers" + OpMemberName %Registers 0 "a" + OpMemberName %Registers 1 "b" + OpMemberName %Registers 2 "c" + OpMemberName %Registers 3 "d" + OpName %_ "" + OpName %gl_GlobalInvocationID "gl_GlobalInvocationID" + OpMemberDecorate %Registers 0 Offset 0 + OpMemberDecorate %Registers 1 Offset 8 + OpMemberDecorate %Registers 2 Offset 16 + OpMemberDecorate %Registers 3 Offset 24 + OpDecorate %Registers Block + OpDecorate %v3float_stride12_ptr ArrayStride 12 + OpDecorate %v3float_stride16_ptr ArrayStride 16 + OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 +%_ptr_PhysicalStorageBuffer_v3float = OpTypePointer PhysicalStorageBuffer %v3float +%v3float_stride12_ptr = OpTypePointer PhysicalStorageBuffer %v3float +%v3float_stride16_ptr = OpTypePointer PhysicalStorageBuffer %v3float +%v3float_stride12_ptr_push = OpTypePointer PushConstant %v3float_stride12_ptr +%v3float_stride16_ptr_push = OpTypePointer PushConstant %v3float_stride16_ptr +%v2uint_ptr = OpTypePointer PushConstant %v2uint + %Registers = OpTypeStruct %v3float_stride12_ptr %v3float_stride16_ptr %v2uint %v2uint +%_ptr_PushConstant_Registers = OpTypePointer PushConstant %Registers + %_ = OpVariable %_ptr_PushConstant_Registers PushConstant + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input + %uint_0 = OpConstant %uint 0 +%_ptr_Input_uint = OpTypePointer Input %uint + %int_1 = OpConstant %int 1 + %int_2 = OpConstant %int 2 +%_ptr_PushConstant_v2uint = OpTypePointer PushConstant %v2uint + %int_3 = 
OpConstant %int 3 + %uint_64 = OpConstant %uint 64 + %uint_1 = OpConstant %uint 1 +%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_64 %uint_1 %uint_1 + %main = OpFunction %void None %3 + %5 = OpLabel + %29 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0 + %index = OpLoad %uint %29 + + %ptr_member_0 = OpAccessChain %v3float_stride12_ptr_push %_ %int_0 + %ptr0 = OpLoad %v3float_stride12_ptr %ptr_member_0 + + %ptr_member_1 = OpAccessChain %v3float_stride16_ptr_push %_ %int_1 + %ptr1 = OpLoad %v3float_stride16_ptr %ptr_member_1 + + %ptr_member_2 = OpAccessChain %v2uint_ptr %_ %int_2 + %ptr2v = OpLoad %v2uint %ptr_member_2 + %ptr2 = OpBitcast %v3float_stride12_ptr %ptr2v + + %ptr_member_3 = OpAccessChain %v2uint_ptr %_ %int_3 + %ptr3v = OpLoad %v2uint %ptr_member_3 + %ptr3 = OpBitcast %v3float_stride16_ptr %ptr3v + + %ptr0_chain = OpPtrAccessChain %v3float_stride12_ptr %ptr0 %index + %ptr1_chain = OpPtrAccessChain %v3float_stride16_ptr %ptr1 %index + %ptr2_chain = OpPtrAccessChain %v3float_stride12_ptr %ptr2 %index + %ptr3_chain = OpPtrAccessChain %v3float_stride16_ptr %ptr3 %index + + %loaded0 = OpLoad %v3float %ptr0_chain Aligned 4 + %loaded1 = OpLoad %v3float %ptr1_chain Aligned 16 + %loaded2 = OpLoad %v3float %ptr2_chain Aligned 4 + %loaded3 = OpLoad %v3float %ptr3_chain Aligned 16 + + %added0 = OpFAdd %v3float %loaded0 %loaded1 + %added1 = OpFAdd %v3float %loaded2 %loaded3 + OpStore %ptr0_chain %added0 Aligned 4 + OpStore %ptr2_chain %added1 Aligned 4 + + OpReturn + OpFunctionEnd diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 3f13febcc..fad1132e8 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read) { auto &type = expression_type(id); - if (type.pointer && should_dereference(id)) + + if (is_pointer(type) && should_dereference(id)) return 
dereference_expression(type, to_enclosed_expression(id, register_expression_read)); else return to_expression(id, register_expression_read); } @@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expre string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read) { auto &type = expression_type(id); - if (type.pointer && expression_is_lvalue(id) && !should_dereference(id)) + if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id)) return address_of_expression(to_enclosed_expression(id, register_expression_read)); else return to_unpacked_expression(id, register_expression_read); } @@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read) { auto &type = expression_type(id); - if (type.pointer && expression_is_lvalue(id) && !should_dereference(id)) + if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id)) return address_of_expression(to_enclosed_expression(id, register_expression_read)); else return to_enclosed_unpacked_expression(id, register_expression_read); } @@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice } else { - append_index(index, is_literal, true); + if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT) + { + SPIRType tmp_type(OpTypeInt); + tmp_type.basetype = SPIRType::UInt64; + tmp_type.width = 64; + tmp_type.vecsize = 1; + tmp_type.columns = 1; + + TypeID ptr_type_id = expression_type_id(base); + const SPIRType &ptr_type = get<SPIRType>(ptr_type_id); + const SPIRType &pointee_type = get_pointee_type(ptr_type); + + // This only runs in native pointer backends. + // Can replace reinterpret_cast with a backend string if ever needed. + // We expect this to count as a de-reference. 
+ // This leaks some MSL details, but feels slightly overkill to + // add yet another virtual interface just for this. + auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")"); + intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ", + get_decoration(ptr_type_id, DecorationArrayStride)); + + if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT) + { + is_packed = true; + expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type), " *>(", intptr_expr, ")"); + } + else + { + expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")"); + } + } + else + append_index(index, is_literal, true); } if (type->basetype == SPIRType::ControlPointArray) @@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP return ret; } +uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const +{ + SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support."); +} + string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type, AccessChainMeta *meta, bool ptr_chain) { @@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32 { AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT; if (ptr_chain) + { flags |= ACCESS_CHAIN_PTR_CHAIN_BIT; + // PtrAccessChain could get complicated. + TypeID type_id = expression_type_id(base); + if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride)) + { + // If there is a mismatch we have to go via 64-bit pointer arithmetic :'( + // Using packed hacks only gets us so far, and is not designed to deal with pointer to + // random values. It works for structs though. 
+ auto &pointee_type = get_pointee_type(get<SPIRType>(type_id)); + uint32_t physical_stride = get_physical_type_stride(pointee_type); + uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride); + if (physical_stride != requested_stride) + { + flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT; + if (is_vector(pointee_type)) + flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT; + } + } + } + return access_chain_internal(base, indices, count, flags, meta); } } diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index f3e545e9f..8a0026323 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -66,7 +66,9 @@ enum AccessChainFlagBits ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3, ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4, ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5, - ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6 + ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6, + ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7, + ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8 }; typedef uint32_t AccessChainFlags; @@ -753,6 +755,10 @@ class CompilerGLSL : public Compiler std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags, AccessChainMeta *meta); + // Only meaningful on backends with physical pointer support ala MSL. + // Relevant for PtrAccessChain / BDA. + virtual uint32_t get_physical_type_stride(const SPIRType &type) const; + spv::StorageClass get_expression_effective_storage_class(uint32_t ptr); virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base); diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 75d935d8a..ebf5ffe03 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -4803,7 +4803,7 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32 return false; } - if (!mbr_type.array.empty()) + if (is_array(mbr_type)) { // If we have an array type, array stride must match exactly with SPIR-V. 
@@ -17050,13 +17050,21 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, return msl_size; } +uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const +{ + // This should only be relevant for plain types such as scalars and vectors? + // If we're pointing to a struct, it will recursively pick up packed/row-major state. + return get_declared_type_size_msl(type, false, false); +} + // Returns the byte size of a struct member. uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const { // Pointers take 8 bytes each + // Match both pointer and array-of-pointer here. if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) { - uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize); + uint32_t type_size = 8; // Work our way through potentially layered arrays, // stopping when we hit a pointer that is not also an array. @@ -17131,9 +17139,10 @@ uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t // Returns the byte alignment of a type. uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const { - // Pointers aligns on multiples of 8 bytes + // Pointers align on multiples of 8 bytes. + // Deliberately ignore array-ness here. It's not relevant for alignment. if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) - return 8 * (type.vecsize == 3 ? 4 : type.vecsize); + return 8; switch (type.basetype) { diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 9a1715808..2d970c0da 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -1028,6 +1028,8 @@ class CompilerMSL : public CompilerGLSL uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const; + uint32_t get_physical_type_stride(const SPIRType &type) const override; + // MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output. 
// These values can change depending on various extended decorations which control packing rules. // We need to make these rules match up with SPIR-V declared rules.