diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py index 7884c85b..53f01707 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py @@ -699,6 +699,7 @@ def fold_constants( size_threshold=None, should_exclude_node=None, recurse_functions=True, + ort_session_options=None, ): """ Folds constants in-place in the graph. The graph's nodes and functions must be topologically @@ -754,7 +755,8 @@ def fold_constants( recurse_functions (bool): Whether to fold constants in this graph's Functions. Defaults to True. - + ort_session_options (Optional[onnxruntime.SessionOptions]): + SessionOptions object to be used for ONNX Runtime sessions. Returns: self """ @@ -1172,6 +1174,7 @@ def get_out_node_ids(): sess = onnxrt.InferenceSession( export_onnx(part, do_type_check=False).SerializeToString(), + sess_options = ort_session_options, providers=ORT_PROVIDERS, ) values = sess.run(names, {}) @@ -1254,6 +1257,7 @@ def should_eval_foldable(tensor): export_onnx( graph_clone, do_type_check=False ).SerializeToString(), + sess_options = ort_session_options, providers=ORT_PROVIDERS, ) values = sess.run(names, {})