Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import tensorrt as trt
from copy import copy
import copy
import numpy as np
import io
from collections import defaultdict
import importlib

from .calibration import (
TensorBatchDataset,
Expand Down Expand Up @@ -297,30 +298,24 @@ def wrapper(*args, **kwargs):
class ConversionHook(object):
"""Attaches TensorRT converter to PyTorch method call"""

def __init__(self, ctx, method, converter):
def __init__(self, ctx, key, converter):
self.ctx = ctx
self.method_str = method
self.key = key
self.converter = converter

def _set_method(self, method):
exec("%s = method" % self.method_str)
module = self.converter['module']
exec('module.%s = method' % self.converter['qual_name'])

def __enter__(self):
try:
self.method_impl = eval(self.method_str)
except AttributeError:
self.method_impl = None

if self.method_impl:
self._set_method(
attach_converter(
self.ctx, self.method_impl, self.converter, self.method_str
)
self._set_method(
attach_converter(
self.ctx, self.converter['method_impl'], self.converter, self.converter['method_str']
)
)

def __exit__(self, type, val, tb):
if self.method_impl:
self._set_method(self.method_impl)
self._set_method(self.converter['method_impl'])

def default_input_names(num_inputs):
return ["input_%d" % i for i in range(num_inputs)]
Expand Down Expand Up @@ -369,8 +364,8 @@ def __init__(self, network, converters=CONVERTERS):
self.method_kwargs = None
self.method_return = None
self.hooks = [
ConversionHook(self, method, converter)
for method, converter in converters.items()
ConversionHook(self, key, converter)
for key, converter in converters.items()
]

def __enter__(self):
Expand Down Expand Up @@ -569,11 +564,40 @@ def torch2trt(module,

# DEFINE ALL CONVERSION FUNCTIONS

def get_module_qualname(name):
s = name.split('.')

for i in range(len(s)):
idx = len(s) - i - 1
modulename, qualname = ".".join(s[:idx]), ".".join(s[idx:])
try:
module = importlib.import_module(modulename)
return module, modulename, qualname
except:
pass

raise RuntimeError("Could not import module")


def tensorrt_converter(method, is_real=True, enabled=True):

def tensorrt_converter(method, is_real=True, enabled=True, imports=[]):

if isinstance(method, str):
module, module_name, qual_name = get_module_qualname(method)
else:
module, module_name, qual_name = importlib.import_module(method.__module__), method.__module__, method.__qualname__

method_impl = eval('copy.deepcopy(module.%s)' % qual_name)

def register_converter(converter):
CONVERTERS[method] = {"converter": converter, "is_real": is_real}
CONVERTERS[method] = {
"converter": converter,
"is_real": is_real,
"module": module,
"module_name": module_name,
"qual_name": qual_name,
"method_str": module_name + '.' + qual_name,
"method_impl": method_impl
}
return converter

def pass_converter(converter):
Expand All @@ -584,4 +608,4 @@ def pass_converter(converter):
else:
return pass_converter

return register_converter
return register_converter