Skip to content

Commit d395902

Browse files
authored
Add WorkgroupCount function (#7734)
Fixes #7733 Copy gl_NumWorkGroups into hlsl.meta.slang as WorkgroupCount function so that it can be used for GLSL and SPIR-V targets without GLSL syntax. Also change WorkgroupSize function to allow use with mesh shading capability. Update pipeline/rasterization/mesh/task-simple.slang to test it in task and mesh stages.
1 parent e16b5ca commit d395902

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

source/slang/glsl.meta.slang

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,25 +141,18 @@ public property int gl_DeviceIndex
141141
public property uint3 gl_NumWorkGroups
142142
{
143143
[require(glsl_spirv, GLSL_430_SPIRV_1_0_compute)]
144+
[require(glsl_spirv, meshshading)]
144145
get
145146
{
146-
__target_switch
147-
{
148-
case glsl:
149-
__intrinsic_asm "(gl_NumWorkGroups)";
150-
case spirv:
151-
return spirv_asm {
152-
result:$$uint3 = OpLoad builtin(NumWorkgroups:uint3);
153-
};
154-
}
147+
return WorkgroupCount();
155148
}
156149
}
157150

158-
[require(compute)]
159151
public property uint3 gl_WorkGroupSize
160152
{
161153
[__unsafeForceInlineEarly]
162154
[require(compute)]
155+
[require(meshshading)]
163156
get
164157
{
165158
return WorkgroupSize();

source/slang/hlsl.meta.slang

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6619,9 +6619,26 @@ void AllMemoryBarrierWithGroupSync()
66196619

66206620
// Returns the workgroup size of the calling entry point.
66216621
[require(compute)]
6622+
[require(meshshading)]
66226623
__intrinsic_op($(kIROp_GetWorkGroupSize))
66236624
int3 WorkgroupSize();
66246625

6626+
// Returns number of workgroups that have been dispatched to a GLSL or SPIR-V compute shader
6627+
[require(glsl_spirv, GLSL_430_SPIRV_1_0_compute)]
6628+
[require(glsl_spirv, meshshading)]
6629+
uint3 WorkgroupCount()
6630+
{
6631+
__target_switch
6632+
{
6633+
case glsl:
6634+
__intrinsic_asm "(gl_NumWorkGroups)";
6635+
case spirv:
6636+
return spirv_asm {
6637+
result:$$uint3 = OpLoad builtin(NumWorkgroups:uint3);
6638+
};
6639+
}
6640+
}
6641+
66256642
// Test if any components is non-zero.
66266643

66276644
__generic<T : __BuiltinType>

tests/pipeline/rasterization/mesh/task-simple.slang

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ struct MeshPayload
3434
int exponent;
3535
};
3636

37-
[numthreads(1, 1, 1)]
37+
const static uint AMPLIFICATION_NUM_THREADS_X = 1;
38+
39+
[numthreads(AMPLIFICATION_NUM_THREADS_X, 1, 1)]
3840
[shader("amplification")]
3941
void taskMain(in uint tig : SV_GroupIndex)
4042
{
4143
MeshPayload p;
42-
p.exponent = 3;
44+
p.exponent = select(AMPLIFICATION_NUM_THREADS_X == WorkgroupSize().x, 3, 0);
4345
DispatchMesh(1,1,1,p);
4446
}
4547

@@ -71,8 +73,10 @@ struct Vertex
7173
const static uint MAX_VERTS = 12;
7274
const static uint MAX_PRIMS = 4;
7375

76+
const static uint MESH_NUM_THREADS_X = 12;
77+
7478
[outputtopology("triangle")]
75-
[numthreads(12, 1, 1)]
79+
[numthreads(MESH_NUM_THREADS_X, 1, 1)]
7680
void meshMain(
7781
in uint tig : SV_GroupIndex,
7882
in payload MeshPayload meshPayload,
@@ -88,7 +92,7 @@ void meshMain(
8892

8993
if(tig < numVertices)
9094
{
91-
const int tri = tig / 3;
95+
const int tri = select(WorkgroupSize().x == MESH_NUM_THREADS_X, tig / 3, -1);
9296
verts[tig] = {float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent))};
9397
}
9498

0 commit comments

Comments
 (0)