Skip to content

Commit

Permalink
Merge pull request #2401 from KhronosGroup/pr-2400
Browse files Browse the repository at this point in the history
Land PR 2400.
  • Loading branch information
HansKristian-Work authored Oct 30, 2024
2 parents f93223f + 492fa9c commit 2f08260
Show file tree
Hide file tree
Showing 9 changed files with 1,478 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}

object_data T& operator [] (size_t pos) object_data
{
return elements[pos];
}
constexpr const object_data T& operator [] (size_t pos) const object_data
{
return elements[pos];
}
};

void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)
{
if (gl_LocalInvocationIndex == 0)
{
spvMeshSizes.x = vertexCount;
spvMeshSizes.y = primitiveCount;
}
}

struct gl_MeshPerPrimitiveEXT
{
uint gl_PrimitiveID [[primitive_id]];
uint gl_Layer [[render_target_array_index]];
uint gl_ViewportIndex [[viewport_array_index]];
bool gl_CullPrimitiveEXT [[primitive_culled]];
};

struct gl_MeshPerVertexEXT
{
float4 gl_Position [[position]];
float gl_PointSize;
float gl_ClipDistance [[clip_distance]] [1];
};

struct BlockOut
{
float4 a;
float4 b;
};

struct BlockOutPrim
{
float4 a;
float4 b;
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u);

struct spvPerVertex
{
float4 gl_Position [[position]];
float gl_PointSize;
float gl_ClipDistance [[clip_distance]] [1];
float gl_ClipDistance_0 [[user(clip0)]];
float4 vOut [[user(locn0)]];
float4 outputs_a [[user(locn2)]];
float4 outputs_b [[user(locn3)]];
};

struct spvPerPrimitive
{
uint gl_PrimitiveID [[primitive_id]];
uint gl_Layer [[render_target_array_index]];
uint gl_ViewportIndex [[viewport_array_index]];
bool gl_CullPrimitiveEXT [[primitive_culled]];
float4 vPrim [[user(locn1)]];
float4 prim_outputs_a [[user(locn4)]];
float4 prim_outputs_b [[user(locn5)]];
};

using spvMesh_t = mesh<spvPerVertex, spvPerPrimitive, 24, 22, topology::line>;

static inline __attribute__((always_inline))
void _4(threadgroup spvUnsafeArray<uint2, 22>& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray<gl_MeshPerPrimitiveEXT, 22>& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT, threadgroup spvUnsafeArray<float4, 24>& vOut, threadgroup spvUnsafeArray<BlockOut, 24>& outputs, threadgroup spvUnsafeArray<float4, 22>& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray<BlockOutPrim, 22>& prim_outputs, threadgroup uint2& spvMeshSizes)
{
spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u);
float3 _158 = float3(gl_GlobalInvocationID);
float _159 = _158.x;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(_159, _158.yz, 1.0);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_PointSize = 2.0;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
vOut[gl_LocalInvocationIndex] = float4(_159, _158.yz, 2.0);
outputs[gl_LocalInvocationIndex].a = float4(5.0);
outputs[gl_LocalInvocationIndex].b = float4(6.0);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationIndex < 22u)
{
vPrim[gl_LocalInvocationIndex] = float4(float3(gl_WorkGroupID), 3.0);
prim_outputs[gl_LocalInvocationIndex].a = float4(0.0);
prim_outputs[gl_LocalInvocationIndex].b = float4(1.0);
gl_PrimitiveLineIndicesEXT[gl_LocalInvocationIndex] = uint2(0u, 1u) + uint2(gl_LocalInvocationIndex);
int _206 = int(gl_GlobalInvocationID.x);
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = _206;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = _206 + 1;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_ViewportIndex = _206 + 2;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_CullPrimitiveEXT = short((gl_GlobalInvocationID.x & 1u) != 0u);
}
}

