@@ -80,11 +80,13 @@
     FALLBACK_ALLOW_LIST,
     fallback_handler,
     fallback_node_due_to_unsupported_type,
+    get_layout_constraint_tag,
     lowerings,
     make_fallback,
     maybe_layout_constraints,
     needs_realized_inputs,
     require_contiguous,
+    tag_to_layout_constraint,
     unsupported_output_tensor,
 )
 from .runtime import autotune_cache
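The two new imports, `get_layout_constraint_tag` and `tag_to_layout_constraint`, come from `torch/_inductor/lowering.py` and centralize the tag handling that the old inline branch in `call_function` (third hunk below) did by hand. A minimal sketch of the contract this diff relies on: the tag-to-constraint mapping is taken from the removed branch below, while the default-resolution policy for untagged ops is an assumption, and the constraint callables are stubbed here.

```python
import torch

# Stubs standing in for the real constraint callables from
# torch/_inductor/lowering.py (named in the removed branch below).
def constrain_to_fake_tensors(*args, **kwargs): ...
def constrain_to_fx_strides(*args, **kwargs): ...

_TAG_PRIORITY = (
    torch._C.Tag.needs_exact_strides,
    torch._C.Tag.needs_fixed_stride_order,
    torch._C.Tag.flexible_layout,
)

def get_layout_constraint_tag(fn, *, with_default=True):
    # Return the explicit layout-constraint tag on the op, if any.
    for tag in _TAG_PRIORITY:
        if tag in fn.tags:
            return tag
    if not with_default:
        return None
    # Assumed defaults: builtin ATen/prim ops are layout-flexible; custom
    # ops keep their FX stride order (the PyTorch 2.5 behavior mentioned
    # in the removed comment below).
    if torch._library.utils.is_builtin(fn):
        return torch._C.Tag.flexible_layout
    return torch._C.Tag.needs_fixed_stride_order

def tag_to_layout_constraint(tag):
    # Mapping mirrors the removed if/elif chain in call_function below.
    return {
        torch._C.Tag.needs_exact_strides: constrain_to_fake_tensors,
        torch._C.Tag.needs_fixed_stride_order: constrain_to_fx_strides,
        torch._C.Tag.flexible_layout: None,
    }[tag]
```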
@@ -244,6 +246,14 @@ def _get_overload_packet(
             cur.meta["dislike_padding"] = True
             continue

+        if (
+            isinstance(cur.target, torch._ops.OpOverload)
+            and get_layout_constraint_tag(cur.target)
+            == torch._C.Tag.needs_exact_strides
+        ):
+            cur.meta["dislike_padding"] = True
+            continue
+
         op = _get_overload_packet(cur)
         if not op:
             continue
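Padding a tensor changes its strides, so an op tagged `needs_exact_strides` cannot tolerate padded inputs; this hunk marks such nodes with `dislike_padding` before padding decisions are made. A standalone check equivalent to the new guard under default config (the diff goes through `get_layout_constraint_tag`, which here reduces to a direct tag-membership test; the op `mylib::exact_op` is hypothetical):

```python
import torch

def dislikes_padding(target) -> bool:
    # Mirrors the new guard: only OpOverloads carry tags, and only the
    # needs_exact_strides tag forces dislike_padding.
    return (
        isinstance(target, torch._ops.OpOverload)
        and torch._C.Tag.needs_exact_strides in target.tags
    )

# Hypothetical custom op carrying the tag:
lib = torch.library.Library("mylib", "DEF")
lib.define(
    "exact_op(Tensor x) -> Tensor",
    tags=(torch._C.Tag.needs_exact_strides,),
)
assert dislikes_padding(torch.ops.mylib.exact_op.default)
assert not dislikes_padding(torch.ops.aten.add.Tensor)
```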
@@ -1150,34 +1160,26 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) ->
                     error.operator_str(target, args, kwargs),
                 )

-                # use contiguous unless the (custom) op asks something else
-                # explicitly
-                if torch._C.Tag.needs_exact_strides in target.tags:
-                    decided_constraint = constrain_to_fake_tensors  # type: ignore[assignment]
-                elif torch._C.Tag.needs_fixed_stride_order in target.tags:
-                    decided_constraint = constrain_to_fx_strides  # type: ignore[assignment]
-                elif torch._C.Tag.flexible_layout in target.tags:
-                    decided_constraint = None  # type: ignore[assignment]
-                else:
-                    # If there are no tags, we do different things depending on
-                    # if it's a builtin ATen/prim ops or custom ops.
-                    # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452
-                    # For custom ops, we constrain_to_fx_strides to maintain the
-                    # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356
+                tag = get_layout_constraint_tag(target, with_default=False)
+                if (
+                    tag is None
+                    and torch._library.utils.is_builtin(target)
+                    and self.is_backward
+                ):
+                    # For implicit-fallback ATen ops during backward, if there is
+                    # no layout constraint tag, we conservatively require contiguous
+                    # inputs, since some eager kernels do not support non-contiguous
+                    # inputs and may silently cause accuracy problems. See
+                    # https://github.com/pytorch/pytorch/issues/140452
+                    # We only do this for ATen ops, and only in backward.
                     #
-                    # For ATen ops, only apply the constraint for backward
-                    # ops since fwd ops should work for any strides.
-                    if torch._library.utils.is_builtin(target) and self.is_backward:
-                        decided_constraint = require_contiguous  # type: ignore[assignment]
-                    else:
-                        # maybe_layout_constraints will decide the layout constraint for the custom op
-                        # lazily
-                        decided_constraint = None  # type: ignore[assignment]
-
-                # for implicitly fallback ops, we conservatively requires
-                # contiguous input since some eager kernels does not
-                # support non-contiguous inputs. They may silently cause
-                # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
+                    # TODO: should really switch to a "needs_fixed_stride_order"
+                    # constraint on these ops and identify them one by one.
+                    decided_constraint = require_contiguous  # type: ignore[assignment]
+                else:
+                    tag = get_layout_constraint_tag(target, with_default=True)
+                    decided_constraint = tag_to_layout_constraint(tag)
+
                 make_fallback(target, layout_constraint=decided_constraint)

             elif get_decompositions([target]):
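Putting the pieces together, the implicit-fallback decision now reduces to the flow below, restated as a pure function. The helper names are the real ones imported in the first hunk; `require_contiguous` is stubbed for self-containedness, and the listed outcomes assume the default-resolution sketch from the first hunk.

```python
def require_contiguous(*args, **kwargs): ...  # stub for the lowering helper

def decide_layout_constraint(target, is_backward):
    # Restates the new branch in call_function, using the sketch helpers above.
    tag = get_layout_constraint_tag(target, with_default=False)
    if tag is None and torch._library.utils.is_builtin(target) and is_backward:
        # Untagged ATen op in a backward graph: be conservative.
        return require_contiguous
    tag = get_layout_constraint_tag(target, with_default=True)
    return tag_to_layout_constraint(tag)

# Expected outcomes (assuming the default policy sketched earlier):
#   untagged ATen op, backward graph -> require_contiguous
#   untagged ATen op, forward graph  -> flexible_layout -> None
#   untagged custom op               -> needs_fixed_stride_order
#                                       -> constrain_to_fx_strides
#   op tagged needs_exact_strides    -> constrain_to_fake_tensors
```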
|