Skip to content

Commit

Permalink
Merge pull request #2310 from KhronosGroup/fix-2300
Browse files Browse the repository at this point in the history
MSL: Fix SUMulExtended for 64-bit inputs.
  • Loading branch information
HansKristian-Work authored Apr 15, 2024
2 parents 56f24d8 + eef4c2a commit 25b7772
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 27 deletions.
4 changes: 2 additions & 2 deletions reference/opt/shaders-msl/asm/comp/uint_smulextended.asm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ struct _20
// Compute kernel: for the element selected by gl_GlobalInvocationID.x, reads one value from _5 and one from _6,
// casts both to int, and writes the plain product to _7 and the mulhi() of the same operands to _8.
// NOTE(review): this is a rendered diff hunk without +/- markers, so two versions of each assignment appear below.
kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
_20 _28;
// Presumably the pre-commit lines (the "2 deletions"): each result is wrapped in an explicit uint() cast.
_28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]));
_28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x])));
// Presumably the post-commit lines (the "2 additions"): same operands, no outer cast on the result.
_28._m0 = int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]);
_28._m1 = mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x]));
// Copy the two struct members out to the two output buffers at the same index.
_7._m0[gl_GlobalInvocationID.x] = _28._m0;
_8._m0[gl_GlobalInvocationID.x] = _28._m1;
}
Expand Down
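25 changes: 25 additions & 0 deletions reference/opt/shaders-msl/asm/comp/ulong_smulextended.asm.msl23.comp
Original file line number Diff line number Diff line change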
@@ -0,0 +1,25 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Buffer layout used by all four kernel bindings (_5, _6, _7, _8): a single member holding an array of ulong.
// The array is declared with one element; main0 indexes it with gl_GlobalInvocationID.x.
struct _4
{
ulong _m0[1];
};

// Two-member result record filled in by main0:
// _m0 receives the plain product of the two inputs, _m1 receives the mulhi() of the same inputs.
struct _21
{
ulong _m0;
ulong _m1;
};

// Compute kernel: for the element selected by gl_GlobalInvocationID.x, reads one ulong from _5 and one from _6,
// multiplies them as signed 64-bit values, and writes the two halves of the result to _7 and _8.
kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
_21 _29;
// Inputs are stored as ulong and cast to long so the multiply is signed.
// The long results are assigned to the ulong members without an explicit outer cast.
_29._m0 = long(_5._m0[gl_GlobalInvocationID.x]) * long(_6._m0[gl_GlobalInvocationID.x]);
// mulhi() supplies the upper half of the product of the same two operands.
_29._m1 = mulhi(long(_5._m0[gl_GlobalInvocationID.x]), long(_6._m0[gl_GlobalInvocationID.x]));
// Copy the two struct members out to the two output buffers at the same index.
_7._m0[gl_GlobalInvocationID.x] = _29._m0;
_8._m0[gl_GlobalInvocationID.x] = _29._m1;
}

4 changes: 2 additions & 2 deletions reference/shaders-msl/asm/comp/uint_smulextended.asm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ struct _20
// Compute kernel: for the element selected by gl_GlobalInvocationID.x, reads one value from _5 and one from _6,
// casts both to int, and writes the plain product to _7 and the mulhi() of the same operands to _8.
// NOTE(review): this is a rendered diff hunk without +/- markers, so two versions of each assignment appear below.
kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
_20 _28;
// Presumably the pre-commit lines (the "2 deletions"): each result is wrapped in an explicit uint() cast.
_28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]));
_28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x])));
// Presumably the post-commit lines (the "2 additions"): same operands, no outer cast on the result.
_28._m0 = int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]);
_28._m1 = mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x]));
// Copy the two struct members out to the two output buffers at the same index.
_7._m0[gl_GlobalInvocationID.x] = _28._m0;
_8._m0[gl_GlobalInvocationID.x] = _28._m1;
}
Expand Down
25 changes: 25 additions & 0 deletions reference/shaders-msl/asm/comp/ulong_smulextended.asm.msl23.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Buffer layout used by all four kernel bindings (_5, _6, _7, _8): a single member holding an array of ulong.
// The array is declared with one element; main0 indexes it with gl_GlobalInvocationID.x.
struct _4
{
ulong _m0[1];
};

// Two-member result record filled in by main0:
// _m0 receives the plain product of the two inputs, _m1 receives the mulhi() of the same inputs.
struct _21
{
ulong _m0;
ulong _m1;
};

// Compute kernel: for the element selected by gl_GlobalInvocationID.x, reads one ulong from _5 and one from _6,
// multiplies them as signed 64-bit values, and writes the two halves of the result to _7 and _8.
kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
_21 _29;
// Inputs are stored as ulong and cast to long so the multiply is signed.
// The long results are assigned to the ulong members without an explicit outer cast.
_29._m0 = long(_5._m0[gl_GlobalInvocationID.x]) * long(_6._m0[gl_GlobalInvocationID.x]);
// mulhi() supplies the upper half of the product of the same two operands.
_29._m1 = mulhi(long(_5._m0[gl_GlobalInvocationID.x]), long(_6._m0[gl_GlobalInvocationID.x]));
// Copy the two struct members out to the two output buffers at the same index.
_7._m0[gl_GlobalInvocationID.x] = _29._m0;
_8._m0[gl_GlobalInvocationID.x] = _29._m1;
}

63 changes: 63 additions & 0 deletions shaders-msl/asm/comp/ulong_smulextended.asm.msl23.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
; Test module: a compute shader that applies OpSMulExtended to two 64-bit values loaded from
; storage buffers and stores the two members of the result struct to two output buffers.
; Int64 is required because the operands and results are 64-bit integers.
OpCapability Shader
OpCapability Int64

OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationId
OpExecutionMode %main LocalSize 1 1 1

; Decorations: the builtin invocation id, the buffer block layout (8-byte array stride),
; and bindings 0-3 in descriptor set 0 for the two inputs and two outputs.
OpDecorate %gl_GlobalInvocationId BuiltIn GlobalInvocationId
OpDecorate %ra_ulong ArrayStride 8
OpDecorate %struct_ulong4 BufferBlock
OpMemberDecorate %struct_ulong4 0 Offset 0
OpDecorate %input0 DescriptorSet 0
OpDecorate %input0 Binding 0
OpDecorate %input1 DescriptorSet 0
OpDecorate %input1 Binding 1
OpDecorate %output0 DescriptorSet 0
OpDecorate %output0 Binding 2
OpDecorate %output1 DescriptorSet 0
OpDecorate %output1 Binding 3

; Scalar, vector, pointer and function types.
; Note that %ulong is declared unsigned (signedness 0) even though the multiply below is the signed opcode.
%uint = OpTypeInt 32 0
%ulong = OpTypeInt 64 0
%ptr_ulong = OpTypePointer Uniform %ulong
%ptr_input_uint = OpTypePointer Input %uint
%uint3 = OpTypeVector %uint 3
%ptr_input_uint3 = OpTypePointer Input %uint3
%void = OpTypeVoid
%voidFn = OpTypeFunction %void

; Constants, the buffer block type, the two-member result struct, and the interface variables.
; NOTE(review): %uint_1 and %ulong4 are not referenced anywhere else in this module, and %ulong4 is a
; vector of %uint despite its name. Removing them would presumably renumber later ids - TODO confirm before cleaning up.
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%ra_ulong = OpTypeRuntimeArray %ulong
%ulong4 = OpTypeVector %uint 4
%struct_ulong4 = OpTypeStruct %ra_ulong
%ptr_struct_ulong4 = OpTypePointer Uniform %struct_ulong4
%resulttype = OpTypeStruct %ulong %ulong
%gl_GlobalInvocationId = OpVariable %ptr_input_uint3 Input
%input0 = OpVariable %ptr_struct_ulong4 Uniform
%input1 = OpVariable %ptr_struct_ulong4 Uniform

%output0 = OpVariable %ptr_struct_ulong4 Uniform
%output1 = OpVariable %ptr_struct_ulong4 Uniform

; Entry point: use component 0 of the global invocation id as the array index,
; then load one %ulong from each input buffer at that index.
%main = OpFunction %void None %voidFn
%mainStart = OpLabel
%index_ptr = OpAccessChain %ptr_input_uint %gl_GlobalInvocationId %uint_0
%index = OpLoad %uint %index_ptr
%in_ptr0 = OpAccessChain %ptr_ulong %input0 %uint_0 %index
%invalue0 = OpLoad %ulong %in_ptr0
%in_ptr1 = OpAccessChain %ptr_ulong %input1 %uint_0 %index
%invalue1 = OpLoad %ulong %in_ptr1

; Signed extended multiply on unsigned-typed 64-bit operands; member 0 of the result goes to
; %output0 and member 1 goes to %output1, both at the same index.
%outvalue = OpSMulExtended %resulttype %invalue0 %invalue1
%outvalue0 = OpCompositeExtract %ulong %outvalue 0
%out_ptr0 = OpAccessChain %ptr_ulong %output0 %uint_0 %index
OpStore %out_ptr0 %outvalue0
%outvalue1 = OpCompositeExtract %ulong %outvalue 1
%out_ptr1 = OpAccessChain %ptr_ulong %output1 %uint_0 %index
OpStore %out_ptr1 %outvalue1


OpReturn
OpFunctionEnd
4 changes: 4 additions & 0 deletions spirv_glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11610,6 +11610,10 @@ uint32_t CompilerGLSL::get_integer_width_for_instruction(const Instruction &inst
case OpUGreaterThanEqual:
return expression_type(ops[2]).width;

case OpSMulExtended:
case OpUMulExtended:
return get<SPIRType>(get<SPIRType>(ops[0]).member_types[0]).width;

default:
{
// We can look at result type which is more robust.
Expand Down
26 changes: 3 additions & 23 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9406,32 +9406,12 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
uint32_t op1 = ops[3];
auto &type = get<SPIRType>(result_type);
auto input_type = opcode == OpSMulExtended ? int_type : uint_type;
auto &output_type = get_type(result_type);
string cast_op0, cast_op1;

auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false);

binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false);
emit_uninitialized_temporary_expression(result_type, result_id);

string mullo_expr, mulhi_expr;
mullo_expr = join(cast_op0, " * ", cast_op1);
mulhi_expr = join("mulhi(", cast_op0, ", ", cast_op1, ")");

auto &low_type = get_type(output_type.member_types[0]);
auto &high_type = get_type(output_type.member_types[1]);
if (low_type.basetype != input_type)
{
expected_type.basetype = input_type;
mullo_expr = join(bitcast_glsl_op(low_type, expected_type), "(", mullo_expr, ")");
}
if (high_type.basetype != input_type)
{
expected_type.basetype = input_type;
mulhi_expr = join(bitcast_glsl_op(high_type, expected_type), "(", mulhi_expr, ")");
}

statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", mullo_expr, ";");
statement(to_expression(result_id), ".", to_member_name(type, 1), " = ", mulhi_expr, ";");
statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", cast_op0, " * ", cast_op1, ";");
statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", cast_op0, ", ", cast_op1, ");");
break;
}

Expand Down

0 comments on commit 25b7772

Please sign in to comment.