@@ -47,6 +47,7 @@
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -60,7 +61,7 @@
 from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
+from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1
 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -72,10 +73,17 @@
     OptimizerLRScheduler,
 )
 
+_ONNX_AVAILABLE = RequirementCache("onnx")
+_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")
+
 if TYPE_CHECKING:
     from torch.distributed.device_mesh import DeviceMesh
 
-_ONNX_AVAILABLE = RequirementCache("onnx")
+    if _TORCH_GREATER_EQUAL_2_5:
+        if _TORCH_GREATER_EQUAL_2_6:
+            from torch.onnx import ONNXProgram
+        else:
+            from torch.onnx._internal.exporter import ONNXProgram  # type: ignore[no-redef]
 
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
@@ -1416,12 +1424,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
     )
 
     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
+    def to_onnx(
+        self,
+        file_path: Union[str, Path, BytesIO, None] = None,
+        input_sample: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Optional["ONNXProgram"]:
         """Saves the model in ONNX format.
 
         Args:
-            file_path: The path of the file the onnx model should be saved to.
+            file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
             input_sample: An input for tracing. Default: None (Use self.example_input_array)
+
             **kwargs: Will be passed to torch.onnx.export function.
 
         Example::
@@ -1442,6 +1456,12 @@ def forward(self, x):
         if not _ONNX_AVAILABLE:
             raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
 
+        if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5):
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_onnx(dynamo=True)` "
+                "requires `onnxscript` and `torch>=2.5.0` to be installed."
+            )
+
         mode = self.training
 
         if input_sample is None:
@@ -1458,8 +1478,9 @@ def forward(self, x):
         file_path = str(file_path) if isinstance(file_path, Path) else file_path
         # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
         # BytesIO does work, too.
-        torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
+        ret = torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
         self.train(mode)
+        return ret
 
     @torch.no_grad()
     def to_torchscript(
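
For reference, a minimal usage sketch of the updated API (not part of the diff; `DemoModel` is a hypothetical module, and the dynamo path assumes torch>=2.5 with `onnx` and `onnxscript` installed):

import torch
from lightning.pytorch import LightningModule

class DemoModel(LightningModule):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        # used as the default input_sample by to_onnx()
        self.example_input_array = torch.randn(1, 4)

    def forward(self, x):
        return self.layer(x)

model = DemoModel()

# Classic exporter path: writes the file, returns None (behavior unchanged).
model.to_onnx("model.onnx")

# New path: with dynamo=True and file_path=None, nothing is written to disk;
# the torch.onnx.ONNXProgram is returned and can be saved or inspected later.
program = model.to_onnx(file_path=None, dynamo=True)
if program is not None:
    program.save("model_dynamo.onnx")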