
Commit 3e38cbe

Add Metal backend core ETMetal runtime. (#15020)
This commit introduces the foundational Metal backend runtime. Key features:

- ETMetalStream for managing Metal devices, command queues, buffers, and synchronization.
- ETMetalShaderLibrary for compiling Metal shader source and caching pipeline states.
- ETMetalKernelFunction for kernel argument binding, dispatching, and synchronization with stream-managed encoders.
- Added global buffer management and pointer tracking between host and Metal buffers.
- Added global stream management utilities and synchronization helpers.

This provides the necessary runtime primitives for executing compute shaders and MPSGraph workloads.
1 parent a393191 commit 3e38cbe
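
For orientation, the sketch below shows how the primitives described above are intended to fit together, using only the interfaces declared in the header added by this commit. It is illustrative rather than code from the commit: the scale_by_two kernel, dataTensor, and numElements are placeholder names, and the header include is assumed.

// Minimal usage sketch (assumes the ETMetal runtime header added in this commit is included).
using namespace executorch::backends::metal;

void run_scale_kernel(
    const executorch::runtime::etensor::Tensor& dataTensor, // placeholder tensor
    uint64_t numElements) { // placeholder element count
  // Compile a shader library; pipeline states are cached per kernel function name.
  ETMetalShaderLibrary library(R"METAL(
      #include <metal_stdlib>
      using namespace metal;
      kernel void scale_by_two(device float* data [[buffer(0)]],
                               uint tid [[thread_position_in_grid]]) {
        data[tid] = data[tid] * 2.0;
      }
  )METAL");
  auto kernel = library.getKernelFunction("scale_by_two");

  // Bind the tensor argument and dispatch a linear workload on the stream-managed encoder.
  kernel->startEncoding();
  kernel->setArg(0, dataTensor);
  kernel->dispatchSingle(numElements);

  // Block until the committed command buffer has finished executing.
  getCurrentMetalStream()->synchronize(SyncType::COMMIT_AND_WAIT);
}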

2 files changed: +1273 −0 lines changed
Lines changed: 382 additions & 0 deletions
@@ -0,0 +1,382 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#ifdef __OBJC__
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#include <dispatch/dispatch.h>
// Forward declarations for MetalPerformanceShadersGraph types
@class MPSGraph;
@class MPSCommandBuffer;
// Metal type definitions for Objective-C compilation
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
typedef id<MTLFunction> MTLFunction_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLBuffer> MTLBuffer_t;
typedef dispatch_queue_t dispatch_queue_t;
typedef MPSGraph* MPSGraph_t;
typedef MPSCommandBuffer* MPSCommandBuffer_t;
typedef NSDictionary* NSDictionary_t;
#else
// Forward declarations for C++ compilation
typedef void* MTLDevice_t;
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandBuffer_t;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLComputePipelineState_t;
typedef void* MTLFunction_t;
typedef void* MTLLibrary_t;
typedef void* MTLBuffer_t;
typedef void* dispatch_queue_t;
typedef void* MPSGraph_t;
typedef void* MPSCommandBuffer_t;
typedef void* NSDictionary_t;
#endif

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace executorch::runtime::etensor {
class Tensor;
}

namespace executorch {
namespace backends {
namespace metal {

// Forward declarations
class ETMetalKernelFunction;
class ETMetalStream;

// =======================
// SyncType - Metal synchronization options
// =======================
enum class SyncType {
  NONE, // no commit to command buffer
  COMMIT, // commit and flush the command buffer
  COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
  COMMIT_AND_CONTINUE, // commit and continue with a new underlying command
                       // buffer
  COMMIT_ADAPTIVE, // commit adaptively based on available memory
};

// =======================
// ETMetalShaderLibrary - ExecuTorch Metal shader library management
// =======================

/**
 * @class ETMetalShaderLibrary
 * @brief Manages Metal shader library compilation and kernel function
 * retrieval.
 *
 * This class provides a high-level interface for compiling Metal shading
 * language source code into a Metal library and creating compute pipeline
 * states for kernel functions. It handles the creation and caching of Metal
 * compute pipeline states and functions, which should be reused across multiple
 * kernel dispatches.
 *
 * The class automatically compiles the provided shader source code upon
 * construction and maintains an internal cache of compute pipeline states for
 * different kernel functions to avoid redundant compilation.
 *
 * Example usage:
 * @code
 * std::string shaderSource = R"(
 *   #include <metal_stdlib>
 *   using namespace metal;
 *   kernel void my_kernel(device float* data [[buffer(0)]],
 *                         uint tid [[thread_position_in_grid]]) {
 *     data[tid] = data[tid] * 2.0;
 *   }
 * )";
 *
 * ETMetalShaderLibrary library(shaderSource);
 * auto kernelFunction = library.getKernelFunction("my_kernel");
 * @endcode
 */
class ETMetalShaderLibrary {
 public:
  ETMetalShaderLibrary(const std::string& source);
  ~ETMetalShaderLibrary();

  std::shared_ptr<ETMetalKernelFunction> getKernelFunction(
      const std::string& name);

 private:
  void compileLibrary();
  std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
      const std::string& functionName);

  friend class ETMetalKernelFunction;

  std::string shaderSource_;
  MTLLibrary_t library_;
  std::unordered_map<
      std::string,
      std::pair<MTLComputePipelineState_t, MTLFunction_t>>
      pipelineStates_;
};

// =======================
// ETMetalKernelFunction - ExecuTorch Metal kernel function execution
// =======================

/**
 * @class ETMetalKernelFunction
 * @brief Represents a Metal compute kernel function ready for execution.
 *
 * This class encapsulates a Metal compute pipeline state and function,
 * providing a high-level interface for setting kernel arguments and dispatching
 * compute work to the GPU. It handles the encoding of compute commands and
 * manages the interaction with Metal's compute command encoder.
 *
 * The class supports different dispatch patterns:
 * - Single-dimension dispatch for linear workloads
 * - Multi-dimensional dispatch for grid-based workloads
 * - Custom thread group sizes for performance optimization
 *
 * Kernel arguments can be set using tensors (which will be mapped to Metal
 * buffers) or scalar values. The class handles the encoding of these arguments
 * into the compute command encoder.
 *
 * Example usage:
 * @code
 * // Get kernel function from library
 * auto kernelFunction = library.getKernelFunction("vector_add");
 *
 * // Start encoding commands
 * kernelFunction->startEncoding();
 *
 * // Set tensor arguments
 * kernelFunction->setArg(0, inputTensorA);
 * kernelFunction->setArg(1, inputTensorB);
 * kernelFunction->setArg(2, outputTensor);
 *
 * // Set scalar argument
 * kernelFunction->setArg(3, static_cast<int64_t>(numElements));
 *
 * // Dispatch for linear workload
 * kernelFunction->dispatchSingle(numElements);
 * @endcode
 */
class ETMetalKernelFunction {
 public:
  ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func);
  ~ETMetalKernelFunction();

  void startEncoding();
  void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor);
  void setArg(unsigned idx, int64_t val);

  void dispatchSingle(uint64_t length);
  void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size);
  void dispatchArray(const uint64_t* length, size_t length_size);
  void dispatchArrayWithGroupSize(
      const uint64_t* length,
      size_t length_size,
      const uint64_t* group_size,
      size_t group_size_size);

  void runCommandBlock(std::function<void(void)> f);

 private:
  MTLComputePipelineState_t cps_;
  MTLFunction_t func_;
  MTLComputeCommandEncoder_t encoder_;
};

// =======================
// ETMetalStream - Metal command buffer and synchronization management
// =======================

/**
 * @class ETMetalStream
 * @brief Manages Metal compute command streams and provides GPU
 * synchronization.
 *
 * This class serves as the central management hub for Metal GPU operations,
 * providing a stream-based abstraction similar to CUDA streams. It handles
 * command buffer lifecycle, compute command encoder management, and various
 * synchronization patterns required for efficient GPU computation.
 *
 * Key features:
 * - Lazy command buffer and encoder creation for optimal resource usage
 * - Thread-safe operations using serial dispatch queues
 * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT,
 *   COMMIT_AND_CONTINUE, etc.)
 * - Kernel coalescing to batch multiple operations efficiently
 * - MPSGraph integration for executing fallback operations (mm, conv, sdpa)
 * - Memory operations (copy, fill) with GPU acceleration via blit encoders
 *
 * The stream follows PyTorch's MPS stream design patterns, providing similar
 * semantics for command buffer management and synchronization.
 *
 * Example usage:
 * @code
 * // Get current stream (typically the default stream)
 * ETMetalStream* stream = getCurrentMetalStream();
 *
 * // Execute kernel operations (handled automatically)
 * auto kernelFunction = library.getKernelFunction("my_kernel");
 * kernelFunction->startEncoding();
 * kernelFunction->setArg(0, inputTensor);
 * kernelFunction->dispatchSingle(numElements);
 *
 * // Synchronize to ensure completion
 * stream->synchronize(SyncType::COMMIT_AND_WAIT);
 *
 * // Copy between GPU buffers using blit encoder
 * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT);
 * @endcode
 */
class ETMetalStream {
 public:
  ETMetalStream();
  ~ETMetalStream();

  // Get the default stream (singleton)
  static ETMetalStream* getDefaultStream();

  // Device and queue access
  MTLDevice_t device() const {
    return device_;
  }
  MTLCommandQueue_t commandQueue() const {
    return commandQueue_;
  }
  dispatch_queue_t queue() const {
    return serialQueue_;
  }

  // Synchronization methods
  void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT);
  void synchronize(); // Overload for backward compatibility
  bool isEmpty() const;

  // Command buffer management with lazy creation
  MPSCommandBuffer_t commandBuffer();
  MTLComputeCommandEncoder_t commandEncoder();

  void endKernelCoalescing();

  // MPSGraph execution
  void executeMPSGraph(
      MPSGraph_t mpsGraph,
      NSDictionary_t feeds,
      NSDictionary_t results,
      SyncType syncType = SyncType::COMMIT_ADAPTIVE);

  // Command buffer lifecycle management
  void commitCommandBuffer(MTLCommandBuffer_t commandBuffer);
  void flush();

  // Memory operations
  void fill(
      MTLBuffer_t buffer,
      uint8_t value,
      size_t length,
      size_t offset,
      SyncType syncType = SyncType::NONE);
  void copy(
      MTLBuffer_t srcBuffer,
      MTLBuffer_t dstBuffer,
      size_t length,
      size_t srcOffset,
      size_t dstOffset,
      SyncType syncType = SyncType::NONE);

 private:
  // Private synchronization methods
  void commit();
  void commitAndWait();
  void commitAndContinue();

 private:
  // Private members
  MTLDevice_t device_;
  MTLCommandQueue_t commandQueue_;
  MPSCommandBuffer_t commandBuffer_;
  MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern
  MTLComputeCommandEncoder_t commandEncoder_;
  dispatch_queue_t serialQueue_; // For thread safety

  // Configuration
  bool enableCommitAndContinue_;

  // Singleton instance
  static ETMetalStream* defaultStream_;
};

// =======================
// Global storage management functions
// =======================
void storeFunctionHandle(
    ETMetalKernelFunction* raw_function,
    std::shared_ptr<ETMetalKernelFunction> function_shared_ptr);
void storeLibraryHandle(
    ETMetalShaderLibrary* raw_library,
    std::unique_ptr<ETMetalShaderLibrary> library);
bool removeFunctionHandle(ETMetalKernelFunction* raw_function);
bool removeLibraryHandle(ETMetalShaderLibrary* raw_library);

// =======================
// Global stream access functions
// =======================
ETMetalStream* getCurrentMetalStream();
void setCurrentMetalStream(ETMetalStream* stream);

// =======================
// Metal stream synchronization functions (C++ interface with exceptions)
// =======================
void synchronize_metal_stream();
void synchronize_metal_stream_with_type(int sync_type);

// =======================
// Metal helper functions (C interface)
// =======================
#ifdef __cplusplus
extern "C" {
#endif

// Memory management functions for Metal
void* metal_allocate_buffer(long bytes);
bool metal_is_device_pointer(void* ptr);
int metal_copy_memory(
    void* dst,
    const void* src,
    size_t nbytes,
    bool src_is_device,
    bool dst_is_device);
void metal_cleanup_resources();

// Helper functions to access Metal objects
MTLDevice_t get_metal_device();
MTLCommandQueue_t get_metal_command_queue();

#ifdef __cplusplus
}

// C++ only - expose the Metal buffer mapping
#ifdef __OBJC__
extern std::unordered_map<void*, MTLBuffer_t> ptr_to_mtl_buffer;
#endif

#endif

} // namespace metal
} // namespace backends
} // namespace executorch
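
As a rough illustration of the host-side C helpers declared at the end of this header, the sketch below allocates a tracked Metal buffer, copies data to and from it, and releases runtime resources. It is not code from this commit: the byte count and hostData payload are placeholders, the header include is assumed (the file path is not shown above), and calling metal_cleanup_resources() at teardown is an assumption about intended usage.

#include <cstddef>
#include <cstdio>

// Sketch of the C-interface helpers (assumes the ETMetal runtime header above is included).
using namespace executorch::backends::metal;

void copy_roundtrip_example() {
  const size_t nbytes = 1024; // placeholder size
  float hostData[256] = {};   // placeholder host payload (256 floats = 1024 bytes)

  // Allocate a Metal-backed buffer; the runtime tracks the returned pointer.
  void* devicePtr = metal_allocate_buffer(static_cast<long>(nbytes));
  if (devicePtr == nullptr) {
    std::fprintf(stderr, "metal_allocate_buffer failed\n");
    return;
  }

  // The tracked pointer should be reported as a device pointer.
  if (!metal_is_device_pointer(devicePtr)) {
    std::fprintf(stderr, "pointer is not tracked as a Metal buffer\n");
  }

  // Host -> device, then device -> host.
  metal_copy_memory(devicePtr, hostData, nbytes, /*src_is_device=*/false, /*dst_is_device=*/true);
  metal_copy_memory(hostData, devicePtr, nbytes, /*src_is_device=*/true, /*dst_is_device=*/false);

  // Wait for any outstanding GPU work before trusting the host-side data.
  synchronize_metal_stream();

  // Assumed teardown step: release buffers tracked by the runtime.
  metal_cleanup_resources();
}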
