From 677fd382437a2c41757ec0515fee9e633050b352 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 25 Jun 2025 20:56:00 +0000
Subject: [PATCH 1/4] fix: Fix unbacked sym int not found issue

---
 py/torch_tensorrt/dynamo/_tracer.py              |  4 ++++
 .../dynamo/runtime/meta_ops/register_meta_ops.py |  4 ++--
 py/torch_tensorrt/dynamo/utils.py                | 13 ++++++++-----
 tests/py/dynamo/models/test_reexport.py          | 14 ++++++++++----
 4 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index 5f4bdd0a8d..b25af0b5c1 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -70,6 +70,9 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
+    # After validation above, arg_inputs is guaranteed to be non-None
+    assert arg_inputs is not None, "arg_inputs should not be None after validation"
+
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
@@ -81,6 +84,7 @@ def trace(
         tuple(torch_arg_inputs),
         kwargs=torch_kwarg_inputs,
         dynamic_shapes=dynamic_shapes,
+        strict=kwargs.get("strict", False),
     )
 
     return exp_program
diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py
index 500a665688..b0d1c69916 100644
--- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py
+++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py
@@ -60,7 +60,7 @@ def fake_tensorrt_execute_engine(
             output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
             # Update var to val (hint)
             output_sym_int_shape_env = output_sym_int.node.shape_env
-            output_sym_int_shape_env.add_var_to_val(
+            output_sym_int_shape_env.set_unbacked_var_to_val(
                 output_sym_int.node.expr, opt_val
             )
             output_shape.append(output_sym_int)
@@ -152,7 +152,7 @@ def __getstate__(self) -> Any:
         pass
 
 
-@torch.library.custom_op(
+@torch.library.custom_op(  # type: ignore
     "tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
 )
 def no_op_placeholder_for_execute_engine(
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index e0b3af7e0b..0703fd1cb9 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -375,8 +375,10 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
     # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy
     # expr.xreplace replaces the symbolic variables with their current values and computes the expression.
     var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(expr)
-    var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
-        shape_env.var_to_val
+    var_val = (
+        shape_env.var_to_val.get(expr, None)
+        or shape_env.unbacked_var_to_val.get(expr, None)
+        or expr.xreplace(shape_env.var_to_val)
     )
     assert var_range, var_val
     min_val, max_val = int(var_range.lower), int(var_range.upper)
@@ -385,8 +387,9 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
     min_max_opt = {}
     min_max_opt["min"] = min_val
     min_max_opt["max"] = max_val
-    if isinstance(var_val, sympy.core.numbers.Integer):
+    if isinstance(var_val, (sympy.core.numbers.Integer, int)):
         min_max_opt["opt"] = int(var_val)
+
     return min_max_opt
 
 
@@ -447,9 +450,9 @@ def get_graph_io_attrs(
             metadata = node.meta["val"]
             if isinstance(metadata, (tuple, list)):
                 for tensor in metadata:
-                    graph_io_attrs.append(attr_fn(tensor))  # type: ignore
+                    graph_io_attrs.append(attr_fn(tensor))
             else:
-                graph_io_attrs.append(attr_fn(metadata))  # type: ignore
+                graph_io_attrs.append(attr_fn(metadata))
 
     return graph_io_attrs
 
diff --git a/tests/py/dynamo/models/test_reexport.py b/tests/py/dynamo/models/test_reexport.py
index c4c7fb6787..755721605e 100644
--- a/tests/py/dynamo/models/test_reexport.py
+++ b/tests/py/dynamo/models/test_reexport.py
@@ -458,7 +458,7 @@ def test_resnet18_dynamic(ir):
 
     dyn_batch = torch.export.Dim("batch", min=1, max=8)
     exp_program = torch.export.export(
-        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},)
+        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
@@ -532,8 +532,9 @@ def test_resnet18_dynamic_fallback(ir):
     }
 
     dyn_batch = torch.export.Dim("batch", min=1, max=8)
+
     exp_program = torch.export.export(
-        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},)
+        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
@@ -610,6 +611,7 @@ def forward(self, lhs_val, rhs_val):
         model,
         inputs_4,
         dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}},
+        strict=False,
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
@@ -699,13 +701,17 @@ def forward(self, x):
 
     dyn_dim = torch.export.Dim("batch", min=1, max=64)
     exp_program = torch.export.export(
-        model, torch_inputs_bs50, dynamic_shapes=({0: dyn_dim},)
+        model, torch_inputs_bs50, dynamic_shapes=({0: dyn_dim},), strict=False
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
     # Reexport with dynamic dimensions
     trt_exp_program = torch.export.export(
-        trt_module, torch_inputs_bs50, strict=False, dynamic_shapes=({0: dyn_dim},)
+        trt_module,
+        torch_inputs_bs50,
+        strict=False,
+        dynamic_shapes=({0: dyn_dim},),
+        strict=False,
     )
     torch.export.save(trt_exp_program, trt_ep_path)
 

From 7a0eb638896e91515b6b241372407bbe8fd1a43d Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 25 Jun 2025 21:29:45 +0000
Subject: [PATCH 2/4] chore: comment out arg_inputs assertion

---
 py/torch_tensorrt/dynamo/_tracer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index b25af0b5c1..322fe35d6f 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -70,8 +70,8 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
-    # After validation above, arg_inputs is guaranteed to be non-None
-    assert arg_inputs is not None, "arg_inputs should not be None after validation"
+    # This assertion is a workaround to fix the mypy type error. The validation above is enough to guarantee arg_inputs is not None.
+    # assert arg_inputs is not None, "arg_inputs should not be None after validation"
 
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)

From 93518d21e7281466d3a0f049780afffae2267831 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 25 Jun 2025 21:30:19 +0000
Subject: [PATCH 3/4] chore: mypy related fixes

---
 py/torch_tensorrt/dynamo/_tracer.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py
index 322fe35d6f..888dddbc3c 100644
--- a/py/torch_tensorrt/dynamo/_tracer.py
+++ b/py/torch_tensorrt/dynamo/_tracer.py
@@ -70,9 +70,6 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
-    # This assertion is a workaround to fix the mypy type error. The validation above is enough to guarantee arg_inputs is not None.
-    # assert arg_inputs is not None, "arg_inputs should not be None after validation"
-
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)

From 426ee36df96a13318b5e89a20184c6a7f5a3d3b5 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 25 Jun 2025 21:32:57 +0000
Subject: [PATCH 4/4] chore: remove duplicate strict=False

---
 tests/py/dynamo/models/test_reexport.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/py/dynamo/models/test_reexport.py b/tests/py/dynamo/models/test_reexport.py
index 755721605e..d86217bd51 100644
--- a/tests/py/dynamo/models/test_reexport.py
+++ b/tests/py/dynamo/models/test_reexport.py
@@ -711,7 +711,6 @@ def forward(self, x):
         torch_inputs_bs50,
         strict=False,
         dynamic_shapes=({0: dyn_dim},),
-        strict=False,
     )
     torch.export.save(trt_exp_program, trt_ep_path)
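
A note on the core fix, for reviewers (not part of the patches themselves):
ShapeEnv keeps hints for backed symbols in var_to_val and hints for unbacked
symbols (the u0-style sizes that ctx.new_dynamic_size() produces) in a
separate unbacked_var_to_val table. The previous code registered the hint
with add_var_to_val (the backed table) and only consulted var_to_val when
resolving it; this series instead records and looks up unbacked symbols
through the unbacked table. Below is a minimal sketch of the resulting
lookup chain. It is illustrative only, and assumes a recent PyTorch where
ShapeEnv exposes create_unbacked_symint in addition to the
set_unbacked_var_to_val and unbacked_var_to_val APIs the diffs rely on:

    import sympy
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # An unbacked SymInt carries a symbol like u0 with no entry in the
    # backed hint table (shape_env.var_to_val). create_unbacked_symint()
    # stands in here for what ctx.new_dynamic_size() does in the fake kernel.
    shape_env = ShapeEnv()
    sym_int = shape_env.create_unbacked_symint()
    expr = sym_int.node.expr

    # Record a concrete hint for the unbacked symbol, as the patched meta
    # kernel now does with the engine's optimal output size.
    shape_env.set_unbacked_var_to_val(expr, 4)

    # The lookup chain from the patched extract_var_range_info():
    # backed table, then unbacked table, then substitution as a last resort.
    var_val = (
        shape_env.var_to_val.get(expr, None)
        or shape_env.unbacked_var_to_val.get(expr, None)
        or expr.xreplace(shape_env.var_to_val)
    )

    # The widened isinstance check accepts sympy Integers and plain ints.
    assert isinstance(var_val, (sympy.core.numbers.Integer, int))
    print(int(var_val))  # 4

The strict=False threaded through trace() and the test exports opts into
non-strict export, matching the mode already used when re-exporting the
compiled TRT module.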