|
10 | 10 |
|
11 | 11 | import executorch.exir as exir |
12 | 12 | import torch |
13 | | -import torch.fx as fx |
14 | | -from executorch.exir import multi_method_program_to_executorch |
15 | | -from executorch.exir.backend.backend_api import ( |
16 | | - LoweredBackendModule, |
17 | | - to_backend, |
18 | | - to_backend_multiple, |
19 | | -) |
| 13 | +from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend |
20 | 14 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
21 | 15 | from executorch.exir.backend.partitioner import ( |
22 | 16 | DelegationSpec, |
@@ -1235,137 +1229,3 @@ def forward(self, x: List[torch.Tensor]): |
1235 | 1229 |
|
1236 | 1230 | gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge() |
1237 | 1231 | gm(*inputs) |
1238 | | - |
1239 | | - def test_lower_multiple(self) -> None: |
1240 | | - class MultipleMethodModule(torch.nn.Module): |
1241 | | - def __init__(self) -> None: |
1242 | | - super().__init__() |
1243 | | - |
1244 | | - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
1245 | | - return x + y * y |
1246 | | - |
1247 | | - def method1(self, x: torch.Tensor) -> torch.Tensor: |
1248 | | - return x + x - x |
1249 | | - |
1250 | | - def method2( |
1251 | | - self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor |
1252 | | - ) -> torch.Tensor: |
1253 | | - return x + y - z |
1254 | | - |
1255 | | - module = MultipleMethodModule() |
1256 | | - method_name_to_args = { |
1257 | | - "forward": (torch.rand(2, 2), torch.rand(2, 2)), |
1258 | | - "method1": (torch.rand(2, 2),), |
1259 | | - "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)), |
1260 | | - } |
1261 | | - |
1262 | | - multi_method_prog = exir.capture_multiple( |
1263 | | - module, method_name_to_args, exir.CaptureConfig() |
1264 | | - ).to_edge() |
1265 | | - |
1266 | | - lowered_multi_method_prog = to_backend_multiple( |
1267 | | - multi_method_prog, AddMulPartitionerDemo |
1268 | | - ) |
1269 | | - |
1270 | | - for method_name, args in method_name_to_args.items(): |
1271 | | - exported_prog = lowered_multi_method_prog.find_method(method_name) |
1272 | | - self.assertIsNotNone(exported_prog) |
1273 | | - exported_gm = exported_prog.exported_program.graph_module |
1274 | | - self.assertIsInstance(exported_gm, fx.GraphModule) |
1275 | | - |
1276 | | - eager_method = getattr(module, method_name) |
1277 | | - eager_results = eager_method(*args) |
1278 | | - exported_results = exported_gm(*args) |
1279 | | - self.assertTrue(torch.allclose(eager_results, exported_results[0])) |
1280 | | - |
1281 | | - add_nodes = [ |
1282 | | - node |
1283 | | - for node in exported_gm.graph.nodes |
1284 | | - if node.op == "call_function" |
1285 | | - and node.target == exir_ops.edge.aten.add.Tensor |
1286 | | - ] |
1287 | | - self.assertEqual(len(add_nodes), 0) |
1288 | | - |
1289 | | - lowered_submods = get_lowered_submodules(exported_gm) |
1290 | | - self.assertEqual(len(lowered_submods), 1) |
1291 | | - |
1292 | | - _ = multi_method_program_to_executorch(lowered_multi_method_prog) |
1293 | | - |
1294 | | - def test_lower_multiple_selective(self) -> None: |
1295 | | - class MultipleMethodModule(torch.nn.Module): |
1296 | | - def __init__(self) -> None: |
1297 | | - super().__init__() |
1298 | | - |
1299 | | - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
1300 | | - return x + y * y |
1301 | | - |
1302 | | - def method1(self, x: torch.Tensor) -> torch.Tensor: |
1303 | | - return x + x - x |
1304 | | - |
1305 | | - def method2( |
1306 | | - self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor |
1307 | | - ) -> torch.Tensor: |
1308 | | - return x + y - z |
1309 | | - |
1310 | | - module = MultipleMethodModule() |
1311 | | - method_name_to_args = { |
1312 | | - "forward": (torch.rand(2, 2), torch.rand(2, 2)), |
1313 | | - "method1": (torch.rand(2, 2),), |
1314 | | - "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)), |
1315 | | - } |
1316 | | - |
1317 | | - multi_method_prog = exir.capture_multiple( |
1318 | | - module, method_name_to_args, exir.CaptureConfig() |
1319 | | - ).to_edge() |
1320 | | - |
1321 | | - method_name_to_partitioners = { |
1322 | | - "forward": AddMulPartitionerDemo, |
1323 | | - "method1": AddMulPartitionerDemo, |
1324 | | - } |
1325 | | - lowered_multi_method_prog = to_backend_multiple( |
1326 | | - multi_method_prog, method_name_to_partitioners |
1327 | | - ) |
1328 | | - |
1329 | | - for method_name, args in method_name_to_args.items(): |
1330 | | - if method_name == "method2": |
1331 | | - break |
1332 | | - |
1333 | | - exported_prog = lowered_multi_method_prog.find_method(method_name) |
1334 | | - self.assertIsNotNone(exported_prog) |
1335 | | - exported_gm = exported_prog.exported_program.graph_module |
1336 | | - self.assertIsInstance(exported_gm, fx.GraphModule) |
1337 | | - |
1338 | | - eager_method = getattr(module, method_name) |
1339 | | - eager_results = eager_method(*args) |
1340 | | - exported_results = exported_gm(*args) |
1341 | | - self.assertTrue(torch.allclose(eager_results, exported_results[0])) |
1342 | | - |
1343 | | - add_nodes = [ |
1344 | | - node |
1345 | | - for node in exported_gm.graph.nodes |
1346 | | - if node.op == "call_function" |
1347 | | - and node.target == exir_ops.edge.aten.add.Tensor |
1348 | | - ] |
1349 | | - self.assertEqual(len(add_nodes), 0) |
1350 | | - |
1351 | | - lowered_submods = get_lowered_submodules(exported_gm) |
1352 | | - self.assertEqual(len(lowered_submods), 1) |
1353 | | - |
1354 | | - # Check that method2 had nothing lowered |
1355 | | - method2_prog = lowered_multi_method_prog.find_method("method2") |
1356 | | - self.assertIsNotNone(method2_prog) |
1357 | | - method2_gm = method2_prog.exported_program.graph_module |
1358 | | - self.assertIsInstance(method2_gm, fx.GraphModule) |
1359 | | - add_nodes = [ |
1360 | | - node |
1361 | | - for node in method2_gm.graph.nodes |
1362 | | - if node.op == "call_function" |
1363 | | - and node.target == exir_ops.edge.aten.add.Tensor |
1364 | | - ] |
1365 | | - self.assertEqual(len(add_nodes), 1) |
1366 | | - |
1367 | | - lowered_submods = get_lowered_submodules(method2_gm) |
1368 | | - self.assertEqual(len(lowered_submods), 0) |
1369 | | - |
1370 | | - # Check we can export to executorch properly |
1371 | | - _ = multi_method_program_to_executorch(lowered_multi_method_prog) |
0 commit comments