Bug Description
The code below results in the following error:
RuntimeError: upper bound and larger bound inconsistent with step sign
Note that if one doesn't freeze the script module before trying to compile, a different error occurs:
RuntimeError: Method 'forward' is not defined.
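For reference, the freezing step mentioned above is roughly the following (a minimal sketch using the TestNeRF module from the reproduction below; torch.jit.freeze is the standard PyTorch API and requires the module to be in eval mode, and frozen_model is just an illustrative name):

import torch

# Trace the module, switch to eval mode, and freeze it before handing it
# to torch_tensorrt.compile.
sub_model = torch.jit.trace(TestNeRF(), torch.ones(1, 3))
frozen_model = torch.jit.freeze(sub_model.eval())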
To Reproduce
Steps to reproduce the behavior:
Try running the following code:
import torch
import torch_tensorrt
import torch.nn as nn


class TestEmbedding(nn.Module):
    def __init__(self, N_freqs):
        super(TestEmbedding, self).__init__()
        self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)

    def forward2(self, x):
        out = [x]
        for freq in self.freq_bands:
            out += [torch.sin(freq * x), torch.cos(freq * x)]
        return torch.cat(out, -1)


class TestNeRF(nn.Module):
    def __init__(self, D=8, W=256, N_freq_xyz=10, skips=[4], rgb_dim=3):
        super(TestNeRF, self).__init__()
        self.D = D
        self.W = W
        self.embedding_xyz = TestEmbedding(N_freq_xyz)
        in_channels_xyz = 3 + 3 * N_freq_xyz * 2
        self.skips = skips

        # xyz encoding layers
        xyz_encodings = []
        for i in range(D):
            if i == 0:
                layer = nn.Linear(in_channels_xyz, W)
            elif i in skips:
                layer = nn.Linear(W + in_channels_xyz, W)
            else:
                layer = nn.Linear(W, W)
            layer = nn.Sequential(layer, nn.ReLU(True))
            xyz_encodings.append(layer)
        self.xyz_encodings = nn.ModuleList(xyz_encodings)

        # output layers
        self.sigma = nn.Linear(W, 1)
        self.rgb = nn.Linear(W, rgb_dim)

    def forward(self, x):
        input_xyz = self.embedding_xyz.forward2(x[:, :3])
        xyz_ = input_xyz
        for i, xyz_encoding in enumerate(self.xyz_encodings):
            if i in self.skips:
                xyz_ = torch.cat([input_xyz, xyz_], -1)
            xyz_ = xyz_encoding(xyz_)
        sigma = self.sigma(xyz_)
        rgb = self.rgb(xyz_)
        return torch.cat([rgb, sigma], -1)


sub_model = torch.jit.trace(TestNeRF(), torch.ones(1, 3))
print('before compile', sub_model(torch.ones(1, 3)))

trt_ts_module = torch_tensorrt.compile(
    sub_model,
    inputs=[torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[1, 3, 512, 512],
        max_shape=[1, 3, 1024, 1024],
        dtype=torch.float)
    ],
    enabled_precisions={torch.float}
)
print('after compile', trt_ts_module(torch.ones(1, 3)))
Expected behavior
Code compiles
Environment
Build information about the TRTorch compiler can be found by turning on debug messages
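(A minimal sketch of turning on those debug messages, assuming the torch_tensorrt.logging API available in the 1.0.0 release:)

import torch_tensorrt

# Raise the reportable log level so the compiler prints debug/build information.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)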
- TRTorch Version (e.g. 0.2.0): 1.0.0
- PyTorch Version (e.g. 1.0): 1.10.0
- CPU Architecture:
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): conda
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.8
- CUDA version: 11.1
- GPU models and configuration: 3090
- Any other relevant information:
Additional context