Skip to content

Commit

Permalink
Merge pull request #2234 from KhronosGroup/fix-2226
Browse files Browse the repository at this point in the history
MSL: Support std140 half matrices and arrays.
  • Loading branch information
HansKristian-Work authored Nov 27, 2023
2 parents 61bbcb2 + 2e022db commit 3717660
Show file tree
Hide file tree
Showing 8 changed files with 450 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template <typename T>
struct spvPaddedStd140 { alignas(16) T data; };
template <typename T, int n>
using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];

struct Foo
{
spvPaddedStd140Matrix<half3, 2> c23;
spvPaddedStd140Matrix<half2, 3> c32;
spvPaddedStd140Matrix<half2, 3> r23;
spvPaddedStd140Matrix<half3, 2> r32;
spvPaddedStd140<half> h1[6];
spvPaddedStd140<half2> h2[6];
spvPaddedStd140<half3> h3[6];
spvPaddedStd140<half4> h4[6];
};

struct main0_out
{
float4 FragColor [[color(0)]];
};

fragment main0_out main0(device Foo& _20 [[buffer(0)]])
{
main0_out out = {};
((device half*)&_20.c23[1].data)[2u] = half(1.0);
((device half*)&_20.c32[2].data)[1u] = half(2.0);
((device half*)&_20.r23[2u])[1] = half(3.0);
((device half*)&_20.r32[1u])[2] = half(4.0);
_20.c23[1].data = half3(half(0.0), half(1.0), half(2.0));
_20.c32[1].data = half2(half(0.0), half(1.0));
((device half*)&_20.r23[0])[1] = half3(half(0.0), half(1.0), half(2.0)).x;
((device half*)&_20.r23[1])[1] = half3(half(0.0), half(1.0), half(2.0)).y;
((device half*)&_20.r23[2])[1] = half3(half(0.0), half(1.0), half(2.0)).z;
((device half*)&_20.r32[0])[1] = half2(half(0.0), half(1.0)).x;
((device half*)&_20.r32[1])[1] = half2(half(0.0), half(1.0)).y;
(device half3&)_20.c23[0] = half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0];
(device half3&)_20.c23[1] = half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1];
(device half2&)_20.c32[0] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0];
(device half2&)_20.c32[1] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1];
(device half2&)_20.c32[2] = half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2];
(device half2&)_20.r23[0] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][0], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][0]);
(device half2&)_20.r23[1] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][1], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][1]);
(device half2&)_20.r23[2] = half2(half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[0][2], half2x3(half3(half(1.0), half(2.0), half(3.0)), half3(half(4.0), half(5.0), half(6.0)))[1][2]);
(device half3&)_20.r32[0] = half3(half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0][0], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1][0], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2][0]);
(device half3&)_20.r32[1] = half3(half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[0][1], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[1][1], half3x2(half2(half(1.0), half(2.0)), half2(half(3.0), half(4.0)), half2(half(5.0), half(6.0)))[2][1]);
_20.h1[5].data = half(1.0);
_20.h2[5].data = half2(half(1.0), half(2.0));
_20.h3[5].data = half3(half(1.0), half(2.0), half(3.0));
_20.h4[5].data = half4(half(1.0), half(2.0), half(3.0), half(4.0));
((device half*)&_20.h2[5].data)[1u] = half(10.0);
((device half*)&_20.h3[5].data)[2u] = half(11.0);
((device half*)&_20.h4[5].data)[3u] = half(12.0);
out.FragColor = float4(1.0);
return out;
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template <typename T>
struct spvPaddedStd140 { alignas(16) T data; };
template <typename T, int n>
using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];

struct Foo
{
spvPaddedStd140Matrix<half2, 2> c22;
spvPaddedStd140Matrix<half2, 2> c22arr[3];
spvPaddedStd140Matrix<half3, 2> c23;
spvPaddedStd140Matrix<half4, 2> c24;
spvPaddedStd140Matrix<half2, 3> c32;
spvPaddedStd140Matrix<half3, 3> c33;
spvPaddedStd140Matrix<half4, 3> c34;
spvPaddedStd140Matrix<half2, 4> c42;
spvPaddedStd140Matrix<half3, 4> c43;
spvPaddedStd140Matrix<half4, 4> c44;
spvPaddedStd140Matrix<half2, 2> r22;
spvPaddedStd140Matrix<half2, 2> r22arr[3];
spvPaddedStd140Matrix<half2, 3> r23;
spvPaddedStd140Matrix<half2, 4> r24;
spvPaddedStd140Matrix<half3, 2> r32;
spvPaddedStd140Matrix<half3, 3> r33;
spvPaddedStd140Matrix<half3, 4> r34;
spvPaddedStd140Matrix<half4, 2> r42;
spvPaddedStd140Matrix<half4, 3> r43;
spvPaddedStd140Matrix<half4, 4> r44;
spvPaddedStd140<half> h1[6];
spvPaddedStd140<half2> h2[6];
spvPaddedStd140<half3> h3[6];
spvPaddedStd140<half4> h4[6];
};

