1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_tracer.py
@@ -81,6 +81,7 @@ def trace(
         tuple(torch_arg_inputs),
         kwargs=torch_kwarg_inputs,
         dynamic_shapes=dynamic_shapes,
+        strict=kwargs.get("strict", False),
     )
 
     return exp_program
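Note: torch.export.export traces through TorchDynamo when strict=True and uses the
non-strict, fake-tensor Python tracer when strict=False. With this change, trace()
defaults to non-strict export but still honors a caller override passed through
**kwargs. A minimal usage sketch, assuming trace is reachable as
torch_tensorrt.dynamo.trace and MyModule is a hypothetical user module:

    import torch
    import torch_tensorrt

    model = MyModule().eval().cuda()  # hypothetical module
    inputs = (torch.randn(2, 3, 224, 224).cuda(),)

    ep = torch_tensorrt.dynamo.trace(model, inputs)  # non-strict by default now
    ep_strict = torch_tensorrt.dynamo.trace(model, inputs, strict=True)  # opt back in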
@@ -60,7 +60,7 @@ def fake_tensorrt_execute_engine(
             output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
             # Update var to val (hint)
             output_sym_int_shape_env = output_sym_int.node.shape_env
-            output_sym_int_shape_env.add_var_to_val(
+            output_sym_int_shape_env.set_unbacked_var_to_val(
                 output_sym_int.node.expr, opt_val
             )
             output_shape.append(output_sym_int)
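This tracks an upstream PyTorch rename: the ShapeEnv helper that records a runtime
hint for an unbacked (data-dependent) symbol is set_unbacked_var_to_val, and the
hints are stored in shape_env.unbacked_var_to_val rather than var_to_val. A minimal
sketch of the mechanism, using ShapeEnv directly:

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    s = shape_env.create_unbacked_symint()             # unbacked SymInt, no hint yet
    shape_env.set_unbacked_var_to_val(s.node.expr, 4)  # record 4 as its runtime hint
    print(shape_env.unbacked_var_to_val)               # e.g. {u0: 4}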
@@ -152,7 +152,7 @@ def __getstate__(self) -> Any:
         pass
 
 
-@torch.library.custom_op(
+@torch.library.custom_op( # type: ignore
     "tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
 )
 def no_op_placeholder_for_execute_engine(
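The added # type: ignore only silences mypy on the decorator factory; the op itself
is unchanged. For reference, the general shape of a torch.library custom-op
registration (mylib::noop is a hypothetical op, not part of this PR):

    import torch

    @torch.library.custom_op("mylib::noop", mutates_args=())  # type: ignore
    def noop(x: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @noop.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        # Fake (meta) kernel so the op can be traced without executing it.
        return torch.empty_like(x)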
13 changes: 8 additions & 5 deletions py/torch_tensorrt/dynamo/utils.py
@@ -375,8 +375,10 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
     # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy
     # expr.xreplace replaces the symbolic variables with their current values and computes the expression.
     var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(expr)
-    var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
-        shape_env.var_to_val
+    var_val = (
+        shape_env.var_to_val.get(expr, None)
+        or shape_env.unbacked_var_to_val.get(expr, None)
+        or expr.xreplace(shape_env.var_to_val)
     )
     assert var_range, var_val
     min_val, max_val = int(var_range.lower), int(var_range.upper)
@@ -385,8 +387,9 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
     min_max_opt = {}
     min_max_opt["min"] = min_val
     min_max_opt["max"] = max_val
-    if isinstance(var_val, sympy.core.numbers.Integer):
+    if isinstance(var_val, (sympy.core.numbers.Integer, int)):
         min_max_opt["opt"] = int(var_val)
+
     return min_max_opt
 
 
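With this change, extract_var_range_info resolves the "opt" hint in three steps:
backed symbols from shape_env.var_to_val, unbacked symbols (such as those produced
by fake_tensorrt_execute_engine above) from shape_env.unbacked_var_to_val, and
composite expressions by substituting known values via expr.xreplace. A usage
sketch, assuming sym_int is a torch.SymInt taken from a dynamically shaped tensor:

    from torch_tensorrt.dynamo.utils import extract_var_range_info

    info = extract_var_range_info(sym_int)
    # e.g. {"min": 1, "max": 8, "opt": 2} -- "opt" appears only when a
    # concrete hint was recorded for the symbol.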
@@ -447,9 +450,9 @@ def get_graph_io_attrs(
             metadata = node.meta["val"]
             if isinstance(metadata, (tuple, list)):
                 for tensor in metadata:
-                    graph_io_attrs.append(attr_fn(tensor))  # type: ignore
+                    graph_io_attrs.append(attr_fn(tensor))
             else:
-                graph_io_attrs.append(attr_fn(metadata))  # type: ignore
+                graph_io_attrs.append(attr_fn(metadata))
 
     return graph_io_attrs
 
13 changes: 9 additions & 4 deletions tests/py/dynamo/models/test_reexport.py
@@ -458,7 +458,7 @@ def test_resnet18_dynamic(ir):
 
     dyn_batch = torch.export.Dim("batch", min=1, max=8)
     exp_program = torch.export.export(
-        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},)
+        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
@@ -532,8 +532,9 @@ def test_resnet18_dynamic_fallback(ir):
     }
 
     dyn_batch = torch.export.Dim("batch", min=1, max=8)
+
     exp_program = torch.export.export(
-        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},)
+        model, (input_bs2,), dynamic_shapes=({0: dyn_batch},), strict=False
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
@@ -610,6 +611,7 @@ def forward(self, lhs_val, rhs_val):
         model,
         inputs_4,
         dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}},
+        strict=False,
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
@@ -699,13 +701,16 @@ def forward(self, x):
 
     dyn_dim = torch.export.Dim("batch", min=1, max=64)
     exp_program = torch.export.export(
-        model, torch_inputs_bs50, dynamic_shapes=({0: dyn_dim},)
+        model, torch_inputs_bs50, dynamic_shapes=({0: dyn_dim},), strict=False
     )
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
 
     # Reexport with dynamic dimensions
     trt_exp_program = torch.export.export(
-        trt_module, torch_inputs_bs50, strict=False, dynamic_shapes=({0: dyn_dim},)
+        trt_module,
+        torch_inputs_bs50,
+        strict=False,
+        dynamic_shapes=({0: dyn_dim},),
     )
     torch.export.save(trt_exp_program, trt_ep_path)
 
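A reloading sketch to complete the round trip the test exercises (trt_ep_path and
torch_inputs_bs50 as defined above):

    reloaded = torch.export.load(trt_ep_path)
    out = reloaded.module()(*torch_inputs_bs50)  # runs the serialized program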