Skip to content

Commit d0cf78c

Browse files
[hal/metal] Mesh Shaders (#8139)
Co-authored-by: Connor Fitzgerald <[email protected]> Co-authored-by: Magnus <[email protected]>
1 parent 92fa99a commit d0cf78c

File tree

13 files changed

+900
-383
lines changed

13 files changed

+900
-383
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206
125125
- `util::StagingBelt` now takes a `Device` when it is created instead of when it is used. By @kpreid in [#8462](https://github.com/gfx-rs/wgpu/pull/8462).
126126
- `wgpu_hal::vulkan::Device::texture_from_raw` now takes an `external_memory` argument. By @s-ol in [#8512](https://github.com/gfx-rs/wgpu/pull/8512)
127127

128+
#### Metal
129+
- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139)
130+
128131
#### Naga
129132

130133
- Prevent UB with invalid ray query calls on spirv. By @Vecvec in [#8390](https://github.com/gfx-rs/wgpu/pull/8390).

examples/features/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ fn all_tests() -> Vec<wgpu_test::GpuTestInitializer> {
4949
cube::TEST,
5050
cube::TEST_LINES,
5151
hello_synchronization::tests::SYNC,
52+
mesh_shader::TEST,
5253
mipmap::TEST,
5354
mipmap::TEST_QUERY,
5455
msaa_line::TEST,

examples/features/src/mesh_shader/mod.rs

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::Sh
6161
}
6262
}
6363

64+
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
65+
unsafe {
66+
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
67+
entry_point: entry.to_owned(),
68+
label: None,
69+
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
70+
num_workgroups: (1, 1, 1),
71+
..Default::default()
72+
})
73+
}
74+
}
75+
6476
pub struct Example {
6577
pipeline: wgpu::RenderPipeline,
6678
}
@@ -71,20 +83,23 @@ impl crate::framework::Example for Example {
7183
device: &wgpu::Device,
7284
_queue: &wgpu::Queue,
7385
) -> Self {
74-
let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Vulkan {
75-
(
86+
let (ts, ms, fs) = match adapter.get_info().backend {
87+
wgpu::Backend::Vulkan => (
7688
compile_glsl(device, "task"),
7789
compile_glsl(device, "mesh"),
7890
compile_glsl(device, "frag"),
79-
)
80-
} else if adapter.get_info().backend == wgpu::Backend::Dx12 {
81-
(
91+
),
92+
wgpu::Backend::Dx12 => (
8293
compile_hlsl(device, "Task", "as"),
8394
compile_hlsl(device, "Mesh", "ms"),
8495
compile_hlsl(device, "Frag", "ps"),
85-
)
86-
} else {
87-
panic!("Example can only run on vulkan or dx12");
96+
),
97+
wgpu::Backend::Metal => (
98+
compile_msl(device, "taskShader"),
99+
compile_msl(device, "meshShader"),
100+
compile_msl(device, "fragShader"),
101+
),
102+
_ => panic!("Example can currently only run on vulkan, dx12 or metal"),
88103
};
89104
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
90105
label: None,
@@ -179,3 +194,21 @@ impl crate::framework::Example for Example {
179194
pub fn main() {
180195
crate::framework::run::<Example>("mesh_shader");
181196
}
197+
198+
#[cfg(test)]
199+
#[wgpu_test::gpu_test]
200+
pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams {
201+
name: "mesh_shader",
202+
image_path: "/examples/features/src/mesh_shader/screenshot.png",
203+
width: 1024,
204+
height: 768,
205+
optional_features: wgpu::Features::default(),
206+
base_test_parameters: wgpu_test::TestParameters::default()
207+
.features(
208+
wgpu::Features::EXPERIMENTAL_MESH_SHADER
209+
| wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS,
210+
)
211+
.limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()),
212+
comparisons: &[wgpu_test::ComparisonType::Mean(0.01)],
213+
_phantom: std::marker::PhantomData::<Example>,
214+
};
33.5 KB
Loading
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using namespace metal;
2+
3+
struct OutVertex {
4+
float4 Position [[position]];
5+
float4 Color [[user(locn0)]];
6+
};
7+
8+
struct OutPrimitive {
9+
float4 ColorMask [[flat]] [[user(locn1)]];
10+
bool CullPrimitive [[primitive_culled]];
11+
};
12+
13+
struct InVertex {
14+
};
15+
16+
struct InPrimitive {
17+
float4 ColorMask [[flat]] [[user(locn1)]];
18+
};
19+
20+
struct FragmentIn {
21+
float4 Color [[user(locn0)]];
22+
float4 ColorMask [[flat]] [[user(locn1)]];
23+
};
24+
25+
struct PayloadData {
26+
float4 ColorMask;
27+
bool Visible;
28+
};
29+
30+
using Meshlet = metal::mesh<OutVertex, OutPrimitive, 3, 1, topology::triangle>;
31+
32+
33+
constant float4 positions[3] = {
34+
float4(0.0, 1.0, 0.0, 1.0),
35+
float4(-1.0, -1.0, 0.0, 1.0),
36+
float4(1.0, -1.0, 0.0, 1.0)
37+
};
38+
39+
constant float4 colors[3] = {
40+
float4(0.0, 1.0, 0.0, 1.0),
41+
float4(0.0, 0.0, 1.0, 1.0),
42+
float4(1.0, 0.0, 0.0, 1.0)
43+
};
44+
45+
46+
[[object]]
47+
void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) {
48+
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
49+
outPayload.Visible = true;
50+
grid.set_threadgroups_per_grid(uint3(3, 1, 1));
51+
}
52+
53+
[[mesh]]
54+
void meshShader(
55+
object_data PayloadData const& payload [[payload]],
56+
Meshlet out
57+
)
58+
{
59+
out.set_primitive_count(1);
60+
61+
for(int i = 0;i < 3;i++) {
62+
OutVertex vert;
63+
vert.Position = positions[i];
64+
vert.Color = colors[i] * payload.ColorMask;
65+
out.set_vertex(i, vert);
66+
out.set_index(i, i);
67+
}
68+
69+
OutPrimitive prim;
70+
prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0);
71+
prim.CullPrimitive = !payload.Visible;
72+
out.set_primitive(0, prim);
73+
}
74+
75+
fragment float4 fragShader(FragmentIn data [[stage_in]]) {
76+
return data.Color * data.ColorMask;
77+
}

