|
9 | 9 | import operator |
10 | 10 | import types |
11 | 11 | from collections import defaultdict, OrderedDict |
12 | | -from collections.abc import KeysView, MutableMapping |
| 12 | +from collections.abc import KeysView |
13 | 13 | from typing import Dict, List, TYPE_CHECKING |
14 | 14 |
|
15 | 15 | import torch |
16 | 16 | from torch import sym_float, sym_int |
17 | 17 | from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
18 | 18 |
|
19 | | -from .. import config, variables |
| 19 | +from .. import config, polyfills, variables |
20 | 20 | from ..exc import ( |
21 | 21 | AttributeMutationError, |
22 | 22 | unimplemented, |
|
38 | 38 | check_numpy_ndarray_args, |
39 | 39 | check_unspec_or_constant_args, |
40 | 40 | check_unspec_python_args, |
41 | | - does_not_override_dict_iter_methods, |
42 | 41 | extract_fake_example_value, |
43 | 42 | get_fake_value, |
44 | 43 | guard_if_dyn, |
@@ -1026,6 +1025,10 @@ def call_method( |
1026 | 1025 | return tx.output.side_effects.track_object_new_from_user_defined_class( |
1027 | 1026 | args[0] |
1028 | 1027 | ) |
| 1028 | + if self.fn is dict and name == "__new__": |
| 1029 | + assert len(args) == 1 |
| 1030 | + assert len(kwargs) == 0 |
| 1031 | + return ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) |
1029 | 1032 | if self.fn is dict and name == "fromkeys": |
1030 | 1033 | return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) |
1031 | 1034 | return super().call_method(tx, name, args, kwargs) |
@@ -1370,73 +1373,11 @@ def call_dict(self, tx: "InstructionTranslator", *args, **kwargs): |
1370 | 1373 |
|
1371 | 1374 | @staticmethod |
1372 | 1375 | def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): |
1373 | | - if not kwargs: |
1374 | | - if not args: |
1375 | | - args = ({},) |
1376 | | - assert len(args) == 1 |
1377 | | - arg = args[0] |
1378 | | - if isinstance(arg, dict): |
1379 | | - return ConstDictVariable( |
1380 | | - arg, user_cls, mutation_type=ValueMutationNew() |
1381 | | - ) |
1382 | | - elif isinstance(arg, variables.ConstDictVariable): |
1383 | | - return arg.clone( |
1384 | | - user_cls=user_cls, source=None, mutation_type=ValueMutationNew() |
1385 | | - ) |
1386 | | - elif isinstance( |
1387 | | - arg, |
1388 | | - ( |
1389 | | - ListVariable, |
1390 | | - TupleVariable, |
1391 | | - ListIteratorVariable, |
1392 | | - variables.IteratorVariable, |
1393 | | - ), |
1394 | | - ): |
1395 | | - items = dict( |
1396 | | - x.force_unpack_var_sequence(tx) |
1397 | | - for x in arg.force_unpack_var_sequence(tx) |
1398 | | - ) |
1399 | | - return ConstDictVariable( |
1400 | | - items, user_cls, mutation_type=ValueMutationNew() |
1401 | | - ) |
1402 | | - elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping): |
1403 | | - # This handles all other `MutableMapping` instances; for |
1404 | | - # example, TensorDict which derives from MutableMapping. |
1405 | | - # |
1406 | | - # TODO(#142414) `hasattr(arg, 'value')` is a local workaround |
1407 | | - # for lack of generall multiple inheritance in Dynamo. We can't |
1408 | | - # use `isinstance(arg, MutableMappingVariable)` here because |
1409 | | - # `arg` could be, e.g., a `UnspecializedNNModuleVariable` when |
1410 | | - # `arg.value` has multiple inheritace. |
1411 | | - if does_not_override_dict_iter_methods(type(arg.value)): |
1412 | | - # In this case, `arg.value.items()` uses the default impls, |
1413 | | - # which are implemented in C and cannot be traced, so we |
1414 | | - # will have to manually construct the items. This is safe |
1415 | | - # because we know they are side-effect free. |
1416 | | - # |
1417 | | - # Mutation tracked by Dynamo isn't reflected in `arg.value`, |
1418 | | - # so we can't handle such cases by just calling |
1419 | | - # `arg.value.items()` |
1420 | | - if tx.output.side_effects.has_pending_mutation(arg): |
1421 | | - unimplemented( |
1422 | | - f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated" |
1423 | | - ) |
1424 | | - new_dict = dict(arg.value.items()) |
1425 | | - return VariableTracker.build(tx, new_dict) |
1426 | | - else: |
1427 | | - func_var = arg.var_getattr(tx, "items") |
1428 | | - if not isinstance(func_var, variables.UserFunctionVariable): |
1429 | | - unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}") |
1430 | | - out = tx.inline_user_function_return(func_var, args, kwargs) |
1431 | | - if isinstance(out, ConstDictVariable): |
1432 | | - return out |
1433 | | - return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out) |
1434 | | - elif not args and kwargs: |
1435 | | - items = {ConstantVariable.create(k): v for k, v in kwargs.items()} |
1436 | | - return variables.ConstDictVariable( |
1437 | | - items, user_cls=user_cls, mutation_type=ValueMutationNew() |
1438 | | - ) |
1439 | | - unimplemented(f"{user_cls.__name__}(): {args} {kwargs}") |
| 1376 | + return tx.inline_user_function_return( |
| 1377 | + VariableTracker.build(tx, polyfills.construct_dict), |
| 1378 | + [VariableTracker.build(tx, user_cls), *args], |
| 1379 | + kwargs, |
| 1380 | + ) |
1440 | 1381 |
|
1441 | 1382 | @staticmethod |
1442 | 1383 | def call_custom_dict_fromkeys( |
|
0 commit comments