| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#pragma once |
| 10 | + |
| 11 | +#ifdef __OBJC__ |
| 12 | +#import <Foundation/Foundation.h> |
| 13 | +#import <Metal/Metal.h> |
| 14 | +#include <dispatch/dispatch.h> |
| 15 | +// Forward declarations for MetalPerformanceShadersGraph types |
| 16 | +@class MPSGraph; |
| 17 | +@class MPSCommandBuffer; |
| 18 | +// Metal type definitions for Objective-C compilation |
| 19 | +typedef id<MTLDevice> MTLDevice_t; |
| 20 | +typedef id<MTLCommandQueue> MTLCommandQueue_t; |
| 21 | +typedef id<MTLCommandBuffer> MTLCommandBuffer_t; |
| 22 | +typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t; |
| 23 | +typedef id<MTLComputePipelineState> MTLComputePipelineState_t; |
| 24 | +typedef id<MTLFunction> MTLFunction_t; |
| 25 | +typedef id<MTLLibrary> MTLLibrary_t; |
| 26 | +typedef id<MTLBuffer> MTLBuffer_t; |
| 27 | +// dispatch_queue_t is provided directly by <dispatch/dispatch.h> above |
| 28 | +typedef MPSGraph* MPSGraph_t; |
| 29 | +typedef MPSCommandBuffer* MPSCommandBuffer_t; |
| 30 | +typedef NSDictionary* NSDictionary_t; |
| 31 | +#else |
| 32 | +// Opaque pointer typedefs for pure C++ compilation (Objective-C/Metal types unavailable) |
| 33 | +typedef void* MTLDevice_t; |
| 34 | +typedef void* MTLCommandQueue_t; |
| 35 | +typedef void* MTLCommandBuffer_t; |
| 36 | +typedef void* MTLComputeCommandEncoder_t; |
| 37 | +typedef void* MTLComputePipelineState_t; |
| 38 | +typedef void* MTLFunction_t; |
| 39 | +typedef void* MTLLibrary_t; |
| 40 | +typedef void* MTLBuffer_t; |
| 41 | +typedef void* dispatch_queue_t; |
| 42 | +typedef void* MPSGraph_t; |
| 43 | +typedef void* MPSCommandBuffer_t; |
| 44 | +typedef void* NSDictionary_t; |
| 45 | +#endif |
| 46 | + |
| | +#include <cstddef> |
| | +#include <cstdint> |
| 47 | +#include <functional> |
| 48 | +#include <memory> |
| 49 | +#include <string> |
| 50 | +#include <unordered_map> |
| | +#include <utility> |
| 51 | +#include <vector> |
| 52 | + |
| 53 | +namespace executorch::runtime::etensor { |
| 54 | +class Tensor; |
| 55 | +} |
| 56 | + |
| 57 | +namespace executorch { |
| 58 | +namespace backends { |
| 59 | +namespace metal { |
| 60 | + |
| 61 | +// Forward declarations |
| 62 | +class ETMetalKernelFunction; |
| 63 | +class ETMetalStream; |
| 64 | + |
| 65 | +// ======================= |
| 66 | +// SyncType - Metal synchronization options |
| 67 | +// ======================= |
| 68 | +enum class SyncType { |
| 69 | + NONE, // no commit to command buffer |
| 70 | + COMMIT, // commit and flush the command buffer |
| 71 | + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish |
| 72 | + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command |
| 73 | + // buffer |
| 74 | + COMMIT_ADAPTIVE, // commit adaptively based on available memory |
| 75 | +}; |
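| | + |
| | +// Illustrative sketch only (not part of the API surface): SyncType is consumed |
| | +// by ETMetalStream::synchronize() and the memory helpers below. A typical |
| | +// pattern is to keep batching work with NONE/COMMIT and only block when a |
| | +// result must be visible to the CPU: |
| | +// |
| | +//   ETMetalStream* stream = getCurrentMetalStream(); |
| | +//   stream->synchronize(SyncType::COMMIT);           // flush without blocking |
| | +//   stream->synchronize(SyncType::COMMIT_AND_WAIT);  // block until the GPU is done |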
| 76 | + |
| 77 | +// ======================= |
| 78 | +// ETMetalShaderLibrary - ExecuTorch Metal shader library management |
| 79 | +// ======================= |
| 80 | + |
| 81 | +/** |
| 82 | + * @class ETMetalShaderLibrary |
| 83 | + * @brief Manages Metal shader library compilation and kernel function |
| 84 | + * retrieval. |
| 85 | + * |
| 86 | + * This class provides a high-level interface for compiling Metal shading |
| 87 | + * language source code into a Metal library and creating compute pipeline |
| 88 | + * states for kernel functions. It handles the creation and caching of Metal |
| 89 | + * compute pipeline states and functions so that they can be reused across |
| 90 | + * multiple kernel dispatches. |
| 91 | + * |
| 92 | + * The class automatically compiles the provided shader source code upon |
| 93 | + * construction and maintains an internal cache of compute pipeline states for |
| 94 | + * different kernel functions to avoid redundant compilation. |
| 95 | + * |
| 96 | + * Example usage: |
| 97 | + * @code |
| 98 | + * std::string shaderSource = R"( |
| 99 | + * #include <metal_stdlib> |
| 100 | + * using namespace metal; |
| 101 | + * kernel void my_kernel(device float* data [[buffer(0)]], |
| 102 | + * uint tid [[thread_position_in_grid]]) { |
| 103 | + * data[tid] = data[tid] * 2.0; |
| 104 | + * } |
| 105 | + * )"; |
| 106 | + * |
| 107 | + * ETMetalShaderLibrary library(shaderSource); |
| 108 | + * auto kernelFunction = library.getKernelFunction("my_kernel"); |
| 109 | + * @endcode |
| 110 | + */ |
| 111 | +class ETMetalShaderLibrary { |
| 112 | + public: |
| 113 | +  explicit ETMetalShaderLibrary(const std::string& source); |
| 114 | + ~ETMetalShaderLibrary(); |
| 115 | + |
| 116 | + std::shared_ptr<ETMetalKernelFunction> getKernelFunction( |
| 117 | + const std::string& name); |
| 118 | + |
| 119 | + private: |
| 120 | + void compileLibrary(); |
| 121 | + std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState( |
| 122 | + const std::string& functionName); |
| 123 | + |
| 124 | + friend class ETMetalKernelFunction; |
| 125 | + |
| 126 | + std::string shaderSource_; |
| 127 | + MTLLibrary_t library_; |
| 128 | + std::unordered_map< |
| 129 | + std::string, |
| 130 | + std::pair<MTLComputePipelineState_t, MTLFunction_t>> |
| 131 | + pipelineStates_; |
| 132 | +}; |
| 133 | + |
| 134 | +// ======================= |
| 135 | +// ETMetalKernelFunction - ExecuTorch Metal kernel function execution |
| 136 | +// ======================= |
| 137 | + |
| 138 | +/** |
| 139 | + * @class ETMetalKernelFunction |
| 140 | + * @brief Represents a Metal compute kernel function ready for execution. |
| 141 | + * |
| 142 | + * This class encapsulates a Metal compute pipeline state and function, |
| 143 | + * providing a high-level interface for setting kernel arguments and dispatching |
| 144 | + * compute work to the GPU. It handles the encoding of compute commands and |
| 145 | + * manages the interaction with Metal's compute command encoder. |
| 146 | + * |
| 147 | + * The class supports different dispatch patterns: |
| 148 | + * - Single-dimension dispatch for linear workloads |
| 149 | + * - Multi-dimensional dispatch for grid-based workloads |
| 150 | + * - Custom thread group sizes for performance optimization |
| 151 | + * |
| 152 | + * Kernel arguments can be set using tensors (which will be mapped to Metal |
| 153 | + * buffers) or scalar values. The class handles the encoding of these arguments |
| 154 | + * into the compute command encoder. |
| 155 | + * |
| 156 | + * Example usage: |
| 157 | + * @code |
| 158 | + * // Get kernel function from library |
| 159 | + * auto kernelFunction = library.getKernelFunction("vector_add"); |
| 160 | + * |
| 161 | + * // Start encoding commands |
| 162 | + * kernelFunction->startEncoding(); |
| 163 | + * |
| 164 | + * // Set tensor arguments |
| 165 | + * kernelFunction->setArg(0, inputTensorA); |
| 166 | + * kernelFunction->setArg(1, inputTensorB); |
| 167 | + * kernelFunction->setArg(2, outputTensor); |
| 168 | + * |
| 169 | + * // Set scalar argument |
| 170 | + * kernelFunction->setArg(3, static_cast<int64_t>(numElements)); |
| 171 | + * |
| 172 | + * // Dispatch for linear workload |
| 173 | + * kernelFunction->dispatchSingle(numElements); |
| 174 | + * @endcode |
| 175 | + */ |
| 176 | +class ETMetalKernelFunction { |
| 177 | + public: |
| 178 | + ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); |
| 179 | + ~ETMetalKernelFunction(); |
| 180 | + |
| 181 | + void startEncoding(); |
| 182 | + void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); |
| 183 | + void setArg(unsigned idx, int64_t val); |
| 184 | + |
| 185 | + void dispatchSingle(uint64_t length); |
| 186 | + void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); |
| 187 | + void dispatchArray(const uint64_t* length, size_t length_size); |
| 188 | + void dispatchArrayWithGroupSize( |
| 189 | + const uint64_t* length, |
| 190 | + size_t length_size, |
| 191 | + const uint64_t* group_size, |
| 192 | + size_t group_size_size); |
| 193 | + |
| 194 | + void runCommandBlock(std::function<void(void)> f); |
| 195 | + |
| 196 | + private: |
| 197 | + MTLComputePipelineState_t cps_; |
| 198 | + MTLFunction_t func_; |
| 199 | + MTLComputeCommandEncoder_t encoder_; |
| 200 | +}; |
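| | + |
| | +// Illustrative sketch only: the class comment above shows a linear dispatch; a |
| | +// grid-based dispatch with an explicit threadgroup size would look roughly like |
| | +// this (the kernel name, tensors, and 2-D shape below are hypothetical): |
| | +// |
| | +//   auto fn = library.getKernelFunction("tile_kernel"); |
| | +//   fn->startEncoding(); |
| | +//   fn->setArg(0, inputTensor); |
| | +//   fn->setArg(1, outputTensor); |
| | +//   const uint64_t grid[2] = {width, height}; |
| | +//   const uint64_t group[2] = {16, 16}; |
| | +//   fn->dispatchArrayWithGroupSize(grid, 2, group, 2); |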
| 201 | + |
| 202 | +// ======================= |
| 203 | +// ETMetalStream - Metal command buffer and synchronization management |
| 204 | +// ======================= |
| 205 | + |
| 206 | +/** |
| 207 | + * @class ETMetalStream |
| 208 | + * @brief Manages Metal compute command streams and provides GPU |
| 209 | + * synchronization. |
| 210 | + * |
| 211 | + * This class serves as the central management hub for Metal GPU operations, |
| 212 | + * providing a stream-based abstraction similar to CUDA streams. It handles |
| 213 | + * command buffer lifecycle, compute command encoder management, and various |
| 214 | + * synchronization patterns required for efficient GPU computation. |
| 215 | + * |
| 216 | + * Key features: |
| 217 | + * - Lazy command buffer and encoder creation for optimal resource usage |
| 218 | + * - Thread-safe operations using serial dispatch queues |
| 219 | + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, |
| 220 | + * COMMIT_AND_CONTINUE, etc.) |
| 221 | + * - Kernel coalescing to batch multiple operations efficiently |
| 222 | + * - MPSGraph integration for executing fallback operations (mm, conv, sdpa) |
| 223 | + * - Memory operations (copy, fill) with GPU acceleration via blit encoders |
| 224 | + * |
| 225 | + * The stream follows PyTorch's MPS stream design patterns, providing similar |
| 226 | + * semantics for command buffer management and synchronization. |
| 227 | + * |
| 228 | + * Example usage: |
| 229 | + * @code |
| 230 | + * // Get current stream (typically the default stream) |
| 231 | + * ETMetalStream* stream = getCurrentMetalStream(); |
| 232 | + * |
| 233 | + * // Execute kernel operations (handled automatically) |
| 234 | + * auto kernelFunction = library.getKernelFunction("my_kernel"); |
| 235 | + * kernelFunction->startEncoding(); |
| 236 | + * kernelFunction->setArg(0, inputTensor); |
| 237 | + * kernelFunction->dispatchSingle(numElements); |
| 238 | + * |
| 239 | + * // Synchronize to ensure completion |
| 240 | + * stream->synchronize(SyncType::COMMIT_AND_WAIT); |
| 241 | + * |
| 242 | + * // Copy between GPU buffers using blit encoder |
| 243 | + * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT); |
| 244 | + * @endcode |
| 245 | + */ |
| 246 | +class ETMetalStream { |
| 247 | + public: |
| 248 | + ETMetalStream(); |
| 249 | + ~ETMetalStream(); |
| 250 | + |
| 251 | + // Get the default stream (singleton) |
| 252 | + static ETMetalStream* getDefaultStream(); |
| 253 | + |
| 254 | + // Device and queue access |
| 255 | + MTLDevice_t device() const { |
| 256 | + return device_; |
| 257 | + } |
| 258 | + MTLCommandQueue_t commandQueue() const { |
| 259 | + return commandQueue_; |
| 260 | + } |
| 261 | + dispatch_queue_t queue() const { |
| 262 | + return serialQueue_; |
| 263 | + } |
| 264 | + |
| 265 | + // Synchronization methods |
| 266 | +  void synchronize(SyncType syncType); |
| 267 | +  void synchronize(); // Equivalent to synchronize(SyncType::COMMIT_AND_WAIT) |
| 268 | + bool isEmpty() const; |
| 269 | + |
| 270 | + // Command buffer management with lazy creation |
| 271 | + MPSCommandBuffer_t commandBuffer(); |
| 272 | + MTLComputeCommandEncoder_t commandEncoder(); |
| 273 | + |
| 274 | + void endKernelCoalescing(); |
| 275 | + |
| 276 | + // MPSGraph execution |
| 277 | + void executeMPSGraph( |
| 278 | + MPSGraph_t mpsGraph, |
| 279 | + NSDictionary_t feeds, |
| 280 | + NSDictionary_t results, |
| 281 | + SyncType syncType = SyncType::COMMIT_ADAPTIVE); |
| 282 | + |
| 283 | + // Command buffer lifecycle management |
| 284 | + void commitCommandBuffer(MTLCommandBuffer_t commandBuffer); |
| 285 | + void flush(); |
| 286 | + |
| 287 | + // Memory operations |
| 288 | + void fill( |
| 289 | + MTLBuffer_t buffer, |
| 290 | + uint8_t value, |
| 291 | + size_t length, |
| 292 | + size_t offset, |
| 293 | + SyncType syncType = SyncType::NONE); |
| 294 | + void copy( |
| 295 | + MTLBuffer_t srcBuffer, |
| 296 | + MTLBuffer_t dstBuffer, |
| 297 | + size_t length, |
| 298 | + size_t srcOffset, |
| 299 | + size_t dstOffset, |
| 300 | + SyncType syncType = SyncType::NONE); |
| 301 | + |
| 302 | + private: |
| 303 | + // Private synchronization methods |
| 304 | + void commit(); |
| 305 | + void commitAndWait(); |
| 306 | + void commitAndContinue(); |
| 307 | + |
| 308 | + private: |
| 309 | + // Private members |
| 310 | + MTLDevice_t device_; |
| 311 | + MTLCommandQueue_t commandQueue_; |
| 312 | + MPSCommandBuffer_t commandBuffer_; |
| 313 | + MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern |
| 314 | + MTLComputeCommandEncoder_t commandEncoder_; |
| 315 | + dispatch_queue_t serialQueue_; // For thread safety |
| 316 | + |
| 317 | + // Configuration |
| 318 | + bool enableCommitAndContinue_; |
| 319 | + |
| 320 | + // Singleton instance |
| 321 | + static ETMetalStream* defaultStream_; |
| 322 | +}; |
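| | + |
| | +// Illustrative sketch only: zero-filling a device buffer and deferring the |
| | +// flush to a later synchronize() call (buffer and size names are hypothetical): |
| | +// |
| | +//   ETMetalStream* stream = getCurrentMetalStream(); |
| | +//   stream->fill(dstBuffer, /*value=*/0, numBytes, /*offset=*/0, SyncType::NONE); |
| | +//   stream->synchronize(SyncType::COMMIT_AND_WAIT); |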
| 323 | + |
| 324 | +// ======================= |
| 325 | +// Global storage management functions |
| 326 | +// ======================= |
| 327 | +void storeFunctionHandle( |
| 328 | + ETMetalKernelFunction* raw_function, |
| 329 | + std::shared_ptr<ETMetalKernelFunction> function_shared_ptr); |
| 330 | +void storeLibraryHandle( |
| 331 | + ETMetalShaderLibrary* raw_library, |
| 332 | + std::unique_ptr<ETMetalShaderLibrary> library); |
| 333 | +bool removeFunctionHandle(ETMetalKernelFunction* raw_function); |
| 334 | +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library); |
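| | + |
| | +// Illustrative sketch only: these helpers keep libraries and kernel functions |
| | +// alive behind raw-pointer handles. One plausible ownership pattern (the exact |
| | +// semantics are an assumption, not a guarantee of this header): |
| | +// |
| | +//   auto library = std::make_unique<ETMetalShaderLibrary>(shaderSource); |
| | +//   ETMetalShaderLibrary* handle = library.get(); |
| | +//   storeLibraryHandle(handle, std::move(library)); |
| | +//   // ... use handle ... |
| | +//   removeLibraryHandle(handle); // presumably releases the stored unique_ptr |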
| 335 | + |
| 336 | +// ======================= |
| 337 | +// Global stream access functions |
| 338 | +// ======================= |
| 339 | +ETMetalStream* getCurrentMetalStream(); |
| 340 | +void setCurrentMetalStream(ETMetalStream* stream); |
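| | + |
| | +// Illustrative sketch only (whether the "current" stream is per-thread or |
| | +// global is an assumption here, not stated by this header): |
| | +// |
| | +//   setCurrentMetalStream(ETMetalStream::getDefaultStream()); |
| | +//   ETMetalStream* stream = getCurrentMetalStream(); |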
| 341 | + |
| 342 | +// ======================= |
| 343 | +// Metal stream synchronization functions (C++ interface with exceptions) |
| 344 | +// ======================= |
| 345 | +void synchronize_metal_stream(); |
| 346 | +void synchronize_metal_stream_with_type(int sync_type); |
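| | + |
| | +// Illustrative sketch only: the int argument is presumed to map onto SyncType |
| | +// values (that mapping is an assumption of this example): |
| | +// |
| | +//   synchronize_metal_stream(); // presumably a blocking sync of the current stream |
| | +//   synchronize_metal_stream_with_type( |
| | +//       static_cast<int>(SyncType::COMMIT_AND_CONTINUE)); |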
| 347 | + |
| 348 | +// ======================= |
| 349 | +// Metal helper functions (C interface) |
| 350 | +// ======================= |
| 351 | +#ifdef __cplusplus |
| 352 | +extern "C" { |
| 353 | +#endif |
| 354 | + |
| 355 | +// Memory management functions for Metal |
| 356 | +void* metal_allocate_buffer(long bytes); |
| 357 | +bool metal_is_device_pointer(void* ptr); |
| 358 | +int metal_copy_memory( |
| 359 | + void* dst, |
| 360 | + const void* src, |
| 361 | + size_t nbytes, |
| 362 | + bool src_is_device, |
| 363 | + bool dst_is_device); |
| 364 | +void metal_cleanup_resources(); |
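| | + |
| | +// Illustrative sketch only: allocating device-visible memory and copying from a |
| | +// host pointer (pointer ownership and error semantics are assumptions, not |
| | +// documented by this header): |
| | +// |
| | +//   void* devPtr = metal_allocate_buffer(static_cast<long>(nbytes)); |
| | +//   if (devPtr != nullptr && metal_is_device_pointer(devPtr)) { |
| | +//     metal_copy_memory(devPtr, hostPtr, nbytes, |
| | +//                       /*src_is_device=*/false, /*dst_is_device=*/true); |
| | +//   } |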
| 365 | + |
| 366 | +// Helper functions to access Metal objects |
| 367 | +MTLDevice_t get_metal_device(); |
| 368 | +MTLCommandQueue_t get_metal_command_queue(); |
| 369 | + |
| 370 | +#ifdef __cplusplus |
| 371 | +} |
| 372 | + |
| 373 | +// C++ only - expose the Metal buffer mapping |
| 374 | +#ifdef __OBJC__ |
| 375 | +extern std::unordered_map<void*, MTLBuffer_t> ptr_to_mtl_buffer; |
| 376 | +#endif |
| 377 | + |
| 378 | +#endif |
| 379 | + |
| 380 | +} // namespace metal |
| 381 | +} // namespace backends |
| 382 | +} // namespace executorch |