Description
Environment
Ubuntu 20.04
Python 3.8.8
torch 1.8.1
torch2trt 0.2.0
cuda 11.1
cudnn 8.1.1
TensorRT 7.2.2
Issue
I'm trying to convert YOLOR (https://github.com/WongKinYiu/yolor), which is implemented in PyTorch, to TensorRT.
YOLOR uses two operations that torch2trt does not currently support:
torch.nn.functional.silu
torch.Tensor.expand_as
silu
Thanks to #527, there is no problem here.
expand_as
Thanks to #487, torch2trt provides a converter for torch.Tensor.expand.
Since torch.Tensor.expand_as(other) is equivalent to torch.Tensor.expand(other.size()) (https://pytorch.org/docs/stable/tensors.html),
I replaced every expand_as(other) with expand(other.size()) in yolor/utils/layers.py.
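The expand converter from #487 turns an expand into a TensorRT slice (the `ctx.network.add_slice(input._trt, start, shape, stride)` call in the traceback below). As a rough illustration of that idea, here is a pure-Python sketch that computes a start/shape/stride triple and applies the slice rule `out[idx] = in[start + idx * stride]` to a flat buffer. It assumes both shapes have the same rank and that every broadcast input dimension has size 1, so a stride of 0 repeats that single element. It is not the converter's actual code, and the function names are hypothetical.

```python
from itertools import product


def expand_slice_params(in_shape, out_shape):
    """Hypothetical helper: (start, shape, stride) that realize expand as a slice.

    Assumes equal rank; each input dim must equal the output dim (stride 1)
    or be 1 (stride 0, i.e. the single element is repeated).
    """
    if len(in_shape) != len(out_shape):
        raise ValueError("this sketch only handles equal-rank shapes")
    stride = []
    for i, o in zip(in_shape, out_shape):
        if i == o:
            stride.append(1)
        elif i == 1:
            stride.append(0)
        else:
            raise ValueError(f"cannot expand a dim of size {i} to {o}")
    start = (0,) * len(out_shape)
    return start, tuple(out_shape), tuple(stride)


def apply_slice(values, in_shape, start, shape, stride):
    """Gather from a flat row-major buffer: out[idx] = in[start + idx * stride]."""
    # Row-major step (elements skipped per index increment) for each input dim.
    step = [1] * len(in_shape)
    for d in range(len(in_shape) - 2, -1, -1):
        step[d] = step[d + 1] * in_shape[d + 1]
    out = []
    for idx in product(*(range(n) for n in shape)):
        flat = sum((s + i * st) * k for s, i, st, k in zip(start, idx, stride, step))
        out.append(values[flat])
    return out


if __name__ == "__main__":
    # A per-channel vector of shape (2, 1, 1) expanded to (2, 2, 3), the way
    # a.expand(x.size()) broadcasts a over a feature map x (illustrative shapes).
    start, shape, stride = expand_slice_params((2, 1, 1), (2, 2, 3))
    print(stride)  # (1, 0, 0)
    print(apply_slice([10.0, 20.0], (2, 1, 1), start, shape, stride))
```

Only the expanded dimensions get stride 0, so each channel value is read repeatedly instead of being copied into a larger input buffer first.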
I wrote the following conversion script:
yolor/torch2trt_conversion.py
import argparse

import torch
from torch2trt import torch2trt

from models.models import Darknet
from utils.torch_utils import select_device


def main():
    out_path = opt.output
    weights = opt.weights
    img_size = opt.img_size
    cfg = opt.cfg
    device = select_device(opt.device)
    precision = opt.precision

    # load model
    print('BEGIN loading weights')
    model = Darknet(cfg, img_size).to(device)
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
    model = model.eval()
    print('END loading weights')

    print('BEGIN conversion')
    input_data = torch.randn((1, 3, img_size, img_size)).to(device)
    if precision == 'int8':
        model_trt = torch2trt(model, [input_data], int8_mode=True)
    elif precision == 'fp16':
        model_trt = torch2trt(model, [input_data], fp16_mode=True)
    elif precision == 'fp32':
        model_trt = torch2trt(model, [input_data])
    else:
        raise ValueError("Invalid precision")
    print('END conversion')

    # save: run the engine once on the same float32 tensor it was built from
    # (passing the precision string as dtype= to torch.empty raises a TypeError)
    result = model_trt(input_data)
    torch.save(model_trt.state_dict(), out_path)
    print('Successfully saved')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output', type=str, default='trt_darknet.pth', help='path to converted model')
    # a single path: with nargs='+' a CLI value becomes a list, which torch.load cannot open
    parser.add_argument('--weights', type=str, default='yolor_p6.pt', help='model.pt path')
    parser.add_argument('--img-size', type=int, default=1280, help='inference size (pixels)')
    parser.add_argument('--cfg', type=str, default='cfg/yolor_p6.cfg', help='*.cfg path')
    parser.add_argument('--device', default='0', help='cuda device, e.g. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--precision', type=str, default='fp16',
                        help='precision of inference for the output model [int8, fp16, fp32]')
    opt = parser.parse_args()
    print(opt)
    main()
Running the script fails with the following error:
/workspace/yolor# python torch2trt_conversion.py --img-size 640
Namespace(cfg='cfg/yolor_p6.cfg', device='0', img_size=640, output='trt_darknet.pth', precision='fp16', weights='yolor_p6.pt')
BEGIN loading weights
END loading weights
BEGIN conversion
Traceback (most recent call last):
File "torch2trt_conversion.py", line 62, in <module>
main()
File "torch2trt_conversion.py", line 34, in main
model_trt = torch2trt(model, [input_data], fp16_mode=True)
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.2.0-py3.8-linux-x86_64.egg/torch2trt/torch2trt.py", line 542, in torch2trt
outputs = module(*inputs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/workspace/yolor/models/models.py", line 550, in forward
return self.forward_once(x)
File "/workspace/yolor/models/models.py", line 601, in forward_once
x = module(x, out) # WeightedFeatureFusion(), FeatureConcat()
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/workspace/yolor/utils/layers.py", line 384, in forward
return a.expand(x.size()) + x
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.2.0-py3.8-linux-x86_64.egg/torch2trt/torch2trt.py", line 289, in wrapper
converter["converter"](ctx)
File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.2.0-py3.8-linux-x86_64.egg/torch2trt/converters/expand.py", line 17, in convert_expand
layer = ctx.network.add_slice(input._trt, start, shape, stride)
AttributeError: 'Parameter' object has no attribute '_trt'
Here 'Parameter' is torch.nn.parameter.Parameter, which has a data attribute holding a plain torch.Tensor.
So I replaced input with input.data in convert_expand (torch2trt/converters/expand.py), but that did not work either.
What is the problem?
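One plausible reading of the error, assuming torch2trt attaches a `_trt` attribute only to the network inputs and to the outputs of ops it has converted: the Parameter `a` reaching line 384 of layers.py was never produced by a converted op, so it has no `_trt`. Using `.data` does not help, because in PyTorch each access to `.data` returns a new tensor object that carries none of the Python attributes set on the original. The toy model below contrasts a converter that reads `input._trt` directly with one that first adds a constant layer for an untracked tensor. Every class and function in it is hypothetical, and plain strings stand in for TensorRT tensors. Recent torch2trt versions have a helper, `add_missing_trt_tensors`, that several converters use for this purpose; whether 0.2.0 ships it, and whether its expand converter calls it, should be checked against the installed source.

```python
class ToyTensor:
    """Hypothetical stand-in for a torch.Tensor; only identity and shape matter here."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    @property
    def data(self):
        # Like torch.Tensor.data, return a *new* Python object each time;
        # attributes such as _trt set on the original are not carried over.
        return ToyTensor(self.shape)


class ToyNetwork:
    """Hypothetical network builder that records layers and returns string handles."""

    def __init__(self):
        self.layers = []

    def add_constant(self, shape):
        self.layers.append(("constant", shape))
        return f"const{len(self.layers) - 1}"

    def add_slice(self, trt_input, start, shape, stride):
        self.layers.append(("slice", trt_input, start, shape, stride))
        return f"slice{len(self.layers) - 1}"


def _slice_args(inp, out_shape):
    # Stride 1 where sizes match, 0 where the input dim is broadcast.
    start = (0,) * len(out_shape)
    stride = tuple(int(i == o) for i, o in zip(inp.shape, out_shape))
    return start, tuple(out_shape), stride


def convert_expand_strict(net, inp, out_shape):
    # Reads inp._trt directly, like the failing line in the traceback.
    start, shape, stride = _slice_args(inp, out_shape)
    return net.add_slice(inp._trt, start, shape, stride)


def get_or_add_constant(net, t):
    # Fallback: a tensor the tracer never saw gets a constant layer first.
    if not hasattr(t, "_trt"):
        t._trt = net.add_constant(t.shape)
    return t._trt


def convert_expand_lenient(net, inp, out_shape):
    start, shape, stride = _slice_args(inp, out_shape)
    return net.add_slice(get_or_add_constant(net, inp), start, shape, stride)


if __name__ == "__main__":
    net = ToyNetwork()
    param = ToyTensor((4, 1, 1))  # illustrative shape for an untracked Parameter
    for t in (param, param.data):
        try:
            convert_expand_strict(net, t, (4, 2, 2))
        except AttributeError as err:
            print("strict:", err)
    print("lenient:", convert_expand_lenient(net, param, (4, 2, 2)))  # lenient: slice1
```

In this model both `param` and `param.data` fail the strict path, while the lenient path succeeds because the missing handle is created on demand as a constant.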