@@ -1065,6 +1065,9 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
10651065 self .check_tensor_buffer_loc (1 , execution_plan .values , 0 , 1 , 48 )
10661066
10671067 def test_emit_prims (self ) -> None :
1068+ tensor_output = torch .rand (1 , 4 )
1069+ tensor_list_output = [torch .rand (1 , 4 ), torch .rand (1 , 4 )]
1070+
10681071 class Simple (torch .nn .Module ):
10691072 def __init__ (self ) -> None :
10701073 super ().__init__ ()
@@ -1078,6 +1081,12 @@ def get_ints(self) -> Tuple[int]:
10781081 def get_str (self ) -> str :
10791082 return "foo"
10801083
1084+ def get_tensor (self ) -> torch .Tensor :
1085+ return tensor_output
1086+
1087+ def get_tensor_list (self ) -> List [torch .Tensor ]:
1088+ return tensor_list_output
1089+
10811090 def forward (self , x : torch .Tensor ) -> torch .Tensor :
10821091 return torch .nn .functional .sigmoid (self .linear (x ))
10831092
@@ -1090,9 +1099,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
10901099 getters = {}
10911100 getters ["get_ints" ] = model .get_ints ()
10921101 getters ["get_str" ] = model .get_str ()
1093- print (getters ["get_str" ])
1102+ getters ["get_tensor" ] = model .get_tensor ()
1103+ getters ["get_tensor_list" ] = model .get_tensor_list ()
1104+
10941105 merged_program = emit_program (exir_input , False , getters ).program
1095- self .assertEqual (len (merged_program .execution_plan ), 3 )
1106+
1107+ self .assertEqual (len (merged_program .execution_plan ), 5 )
10961108
10971109 self .assertEqual (
10981110 merged_program .execution_plan [0 ].name ,
@@ -1106,6 +1118,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
11061118 merged_program .execution_plan [2 ].name ,
11071119 "get_str" ,
11081120 )
1121+ self .assertEqual (
1122+ merged_program .execution_plan [3 ].name ,
1123+ "get_tensor" ,
1124+ )
1125+ self .assertEqual (
1126+ merged_program .execution_plan [4 ].name ,
1127+ "get_tensor_list" ,
1128+ )
1129+
11091130 # no instructions in a getter
11101131 self .assertEqual (
11111132 len (merged_program .execution_plan [1 ].chains [0 ].instructions ),
@@ -1141,6 +1162,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
11411162 merged_program .execution_plan [2 ].values [0 ].val .string_val ,
11421163 "foo" ,
11431164 )
1165+ self .assertEqual (len (merged_program .execution_plan [3 ].outputs ), 1 )
1166+ self .assertEqual (len (merged_program .execution_plan [4 ].outputs ), 2 )
1167+
1168+ merged_program = to_edge (
1169+ export (model , inputs ), constant_methods = getters
1170+ ).to_executorch ()
1171+ executorch_module = _load_for_executorch_from_buffer (merged_program .buffer )
1172+ torch .allclose (executorch_module .run_method ("get_tensor" , [])[0 ], tensor_output )
1173+ model_output = executorch_module .run_method ("get_tensor_list" , [])
1174+ for i in range (len (tensor_list_output )):
1175+ torch .allclose (model_output [i ], tensor_list_output [i ])
11441176
11451177 def test_emit_debug_handle_map (self ) -> None :
11461178 mul_model = Mul ()
0 commit comments