struct main0_out
{
float4 FragColor [[color(0)]];
};

fragment main0_out main0(constant Foo& u [[buffer(0)]])
{
main0_out out = {};
half2 c2 = half2(u.c22[0].data) + half2(u.c22[1].data);
c2 = half2(u.c22arr[2][0].data) + half2(u.c22arr[2][1].data);
half3 c3 = half3(u.c23[0].data) + half3(u.c23[1].data);
half4 c4 = half4(u.c24[0].data) + half4(u.c24[1].data);
c2 = (half2(u.c32[0].data) + half2(u.c32[1].data)) + half2(u.c32[2].data);
c3 = (half3(u.c33[0].data) + half3(u.c33[1].data)) + half3(u.c33[2].data);
c4 = (half4(u.c34[0].data) + half4(u.c34[1].data)) + half4(u.c34[2].data);
c2 = ((half2(u.c42[0].data) + half2(u.c42[1].data)) + half2(u.c42[2].data)) + half2(u.c42[3].data);
c3 = ((half3(u.c43[0].data) + half3(u.c43[1].data)) + half3(u.c43[2].data)) + half3(u.c43[3].data);
c4 = ((half4(u.c44[0].data) + half4(u.c44[1].data)) + half4(u.c44[2].data)) + half4(u.c44[3].data);
half c = ((u.c22[0].data.x + u.c22[0].data.y) + u.c22[1].data.x) + u.c22[1].data.y;
c = ((u.c22arr[2][0].data.x + u.c22arr[2][0].data.y) + u.c22arr[2][1].data.x) + u.c22arr[2][1].data.y;
half2x2 c22 = half2x2(u.c22[0].data.xy, u.c22[1].data.xy);
c22 = half2x2(u.c22arr[2][0].data.xy, u.c22arr[2][1].data.xy);
half2x3 c23 = half2x3(u.c23[0].data.xyz, u.c23[1].data.xyz);
half2x4 c24 = half2x4(u.c24[0].data, u.c24[1].data);
half3x2 c32 = half3x2(u.c32[0].data.xy, u.c32[1].data.xy, u.c32[2].data.xy);
half3x3 c33 = half3x3(u.c33[0].data.xyz, u.c33[1].data.xyz, u.c33[2].data.xyz);
half3x4 c34 = half3x4(u.c34[0].data, u.c34[1].data, u.c34[2].data);
half4x2 c42 = half4x2(u.c42[0].data.xy, u.c42[1].data.xy, u.c42[2].data.xy, u.c42[3].data.xy);
half4x3 c43 = half4x3(u.c43[0].data.xyz, u.c43[1].data.xyz, u.c43[2].data.xyz, u.c43[3].data.xyz);
half4x4 c44 = half4x4(u.c44[0].data, u.c44[1].data, u.c44[2].data, u.c44[3].data);
half2 r2 = half2(u.r22[0].data[0], u.r22[1].data[0]) + half2(u.r22[0].data[1], u.r22[1].data[1]);
r2 = half2(u.r22arr[2][0].data[0], u.r22arr[2][1].data[0]) + half2(u.r22arr[2][0].data[1], u.r22arr[2][1].data[1]);
half3 r3 = half3(u.r23[0].data[0], u.r23[1].data[0], u.r23[2].data[0]) + half3(u.r23[0].data[1], u.r23[1].data[1], u.r23[2].data[1]);
half4 r4 = half4(u.r24[0].data[0], u.r24[1].data[0], u.r24[2].data[0], u.r24[3].data[0]) + half4(u.r24[0].data[1], u.r24[1].data[1], u.r24[2].data[1], u.r24[3].data[1]);
r2 = (half2(u.r32[0].data[0], u.r32[1].data[0]) + half2(u.r32[0].data[1], u.r32[1].data[1])) + half2(u.r32[0].data[2], u.r32[1].data[2]);
r3 = (half3(u.r33[0].data[0], u.r33[1].data[0], u.r33[2].data[0]) + half3(u.r33[0].data[1], u.r33[1].data[1], u.r33[2].data[1])) + half3(u.r33[0].data[2], u.r33[1].data[2], u.r33[2].data[2]);
r4 = (half4(u.r34[0].data[0], u.r34[1].data[0], u.r34[2].data[0], u.r34[3].data[0]) + half4(u.r34[0].data[1], u.r34[1].data[1], u.r34[2].data[1], u.r34[3].data[1])) + half4(u.r34[0].data[2], u.r34[1].data[2], u.r34[2].data[2], u.r34[3].data[2]);
r2 = ((half2(u.r42[0].data[0], u.r42[1].data[0]) + half2(u.r42[0].data[1], u.r42[1].data[1])) + half2(u.r42[0].data[2], u.r42[1].data[2])) + half2(u.r42[0].data[3], u.r42[1].data[3]);
r3 = ((half3(u.r43[0].data[0], u.r43[1].data[0], u.r43[2].data[0]) + half3(u.r43[0].data[1], u.r43[1].data[1], u.r43[2].data[1])) + half3(u.r43[0].data[2], u.r43[1].data[2], u.r43[2].data[2])) + half3(u.r43[0].data[3], u.r43[1].data[3], u.r43[2].data[3]);
r4 = ((half4(u.r44[0].data[0], u.r44[1].data[0], u.r44[2].data[0], u.r44[3].data[0]) + half4(u.r44[0].data[1], u.r44[1].data[1], u.r44[2].data[1], u.r44[3].data[1])) + half4(u.r44[0].data[2], u.r44[1].data[2], u.r44[2].data[2], u.r44[3].data[2])) + half4(u.r44[0].data[3], u.r44[1].data[3], u.r44[2].data[3], u.r44[3].data[3]);
half r = ((u.r22[0u].data[0] + u.r22[1u].data[0]) + u.r22[0u].data[1]) + u.r22[1u].data[1];
half2x2 r22 = transpose(half2x2(u.r22[0].data.xy, u.r22[1].data.xy));
half2x3 r23 = transpose(half3x2(u.r23[0].data.xy, u.r23[1].data.xy, u.r23[2].data.xy));
half2x4 r24 = transpose(half4x2(u.r24[0].data.xy, u.r24[1].data.xy, u.r24[2].data.xy, u.r24[3].data.xy));
half3x2 r32 = transpose(half2x3(u.r32[0].data.xyz, u.r32[1].data.xyz));
half3x3 r33 = transpose(half3x3(u.r33[0].data.xyz, u.r33[1].data.xyz, u.r33[2].data.xyz));
half3x4 r34 = transpose(half4x3(u.r34[0].data.xyz, u.r34[1].data.xyz, u.r34[2].data.xyz, u.r34[3].data.xyz));
half4x2 r42 = transpose(half2x4(u.r42[0].data, u.r42[1].data));
half4x3 r43 = transpose(half3x4(u.r43[0].data, u.r43[1].data, u.r43[2].data));
half4x4 r44 = transpose(half4x4(u.r44[0].data, u.r44[1].data, u.r44[2].data, u.r44[3].data));
half h1 = half(u.h1[5].data);
half2 h2 = half2(u.h2[5].data);
half3 h3 = half3(u.h3[5].data);
half4 h4 = half4(u.h4[5].data);
out.FragColor = float4(1.0);
return out;
}

