@@ -3395,8 +3395,8 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
33953395def is_using_cudagraph_partition () -> bool :
33963396 return (
33973397 torch ._inductor .config .triton .cudagraphs
3398- and torch . _inductor . config . graph_partition
3399- )
3398+ or _unstable_customized_partition_wrapper . wrapper is not None
3399+ ) and torch . _inductor . config . graph_partition
34003400
34013401
34023402def dtype_from_size (size : int ) -> torch .dtype :
@@ -3621,3 +3621,48 @@ def python_subprocess_env() -> dict[str, str]:
36213621 env ["PYTHONHOME" ] = sysconfig .get_path ("data" )
36223622
36233623 return env
3624+
3625+
3626+ @dataclasses .dataclass (frozen = True )
3627+ class CUDAGraphWrapperMetadata :
3628+ """
3629+ Metadata for Customized CUDAGraphWrapper.
3630+
3631+ Currently assumes there is 1 dynamo graph and will extend to
3632+ multiple graphs in the future.
3633+ """
3634+
3635+ # The number of partitions that are cudagraphable.
3636+ num_partitions : int
3637+
3638+ # Index of the current partition.
3639+ partition_index : int
3640+
3641+
3642+ PartitionFnType = Callable [..., Any ]
3643+ CUDAGraphWrapperType = Callable [
3644+ [PartitionFnType , CUDAGraphWrapperMetadata ], PartitionFnType
3645+ ]
3646+
3647+
3648+ # only incremented by user call of mark_step_begin
3649+ class CUDAGraphWrapper :
3650+ wrapper : Optional [CUDAGraphWrapperType ] = None
3651+
3652+
3653+ # A customized partition wrappers from users. Interface should be:
3654+ #
3655+ # def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
3656+ #
3657+ # Inductor generates N wrapper functions for N partition functions, and mechanically wrap
3658+ # each partition fn with the generated wrapper function. Users need to handle all details
3659+ # such as static inputs, dynamic shapes, etc.
3660+ # Users could customize the wrapper based on the metadata. One example is to have special
3661+ # handle for the first and last wrapper function.
3662+ #
3663+ # Warning: This API is unstable and may change in the future.
3664+ _unstable_customized_partition_wrapper = CUDAGraphWrapper ()
3665+
3666+
3667+ def set_customized_partition_wrappers (wrapper : CUDAGraphWrapperType ) -> None :
3668+ _unstable_customized_partition_wrapper .wrapper = wrapper
0 commit comments