[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh)
{
threadgroup uint2 spvMeshSizes;
threadgroup spvUnsafeArray<uint2, 22> gl_PrimitiveLineIndicesEXT;
threadgroup spvUnsafeArray<gl_MeshPerPrimitiveEXT, 22> gl_MeshPrimitivesEXT;
threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24> gl_MeshVerticesEXT;
threadgroup spvUnsafeArray<float4, 24> vOut;
threadgroup spvUnsafeArray<BlockOut, 24> outputs;
threadgroup spvUnsafeArray<float4, 22> vPrim;
threadgroup spvUnsafeArray<BlockOutPrim, 22> prim_outputs;
threadgroup spvUnsafeArray<float, 16> shared_float;
if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;
_4(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, spvMeshSizes);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (spvMeshSizes.y == 0)
{
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
spvPerVertex spvV = {};
spvV.gl_Position = gl_MeshVerticesEXT[spvVI].gl_Position;
spvV.gl_PointSize = gl_MeshVerticesEXT[spvVI].gl_PointSize;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.outputs_a = outputs[spvVI].a;
spvV.outputs_b = outputs[spvVI].b;
spvMesh.set_vertex(spvVI, spvV);
}
const uint spvPI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.y)
{
spvMesh.set_index(spvPI * 2u + 0u, gl_PrimitiveLineIndicesEXT[spvPI].x);
spvMesh.set_index(spvPI * 2u + 1u, gl_PrimitiveLineIndicesEXT[spvPI].y);
spvPerPrimitive spvP = {};
spvP.gl_PrimitiveID = gl_MeshPrimitivesEXT[spvPI].gl_PrimitiveID;
spvP.gl_Layer = gl_MeshPrimitivesEXT[spvPI].gl_Layer;
spvP.gl_ViewportIndex = gl_MeshPrimitivesEXT[spvPI].gl_ViewportIndex;
spvP.gl_CullPrimitiveEXT = gl_MeshPrimitivesEXT[spvPI].gl_CullPrimitiveEXT;
spvP.vPrim = vPrim[spvPI];
spvP.prim_outputs_a = prim_outputs[spvPI].a;
spvP.prim_outputs_b = prim_outputs[spvPI].b;
spvMesh.set_primitive(spvPI, spvP);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}

object_data T& operator [] (size_t pos) object_data
{
return elements[pos];
}
constexpr const object_data T& operator [] (size_t pos) const object_data
{
return elements[pos];
}
};

void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)
{
if (gl_LocalInvocationIndex == 0)
{
spvMeshSizes.x = vertexCount;
spvMeshSizes.y = primitiveCount;
}
}

struct gl_MeshPerVertexEXT
{
float4 gl_Position [[position]];
float gl_PointSize;
float gl_ClipDistance [[clip_distance]] [1];
};

struct BlockOut
{
float4 a;
float4 b;
};

struct BlockOutPrim
{
float4 a;
float4 b;
};

struct gl_MeshPerPrimitiveEXT
{
uint gl_PrimitiveID [[primitive_id]];
uint gl_Layer [[render_target_array_index]];
uint gl_ViewportIndex [[viewport_array_index]];
bool gl_CullPrimitiveEXT [[primitive_culled]];
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u);

struct TaskPayload
{
float a;
float b;
int c;
};

struct spvPerVertex
{
float4 gl_Position [[position]];
float gl_PointSize;
float gl_ClipDistance [[clip_distance]] [1];
float gl_ClipDistance_0 [[user(clip0)]];
float4 vOut [[user(locn0)]];
float4 outputs_a [[user(locn2)]];
float4 outputs_b [[user(locn3)]];
};

struct spvPerPrimitive
{
float4 vPrim [[user(locn1)]];
float4 prim_outputs_a [[user(locn4)]];
float4 prim_outputs_b [[user(locn5)]];
uint gl_PrimitiveID [[primitive_id]];
uint gl_Layer [[render_target_array_index]];
uint gl_ViewportIndex [[viewport_array_index]];
bool gl_CullPrimitiveEXT [[primitive_culled]];
};

using spvMesh_t = mesh<spvPerVertex, spvPerPrimitive, 24, 22, topology::triangle>;

