Skip to content

Commit 03e5458

Browse files
anijain2305 authored and wweic committed
Relaxing convolution infer checks. (apache#3511)
- The weight dtype can differ from the input dtype (idtype), so the weight tensor, when available, is used to set the dtype of the weight.
- For the conv2d NCHWc operator, the weight can have any number of dimensions; for int8 computation on Intel it can be 7-D. The weight type checking is relaxed accordingly.
1 parent 4aaf241 commit 03e5458

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/relay/op/nn/convolution.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,12 @@ bool Conv2DRel(const Array<Type>& types,
8383
channels = param->channels;
8484
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
8585
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
86+
DataType weight_dtype = data->dtype;
87+
if (weight != nullptr) {
88+
weight_dtype = weight->dtype;
89+
}
8690
// assign result to reporter
87-
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
91+
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
8892
} else {
8993
// use weight to infer the conv shape.
9094
if (weight == nullptr) return false;
@@ -701,7 +705,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
701705
.add_argument("data", "Tensor", "The input tensor.")
702706
.add_argument("weight", "Tensor", "The weight tensor.")
703707
.set_support_level(10)
704-
.add_type_rel("Conv2D", Conv2DRel)
708+
.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
705709
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
706710
Conv2DInferCorrectLayout<Conv2DAttrs>);
707711

tests/python/relay/test_op_level2.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ def test_conv2d_infer_type():
4545
(2, 10, 3, 3), "float32")
4646

4747
# infer by shape of w, mixed precision
48-
4948
n, c, h, w = tvm.var("n"), 10, 224, 224
5049
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
5150
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
@@ -55,6 +54,16 @@ def test_conv2d_infer_type():
5554
assert yy.checked_type == relay.TensorType(
5655
(n, 2, 222, 222), "int32")
5756

57+
# infer shape in case of different dtypes for input and weight.
58+
n, c, h, w = tvm.var("n"), 10, 224, 224
59+
x = relay.var("x", relay.TensorType((n, c, h, w), "uint8"))
60+
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
61+
y = relay.nn.conv2d(x, w, out_dtype="int32")
62+
assert "out_dtype=\"int32\"" in y.astext()
63+
yy = run_infer_type(y)
64+
assert yy.checked_type == relay.TensorType(
65+
(n, 2, 222, 222), "int32")
66+
5867
# Infer with a different layout
5968
n, c, h, w = 4, 32, 224, 224
6069
x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))

0 commit comments

Comments
 (0)