51 changes: 51 additions & 0 deletions shaders-msl-no-opt/packing/std140-half-matrix-and-array-write.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require

layout(set = 0, binding = 0, std140) buffer Foo
{
f16mat2x3 c23;
f16mat3x2 c32;
layout(row_major) f16mat2x3 r23;
layout(row_major) f16mat3x2 r32;

float16_t h1[6];
f16vec2 h2[6];
f16vec3 h3[6];
f16vec4 h4[6];
};

layout(location = 0) out vec4 FragColor;

void main()
{
// Store scalar
c23[1][2] = 1.0hf;
c32[2][1] = 2.0hf;
r23[1][2] = 3.0hf;
r32[2][1] = 4.0hf;

// Store vector
c23[1] = f16vec3(0, 1, 2);
c32[1] = f16vec2(0, 1);
r23[1] = f16vec3(0, 1, 2);
r32[1] = f16vec2(0, 1);

// Store matrix
c23 = f16mat2x3(1, 2, 3, 4, 5, 6);
c32 = f16mat3x2(1, 2, 3, 4, 5, 6);
r23 = f16mat2x3(1, 2, 3, 4, 5, 6);
r32 = f16mat3x2(1, 2, 3, 4, 5, 6);

// Store array
h1[5] = 1.0hf;
h2[5] = f16vec2(1, 2);
h3[5] = f16vec3(1, 2, 3);
h4[5] = f16vec4(1, 2, 3, 4);

// Store scalar in array
h2[5][1] = 10.0hf;
h3[5][2] = 11.0hf;
h4[5][3] = 12.0hf;

FragColor = vec4(1.0);
}
110 changes: 110 additions & 0 deletions shaders-msl-no-opt/packing/std140-half-matrix-and-array.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require

