Labels: bug (Something isn't working)
Bug Description
To Reproduce
The code comes from the official documentation:
https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html#custom-dynamic-shape-constraints
import torch
import torch_tensorrt

class MatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key):
        attn_weight = torch.matmul(query, key.transpose(-1, -2))
        return attn_weight

model = MatMul().eval().cuda()
inputs = [torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda()]
seq_len = torch.export.Dim("seq_len", min=1, max=10)
dynamic_shapes = ({2: seq_len}, {2: seq_len})
# Export the model first with custom dynamic shape constraints
exp_program = torch.export.export(model, tuple(inputs), dynamic_shapes=dynamic_shapes)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs])
# Run inference
trt_gm(inputs)
Expected behavior
The script should compile the exported program and run inference without raising an error.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 1 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 7909 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 81 timing cache entries
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.846883
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 123972 bytes of Memory
WARNING: [Torch-TensorRT] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
Traceback (most recent call last):
  File "/mnt/bn/hukongtao-infer-speed/mlx/users/kongtao.hu/codebase/EasyGuard_0617/speed_test.py", line 18, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs])
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 227, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/usr/local/lib/python3.9/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 421, in compile_module
    sample_outputs = gm(
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/graph_module.py", line 737, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/graph_module.py", line 317, in __call__
    raise e
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/graph_module.py", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/export/_unlift.py", line 25, in _check_input_constraints_pre_hook
    raise ValueError( # noqa: TRY200
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [*, *]), TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(list, None, [*, *])]), TreeSpec(dict, [], [])])
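To make the mismatch in the two tree specs easier to see, here is a minimal pure-Python sketch that does not use PyTorch. The helpers `structure` and `call_structure` are hypothetical and only imitate the shape of a tree spec, with strings standing in for the two tensors. It assumes that `compile` unpacks its `inputs` argument as positional arguments when it calls the graph module, which is consistent with the `sample_outputs = gm(` frame above but not confirmed from the Torch-TensorRT source.

```python
def structure(obj):
    """Describe the container nesting of obj; every leaf becomes '*'."""
    if isinstance(obj, (list, tuple)):
        return (type(obj).__name__, [structure(x) for x in obj])
    if isinstance(obj, dict):
        return ("dict", {k: structure(v) for k, v in obj.items()})
    return "*"


def call_structure(*args, **kwargs):
    """Describe the (args, kwargs) pair a callable would receive."""
    return structure((args, kwargs))


# Stand-ins for the two tensors in the reproduction script.
inputs = ["query", "key"]

# The exported program expects forward(query, key): two positional leaves.
expected = call_structure(*inputs)

# Assumption: compile(exp_program, [inputs]) ends up calling gm(*[inputs]),
# so the module receives a single positional argument that is a list.
got = call_structure(*[inputs])

print("expected:", expected)
print("got:     ", got)
print("match:   ", expected == got)
```

Under that assumption, the extra list level in `got` corresponds to the `TreeSpec(list, None, [*, *])` reported in the error. If this reading is right, passing `inputs` instead of `[inputs]` to `torch_tensorrt.dynamo.compile`, and calling `trt_gm(*inputs)`, might avoid the error. I have not verified this against the library.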