@@ -5,6 +5,7 @@
 import torch
 
 from vllm.logger import init_logger
+from vllm.utils import _is_torch_equal
 
 logger = init_logger(__name__)
 
@@ -21,3 +22,75 @@
 os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 # see https://github.com/vllm-project/vllm/issues/10619
 torch._inductor.config.compile_threads = 1
+
+# ===================================================
+# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
+# ===================================================
+# This monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around
+# a failure in test_multi_graph_piecewise_compile_outputs_equal.
+# For more context, see https://github.com/pytorch/pytorch/pull/165514.
+
+
+def memory_plan_reuse_patched(self):
+    import torch._inductor.ir as ir
+    from torch._inductor.codegen.wrapper import (
+        EnterSubgraphLine,
+        ExitSubgraphLine,
+        MemoryPlanningLine,
+        MemoryPlanningState,
+        SubgraphPythonWrapperCodegen,
+    )
+    from torch._inductor.virtualized import V
+
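+    # NoneAsConstantBuffer / ShapeAsConstantBuffer outputs own no buffer, so
+    # synthesize per-graph names for them; real buffers keep their own names.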
+    def get_output_names(graph_outputs) -> list[str]:
+        import itertools
+
+        names = []
+        shape_counter = itertools.count(0)
+        none_counter = itertools.count(0)
+        for node in graph_outputs:
+            if isinstance(node, ir.NoneAsConstantBuffer):
+                names.append(f"{V.graph.name}_none{next(none_counter)}")
+            elif isinstance(node, ir.ShapeAsConstantBuffer):
+                names.append(f"{V.graph.name}_shape{next(shape_counter)}")
+            else:
+                names.append(node.get_name())
+        return names
+
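+    # When codegen runs for a partitioned subgraph, the live outputs are the
+    # partition signature's outputs rather than the top-level graph's.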
+    if (
+        isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
+        and V.graph.wrapper_code.partition_signatures is not None
+    ):
+        out_names = get_output_names(
+            V.graph.wrapper_code.partition_signatures.output_nodes
+        )
+    else:
+        out_names = V.graph.get_output_names()
+
+    while (
+        self.lines
+        and isinstance(self.lines[-1], MemoryPlanningLine)
+        and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
+    ):
+        # these lines will be pointless
+        self.lines.pop()
+
+    # codegen allocations in two passes
+    planning_states = [MemoryPlanningState()]
+    past_planning_states = []
+    for i in range(len(self.lines)):
+        line = self.lines[i]
+        if isinstance(line, MemoryPlanningLine):
+            self.lines[i] = line.plan(planning_states[-1])
+        elif isinstance(line, EnterSubgraphLine):
+            planning_states.append(MemoryPlanningState())
+        elif isinstance(line, ExitSubgraphLine):
+            past_planning_states.append(planning_states.pop())
+    # after the loop, retire the remaining top-level planning state
+    past_planning_states.append(planning_states.pop())
+    assert len(planning_states) == 0
+
+
+if _is_torch_equal("2.9.0"):
+    from torch._inductor.codegen.wrapper import PythonWrapperCodegen
+
+    PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
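
For context on the version gate at the end of the patch: below is a minimal sketch of what a strict-equality helper like `_is_torch_equal("2.9.0")` is presumably checking. The real helper lives in `vllm.utils` and is not shown in this diff; the name `_is_torch_equal_sketch` and the release-triple comparison are assumptions for illustration only.

```python
# Sketch only: approximates the version gate above, assuming local build
# tags (e.g. "2.9.0+cu128") should still count as torch 2.9.0. The real
# _is_torch_equal in vllm.utils may differ.
import torch
from packaging.version import Version


def _is_torch_equal_sketch(target: str) -> bool:
    # Version("2.9.0+cu128").release == (2, 9, 0) == Version("2.9.0").release,
    # so a CUDA-suffixed 2.9.0 wheel would still take the patched path.
    return Version(torch.__version__).release == Version(target).release


print(_is_torch_equal_sketch("2.9.0"))  # True only on a torch 2.9.0 build
```

Because the patch assigns to the class attribute, every `PythonWrapperCodegen` instance created after this module is imported resolves `memory_plan_reuse` to the patched function, so the workaround applies process-wide without touching the installed torch sources.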