8787 TensorSpec ,
8888)
8989from executorch .exir .types import LeafValueSpec , ValueSpec
90+ from torch ._subclasses .fake_tensor import FakeTensor
9091
9192from torch .export .exported_program import ExportedProgram
9293from torch .utils import _pytree as pytree
@@ -933,6 +934,35 @@ def _emit_argument(
933934 return arg
934935 return self ._emit_evalue (self ._constant_to_evalue (arg , arg_type ))
935936
937+ def _get_sym_ret (
938+ self ,
939+ val : Tuple [Union [torch .SymInt , torch .BoolType , torch .FloatType , FakeTensor ]],
940+ ) -> Optional [_AbstractValue ]:
941+ """
942+ Returns the emit ret for sym value.
943+ """
944+ ret = None
945+ if isinstance (val , torch .SymInt ):
946+ ret = self ._emit_evalue (EValue (Int (0 )))
947+ elif isinstance (val , torch .BoolType ):
948+ ret = self ._emit_evalue (EValue (Bool (False )))
949+ elif isinstance (val , torch .FloatType ):
950+ ret = self ._emit_evalue (EValue (Double (0 )))
951+ return ret
952+
953+ def _get_sym_and_fake_tensor_ret (
954+ self ,
955+ val : Tuple [Union [torch .SymInt , torch .BoolType , torch .FloatType , FakeTensor ]],
956+ spec : TensorSpec ,
957+ ) -> Union [List [_AbstractValue ], _AbstractValue , Tuple [_AbstractValue , ...]]:
958+ # Try to get the ret if it's a sym value.
959+ ret = self ._get_sym_ret (val )
960+ # If the ret is None, it means that the val is not a sym value, but a regular tensor
961+ if ret is None :
962+ ret = self ._emit_spec (spec )
963+ assert ret is not None , "Can't have a None ret"
964+ return ret
965+
936966 def _emit_delegate (
937967 self ,
938968 lowered_module : "LoweredBackendModule" , # noqa
@@ -944,7 +974,40 @@ def _emit_delegate(
944974 processed_bytes = lowered_module .processed_bytes
945975
946976 delegate_index = self .emitter_state .delegate_cache .get (processed_bytes )
947- delegate_ret = self ._emit_spec (self .node .meta ["spec" ])
977+ delegate_ret = None
978+
979+ if isinstance (self .node .meta ["spec" ], list ):
980+ delegate_ret = []
981+ for index , _ in enumerate (self .node .meta ["val" ]):
982+ ret = self ._get_sym_and_fake_tensor_ret (
983+ self .node .meta ["val" ][index ], self .node .meta ["spec" ][index ]
984+ )
985+ delegate_ret .append (ret )
986+ elif isinstance (self .node .meta ["spec" ], tuple ):
987+ if isinstance (self .node .meta ["val" ], FakeTensor ):
988+ # There is a case when node.meta["spec"] is (TensorSpec, ) while node.meta["val"] is FakeTensor
989+ ret = self ._get_sym_and_fake_tensor_ret (
990+ self .node .meta ["val" ], self .node .meta ["spec" ][0 ]
991+ )
992+ delegate_ret = (ret ,)
993+ else :
994+ delegate_ret = []
995+ for index , _ in enumerate (self .node .meta ["val" ]):
996+ ret = self ._get_sym_and_fake_tensor_ret (
997+ self .node .meta ["val" ][index ], self .node .meta ["spec" ][index ]
998+ )
999+ delegate_ret .append (ret )
1000+ delegate_ret = tuple (delegate_ret )
1001+ elif isinstance (self .node .meta ["spec" ], TensorSpec ):
1002+ ret = self ._get_sym_and_fake_tensor_ret (
1003+ self .node .meta ["val" ], self .node .meta ["spec" ]
1004+ )
1005+ delegate_ret = ret
1006+ else :
1007+ raise NotImplementedError (
1008+ f"self.node.meta['spec'] { type (self .node .meta ['spec' ])} is not supported"
1009+ )
1010+ assert delegate_ret is not None , "Can't have a None delegate_ret"
9481011 if delegate_index is None :
9491012 # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
9501013 # present.
@@ -1062,13 +1125,8 @@ def _get_empty_tensor_evalue() -> EValue:
10621125 torch .BoolType ,
10631126 torch .NumberType ,
10641127 ), f"Only symbolic ops that return a Int Bool Float are supported currently got { type (target ._schema .returns [0 ].type )} ."
1065- if type (target ._schema .returns [0 ].type ) == torch .IntType :
1066- ret = self ._emit_evalue (EValue (Int (0 )))
1067- elif type (target ._schema .returns [0 ].type ) == torch .BoolType :
1068- ret = self ._emit_evalue (EValue (Bool (False )))
1069- elif type (target ._schema .returns [0 ].type ) == torch .FloatType :
1070- ret = self ._emit_evalue (EValue (Double (0 )))
1071- else : # type(target._schema.returns[0].type) == torch.NumberType:
1128+ ret = self ._get_sym_ret (target ._schema .returns [0 ])
1129+ if ret is None : # type(target._schema.returns[0].type) == torch.NumberType:
10721130 # Cant definitively say what type this is, the runtime operator just overrides the EValue completely
10731131 # though so we can just serialize whatever as a placeholder.
10741132 ret = self ._emit_evalue (EValue (Int (0 )))
0 commit comments