tests/tests/wgpu-gpu/mesh_shader/mod.rs

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,11 @@ use std::{
33
process::Stdio,
44
};
55

6-
use wgpu::{util::DeviceExt, Backends};
6+
use wgpu::util::DeviceExt;
77
use wgpu_test::{
8-
fail, gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters,
9-
TestingContext,
8+
fail, gpu_test, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext,
109
};
1110

12-
/// Backends that support mesh shaders
13-
const MESH_SHADER_BACKENDS: Backends = Backends::DX12.union(Backends::VULKAN);
14-
1511
pub fn all_tests(tests: &mut Vec<GpuTestInitializer>) {
1612
tests.extend([
1713
MESH_PIPELINE_BASIC_MESH,
@@ -98,6 +94,18 @@ fn compile_hlsl(
9894
}
9995
}
10096

97+
fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule {
98+
unsafe {
99+
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
100+
entry_point: entry.to_owned(),
101+
label: None,
102+
msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))),
103+
num_workgroups: (1, 1, 1),
104+
..Default::default()
105+
})
106+
}
107+
}
108+
101109
fn get_shaders(
102110
device: &wgpu::Device,
103111
backend: wgpu::Backend,
@@ -114,18 +122,17 @@ fn get_shaders(
114122
// (In the case that the platform does support mesh shaders, the dummy
115123
// shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.)
116124
let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl"));
117-
if backend == wgpu::Backend::Vulkan {
118-
(
125+
match backend {
126+
wgpu::Backend::Vulkan => (
119127
info.use_task.then(|| compile_glsl(device, "task")),
120128
if info.use_mesh {
121129
compile_glsl(device, "mesh")
122130
} else {
123131
dummy_shader
124132
},
125133
info.use_frag.then(|| compile_glsl(device, "frag")),
126-
)
127-
} else if backend == wgpu::Backend::Dx12 {
128-
(
134+
),
135+
wgpu::Backend::Dx12 => (
129136
info.use_task
130137
.then(|| compile_hlsl(device, "Task", "as", test_name)),
131138
if info.use_mesh {
@@ -135,11 +142,20 @@ fn get_shaders(
135142
},
136143
info.use_frag
137144
.then(|| compile_hlsl(device, "Frag", "ps", test_name)),
138-
)
139-
} else {
140-
assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend)));
141-
assert!(!info.use_task && !info.use_mesh && !info.use_frag);
142-
(None, dummy_shader, None)
145+
),
146+
wgpu::Backend::Metal => (
147+
info.use_task.then(|| compile_msl(device, "taskShader")),
148+
if info.use_mesh {
149+
compile_msl(device, "meshShader")
150+
} else {
151+
dummy_shader
152+
},
153+
info.use_frag.then(|| compile_msl(device, "fragShader")),
154+
),
155+
_ => {
156+
assert!(!info.use_task && !info.use_mesh && !info.use_frag);
157+
(None, dummy_shader, None)
158+
}
143159
}
144160
}
145161