layout(set = 0, binding = 0, std140) uniform Foo
{
f16mat2x2 c22;
f16mat2x2 c22arr[3];
f16mat2x3 c23;
f16mat2x4 c24;

f16mat3x2 c32;
f16mat3x3 c33;
f16mat3x4 c34;

f16mat4x2 c42;
f16mat4x3 c43;
f16mat4x4 c44;

layout(row_major) f16mat2x2 r22;
layout(row_major) f16mat2x2 r22arr[3];
layout(row_major) f16mat2x3 r23;
layout(row_major) f16mat2x4 r24;

layout(row_major) f16mat3x2 r32;
layout(row_major) f16mat3x3 r33;
layout(row_major) f16mat3x4 r34;

layout(row_major) f16mat4x2 r42;
layout(row_major) f16mat4x3 r43;
layout(row_major) f16mat4x4 r44;

float16_t h1[6];
f16vec2 h2[6];
f16vec3 h3[6];
f16vec4 h4[6];
} u;

layout(location = 0) out vec4 FragColor;

void main()
{
// Load vectors.
f16vec2 c2 = u.c22[0] + u.c22[1];
c2 = u.c22arr[2][0] + u.c22arr[2][1];
f16vec3 c3 = u.c23[0] + u.c23[1];
f16vec4 c4 = u.c24[0] + u.c24[1];

c2 = u.c32[0] + u.c32[1] + u.c32[2];
c3 = u.c33[0] + u.c33[1] + u.c33[2];
c4 = u.c34[0] + u.c34[1] + u.c34[2];

c2 = u.c42[0] + u.c42[1] + u.c42[2] + u.c42[3];
c3 = u.c43[0] + u.c43[1] + u.c43[2] + u.c43[3];
c4 = u.c44[0] + u.c44[1] + u.c44[2] + u.c44[3];

// Load scalars.
float16_t c = u.c22[0].x + u.c22[0].y + u.c22[1].x + u.c22[1].y;
c = u.c22arr[2][0].x + u.c22arr[2][0].y + u.c22arr[2][1].x + u.c22arr[2][1].y;

// Load full matrix.
f16mat2x2 c22 = u.c22;
c22 = u.c22arr[2];
f16mat2x3 c23 = u.c23;
f16mat2x4 c24 = u.c24;

f16mat3x2 c32 = u.c32;
f16mat3x3 c33 = u.c33;
f16mat3x4 c34 = u.c34;

f16mat4x2 c42 = u.c42;
f16mat4x3 c43 = u.c43;
f16mat4x4 c44 = u.c44;

// Same, but row-major.
f16vec2 r2 = u.r22[0] + u.r22[1];
r2 = u.r22arr[2][0] + u.r22arr[2][1];
f16vec3 r3 = u.r23[0] + u.r23[1];
f16vec4 r4 = u.r24[0] + u.r24[1];

r2 = u.r32[0] + u.r32[1] + u.r32[2];
r3 = u.r33[0] + u.r33[1] + u.r33[2];
r4 = u.r34[0] + u.r34[1] + u.r34[2];

r2 = u.r42[0] + u.r42[1] + u.r42[2] + u.r42[3];
r3 = u.r43[0] + u.r43[1] + u.r43[2] + u.r43[3];
r4 = u.r44[0] + u.r44[1] + u.r44[2] + u.r44[3];

// Load scalars.
float16_t r = u.r22[0].x + u.r22[0].y + u.r22[1].x + u.r22[1].y;

// Load full matrix.
f16mat2x2 r22 = u.r22;
f16mat2x3 r23 = u.r23;
f16mat2x4 r24 = u.r24;

f16mat3x2 r32 = u.r32;
f16mat3x3 r33 = u.r33;
f16mat3x4 r34 = u.r34;

f16mat4x2 r42 = u.r42;
f16mat4x3 r43 = u.r43;
f16mat4x4 r44 = u.r44;

float16_t h1 = u.h1[5];
f16vec2 h2 = u.h2[5];
f16vec3 h3 = u.h3[5];
f16vec4 h4 = u.h4[5];

FragColor = vec4(1.0);
}
Loading

0 comments on commit 3717660

Please sign in to comment.