
Commit 2e7c9d3

zou3519 authored and pytorchmergebot committed
Refactor layout constraint selection logic (pytorch#148104)
This PR:
- cleans up some existing comments that don't make sense anymore
- hooks the "custom_op_default_layout_constraint" back up (it seems to have broken)
- cleans up the "lazy registration path", which never seems to get hit anymore
- adds dislike_padding to nodes that require exact strides

Test Plan:
- tests + CI
- disable padding

Pull Request resolved: pytorch#148104
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: pytorch#150495
1 parent 44deb67 commit 2e7c9d3
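
For orientation (not part of this commit): the three layout-constraint tags that the selection logic arbitrates between are exposed on torch.Tag and can be inspected directly. A minimal check, listed in the priority order used by the new helper in lowering.py:

import torch

# The three layout-constraint tags, highest priority first.
for name in ("needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"):
    print(name, getattr(torch.Tag, name))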

4 files changed (+57, -50 lines)


torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def prologue_fusion_enabled() -> bool:
 # If the custom op does not have a layout constraint tag already
 # then we assume the following applies.
 custom_op_default_layout_constraint: Literal[
-    "needs_fixed_stride_order", "flexible_layout"
+    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
 ] = "needs_fixed_stride_order"
 
 # The default layout constraint for user-defined triton kernels.
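
As a usage sketch (not part of the diff above), the newly allowed value can be selected by assigning the config attribute; per the comment in config.py this only applies to custom ops that carry no layout-constraint tag of their own:

import torch._inductor.config as inductor_config

# Make untagged custom ops default to the strictest constraint.
# The shipped default stays "needs_fixed_stride_order".
inductor_config.custom_op_default_layout_constraint = "needs_exact_strides"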

torch/_inductor/graph.py

Lines changed: 29 additions & 27 deletions
@@ -80,11 +80,13 @@
     FALLBACK_ALLOW_LIST,
     fallback_handler,
     fallback_node_due_to_unsupported_type,
+    get_layout_constraint_tag,
     lowerings,
     make_fallback,
     maybe_layout_constraints,
     needs_realized_inputs,
     require_contiguous,
+    tag_to_layout_constraint,
     unsupported_output_tensor,
 )
 from .runtime import autotune_cache
@@ -244,6 +246,14 @@ def _get_overload_packet(
             cur.meta["dislike_padding"] = True
             continue
 
+        if (
+            isinstance(cur.target, torch._ops.OpOverload)
+            and get_layout_constraint_tag(cur.target)
+            == torch._C.Tag.needs_exact_strides
+        ):
+            cur.meta["dislike_padding"] = True
+            continue
+
         op = _get_overload_packet(cur)
         if not op:
             continue
@@ -1150,34 +1160,26 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) ->
                     error.operator_str(target, args, kwargs),
                 )
 
-                # use contiguous unless the (custom) op asks something else
-                # explicitly
-                if torch._C.Tag.needs_exact_strides in target.tags:
-                    decided_constraint = constrain_to_fake_tensors  # type: ignore[assignment]
-                elif torch._C.Tag.needs_fixed_stride_order in target.tags:
-                    decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
-                elif torch._C.Tag.flexible_layout in target.tags:
-                    decided_constraint = None  # type: ignore[assignment]
-                else:
-                    # If there are no tags, we do different things depending on
-                    # if it's a builtin ATen/prim ops or custom ops.
-                    # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
-                    # For custom ops, we constrain_to_fx_strides to maintain the
-                    # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
+                tag = get_layout_constraint_tag(target, with_default=False)
+                if (
+                    tag is None
+                    and torch._library.utils.is_builtin(target)
+                    and self.is_backward
+                ):
+                    # for implicit fallback ATen ops during backward, if there
+                    # is no layout constraint tag, we conservatively require contiguous
+                    # input since some eager kernels do not
+                    # support non-contiguous inputs. Otherwise they may silently cause
+                    # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                    # We only do this For ATen ops and for backward.
                     #
-                    # For ATen ops, only apply the constraint for backward
-                    # ops since fwd ops should work for any strides.
-                    if torch._library.utils.is_builtin(target) and self.is_backward:
-                        decided_constraint = require_contiguous  # type: ignore[assignment]
-                    else:
-                        # maybe_layout_constraints will decide the layout constraint for the custom op
-                        # lazily
-                        decided_constraint = None  # type: ignore[assignment]
-
-                # for implicitly fallback ops, we conservatively requires
-                # contiguous input since some eager kernels does not
-                # support non-contiguous inputs. They may silently cause
-                # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                    # TODO: should really switch to "needs_fixed_stride" constraint on these
+                    # and identify them one by one.
+                    decided_constraint = require_contiguous  # type: ignore[assignment]
+                else:
+                    tag = get_layout_constraint_tag(target, with_default=True)
+                    decided_constraint = tag_to_layout_constraint(tag)
+
                 make_fallback(target, layout_constraint=decided_constraint)
 
             elif get_decompositions([target]):
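
A standalone sketch of the per-node check added in the second hunk; here "op" stands in for cur.target, and the choice of aten.convolution_backward is arbitrary. get_layout_constraint_tag is the internal Inductor helper imported above:

import torch
from torch._inductor.lowering import get_layout_constraint_tag

op = torch.ops.aten.convolution_backward.default  # plays the role of cur.target
needs_exact = (
    isinstance(op, torch._ops.OpOverload)
    and get_layout_constraint_tag(op) == torch._C.Tag.needs_exact_strides
)
print(needs_exact)  # True => the node would get meta["dislike_padding"] = True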

torch/_inductor/lowering.py

Lines changed: 24 additions & 21 deletions
@@ -157,37 +157,40 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A
         return None
     if fn in _maybe_layout_constraints:
         return _maybe_layout_constraints[fn]
-    # OpOverload with custom lowerings override tag-based layout constraints
-    if fn in lowerings:
-        _maybe_layout_constraints[fn] = None
-        return None
-    # We lazily register tag-based layout constraints.
-
-    def handle_layout_constraint_tag(tag):
-        if tag is torch._C.Tag.needs_fixed_stride_order:
-            _maybe_layout_constraints[fn] = constrain_to_fx_strides
-            return _maybe_layout_constraints[fn]
-        elif tag is torch._C.Tag.flexible_layout:
-            _maybe_layout_constraints[fn] = None
-            return None
-        else:
-            raise AssertionError(f"Unknown layout constraint tag: {tag}")
+    return None
+
 
-    tag = get_layout_constraint_tag(fn)
-    return handle_layout_constraint_tag(tag)
+tags_by_priority = [
+    torch._C.Tag.needs_exact_strides,
+    torch._C.Tag.needs_fixed_stride_order,
+    torch._C.Tag.flexible_layout,
+]
 
 
-def get_layout_constraint_tag(fn):
+def get_layout_constraint_tag(fn, *, with_default=True):
     tags_by_priority = [
+        torch._C.Tag.needs_exact_strides,
         torch._C.Tag.needs_fixed_stride_order,
         torch._C.Tag.flexible_layout,
     ]
     for tag in tags_by_priority:
         if tag in fn.tags:
             return tag
-    if torch._library.utils.is_builtin(fn):
-        return torch._C.Tag.flexible_layout
-    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
+    if with_default:
+        if torch._library.utils.is_builtin(fn):
+            return torch._C.Tag.flexible_layout
+        return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
+    return None
+
+
+def tag_to_layout_constraint(tag):
+    if tag == torch._C.Tag.needs_exact_strides:
+        return constrain_to_fake_tensors
+    if tag == torch._C.Tag.needs_fixed_stride_order:
+        return constrain_to_fx_strides
+    if tag == torch._C.Tag.flexible_layout:
+        return None
+    raise AssertionError(f"Unknown layout constraint tag: {tag}")
 
 
 def assert_nyi(cond, msg):
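
After this change the two helpers compose as follows. A short usage sketch against the internal API (torch.ops.aten.mm.default is just an arbitrary untagged builtin op):

import torch
from torch._inductor.lowering import get_layout_constraint_tag, tag_to_layout_constraint

op = torch.ops.aten.mm.default
tag = get_layout_constraint_tag(op)   # untagged builtin ops default to flexible_layout
print(tag_to_layout_constraint(tag))  # None, i.e. no stride constraint is applied

# with_default=False reports "no explicit tag" instead of a default, which is
# how graph.py now detects untagged builtin ops in the backward pass.
print(get_layout_constraint_tag(op, with_default=False))  # None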

torch/fx/experimental/proxy_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -1169,7 +1169,9 @@ def _should_save_eager_input_vals(
             f"propagate the FakeTensor vals. Please file an issue."
         )
     if isinstance(target, torch._ops.OpOverload):
-        return torch._C.Tag.needs_exact_strides in target.tags
+        from torch._inductor.lowering import get_layout_constraint_tag
+
+        return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
     return False
 
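
For context, a hedged sketch of how a custom op ends up carrying the needs_exact_strides tag that this check now resolves through get_layout_constraint_tag. The op name mylib::clone_op and its implementation are hypothetical, and torch.library.define is assumed to accept a tags argument, as in recent PyTorch releases:

import torch

torch.library.define(
    "mylib::clone_op", "(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_exact_strides,),
)

@torch.library.impl("mylib::clone_op", "CompositeExplicitAutograd")
def clone_op(x):
    return x.clone()

# The tag is visible on the resulting OpOverload, so the check above is True for it.
print(torch.Tag.needs_exact_strides in torch.ops.mylib.clone_op.default.tags)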

0 commit comments