Skip to content

Commit

Permalink
Merge pull request #2306 from KhronosGroup/pr-2292
Browse files Browse the repository at this point in the history
MSL: Implement support for EXT_mutable_descriptor_type and general aliasing with argument buffers
  • Loading branch information
HansKristian-Work authored Apr 3, 2024
2 parents f9393f4 + 061bf6b commit 0640756
Show file tree
Hide file tree
Showing 16 changed files with 770 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

using namespace metal;

// Issues a texture fence on `img` by calling its fence() member.
// The argument is taken by value, so fence() runs on a local, non-const copy of
// the handle the caller passed in.
// NOTE(review): presumably this lets callers fence textures they only hold through
// const-qualified references — confirm against the generator.
template <typename TextureT>
void spvImageFence(TextureT img)
{
    img.fence();
}

static inline __attribute__((always_inline))
void _main(thread const uint3& id, texture2d<float, access::read_write> TargetTexture)
{
TargetTexture.fence();
spvImageFence(TargetTexture);
float2 loaded = TargetTexture.read(uint2(id.xy)).xy;
float2 storeTemp = loaded + float2(1.0);
TargetTexture.write(storeTemp.xyyy, uint2((id.xy + uint2(1u))));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wincompatible-pointer-types-discards-qualifiers"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Wraps a single descriptor of type T in a struct with one member, `value`.
// spvDescriptorArray (below) indexes an array of these wrappers and returns `value`.
template<typename T>
struct spvDescriptor
{
T value;
};

// Read-only, array-like view over consecutive spvDescriptor<T> entries in device memory.
// The view stores only the base pointer; it carries no element count and performs no
// bounds checking, so callers are responsible for passing valid indices.
template <typename T>
struct spvDescriptorArray
{
    // First wrapped descriptor of the run being viewed.
    const device spvDescriptor<T>* ptr;

    // Builds a view starting at `base`.
    spvDescriptorArray(const device spvDescriptor<T>* base)
        : ptr(base)
    {
    }

    // Returns the descriptor stored in entry `index`, unwrapped from its spvDescriptor.
    const device T& operator[](size_t index) const
    {
        return (ptr + index)->value;
    }
};

// Issues a texture fence on `img` by calling its fence() member.
// The argument is taken by value, so fence() runs on a local, non-const copy of
// the handle the caller passed in; main0 below relies on this when it fences
// textures reached through `const device` / `constant` references.
template <typename TextureT>
void spvImageFence(TextureT img)
{
    img.fence();
}

// Buffer payload types. Each pair below is reached through the same argument-buffer
// slot: the first type via the declared member, the second via a reinterpret_cast
// alias created in main0.

// Pointed to by spvDescriptorSetBuffer3::b10 (device address space).
struct B10
{
float v;
};

// Pointed to by the `b11` alias of spvDescriptorSetBuffer3::b10 (constant address space).
struct B11
{
float v;
};

// Pointed to by spvDescriptorSetBuffer2::b20 (device address space).
struct B20
{
float v;
};

// Pointed to by the `b21` alias of spvDescriptorSetBuffer2::b20 (constant address space).
struct B21
{
float v;
};

// Pointed to by spvDescriptorSetBuffer1::b30 (device address space).
// main0 reads `i` and uses it as the base for its set-1 descriptor indices.
struct B30
{
uint i;
};

// Pointed to by the `b31` alias of spvDescriptorSetBuffer1::b30 (constant address space).
struct B31
{
float v;
};

// Pointed to by spvDescriptorSetBuffer4::b40 (device address space).
struct B40
{
float v;
};

// Pointed to by the `b41` alias of spvDescriptorSetBuffer4::b40 (constant address space).
struct B41
{
float v;
};

// Constant named after the gl_WorkGroupSize built-in, set to (1, 1, 1). Nothing else in
// this file reads it, which is why it carries [[maybe_unused]].
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
// Value-initialised (all-zero) float4. No other line in this file refers to it.
constant float4 _477 = {};

// Argument-buffer layout that main0 receives as spvDescriptorSet0 (`const device`, [[buffer(0)]]).
// Only t00 is actually declared for [[id(0)]]. The commented-out members record the other
// resource arrays that share that id; main0 reaches them (t01, t02, u0, s00) by
// reinterpret_cast-ing t00 to the corresponding array type.
struct spvDescriptorSetBuffer0
{
array<texture2d<float>, 8> t00 [[id(0)]];
// Overlapping binding: array<texture2d<uint>, 8> t01 [[id(0)]];
// Overlapping binding: array<texture2d<int>, 8> t02 [[id(0)]];
// Overlapping binding: array<texture_buffer<uint, access::read_write>, 8> u0 [[id(0)]];
// Overlapping binding: array<sampler, 8> s00 [[id(0)]];
};

// Argument-buffer layout that main0 receives as spvDescriptorSet1 (`const device`, [[buffer(1)]]).
// b30 is declared with a single element (the "unsized array hack"), but main0 wraps it in
// spvDescriptorArray views and indexes it with values computed at run time, so the real
// element count is not visible in this file.
// The commented-out members record the other resource types sharing [[id(0)]]; main0 builds
// a separate spvDescriptorArray view over &b30 for each of them.
struct spvDescriptorSetBuffer1
{
spvDescriptor<device B30 *> b30 [[id(0)]][1] /* unsized array hack */;
// Overlapping binding: spvDescriptor<constant B31 *> b31 [[id(0)]][1] /* unsized array hack */;
// Overlapping binding: spvDescriptor<texture2d<uint>> t31 [[id(0)]][1] /* unsized array hack */;
// Overlapping binding: spvDescriptor<texture2d<int>> t32 [[id(0)]][1] /* unsized array hack */;
// Overlapping binding: spvDescriptor<texture_buffer<uint, access::read_write>> u3 [[id(0)]][1] /* unsized array hack */;
};

// Argument-buffer layout that main0 receives as spvDescriptorSet2 (`constant`, [[buffer(2)]]).
// Only the 8-element pointer array b20 is declared for [[id(0)]]. The commented-out members
// record the aliases (b21, t21, t22, u2) that main0 obtains by reinterpret_cast-ing b20.
struct spvDescriptorSetBuffer2
{
device B20* b20 [[id(0)]][8];
// Overlapping binding: constant B21* b21 [[id(0)]][8];
// Overlapping binding: array<texture2d<uint>, 8> t21 [[id(0)]];
// Overlapping binding: array<texture2d<int>, 8> t22 [[id(0)]];
// Overlapping binding: array<texture_buffer<uint, access::read_write>, 8> u2 [[id(0)]];
};

// Argument-buffer layout that main0 receives as spvDescriptorSet3 (`constant`, [[buffer(3)]]).
// Only the 8-element pointer array b10 is declared for [[id(0)]]. The commented-out members
// record the aliases (b11, u1) that main0 obtains by reinterpret_cast-ing b10.
struct spvDescriptorSetBuffer3
{
device B10* b10 [[id(0)]][8];
// Overlapping binding: constant B11* b11 [[id(0)]][8];
// Overlapping binding: array<texture_buffer<uint, access::read_write>, 8> u1 [[id(0)]];
};

// Argument-buffer layout that main0 receives as spvDescriptorSet4 (`constant`, [[buffer(4)]]).
// Unlike the other sets this one holds a single (non-array) descriptor, b40. The
// commented-out members record the aliases (b41, t41, t42, u4) that main0 obtains by
// reinterpret_cast-ing b40.
struct spvDescriptorSetBuffer4
{
device B40* b40 [[id(0)]];
// Overlapping binding: constant B41* b41 [[id(0)]];
// Overlapping binding: texture2d<uint> t41 [[id(0)]];
// Overlapping binding: texture2d<int> t42 [[id(0)]];
// Overlapping binding: texture_buffer<uint, access::read_write> u4 [[id(0)]];
};

// Compute kernel that reads every aliased view of the five argument buffers and writes one
// uint4 result per set to texel 0 of u0[0], u1[0], u2[0], u3[0] and u4.
//
// Parameters:
//   spvDescriptorSet0..4 - the argument buffers bound at [[buffer(0)]]..[[buffer(4)]].
//   gl_WorkGroupID       - threadgroup position in the grid; its .x selects which b30 entry
//                          supplies the base index for the set-1 lookups.
//
// Each set declares a single member per [[id]]; the remaining resource types that share the
// slot are reached through the reinterpret_cast aliases created at the top of the function.
kernel void main0(const device spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]], const device spvDescriptorSetBuffer1& spvDescriptorSet1 [[buffer(1)]], constant spvDescriptorSetBuffer2& spvDescriptorSet2 [[buffer(2)]], constant spvDescriptorSetBuffer3& spvDescriptorSet3 [[buffer(3)]], constant spvDescriptorSetBuffer4& spvDescriptorSet4 [[buffer(4)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    // Set 1: one spvDescriptorArray view per aliased resource type, all starting at b30.
    spvDescriptorArray<device B30*> b30 {spvDescriptorSet1.b30};
    spvDescriptorArray<texture2d<uint>> t31 {reinterpret_cast<const device spvDescriptor<texture2d<uint>>*>(&spvDescriptorSet1.b30)};
    spvDescriptorArray<texture2d<int>> t32 {reinterpret_cast<const device spvDescriptor<texture2d<int>>*>(&spvDescriptorSet1.b30)};
    spvDescriptorArray<constant B31*> b31 {reinterpret_cast<spvDescriptor<constant B31 *> const device *>(&spvDescriptorSet1.b30)};
    spvDescriptorArray<texture_buffer<uint, access::read_write>> u3 {reinterpret_cast<const device spvDescriptor<texture_buffer<uint, access::read_write>>*>(&spvDescriptorSet1.b30)};

    // Sets 0, 2, 3 and 4: reference aliases that reinterpret the declared member as each
    // of the other resource types sharing its slot.
    const device auto &t01 = reinterpret_cast<const device array<texture2d<uint>, 8> &>(spvDescriptorSet0.t00);
    const device auto &t02 = reinterpret_cast<const device array<texture2d<int>, 8> &>(spvDescriptorSet0.t00);
    const device auto &u0 = reinterpret_cast<const device array<texture_buffer<uint, access::read_write>, 8> &>(spvDescriptorSet0.t00);
    const device auto &s00 = reinterpret_cast<const device array<sampler, 8> &>(spvDescriptorSet0.t00);
    constant auto &b21 = reinterpret_cast<constant B21* constant (&)[8]>(spvDescriptorSet2.b20);
    constant auto &t21 = reinterpret_cast<constant array<texture2d<uint>, 8> &>(spvDescriptorSet2.b20);
    constant auto &t22 = reinterpret_cast<constant array<texture2d<int>, 8> &>(spvDescriptorSet2.b20);
    constant auto &u2 = reinterpret_cast<constant array<texture_buffer<uint, access::read_write>, 8> &>(spvDescriptorSet2.b20);
    constant auto &b11 = reinterpret_cast<constant B11* constant (&)[8]>(spvDescriptorSet3.b10);
    constant auto &u1 = reinterpret_cast<constant array<texture_buffer<uint, access::read_write>, 8> &>(spvDescriptorSet3.b10);
    constant auto &b41 = *reinterpret_cast<constant B41* constant &>(spvDescriptorSet4.b40);
    constant auto &t41 = reinterpret_cast<constant texture2d<uint> &>(spvDescriptorSet4.b40);
    constant auto &t42 = reinterpret_cast<constant texture2d<int> &>(spvDescriptorSet4.b40);
    constant auto &u4 = reinterpret_cast<constant texture_buffer<uint, access::read_write> &>(spvDescriptorSet4.b40);

    // Set 0 result: start from a sample of t00[0], then overwrite x/y/z with bit-cast
    // reads through the aliased views (w keeps the sampled value).
    float4 _292 = spvDescriptorSet0.t00[0].sample(s00[3], float2(0.0), level(0.0));
    _292.x = as_type<float>(t01[1].read(uint2(int2(0)), 0).x);
    _292.y = as_type<float>(t02[2].read(uint2(int2(0)), 0).x);
    spvImageFence(u0[2]);
    _292.z = as_type<float>(u0[2].read(uint(0)).x);

    // Set 3 result. Zero-initialised so the .w lane, which no statement below assigns,
    // holds a defined value when the whole vector is stored to u1[0].
    float4 _448 = float4(0.0);
    _448.x = spvDescriptorSet3.b10[3]->v;
    _448.y = b11[4]->v;
    spvImageFence(u1[2]);
    _448.z = as_type<float>(u1[2].read(uint(0)).x);

    // Set 2: load the buffer value and fence u2[2] ahead of its read in the write below.
    float _342 = spvDescriptorSet2.b20[3]->v;
    spvImageFence(u2[2]);

    // Set 1: the b30 entry selected by the workgroup id supplies the base index for
    // every other set-1 lookup.
    uint _356 = b30[gl_WorkGroupID.x]->i;
    uint _388 = _356 + 6u;
    spvImageFence(u3[_388]);

    // Set 4: load the buffer value and fence u4 ahead of its read in the write below.
    float _410 = (*spvDescriptorSet4.b40).v;
    spvImageFence(u4);

    // Store one bit-cast uint4 per set at texel 0.
    u0[0].write(as_type<uint4>(_292), uint(0));
    u1[0].write(as_type<uint4>(_448), uint(0));
    u2[0].write(as_type<uint4>(float4(as_type<float>(t21[1].read(uint2(int2(0)), 0).x), as_type<float>(t22[2].read(uint2(int2(0)), 0).x), _342 + as_type<float>(u2[2].read(uint(0)).x), b21[4]->v)), uint(0));
    u3[0].write(as_type<uint4>(float4(as_type<float>(t31[_356 + 2u].read(uint2(int2(0)), 0).x), as_type<float>(t32[_356 + 3u].read(uint2(int2(0)), 0).x), b31[_356 + 5u]->v, as_type<float>(u3[_388].read(uint(0)).x))), uint(0));
    u4.write(as_type<uint4>(float4(as_type<float>(t41.read(uint2(int2(0)), 0).x), as_type<float>(t42.read(uint2(int2(0)), 0).x), _410 + b41.v, as_type<float>(u4.read(uint(0)).x))), uint(0));
}

Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma clang diagnostic ignored "-Wincompatible-pointer-types-discards-qualifiers"
#include <metal_stdlib>
#include <simd/simd.h>

Expand Down Expand Up @@ -87,25 +88,31 @@ constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(64u, 1u, 1u);

struct spvDescriptorSetBuffer0
{
device SSBO_A* ssbo_a [[id(0)]];
constant UBO_C* ubo_c [[id(1)]];
// Overlapping binding: constant UBO_D* ubo_d [[id(1)]];
device SSBO_As* ssbo_as [[id(2)]][4];
// Overlapping binding: device SSBO_Bs* ssbo_bs [[id(2)]][4];
// Overlapping binding: const device SSBO_BsRO* ssbo_bs_readonly [[id(2)]][4];
constant UBO_Cs* ubo_cs [[id(6)]][4];
};

kernel void main0(const device spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]], constant Registers& _42 [[buffer(1)]], device void* spvBufferAliasSet2Binding0 [[buffer(2)]], constant void* spvBufferAliasSet2Binding1 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
device auto& ssbo_e = *(device SSBO_E*)spvBufferAliasSet2Binding0;
constant auto& ubo_g = *(constant UBO_G*)spvBufferAliasSet2Binding1;
device auto& ssbo_f = *(device SSBO_F*)spvBufferAliasSet2Binding0;
constant auto& ubo_h = *(constant UBO_H*)spvBufferAliasSet2Binding1;
const device auto& ssbo_i = *(const device SSBO_I*)spvBufferAliasSet2Binding0;
device auto& ssbo_b = (device SSBO_B&)(*spvDescriptorSet0.ssbo_a);
constant auto& ubo_d = (constant UBO_D&)(*spvDescriptorSet0.ubo_c);
const device auto& ssbo_b_readonly = (const device SSBO_BRO&)(*spvDescriptorSet0.ssbo_a);
const device auto& ssbo_bs = (device SSBO_Bs* const device (&)[4])spvDescriptorSet0.ssbo_as;
const device auto& ubo_ds = (constant UBO_Ds* const device (&)[4])spvDescriptorSet0.ubo_cs;
const device auto& ssbo_bs_readonly = (const device SSBO_BsRO* const device (&)[4])spvDescriptorSet0.ssbo_as;
// Overlapping binding: constant UBO_Ds* ubo_ds [[id(6)]][4];
device SSBO_A* ssbo_a [[id(10)]];
// Overlapping binding: device SSBO_B* ssbo_b [[id(10)]];
// Overlapping binding: const device SSBO_BRO* ssbo_b_readonly [[id(10)]];
};

