Skip to content

Commit 8b725eb

Browse files
committed
add dilations field to onnx importer
blacking files
1 parent 7f85111 commit 8b725eb

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""ONNX: Open Neural Network Exchange frontend for Relay."""
2020
import copy
2121
import warnings
22+
2223
import numpy as np
2324
import tvm
2425
from tvm.ir import IRModule
@@ -28,16 +29,14 @@
2829
from .. import analysis
2930
from .. import expr as _expr
3031
from .. import function as _function
32+
from .. import loops as _loops
3133
from .. import op as _op
3234
from .. import qnn as _qnn
33-
from .. import vision as _vision
34-
from .. import loops as _loops
3535
from .. import ty as _ty
36-
37-
from .common import AttrCvt, Renamer
38-
from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value, fold_constant
39-
from .common import infer_type, get_name
40-
36+
from .. import vision as _vision
37+
from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op,
38+
infer_channels, infer_shape, infer_type, infer_value,
39+
new_var)
4140

4241
__all__ = ["from_onnx"]
4342

@@ -312,8 +311,8 @@ def _impl_v1(cls, inputs, attr, params):
312311

313312
return AttrCvt(
314313
op_name=dimension_picker(cls.name),
315-
transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)},
316-
ignores=["dilations", "storage_order"],
314+
transforms={"kernel_shape": "pool_size", "pads": ("padding", 0), "dilations": ("dilation", 1)},
315+
ignores=["storage_order"],
317316
custom_check=dimension_constraint(),
318317
)([data], attr, params)
319318

tests/python/frontend/onnx/test_forward.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import numpy as np
18-
import onnx
19-
from onnx import helper, TensorProto, mapping, numpy_helper
18+
import pytest
19+
import scipy
2020
import torch
2121
import torchvision
22-
import pytest
23-
import tvm.topi.testing
2422
import tvm
23+
import tvm.testing
24+
import tvm.topi.testing
2525
from tvm import relay
2626
from tvm.contrib import graph_executor
27-
import scipy
28-
import tvm.testing
27+
28+
import onnx
29+
from onnx import TensorProto, helper, mapping, numpy_helper
2930

3031

3132
def get_input_data_shape_dict(graph_def, input_data):
@@ -2696,7 +2697,7 @@ def repeat(N, D):
26962697

26972698
@tvm.testing.uses_gpu
26982699
def test_unsqueeze_constant():
2699-
from torch.nn import Linear, Sequential, Module
2700+
from torch.nn import Linear, Module, Sequential
27002701

27012702
class Flatten(Module):
27022703
def forward(self, input):
@@ -4212,7 +4213,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
42124213
"test_isinf_negative/",
42134214
"test_isinf_positive/",
42144215
"test_matmulinteger/",
4215-
"test_maxpool_2d_dilations/",
42164216
"test_maxpool_2d_same_lower/",
42174217
"test_maxpool_2d_same_upper/",
42184218
"test_maxpool_with_argmax_2d_precomputed_pads/",

0 commit comments

Comments
 (0)