Bug Description
Some networks appear to be fully supported, yet partitioning still runs, and it is not clear why.
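The sketch below makes the expectation concrete. It is a deliberately simplified model of partitioning, not Torch-TensorRT's actual partitioner. An op is assigned to TensorRT only if it is convertible and not listed in `torch_executed_ops`. Contiguous ops with the same target are grouped, and any TensorRT group shorter than `min_block_size` falls back to PyTorch. Under this model, a graph whose ops are all convertible, with none forced to PyTorch, collapses to a single TensorRT segment, so partitioning has nothing to split. The function `partition` and the op sets are hypothetical, and real TorchScript graphs are DAGs rather than flat sequences.

```python
from typing import List, Set, Tuple

Segment = Tuple[str, List[str]]  # ("trt" | "torch", ops in the segment)


def partition(
    ops: List[str],
    convertible: Set[str],
    torch_executed_ops: Set[str],
    min_block_size: int,
) -> List[Segment]:
    """Toy partitioner over a flat op sequence (hypothetical, simplified)."""
    # 1. Group contiguous ops that share a target engine.
    segments: List[Segment] = []
    for op in ops:
        target = "trt" if op in convertible and op not in torch_executed_ops else "torch"
        if segments and segments[-1][0] == target:
            segments[-1][1].append(op)
        else:
            segments.append((target, [op]))

    # 2. TensorRT segments shorter than min_block_size fall back to PyTorch;
    #    adjacent PyTorch segments are then merged.
    merged: List[Segment] = []
    for target, seg in segments:
        if target == "trt" and len(seg) < min_block_size:
            target = "torch"
        if merged and merged[-1][0] == target:
            merged[-1][1].extend(seg)
        else:
            merged.append((target, list(seg)))
    return merged


if __name__ == "__main__":
    ops = ["aten::conv2d", "aten::relu", "prim::TupleConstruct"]
    convertible = set(ops)  # hypothetical: pretend every op is convertible
    print(partition(ops, convertible, set(), min_block_size=1))
    # [('trt', ['aten::conv2d', 'aten::relu', 'prim::TupleConstruct'])]
    print(partition(ops, convertible, {"prim::TupleConstruct"}, min_block_size=1))
    # [('trt', ['aten::conv2d', 'aten::relu']), ('torch', ['prim::TupleConstruct'])]
```

In this model, only `torch_executed_ops` or unconvertible ops can produce more than one segment. Raising `min_block_size` can only push more work back to PyTorch.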
To Reproduce
```python
import torch
from torchvision import models
import torch_tensorrt as torchtrt

classification_arches = [
    models.alexnet,
    models.convnext_base,
    models.densenet121,
    models.efficientnet_b0,
    models.efficientnet_v2_s,
    models.googlenet,
    models.inception_v3,
    models.mnasnet0_5,
    models.mobilenet_v2,
    models.mobilenet_v3_small,
    models.regnet_y_400mf,
    models.resnet18,
    models.resnext50_32x4d,
    models.shufflenet_v2_x0_5,
    models.squeezenet1_0,
    models.swin_t,
    models.vgg11_bn,
    models.vit_b_16,
    models.wide_resnet50_2,
]

failures = []
for arch in classification_arches:
    model = arch()
    model = torch.jit.script(model)
    model.eval().cuda()
    try:
        print(f"Running {arch.__name__}")
        with torchtrt.logging.errors():
            mod = torchtrt.ts.compile(
                model,
                inputs=[torchtrt.Input((1, 3, 300, 300))],
                truncate_long_and_double=True,
                # Force tuple construction to run in PyTorch
                torch_executed_ops=[
                    "prim::TupleConstruct",
                ],
            )
        x = torch.randn((1, 3, 300, 300)).cuda()
        mod(x)
    except Exception as e:
        print(f"{arch.__name__} failed: {e}")
        failures.append(arch.__name__)
print(f"Classification Failures: {failures}")

segmentation_arches = [
    models.segmentation.deeplabv3_mobilenet_v3_large,
    models.segmentation.fcn_resnet50,
    models.segmentation.lraspp_mobilenet_v3_large,
]

failures = []
for arch in segmentation_arches:
    model = arch()
    model = torch.jit.script(model)
    model.eval().cuda()
    try:
        print(f"Running {arch.__name__}")
        # Log the graphs at each compilation stage for debugging
        with torchtrt.logging.graphs():
            mod = torchtrt.ts.compile(
                model,
                inputs=[torchtrt.Input((1, 3, 300, 300))],
                truncate_long_and_double=True,
                min_block_size=1,
                # Force in-place list item assignment to run in PyTorch
                torch_executed_ops=[
                    "aten::_set_item",
                ],
            )
        x = torch.randn((1, 3, 300, 300)).cuda()
        mod(x)
    except Exception as e:
        print(f"{arch.__name__} failed: {e}")
        failures.append(arch.__name__)
print(f"Segmentation Failures: {failures}")
```
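The loops above store only architecture names in `failures`, so the summary line does not say why a model failed. The sketch below records the exception type and message for each failing architecture. `collect_failures` is a hypothetical helper, not part of Torch-TensorRT. In the repro, `run_one` would wrap the script, `torchtrt.ts.compile`, and forward-pass steps.

```python
from typing import Callable, Dict, Iterable


def collect_failures(
    arches: Iterable[Callable], run_one: Callable[[Callable], None]
) -> Dict[str, str]:
    """Call run_one(arch) for every arch; map each failing arch's name to its error."""
    failures: Dict[str, str] = {}
    for arch in arches:
        name = getattr(arch, "__name__", repr(arch))
        try:
            run_one(arch)
        except Exception as exc:  # keep going, but remember why this arch failed
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


if __name__ == "__main__":
    # Stand-ins for model constructors (hypothetical); in the repro, run_one
    # would script the model, compile it, and run one forward pass.
    def works() -> None:
        pass

    def breaks() -> None:
        raise RuntimeError("unsupported op")

    print(collect_failures([works, breaks], lambda arch: arch()))
    # {'breaks': 'RuntimeError: unsupported op'}
```

Returning a dict keeps the summary short while preserving the first error line per model, which is usually enough to tell a conversion failure from a runtime failure.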
Expected behavior
These models should run end to end.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.2.0a0+master
- PyTorch Version (e.g. 1.0): 1.12
- CPU Architecture: x86
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): python3 setup.py develop
- Are you using local sources or building from archives: archives
- Python version: 3.9
- CUDA version: 11.3
- GPU models and configuration: TITAN V
- Any other relevant information:
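Most of these fields can be gathered from Python. PyTorch ships `python -m torch.utils.collect_env` for this purpose. The sketch below is a smaller, hypothetical alternative: `collect_env` and `package_version` are local helpers, not PyTorch's. It uses the standard library plus an optional `torch` import and reports `None` for anything it cannot determine. The Torch-TensorRT lookup tries both the hyphenated and underscored distribution names as an assumption about how the package is registered.

```python
import importlib.metadata
import platform
from typing import Dict, Optional


def package_version(*names: str) -> Optional[str]:
    """Return the installed version of the first distribution found, else None."""
    for name in names:
        try:
            return importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            continue
    return None


def torch_cuda_info() -> Dict[str, Optional[str]]:
    """CUDA version PyTorch was built with and the first GPU's name, if available."""
    try:
        import torch
    except ImportError:
        return {"cuda": None, "gpu": None}
    gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
    return {"cuda": torch.version.cuda, "gpu": gpu}


def collect_env() -> Dict[str, Optional[str]]:
    """Gather the template fields that are visible from a Python process."""
    env: Dict[str, Optional[str]] = {
        "torch_tensorrt": package_version("torch-tensorrt", "torch_tensorrt"),
        "torch": package_version("torch"),
        "cpu_arch": platform.machine(),
        "os": platform.system(),
        "python": platform.python_version(),
    }
    env.update(torch_cuda_info())
    return env


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"- {key}: {value if value is not None else 'unknown'}")
```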