kernel void main0(const device spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]], constant Registers& _42 [[buffer(1)]], device void* spvBufferAliasSet2Binding11 [[buffer(11)]], constant void* spvBufferAliasSet2Binding12 [[buffer(12)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
device auto& ssbo_e = *(device SSBO_E*)spvBufferAliasSet2Binding11;
constant auto& ubo_g = *(constant UBO_G*)spvBufferAliasSet2Binding12;
device auto& ssbo_f = *(device SSBO_F*)spvBufferAliasSet2Binding11;
constant auto& ubo_h = *(constant UBO_H*)spvBufferAliasSet2Binding12;
const device auto& ssbo_i = *(const device SSBO_I*)spvBufferAliasSet2Binding11;
constant auto &ubo_d = *reinterpret_cast<constant UBO_D* const device &>(spvDescriptorSet0.ubo_c);
const device auto &ssbo_bs = reinterpret_cast<device SSBO_Bs* const device (&)[4]>(spvDescriptorSet0.ssbo_as);
const device auto &ssbo_bs_readonly = reinterpret_cast<const device SSBO_BsRO* const device (&)[4]>(spvDescriptorSet0.ssbo_as);
const device auto &ubo_ds = reinterpret_cast<constant UBO_Ds* const device (&)[4]>(spvDescriptorSet0.ubo_cs);
device auto &ssbo_b = *reinterpret_cast<device SSBO_B* const device &>(spvDescriptorSet0.ssbo_a);
const device auto &ssbo_b_readonly = *reinterpret_cast<const device SSBO_BRO* const device &>(spvDescriptorSet0.ssbo_a);
(*spvDescriptorSet0.ssbo_a).data[gl_GlobalInvocationID.x] = (*spvDescriptorSet0.ubo_c).data[gl_WorkGroupID.x].x + _42.reg;
ssbo_b.data[gl_GlobalInvocationID.x] = ubo_d.data[gl_WorkGroupID.y].xy + ssbo_b_readonly.data[gl_GlobalInvocationID.x];
spvDescriptorSet0.ssbo_as[gl_WorkGroupID.x]->data[gl_GlobalInvocationID.x] = spvDescriptorSet0.ubo_cs[gl_WorkGroupID.x]->data[0].x;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,14 @@ struct spvDescriptorSetBuffer0
{
device SSBO_A* ssbo_a [[id(0)]];
constant UBO_C* ubo_c [[id(1)]];
device SSBO_As* ssbo_as [[id(2)]][4];
constant UBO_Cs* ubo_cs [[id(6)]][4];
device SSBO_B* ssbo_b [[id(2)]];
constant UBO_D* ubo_d [[id(3)]];
const device SSBO_BRO* ssbo_b_readonly [[id(4)]];
device SSBO_As* ssbo_as [[id(5)]][4];
constant UBO_Cs* ubo_cs [[id(9)]][4];
device SSBO_Bs* ssbo_bs [[id(13)]][4];
constant UBO_Ds* ubo_ds [[id(17)]][4];
const device SSBO_BsRO* ssbo_bs_readonly [[id(21)]][4];
};

kernel void main0(constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]], constant Registers& _42 [[buffer(1)]], device void* spvBufferAliasSet2Binding0 [[buffer(2)]], constant void* spvBufferAliasSet2Binding1 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
Expand All @@ -100,16 +106,10 @@ kernel void main0(constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0
device auto& ssbo_f = *(device SSBO_F*)spvBufferAliasSet2Binding0;
constant auto& ubo_h = *(constant UBO_H*)spvBufferAliasSet2Binding1;
const device auto& ssbo_i = *(const device SSBO_I*)spvBufferAliasSet2Binding0;
device auto& ssbo_b = (device SSBO_B&)(*spvDescriptorSet0.ssbo_a);
constant auto& ubo_d = (constant UBO_D&)(*spvDescriptorSet0.ubo_c);
const device auto& ssbo_b_readonly = (const device SSBO_BRO&)(*spvDescriptorSet0.ssbo_a);
constant auto& ssbo_bs = (device SSBO_Bs* constant (&)[4])spvDescriptorSet0.ssbo_as;
constant auto& ubo_ds = (constant UBO_Ds* constant (&)[4])spvDescriptorSet0.ubo_cs;
constant auto& ssbo_bs_readonly = (const device SSBO_BsRO* constant (&)[4])spvDescriptorSet0.ssbo_as;
(*spvDescriptorSet0.ssbo_a).data[gl_GlobalInvocationID.x] = (*spvDescriptorSet0.ubo_c).data[gl_WorkGroupID.x].x + _42.reg;
ssbo_b.data[gl_GlobalInvocationID.x] = ubo_d.data[gl_WorkGroupID.y].xy + ssbo_b_readonly.data[gl_GlobalInvocationID.x];
(*spvDescriptorSet0.ssbo_b).data[gl_GlobalInvocationID.x] = (*spvDescriptorSet0.ubo_d).data[gl_WorkGroupID.y].xy + (*spvDescriptorSet0.ssbo_b_readonly).data[gl_GlobalInvocationID.x];
spvDescriptorSet0.ssbo_as[gl_WorkGroupID.x]->data[gl_GlobalInvocationID.x] = spvDescriptorSet0.ubo_cs[gl_WorkGroupID.x]->data[0].x;
ssbo_bs[gl_WorkGroupID.x]->data[gl_GlobalInvocationID.x] = ubo_ds[gl_WorkGroupID.x]->data[0].xy + ssbo_bs_readonly[gl_WorkGroupID.x]->data[gl_GlobalInvocationID.x];
spvDescriptorSet0.ssbo_bs[gl_WorkGroupID.x]->data[gl_GlobalInvocationID.x] = spvDescriptorSet0.ubo_ds[gl_WorkGroupID.x]->data[0].xy + spvDescriptorSet0.ssbo_bs_readonly[gl_WorkGroupID.x]->data[gl_GlobalInvocationID.x];
ssbo_e.data[gl_GlobalInvocationID.x] = ubo_g.data[gl_WorkGroupID.x].x;
ssbo_f.data[gl_GlobalInvocationID.x] = ubo_h.data[gl_WorkGroupID.y].xy + ssbo_i.data[gl_GlobalInvocationID.x];
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Issues a texture fence on `img` by calling its fence() member.
// The argument is taken by value, so fence() runs on a local, non-const copy of
// the handle the caller passed in.
// NOTE(review): presumably this lets callers fence textures they only hold through
// const-qualified references — confirm against the generator.
template <typename TextureT>
void spvImageFence(TextureT img)
{
    img.fence();
}

// Fragment entry point that copies texels between the bound images:
//   - sample 2 of uImageMS at (1, 2) is written to uImage at (2, 3);
//   - uImageArray texel (1, 2) in slice 4 is written to (2, 3) in slice 7.
//
// Parameters:
//   uImageMS    - multisampled source texture, [[texture(0)]].
//   uImageArray - read_write 2D array texture, [[texture(1)]]; both read and written here.
//   uImage      - write-only destination texture, [[texture(2)]].
fragment void main0(texture2d_ms<float> uImageMS [[texture(0)]], texture2d_array<float, access::read_write> uImageArray [[texture(1)]], texture2d<float, access::write> uImage [[texture(2)]])
{
    // Fence the read_write texture once, through the shared helper, before it is read
    // below. A second direct uImageArray.fence() next to this call would be redundant.
    spvImageFence(uImageArray);
    uImage.write(uImageMS.read(uint2(int2(1, 2)), 2), uint2(int2(2, 3)));
    uImageArray.write(uImageArray.read(uint2(int3(1, 2, 4).xy), uint(int3(1, 2, 4).z)), uint2(int3(2, 3, 7).xy), uint(int3(2, 3, 7).z));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

using namespace metal;

// Issues a texture fence on `img` by calling its fence() member.
// The argument is taken by value, so fence() runs on a local, non-const copy of
// the handle the caller passed in.
// NOTE(review): presumably this lets callers fence textures they only hold through
// const-qualified references — confirm against the generator.
template <typename TextureT>
void spvImageFence(TextureT img)
{
    img.fence();
}

static inline __attribute__((always_inline))
void _main(thread const uint3& id, texture2d<float, access::read_write> TargetTexture)
{
TargetTexture.fence();
spvImageFence(TargetTexture);
float2 loaded = TargetTexture.read(uint2(id.xy)).xy;
float2 storeTemp = loaded + float2(1.0);
TargetTexture.write(storeTemp.xyyy, uint2((id.xy + uint2(1u))));
Expand Down
Loading

0 comments on commit 0640756

Please sign in to comment.