
Commit 38dfea7

mbaret authored and Lucien0 committed
[CUDNN] Add cuDNN as a Relay partitioning target (BYOC) (apache#10871)
* [CUDNN] Add cuDNN as a Relay partitioning target (BYOC)

  This adds infrastructure to support offloading of Relay patterns to cuDNN. In this initial commit, only softmax is supported.

* Refactor common TE BYOC code into separate file
* Add test guard
1 parent 998abe0 commit 38dfea7

File tree

5 files changed: +237 −48 lines changed

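For orientation, here is a minimal sketch (not part of the diff itself) of how the new flow is used end to end, assuming a TVM build with CUDA and cuDNN enabled; the input shape and data below are illustrative:

```python
# Minimal sketch of the BYOC flow added by this commit (assumes TVM was
# built with CUDA and cuDNN enabled; shapes/data are illustrative).
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.op.contrib.cudnn import partition_for_cudnn

x = relay.var("x", relay.TensorType((1, 16), "float32"))
mod = tvm.IRModule.from_expr(relay.nn.softmax(x, axis=1))

# Offload the supported pattern (currently only softmax) to cuDNN.
mod = partition_for_cudnn(mod)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=tvm.target.cuda())

dev = tvm.cuda()
runtime = graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("x", tvm.nd.array(np.random.rand(1, 16).astype("float32"), dev))
runtime.run()
print(runtime.get_output(0).numpy())
```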

python/tvm/relay/op/contrib/cublas.py

Lines changed: 7 additions & 47 deletions
@@ -26,9 +26,13 @@
 from tvm.contrib import cublas
 
 from ...dataflow_pattern import is_op, wildcard
+from .te_target import lower_composite, relay_to_runtime
 from .register import register_pattern_table
 
 
+tvm._ffi.register_func("relay.ext.cublas", relay_to_runtime(tvm.target.cuda()))
+
+
 def partition_for_cublas(
     mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
 ) -> tvm.IRModule:
@@ -111,51 +115,7 @@ def check_matmul_like(matched: relay.Call) -> bool:
     ]
 
 
-_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
-_LOWER_MAP: Dict[str, _LowerFunc] = {}
-
-
-def _lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
-    """Register a lowering function for a given composite function name."""
-
-    def _register(f: _LowerFunc) -> _LowerFunc:
-        _LOWER_MAP[comp_name] = f
-        return f
-
-    return _register
-
-
-@tvm._ffi.register_func("relay.ext.cublas")
-def relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
-    """Compile cuBLAS Relay functions to a runtime module."""
-    assert isinstance(partition, relay.Function)
-    assert isinstance(partition.body, relay.Call)
-    assert isinstance(partition.body.op, relay.Function)
-
-    global_name = str(partition.attrs.global_symbol)
-    target = tvm.target.cuda()
-    comp_func = partition.body.op
-    comp_name = comp_func.attrs["Composite"]
-    assert comp_name in _LOWER_MAP
-    assert isinstance(comp_func.body, relay.Call)
-
-    op = comp_func.body
-    inputs = []
-    for i, param in enumerate(comp_func.params):
-        inputs.append(
-            te.placeholder(
-                param.checked_type.shape,
-                name=f"input_{i}",
-                dtype=param.checked_type.dtype,
-            )
-        )
-
-    output = _LOWER_MAP[comp_name](op, inputs)
-    prim_func = te.create_prim_func(inputs + [output])
-    return tvm.build(prim_func, target=target, name=global_name)
-
-
-@_lower_composite("cublas.matmul")
+@lower_composite("cublas.matmul")
 def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a matmul using cuBLAS."""
     return cublas.matmul(
@@ -167,7 +127,7 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     )
 
 
-@_lower_composite("cublas.batch_matmul")
+@lower_composite("cublas.batch_matmul")
 def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a batch_matmul using cuBLAS."""
     return cublas.batch_matmul(
@@ -179,7 +139,7 @@ def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     )
 
 
-@_lower_composite("cublas.dense")
+@lower_composite("cublas.dense")
 def _lower_dense(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a dense using cuBLAS."""
     return cublas.matmul(
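The refactor replaces the decorator-based `@tvm._ffi.register_func("relay.ext.cublas")` with a closure: `relay_to_runtime(target)` returns the actual external-compiler callback, so the same TE lowering machinery can be registered for any target. A hedged sketch of the pattern this enables (the `"relay.ext.mylib"` name is hypothetical, not part of this commit):

```python
# Sketch of the registration pattern the refactor enables; "relay.ext.mylib"
# is a hypothetical BYOC target name, not part of this commit.
import tvm
from tvm.relay.op.contrib.te_target import relay_to_runtime

# relay_to_runtime(target) returns a Callable[[relay.Function], tvm.runtime.Module],
# which is the signature the BYOC codegen hook "relay.ext.<name>" expects.
tvm._ffi.register_func("relay.ext.mylib", relay_to_runtime(tvm.target.cuda()))
```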
python/tvm/relay/op/contrib/cudnn.py

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""cuDNN Relay integration."""
+from typing import Callable, List, Tuple, Dict, Optional
+
+import tvm
+import tvm.ir
+from tvm import relay
+from tvm import te
+from tvm.relay import transform
+from tvm.contrib import cudnn
+
+from ...dataflow_pattern import is_op, wildcard
+from .te_target import lower_composite, relay_to_runtime
+from .register import register_pattern_table
+
+
+tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))
+
+
+def partition_for_cudnn(
+    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
+) -> tvm.IRModule:
+    """Partition the graph to offload for cuDNN.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The module to partition.
+    params : Optional[Dict[str, tvm.runtime.NDArray]]
+        Constant input parameters.
+
+    Returns
+    -------
+    tvm.IRModule
+        The partitioned module.
+    """
+
+    seq = tvm.transform.Sequential(
+        [
+            transform.InferType(),
+            transform.MergeComposite(pattern_table()),
+            transform.AnnotateTarget("cudnn"),
+            transform.PartitionGraph(),
+            transform.InferType(),
+        ]
+    )
+    return seq(mod)
+
+
+@register_pattern_table("cudnn")
+def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], bool]]]:
+    """Get the cuDNN pattern table."""
+
+    def softmax_pattern() -> relay.Pattern:
+        """Create pattern for softmax."""
+        return is_op("nn.softmax")(wildcard())
+
+    def check_softmax(matched: relay.Call) -> bool:
+        """Check if softmax is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        return True
+
+    return [
+        ("cudnn.softmax", softmax_pattern(), check_softmax),
+    ]
+
+
+@lower_composite("cudnn.softmax")
+def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a softmax using cuDNN."""
+    return cudnn.softmax(inputs[0], axis=op.attrs["axis"])
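Each pattern-table entry is a `(name, pattern, predicate)` tuple: `MergeComposite` fuses matching sub-graphs into a composite function named `name`, and the predicate gates which matches are offloaded. A hedged sketch of what another entry could look like (the `cudnn.log_softmax` name and pattern are illustrative only; the commit itself supports only softmax):

```python
# Illustrative only: the shape of an additional pattern-table entry.
# "cudnn.log_softmax" is hypothetical; this commit registers only "cudnn.softmax".
from tvm import relay
from tvm.relay.dataflow_pattern import DFPattern, is_op, wildcard


def log_softmax_pattern() -> DFPattern:
    # Match a single nn.log_softmax call with any input.
    return is_op("nn.log_softmax")(wildcard())


def check_log_softmax(matched: relay.Call) -> bool:
    # Same dtype gate as the softmax predicate above.
    return matched.args[0].checked_type.dtype in ["float64", "float32", "float16"]


extra_entry = ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax)
```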
python/tvm/relay/op/contrib/te_target.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Support a Relay partitioning target using Tensor Expressions."""
+from typing import Callable, List, Dict
+
+import tvm
+import tvm.ir
+from tvm import relay
+from tvm import te
+
+
+_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]
+_LOWER_MAP: Dict[str, _LowerFunc] = {}
+
+
+def lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
+    """Register a lowering function for a given composite function name."""
+
+    def _register(f: _LowerFunc) -> _LowerFunc:
+        _LOWER_MAP[comp_name] = f
+        return f
+
+    return _register
+
+
+def relay_to_runtime(target: tvm.target.Target) -> Callable[[relay.Function], tvm.runtime.Module]:
+    """Create a Relay to runtime module lowering function using Tensor Expressions for lowering."""
+
+    def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
+        """Compile Relay functions to a runtime module using Tensor Expressions."""
+        assert isinstance(partition, relay.Function)
+        assert isinstance(partition.body, relay.Call)
+        assert isinstance(partition.body.op, relay.Function)
+
+        global_name = str(partition.attrs.global_symbol)
+        comp_func = partition.body.op
+        comp_name = comp_func.attrs["Composite"]
+        assert comp_name in _LOWER_MAP
+        assert isinstance(comp_func.body, relay.Call)
+
+        op = comp_func.body
+        inputs = []
+        for i, param in enumerate(comp_func.params):
+            inputs.append(
+                te.placeholder(
+                    param.checked_type.shape,
+                    name=f"input_{i}",
+                    dtype=param.checked_type.dtype,
+                )
+            )
+
+        output = _LOWER_MAP[comp_name](op, inputs)
+        prim_func = te.create_prim_func(inputs + [output])
+        return tvm.build(prim_func, target=target, name=global_name)
+
+    return _relay_to_runtime
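`lower_composite` simply records a lowering callback in `_LOWER_MAP` keyed by composite name; `relay_to_runtime` then looks the callback up, builds `te.placeholder`s from the partition's parameters, and compiles the resulting PrimFunc for the given target. A hedged sketch of how a backend plugs in (the `"mylib.relu"` composite and its TOPI body are illustrative; a real backend would call into its own library):

```python
# Illustrative sketch: registering a lowering for a hypothetical composite.
# "mylib.relu" is not part of this commit, and topi.nn.relu stands in for
# a library call such as cublas.matmul or cudnn.softmax.
from typing import List

from tvm import relay, te, topi
from tvm.relay.op.contrib.te_target import lower_composite


@lower_composite("mylib.relu")
def _lower_relu(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
    """Produce the TE output tensor for the matched composite call."""
    return topi.nn.relu(inputs[0])
```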

tests/python/contrib/test_cudnn.py

Lines changed: 68 additions & 0 deletions
@@ -21,11 +21,14 @@
 
 import tvm
 from tvm import te
+from tvm import relay
 from tvm.contrib import cudnn
 from tvm.contrib.nvcc import have_fp16
+from tvm.contrib import graph_executor
 import numpy as np
 import tvm.topi.testing
 import tvm.testing
+from tvm.relay.op.contrib.cudnn import partition_for_cudnn
 
 
 requires_cudnn = pytest.mark.skipif(
@@ -445,5 +448,70 @@ def conv_output_shape_kwargs(request):
     return request.param
 
 
+def _verify_cudnn_relay(expr):
+    np.random.seed(42)
+
+    mod = tvm.IRModule.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    func = mod["main"]
+    cudnn_mod = partition_for_cudnn(mod)
+    assert len(cudnn_mod.get_global_vars()) == 2
+
+    input_data = []
+    for param in func.params:
+        shape = [int(x) for x in param.checked_type.shape]
+        input_data.append(
+            (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
+        )
+
+    # Test against CPU reference
+    cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
+    cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
+    outputs = []
+    for target, dev, test_mod in [cuda_config, cpu_config]:
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(test_mod, target=target, target_host=cpu_config[0])
+            module = graph_executor.GraphModule(lib["default"](dev))
+            for name, data in input_data:
+                module.set_input(name, tvm.nd.array(data, dev))
+
+            module.run()
+            out_type = func.body.checked_type
+            outputs.append(
+                module.get_output(0, tvm.nd.empty(out_type.shape, dtype=out_type.dtype)).numpy()
+            )
+
+    tvm.testing.assert_allclose(
+        outputs[0],
+        outputs[1],
+        rtol=1e-3,
+    )
+
+
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "shape,axis",
+    [
+        ((200,), 0),
+        ((13, 27), 0),
+        ((44, 12, 67), 1),
+        ((1, 16, 16, 8), 2),
+        ((2, 4, 6, 8, 10), 3),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "float32",
+        "float16",
+        "float64",
+    ],
+)
+def test_relay_cudnn_softmax(shape, axis, dtype):
+    x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
+    softmax = relay.op.nn.softmax(x, axis=axis)
+    _verify_cudnn_relay(softmax)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
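The test builds the partitioned module for CUDA and an unpartitioned copy for LLVM, then compares the two outputs within `rtol=1e-3`; the `== 2` global-var check confirms the partition produced exactly one offloaded function alongside `main`. To run just this test locally (requires a CUDA build with cuDNN), something like the following should work:

```python
# Hedged: invoke only the new cuDNN Relay test (needs CUDA + cuDNN at runtime).
import sys
import pytest

sys.exit(pytest.main(["tests/python/contrib/test_cudnn.py", "-k", "relay_cudnn_softmax"]))
```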

tests/scripts/task_mypy.sh

Lines changed: 3 additions & 1 deletion
@@ -36,8 +36,10 @@ mypy --check-untyped-defs python/tvm/tir/transform/
 echo "Checking MyPy Type defs in the TIR package with unittest"
 MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py
 
-echo "Checking MyPy Type defs in tvm.relay.op.contrib.cublas"
+echo "Checking MyPy Type defs in tvm.relay.op.contrib"
 mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py
+mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cudnn.py
+mypy --disallow-untyped-defs python/tvm/relay/op/contrib/te_target.py
 
 #TODO(@mikepapadim): This is failing atm
 # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package."
