AttributeError: 'Parameter' object has no attribute '_trt' #565

@maronuu

Environment

Ubuntu 20.04
Python 3.8.8
torch 1.8.1
torch2trt 0.2.0
cuda 11.1
cudnn 8.1.1
TensorRT 7.2.2

Issue

I'm trying to convert YOLOR (https://github.com/WongKinYiu/yolor), which is implemented in PyTorch, to TensorRT.

YOLOR uses two operations that torch2trt does not support out of the box:

  • torch.nn.functional.silu
  • torch.Tensor.expand_as

silu

Thanks to #527, there is no problem here.

expand_as

Thanks to #487, a converter for torch.Tensor.expand is provided.
Since torch.Tensor.expand_as(other) is equivalent to torch.Tensor.expand(other.size()) (https://pytorch.org/docs/stable/tensors.html),
I replaced every expand_as(other) with expand(other.size()) in yolor/utils/layers.py.
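For reference, the replacement relies only on the shape rule that expand applies. The sketch below restates that rule in plain Python, with no torch required. `expand_shape` is a hypothetical helper written for illustration, not a PyTorch function, and the `(1, C, 1, 1)` parameter shape is an assumption about the layer in question. The `-1` placeholder accepted by expand is not handled, since `other.size()` never contains it.

```python
def expand_shape(src, target):
    """Return the shape produced by expanding `src` to `target`.

    Restates the documented rule for torch.Tensor.expand: shapes are aligned
    from the trailing dimension, and a source dimension may only change
    if it has size 1. New leading dimensions may be added.
    """
    if len(target) < len(src):
        raise ValueError("target must have at least as many dimensions as src")
    # Left-pad the source shape with 1s so both shapes have equal length.
    padded = (1,) * (len(target) - len(src)) + tuple(src)
    for s, t in zip(padded, target):
        if s != t and s != 1:
            raise ValueError(f"cannot expand size {s} to {t}")
    return tuple(target)


# Assumed shapes: a per-channel parameter and a feature map.
a_shape = (1, 256, 1, 1)
x_shape = (1, 256, 80, 80)
print(expand_shape(a_shape, x_shape))
```

Under these assumed shapes, `a.expand(x.size())` and `a.expand_as(x)` both produce a tensor shaped like `x`, so the substitution does not change what the layer computes.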

I made a script for conversion:

yolor/torch2trt_conversion.py

import argparse
import os
import sys
import time

import torch
import torch.jit
from torch2trt import torch2trt

from models.models import Darknet
from utils.torch_utils import select_device


def main():
    out_path = opt.output
    weights = opt.weights
    img_size = opt.img_size
    cfg = opt.cfg
    device = select_device(opt.device)
    # str2prec = {'int8': torch.int8, 'fp16': torch.float16, 'fp32': torch.float32}
    precision = opt.precision

    # load model
    print('BEGIN loading weights')
    model = Darknet(cfg, img_size).to(device)
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
    model = model.eval()
    print('END loading weights')
    print('BEGIN conversion')
    input_data = torch.randn((1, 3, img_size, img_size)).to(device)
    if precision == 'int8':
        model_trt = torch2trt(model, [input_data], int8_mode=True)
    elif precision == 'fp16':
        model_trt = torch2trt(model, [input_data], fp16_mode=True)
    elif precision == 'fp32':
        model_trt = torch2trt(model, [input_data])
    else:
        raise ValueError("Invalid precision")

    print('END conversion')
    

    # save
    input_data = torch.empty([1, 3, img_size, img_size], dtype=torch.float32).to(device)  # dtype must be a torch.dtype, not the precision string
    result = model_trt(input_data)
    torch.save(model_trt.state_dict(), out_path)
    print('Successfully saved')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output', type=str, default='trt_darknet.pth', help='path to converted model')
    parser.add_argument('--weights', type=str, default='yolor_p6.pt', help='model.pt path')
    parser.add_argument('--img-size', type=int, default=1280, help='inference size (pixels)')
    parser.add_argument('--cfg', type=str, default='cfg/yolor_p6.cfg', help='*.cfg path')
    parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--precision', type=str, default='fp16', 
        help='precision of inference for the output model [int8, fp16, fp32]')
    opt = parser.parse_args()
    print(opt)

    main()
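
The if/elif chain over --precision can also be written as a lookup table, with argparse rejecting unknown values up front. This is only a minimal sketch: `PRECISION_KWARGS` and `build_parser` are names introduced here for illustration, and the only torch2trt keyword arguments assumed are `int8_mode` and `fp16_mode`, which the script above already uses.

```python
import argparse

# Maps each --precision value to the keyword arguments passed to torch2trt.
# fp32 needs no extra flag, matching the script's plain torch2trt(...) call.
PRECISION_KWARGS = {
    'int8': {'int8_mode': True},
    'fp16': {'fp16_mode': True},
    'fp32': {},
}


def build_parser():
    parser = argparse.ArgumentParser()
    # `choices` makes argparse exit with a usage error on any other value.
    parser.add_argument('--precision', type=str, default='fp16',
                        choices=sorted(PRECISION_KWARGS),
                        help='precision of inference for the output model')
    return parser


opt = build_parser().parse_args(['--precision', 'int8'])
trt_kwargs = PRECISION_KWARGS[opt.precision]
print(trt_kwargs)
# The conversion call would then read:
#   model_trt = torch2trt(model, [input_data], **trt_kwargs)
```

With this layout the `raise ValueError("Invalid precision")` branch is no longer needed, because an invalid value never gets past argument parsing.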

Running the script fails during conversion:

/workspace/yolor# python torch2trt_conversion.py --img-size 640
Namespace(cfg='cfg/yolor_p6.cfg', device='0', img_size=640, output='trt_darknet.pth', precision='fp16', weights='yolor_p6.pt')
BEGIN loading weights
END loading weights
BEGIN conversion
Traceback (most recent call last):
  File "torch2trt_conversion.py", line 62, in <module>
    main()
  File "torch2trt_conversion.py", line 34, in main
    model_trt = torch2trt(model, [input_data], fp16_mode=True)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.2.0-py3.8-linux-x86_64.egg/torch2trt/torch2trt.py", line 542, in torch2trt
    outputs = module(*inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/workspace/yolor/models/models.py", line 550, in forward
    return self.forward_once(x)
  File "/workspace/yolor/models/models.py", line 601, in forward_once
    x = module(x, out)  # WeightedFeatureFusion(), FeatureConcat()
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/workspace/yolor/utils/layers.py", line 384, in forward
    return a.expand(x.size()) + x
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.2.0-py3.8-linux-x86_64.egg/torch2trt/torch2trt.py", line 289, in wrapper
    converter["converter"](ctx)
  File "/opt/conda/lib/python3.8/site-packages/torch2trt-0.2.0-py3.8-linux-x86_64.egg/torch2trt/converters/expand.py", line 17, in convert_expand
    layer = ctx.network.add_slice(input._trt, start, shape, stride)
AttributeError: 'Parameter' object has no attribute '_trt'

The 'Parameter' object here is a torch.nn.parameter.Parameter, which exposes its underlying torch.Tensor through the data attribute.
So I replaced input with input.data, but that did not work either.
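
One reading of the traceback, stated as an assumption rather than a confirmed diagnosis: the converter reads a `_trt` attribute that torch2trt attaches to tensors as they flow through the traced network, and a model weight never receives that tag. The toy classes below (`FakeTensor` and `FakeParameter` are hypothetical stand-ins; no torch or torch2trt is involved) mimic that behaviour and show why switching to `.data` would not help under this reading.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, name):
        self.name = name


class FakeParameter(FakeTensor):
    """Hypothetical stand-in for torch.nn.Parameter."""

    @property
    def data(self):
        # Like Parameter.data, this returns a plain tensor object,
        # which carries no `_trt` tag either.
        return FakeTensor(self.name + '.data')


def convert_expand(inp):
    # Mirrors the failing access in the traceback: `input._trt` is read directly.
    return inp._trt


x = FakeTensor('x')
x._trt = 'trt tensor for x'  # tagged: it flowed through the traced network
a = FakeParameter('a')       # a model weight: never tagged

print(convert_expand(x))
for candidate in (a, a.data):
    try:
        convert_expand(candidate)
    except AttributeError as err:
        print(err)
```

If this reading is right, the missing piece is a step that registers the weight with the TensorRT network before the slice layer is added, not a different way of unwrapping the Parameter.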

What is the problem?