static inline __attribute__((always_inline))
void _4(threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24>& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray<float4, 24>& vOut, threadgroup spvUnsafeArray<BlockOut, 24>& outputs, threadgroup spvUnsafeArray<float4, 22>& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray<BlockOutPrim, 22>& prim_outputs, threadgroup spvUnsafeArray<uint3, 22>& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray<gl_MeshPerPrimitiveEXT, 22>& gl_MeshPrimitivesEXT, threadgroup uint2& spvMeshSizes)
{
spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u);
float3 _27 = float3(gl_GlobalInvocationID);
float _29 = _27.x;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(_29, _27.yz, 1.0);
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_PointSize = 2.0;
gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0;
vOut[gl_LocalInvocationIndex] = float4(_29, _27.yz, 2.0);
outputs[gl_LocalInvocationIndex].a = float4(5.0);
outputs[gl_LocalInvocationIndex].b = float4(6.0);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationIndex < 22u)
{
vPrim[gl_LocalInvocationIndex] = float4(float3(gl_WorkGroupID), 3.0);
prim_outputs[gl_LocalInvocationIndex].a = float4(0.0);
prim_outputs[gl_LocalInvocationIndex].b = float4(1.0);
gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uint3(0u, 1u, 2u) + uint3(gl_LocalInvocationIndex);
int _116 = int(gl_GlobalInvocationID.x);
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = _116;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = _116 + 1;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_ViewportIndex = _116 + 2;
gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_CullPrimitiveEXT = short((gl_GlobalInvocationID.x & 1u) != 0u);
}
}

[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh)
{
threadgroup uint2 spvMeshSizes;
threadgroup spvUnsafeArray<gl_MeshPerVertexEXT, 24> gl_MeshVerticesEXT;
threadgroup spvUnsafeArray<float4, 24> vOut;
threadgroup spvUnsafeArray<BlockOut, 24> outputs;
threadgroup spvUnsafeArray<float4, 22> vPrim;
threadgroup spvUnsafeArray<BlockOutPrim, 22> prim_outputs;
threadgroup spvUnsafeArray<uint3, 22> gl_PrimitiveTriangleIndicesEXT;
threadgroup spvUnsafeArray<gl_MeshPerPrimitiveEXT, 22> gl_MeshPrimitivesEXT;
threadgroup spvUnsafeArray<float, 16> shared_float;
if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;
_4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, gl_PrimitiveTriangleIndicesEXT, gl_MeshPrimitivesEXT, spvMeshSizes);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (spvMeshSizes.y == 0)
{
return;
}
spvMesh.set_primitive_count(spvMeshSizes.y);
const uint spvThreadCount = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
const uint spvVI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.x)
{
spvPerVertex spvV = {};
spvV.gl_Position = gl_MeshVerticesEXT[spvVI].gl_Position;
spvV.gl_PointSize = gl_MeshVerticesEXT[spvVI].gl_PointSize;
spvV.gl_ClipDistance[0] = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.gl_ClipDistance_0 = gl_MeshVerticesEXT[spvVI].gl_ClipDistance[0];
spvV.outputs_a = outputs[spvVI].a;
spvV.outputs_b = outputs[spvVI].b;
spvMesh.set_vertex(spvVI, spvV);
}
const uint spvPI = gl_LocalInvocationIndex;
if (gl_LocalInvocationIndex < spvMeshSizes.y)
{
spvMesh.set_index(spvPI * 3u + 0u, gl_PrimitiveTriangleIndicesEXT[spvPI].x);
spvMesh.set_index(spvPI * 3u + 1u, gl_PrimitiveTriangleIndicesEXT[spvPI].y);
spvMesh.set_index(spvPI * 3u + 2u, gl_PrimitiveTriangleIndicesEXT[spvPI].z);
spvPerPrimitive spvP = {};
spvP.vPrim = vPrim[spvPI];
spvP.prim_outputs_a = prim_outputs[spvPI].a;
spvP.prim_outputs_b = prim_outputs[spvPI].b;
spvP.gl_PrimitiveID = gl_MeshPrimitivesEXT[spvPI].gl_PrimitiveID;
spvP.gl_Layer = gl_MeshPrimitivesEXT[spvPI].gl_Layer;
spvP.gl_ViewportIndex = gl_MeshPrimitivesEXT[spvPI].gl_ViewportIndex;
spvP.gl_CullPrimitiveEXT = gl_MeshPrimitivesEXT[spvPI].gl_CullPrimitiveEXT;
spvMesh.set_primitive(spvPI, spvP);
}
}

Loading

0 comments on commit 2f08260

Please sign in to comment.