
Commit 145fe68

BoyuanFeng authored and dsashidh committed

[Graph Partition] interface for custom cg wrapper (pytorch#162207)

This PR adds an interface that lets users specify a custom cudagraph wrapper. User example: vLLM (vllm-project/vllm#24281).

Pull Request resolved: pytorch#162207
Approved by: https://github.com/zou3519, https://github.com/eellison, https://github.com/ProExpertProg

1 parent 5dab669 commit 145fe68

File tree

3 files changed: +71 −3 lines
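
Before the per-file diffs, a minimal sketch of the user-facing flow this commit enables. The names are taken from the diffs below; the cudagraph capture/replay logic a real wrapper would perform is deliberately left out:

    from torch._inductor.utils import (
        CUDAGraphWrapperMetadata,
        set_customized_partition_wrappers,
    )

    def my_wrapper(fn, metadata: CUDAGraphWrapperMetadata):
        # Inductor calls this once per partition; return the callable it
        # should invoke in place of the raw partition function.
        def wrapped(*args, **kwargs):
            # A real wrapper would handle cudagraph capture/replay,
            # static inputs, dynamic shapes, etc. here.
            return fn(*args, **kwargs)

        return wrapped

    set_customized_partition_wrappers(my_wrapper)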

torch/_inductor/output_code.py

Lines changed: 19 additions & 0 deletions

@@ -41,8 +41,10 @@
 )
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
 from torch._inductor.utils import (
+    _unstable_customized_partition_wrapper,
     align_inputs_from_check_idxs,
     BoxedBool,
+    CUDAGraphWrapperMetadata,
     GraphPartitionMap,
     InputType,
     output_node,
@@ -628,6 +630,23 @@ def post_compile(
         This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
         The results of this function are *not* saved in the cache itself.
         """
+        if config.graph_partition and _unstable_customized_partition_wrapper.wrapper:
+            # Mechanically apply user-specified cudagraph wrappers without modification
+            assert self.recursively_apply_fns is not None
+            assert self.compiled_fn_runner is not None
+            num_partitions = len(self.compiled_fn_runner.partitions)
+            wrapper_metadatas = [
+                CUDAGraphWrapperMetadata(num_partitions, i)
+                for i in range(num_partitions)
+            ]
+            customized_wrapper = _unstable_customized_partition_wrapper.wrapper
+            customized_wrappers_with_metadata = [
+                lambda f, m=metadata: customized_wrapper(f, m)
+                for metadata in wrapper_metadatas
+            ]
+            self.recursively_apply_fns(customized_wrappers_with_metadata)
+            return
+
         set_tracing_context_output_strides(example_inputs, self)
         assert graph_kwargs["cudagraphs"] is not None
         assert graph_kwargs["is_backward"] is not None
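
One detail in the hunk above deserves a note: binding `m=metadata` as a lambda default captures each partition's metadata at lambda creation time. Without it, Python's late-binding closures would hand every wrapper the last partition's metadata. A standalone illustration of the pitfall (plain Python, independent of this PR):

    metas = ["p0", "p1", "p2"]

    # Late binding: every lambda reads `m` after the loop has finished.
    late = [lambda: m for m in metas]
    # Default-argument binding: each lambda freezes its own `m`.
    bound = [lambda m=m: m for m in metas]

    print([f() for f in late])   # ['p2', 'p2', 'p2']
    print([f() for f in bound])  # ['p0', 'p1', 'p2']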

torch/_inductor/scheduler.py

Lines changed: 5 additions & 1 deletion

@@ -54,6 +54,7 @@
 from .runtime.runtime_utils import green_text, red_text
 from .sizevars import SimplifyIndexing
 from .utils import (
+    _unstable_customized_partition_wrapper,
     cache_on_self,
     cmp,
     device_need_guard,
@@ -4472,7 +4473,10 @@ def should_partition(
         # When not using cudagraphs, keep all kernels in the `call` function
         # instead of graph partition functions, since graph partition only brings
         # benefit to cudagraph
-        if not torch._inductor.config.triton.cudagraphs:
+        if (
+            not torch._inductor.config.triton.cudagraphs
+            and _unstable_customized_partition_wrapper.wrapper is None
+        ):
             return True

         # avoid duplicating logs when should_partition is called multiple times
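
The effect of this gate: registering a custom wrapper now keeps graph partitioning active even when Triton cudagraphs are disabled. A hedged configuration sketch (config and helper names are taken from the diffs; the identity wrapper is a placeholder):

    import torch._inductor.config as inductor_config
    from torch._inductor.utils import set_customized_partition_wrappers

    inductor_config.graph_partition = True
    # Built-in cudagraphs off: the custom wrapper takes over instead.
    inductor_config.triton.cudagraphs = False

    set_customized_partition_wrappers(lambda fn, metadata: fn)  # placeholder wrapper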

torch/_inductor/utils.py

Lines changed: 47 additions & 2 deletions

@@ -3395,8 +3395,8 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
 def is_using_cudagraph_partition() -> bool:
     return (
         torch._inductor.config.triton.cudagraphs
-        and torch._inductor.config.graph_partition
-    )
+        or _unstable_customized_partition_wrapper.wrapper is not None
+    ) and torch._inductor.config.graph_partition


 def dtype_from_size(size: int) -> torch.dtype:
@@ -3621,3 +3621,48 @@ def python_subprocess_env() -> dict[str, str]:
         env["PYTHONHOME"] = sysconfig.get_path("data")

     return env
+
+
+@dataclasses.dataclass(frozen=True)
+class CUDAGraphWrapperMetadata:
+    """
+    Metadata for a customized CUDAGraphWrapper.
+
+    Currently assumes there is 1 dynamo graph; support for
+    multiple graphs will be added in the future.
+    """
+
+    # The number of partitions that are cudagraphable.
+    num_partitions: int
+
+    # Index of the current partition.
+    partition_index: int
+
+
+PartitionFnType = Callable[..., Any]
+CUDAGraphWrapperType = Callable[
+    [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType
+]
+
+
+# Holder for the user-specified partition wrapper.
+class CUDAGraphWrapper:
+    wrapper: Optional[CUDAGraphWrapperType] = None
+
+
+# A customized partition wrapper from users. The interface should be:
+#
+#   def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
+#
+# Inductor generates N wrapper functions for N partition functions and mechanically wraps
+# each partition fn with the generated wrapper function. Users need to handle all details
+# such as static inputs, dynamic shapes, etc.
+# Users can customize the wrapper based on the metadata, e.g., to handle the first and
+# last partition functions specially.
+#
+# Warning: This API is unstable and may change in the future.
+_unstable_customized_partition_wrapper = CUDAGraphWrapper()
+
+
+def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
+    _unstable_customized_partition_wrapper.wrapper = wrapper
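
To illustrate the metadata-driven customization the comment block above describes, a sketch of a wrapper that special-cases the first and last partitions. The begin/end hooks here are hypothetical placeholders, not part of this PR:

    from torch._inductor.utils import (
        CUDAGraphWrapperMetadata,
        set_customized_partition_wrappers,
    )

    def boundary_aware_wrapper(fn, metadata: CUDAGraphWrapperMetadata):
        is_first = metadata.partition_index == 0
        is_last = metadata.partition_index == metadata.num_partitions - 1

        def wrapped(*args, **kwargs):
            if is_first:
                print("entering first partition")  # hypothetical begin-of-step hook
            out = fn(*args, **kwargs)
            if is_last:
                print("leaving last partition")    # hypothetical end-of-step hook
            return out

        return wrapped

    set_customized_partition_wrappers(boundary_aware_wrapper)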
