Skip to content

Commit a54a784

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][dicts] Consolidate dict(..) construction (pytorch#144342)
Pull Request resolved: pytorch#144342 Approved by: https://github.com/StrongerXi
1 parent 0373cd9 commit a54a784

File tree

5 files changed

+65
-84
lines changed

5 files changed

+65
-84
lines changed

test/dynamo/test_higher_order_ops.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,8 +1810,8 @@ def forward(self, L_x_ : torch.Tensor):
18101810
getitem_4 = map_impl[3]
18111811
getitem_5 = map_impl[4]
18121812
getitem_6 = map_impl[5]
1813-
getitem_7 = map_impl[6]; map_impl = None
1814-
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""",
1813+
value = map_impl[6]; map_impl = None
1814+
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""",
18151815
)
18161816
self.assertExpectedInline(
18171817
body_graph,
@@ -2632,8 +2632,8 @@ def forward(self, L_x_: "f32[3]"):
26322632
26332633
wrap_body_0 = self.wrap_body_0
26342634
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
2635-
getitem: "f32[3]" = wrap[0]; wrap = None
2636-
return (getitem,)
2635+
value: "f32[3]" = wrap[0]; wrap = None
2636+
return (value,)
26372637
26382638
class wrap_body_0(torch.nn.Module):
26392639
def forward(self, l_x_: "f32[3]"):
@@ -4209,8 +4209,8 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
42094209
child_1: "f32[5]" = child.sin()
42104210
child_2: "f32[5]" = child.cos(); child = None
42114211
4212-
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
4213-
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
4212+
value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
4213+
value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
42144214
42154215
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
42164216
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
@@ -4219,7 +4219,7 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"):
42194219
42204220
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None
42214221
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
4222-
return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
4222+
return (value, value_1, getitem)
42234223
""",
42244224
)
42254225

test/functorch/test_control_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4088,9 +4088,9 @@ def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_py
40884088
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None
40894089
getitem = while_loop[0]
40904090
getitem_1 = while_loop[1]
4091-
getitem_2 = while_loop[2]
4092-
getitem_3 = while_loop[3]; while_loop = None
4093-
return (getitem, getitem_1, getitem_2, getitem_3)""", # noqa: B950
4091+
value = while_loop[2]
4092+
value_1 = while_loop[3]; while_loop = None
4093+
return (getitem, getitem_1, value, value_1)""", # noqa: B950
40944094
)
40954095

40964096
def _wrap_with_functionalize(self, fn, func_type):

torch/_dynamo/polyfills/__init__.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# mypy: allow-untyped-defs
1010

1111
from itertools import repeat as _repeat
12-
from typing import Any, Callable, List, Sequence, TYPE_CHECKING
12+
from typing import Any, Callable, List, MutableMapping, Sequence, TYPE_CHECKING
1313

1414
import torch
1515

@@ -146,6 +146,30 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
146146
return obj
147147

148148

149+
# Used with something like dict(obj)
150+
def construct_dict(cls, /, *args, **kwargs):
151+
dst = cls.__new__(cls)
152+
153+
if args:
154+
src = args[0]
155+
156+
# Ensure that the overridden __iter__ method is invoked
157+
if isinstance(src, (dict, MutableMapping)):
158+
for key in src:
159+
# This will inline the __getitem__ of the src object
160+
dst[key] = src[key]
161+
else:
162+
# likely a sequence like tuple of pairs
163+
for key, value in src:
164+
dst[key] = value
165+
166+
if kwargs:
167+
for key in kwargs:
168+
dst[key] = kwargs[key]
169+
170+
return dst
171+
172+
149173
def foreach_map_fn(*args):
150174
op = args[0]
151175
new_args: List[Any] = []

torch/_dynamo/variables/builtin.py

Lines changed: 11 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
import operator
1010
import types
1111
from collections import defaultdict, OrderedDict
12-
from collections.abc import KeysView, MutableMapping
12+
from collections.abc import KeysView
1313
from typing import Dict, List, TYPE_CHECKING
1414

1515
import torch
1616
from torch import sym_float, sym_int
1717
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
1818

19-
from .. import config, variables
19+
from .. import config, polyfills, variables
2020
from ..exc import (
2121
AttributeMutationError,
2222
unimplemented,
@@ -38,7 +38,6 @@
3838
check_numpy_ndarray_args,
3939
check_unspec_or_constant_args,
4040
check_unspec_python_args,
41-
does_not_override_dict_iter_methods,
4241
extract_fake_example_value,
4342
get_fake_value,
4443
guard_if_dyn,
@@ -1026,6 +1025,10 @@ def call_method(
10261025
return tx.output.side_effects.track_object_new_from_user_defined_class(
10271026
args[0]
10281027
)
1028+
if self.fn is dict and name == "__new__":
1029+
assert len(args) == 1
1030+
assert len(kwargs) == 0
1031+
return ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
10291032
if self.fn is dict and name == "fromkeys":
10301033
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
10311034
return super().call_method(tx, name, args, kwargs)
@@ -1370,73 +1373,11 @@ def call_dict(self, tx: "InstructionTranslator", *args, **kwargs):
13701373

13711374
@staticmethod
13721375
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
1373-
if not kwargs:
1374-
if not args:
1375-
args = ({},)
1376-
assert len(args) == 1
1377-
arg = args[0]
1378-
if isinstance(arg, dict):
1379-
return ConstDictVariable(
1380-
arg, user_cls, mutation_type=ValueMutationNew()
1381-
)
1382-
elif isinstance(arg, variables.ConstDictVariable):
1383-
return arg.clone(
1384-
user_cls=user_cls, source=None, mutation_type=ValueMutationNew()
1385-
)
1386-
elif isinstance(
1387-
arg,
1388-
(
1389-
ListVariable,
1390-
TupleVariable,
1391-
ListIteratorVariable,
1392-
variables.IteratorVariable,
1393-
),
1394-
):
1395-
items = dict(
1396-
x.force_unpack_var_sequence(tx)
1397-
for x in arg.force_unpack_var_sequence(tx)
1398-
)
1399-
return ConstDictVariable(
1400-
items, user_cls, mutation_type=ValueMutationNew()
1401-
)
1402-
elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping):
1403-
# This handles all other `MutableMapping` instances; for
1404-
# example, TensorDict which derives from MutableMapping.
1405-
#
1406-
# TODO(#142414) `hasattr(arg, 'value')` is a local workaround
1407-
# for lack of generall multiple inheritance in Dynamo. We can't
1408-
# use `isinstance(arg, MutableMappingVariable)` here because
1409-
# `arg` could be, e.g., a `UnspecializedNNModuleVariable` when
1410-
# `arg.value` has multiple inheritace.
1411-
if does_not_override_dict_iter_methods(type(arg.value)):
1412-
# In this case, `arg.value.items()` uses the default impls,
1413-
# which are implemented in C and cannot be traced, so we
1414-
# will have to manually construct the items. This is safe
1415-
# because we know they are side-effect free.
1416-
#
1417-
# Mutation tracked by Dynamo isn't reflected in `arg.value`,
1418-
# so we can't handle such cases by just calling
1419-
# `arg.value.items()`
1420-
if tx.output.side_effects.has_pending_mutation(arg):
1421-
unimplemented(
1422-
f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
1423-
)
1424-
new_dict = dict(arg.value.items())
1425-
return VariableTracker.build(tx, new_dict)
1426-
else:
1427-
func_var = arg.var_getattr(tx, "items")
1428-
if not isinstance(func_var, variables.UserFunctionVariable):
1429-
unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}")
1430-
out = tx.inline_user_function_return(func_var, args, kwargs)
1431-
if isinstance(out, ConstDictVariable):
1432-
return out
1433-
return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out)
1434-
elif not args and kwargs:
1435-
items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
1436-
return variables.ConstDictVariable(
1437-
items, user_cls=user_cls, mutation_type=ValueMutationNew()
1438-
)
1439-
unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
1376+
return tx.inline_user_function_return(
1377+
VariableTracker.build(tx, polyfills.construct_dict),
1378+
[VariableTracker.build(tx, user_cls), *args],
1379+
kwargs,
1380+
)
14401381

14411382
@staticmethod
14421383
def call_custom_dict_fromkeys(

torch/_dynamo/variables/user_defined.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
217217
source
218218
and not inspect.ismethoddescriptor(obj)
219219
and not is_wrapper_or_member_descriptor(obj)
220+
and obj is not dict.__new__
220221
):
221222
return VariableTracker.build(tx, obj, source)
222223

@@ -321,6 +322,12 @@ def call_method(
321322
return variables.ConstantVariable(self.value == args[0].value)
322323
elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
323324
return variables.ConstantVariable(self.value != args[0].value)
325+
elif name == "__new__" and self.value is collections.OrderedDict:
326+
assert len(args) == 1
327+
assert len(kwargs) == 0
328+
return variables.ConstDictVariable(
329+
{}, collections.OrderedDict, mutation_type=ValueMutationNew()
330+
)
324331

325332
return super().call_method(tx, name, args, kwargs)
326333

@@ -332,7 +339,6 @@ def call_function(
332339
) -> "VariableTracker":
333340
from ..side_effects import SideEffects
334341
from .builder import wrap_fx_proxy
335-
from .builtin import BuiltinVariable
336342

337343
constant_args = check_constant_args(args, kwargs)
338344

@@ -352,8 +358,10 @@ def call_function(
352358

353359
return NullContextVariable()
354360
elif self.value is collections.OrderedDict:
355-
return BuiltinVariable.call_custom_dict(
356-
tx, collections.OrderedDict, *args, **kwargs
361+
return tx.inline_user_function_return(
362+
VariableTracker.build(tx, polyfills.construct_dict),
363+
[self, *args],
364+
kwargs,
357365
)
358366
elif (
359367
self.value is collections.defaultdict
@@ -1418,6 +1426,14 @@ def call_method(
14181426
return self._dict_vt.call_method(tx, name, args, kwargs)
14191427
return super().call_method(tx, name, args, kwargs)
14201428

1429+
def unpack_var_sequence(self, tx):
1430+
if type(self.value).__iter__ in (
1431+
dict.__iter__,
1432+
collections.OrderedDict.__iter__,
1433+
):
1434+
return self._dict_vt.unpack_var_sequence(tx)
1435+
raise NotImplementedError
1436+
14211437

14221438
class MutableMappingVariable(UserDefinedObjectVariable):
14231439
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields

0 commit comments

Comments
 (0)