@@ -377,7 +393,6 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) {
377393
fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration {
378394
GpuTestConfiguration::new().parameters(
379395
TestParameters::default()
380-
.skip(FailureCase::backend(!MESH_SHADER_BACKENDS))
381396
.test_features_limits()
382397
.features(
383398
wgpu::Features::EXPERIMENTAL_MESH_SHADER
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using namespace metal;
2+
3+
struct OutVertex {
4+
float4 Position [[position]];
5+
float4 Color [[user(locn0)]];
6+
};
7+
8+
struct OutPrimitive {
9+
float4 ColorMask [[flat]] [[user(locn1)]];
10+
bool CullPrimitive [[primitive_culled]];
11+
};
12+
13+
struct InVertex {
14+
};
15+
16+
struct InPrimitive {
17+
float4 ColorMask [[flat]] [[user(locn1)]];
18+
};
19+
20+
struct FragmentIn {
21+
float4 Color [[user(locn0)]];
22+
float4 ColorMask [[flat]] [[user(locn1)]];
23+
};
24+
25+
struct PayloadData {
26+
float4 ColorMask;
27+
bool Visible;
28+
};
29+
30+
using Meshlet = metal::mesh<OutVertex, OutPrimitive, 3, 1, topology::triangle>;
31+
32+
33+
constant float4 positions[3] = {
34+
float4(0.0, 1.0, 0.0, 1.0),
35+
float4(-1.0, -1.0, 0.0, 1.0),
36+
float4(1.0, -1.0, 0.0, 1.0)
37+
};
38+
39+
constant float4 colors[3] = {
40+
float4(0.0, 1.0, 0.0, 1.0),
41+
float4(0.0, 0.0, 1.0, 1.0),
42+
float4(1.0, 0.0, 0.0, 1.0)
43+
};
44+
45+
46+
[[object]]
47+
void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) {
48+
outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0);
49+
outPayload.Visible = true;
50+
grid.set_threadgroups_per_grid(uint3(3, 1, 1));
51+
}
52+
53+
[[mesh]]
54+
void meshShader(
55+
object_data PayloadData const& payload [[payload]],
56+
Meshlet out
57+
)
58+
{
59+
out.set_primitive_count(1);
60+
61+
for(int i = 0;i < 3;i++) {
62+
OutVertex vert;
63+
vert.Position = positions[i];
64+
vert.Color = colors[i] * payload.ColorMask;
65+
out.set_vertex(i, vert);
66+
out.set_index(i, i);
67+
}
68+
69+
OutPrimitive prim;
70+
prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0);
71+
prim.CullPrimitive = !payload.Visible;
72+
out.set_primitive(0, prim);
73+
}
74+
75+
fragment float4 fragShader(FragmentIn data [[stage_in]]) {
76+
return data.Color * data.ColorMask;
77+
}

wgpu-hal/src/metal/adapter.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,8 @@ impl super::PrivateCapabilities {
607607

608608
let argument_buffers = device.argument_buffers_support();
609609

610+
let is_virtual = device.name().to_lowercase().contains("virtual");
611+
610612
Self {
611613
family_check,
612614
msl_version: if os_is_xr || version.at_least((14, 0), (17, 0), os_is_mac) {
@@ -902,6 +904,12 @@ impl super::PrivateCapabilities {
902904
&& (device.supports_family(MTLGPUFamily::Apple7)
903905
|| device.supports_family(MTLGPUFamily::Mac2)),
904906
supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac),
907+
mesh_shaders: family_check
908+
&& (device.supports_family(MTLGPUFamily::Metal3)
909+
|| device.supports_family(MTLGPUFamily::Apple7)
910+
|| device.supports_family(MTLGPUFamily::Mac2))
911+
// Mesh shaders don't work on virtual devices even if they should be supported.
912+
&& !is_virtual,
905913
supported_vertex_amplification_factor: {
906914
let mut factor = 1;
907915
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=8
@@ -1023,6 +1031,8 @@ impl super::PrivateCapabilities {
10231031
features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER);
10241032
}
10251033

1034+
features.set(F::EXPERIMENTAL_MESH_SHADER, self.mesh_shaders);
1035+
10261036
if self.supported_vertex_amplification_factor > 1 {
10271037
features.insert(F::MULTIVIEW);
10281038
}
@@ -1102,10 +1112,11 @@ impl super::PrivateCapabilities {
11021112
max_buffer_size: self.max_buffer_size,
11031113
max_non_sampler_bindings: u32::MAX,
11041114

1105-
max_task_workgroup_total_count: 0,
1106-
max_task_workgroups_per_dimension: 0,
1115+
// See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, Maximum threadgroups per mesh shader grid
1116+
max_task_workgroup_total_count: 1024,
1117+
max_task_workgroups_per_dimension: 1024,
11071118
max_mesh_multiview_view_count: 0,
1108-
max_mesh_output_layers: 0,
1119+
max_mesh_output_layers: self.max_texture_layers as u32,
11091120

11101121
max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits
11111122
max_blas_geometry_count: 0, // When added: 2^24

0 commit comments

Comments
 (0)