Skip to content

Commit 0c694d1

Browse files
committed
[relay] Relay annotation and partitioning for codegen
1 parent e6ff3f7 commit 0c694d1

File tree

19 files changed

+1279
-5
lines changed

19 files changed

+1279
-5
lines changed

include/tvm/relay/attrs/annotation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
5757
}
5858
};
5959

60+
/*!
61+
* \brief Options for the operators used to annotate a compiler.
62+
*/
63+
struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
64+
/*! \brief The 3rd party compiler for code generation. */
65+
std::string compiler;
66+
67+
TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") {
68+
TVM_ATTR_FIELD(compiler)
69+
.describe("The 3rd compiler used for code generation.");
70+
}
71+
};
72+
6073
} // namespace relay
6174
} // namespace tvm
6275
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_

include/tvm/relay/op_attr_types.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/build_module.h>
3030
#include <tvm/relay/type.h>
3131
#include <tvm/relay/expr.h>
32+
#include <string>
3233

3334
namespace tvm {
3435
namespace relay {
@@ -122,7 +123,7 @@ using FTVMSchedule = runtime::TypedPackedFunc<
122123
* operator with other expressions. This function will be invoked
123124
* in AlterOpLayout pass.
124125
* \param attrs The attribute of the original node.
125-
* \param inputs The input symbols of the original node.
126+
* \param args The input symbols of the original node.
126127
* \param tinfos An array of placeholders, use for getting the inferred shape
127128
* and dtype of the inputs.
128129
* \return new_expr The modified expression.
@@ -136,8 +137,8 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
136137
* \brief Legalizes an expression with another expression. This function will be
137138
* invoked in Legalize pass. It is a target-dependent pass.
138139
* \param attrs The attribute of the original node.
139-
* \param inputs The input symbols of the original node.
140-
* \param tinfos An array of placeholders, use for getting the inferred shape
140+
* \param args The input symbols of the original node.
141+
* \param arg_types An array of placeholders, use for getting the inferred shape
141142
* and dtype of the inputs.
142143
* \return new_expr The modified expression.
143144
*/
@@ -146,6 +147,22 @@ using FTVMLegalize = runtime::TypedPackedFunc<
146147
const Array<Expr>& args,
147148
const Array<tvm::relay::Type>& arg_types)>;
148149

150+
/*!
151+
* \brief Annotates an expression to indicate which compiler an op
152+
* should be used for codegen.
153+
*
154+
* \param attrs The attribute of the original expr.
155+
* \param args The arguments of the original expr.
156+
* \param compiler The compiler that is used to compile the op.
157+
*
158+
* \return true if this op should be registered to invoke a specific compiler
159+
* for codegen, otherwise, false.
160+
*/
161+
using FTVMAnnotateCompiler = runtime::TypedPackedFunc<
162+
bool(const Attrs& attrs, // NOLINT(*)
163+
const Array<Expr>& args,
164+
const std::string& compiler)>;
165+
149166
/*!
150167
* \brief Forward rewriting rule for a specific op.
151168
*

include/tvm/relay/transform.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,14 @@ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
576576
*/
577577
TVM_DLL Pass PrintIR(bool show_meta_data = true);
578578

579+
/*!
580+
* \brief Partition a Relay program into regions that can be executed on
581+
* different backends.
582+
*
583+
* \return The pass.
584+
*/
585+
TVM_DLL Pass PartitionGraph();
586+
579587
} // namespace transform
580588

581589
/*!

python/tvm/relay/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from . import adt
3030
from . import analysis
3131
from . import transform
32-
from .build_module import build, create_executor, optimize
32+
from .build_module import build, create_executor, optimize, build_extern_compiler
3333
from .transform import build_config
3434
from . import prelude
3535
from . import parser

python/tvm/relay/build_module.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .module import Module as _Module
3131
from .backend import interpreter as _interpreter
3232
from .backend.vm import VMExecutor
33+
from . import transform as _transform
3334

3435
def _update_target(target):
3536
target = target if target else _target.current_target()
@@ -296,6 +297,34 @@ def optimize(mod, target=None, params=None):
296297
return mod, params
297298

298299

300+
def build_extern_compiler(mod, compiler):
301+
"""Helper function that annotates a Relay module and patitions the
302+
expression init into various regions. These regions will be handled
303+
by either default compilers in TVM stack or the provided external compiler.
304+
305+
Parameters
306+
----------
307+
mod : relay.Module
308+
The module to build. Using relay.Function is deprecated.
309+
310+
compiler : str
311+
The name of the external compiler.
312+
313+
Returns
314+
-------
315+
mod : relay.Module
316+
The relay module contains partitioned program regions (e.g. functions)
317+
that will be compiled using different compilers.
318+
"""
319+
if isinstance(mod, _expr.Function):
320+
mod = _Module.from_expr(mod)
321+
322+
seq = _transform.Sequential([_transform.AnnotateCompiler(compiler),
323+
_transform.PartitionGraph()])
324+
mod = seq(mod)
325+
return mod
326+
327+
299328
class GraphExecutor(_interpreter.Executor):
300329
"""Wrapper around Executor interface.
301330

python/tvm/relay/op/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# operator defs
2020
from .op import get, register, register_schedule, register_compute, register_gradient, \
2121
register_pattern, register_alter_op_layout, register_legalize, \
22-
schedule_injective, Op, OpPattern, debug
22+
register_annotate_compiler, schedule_injective, Op, OpPattern, debug
2323

2424
# Operators
2525
from .reduce import *

python/tvm/relay/op/annotation/annotation.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def stop_fusion(data):
6262
"""
6363
return _make.stop_fusion(data)
6464

65+
6566
def checkpoint(data):
6667
"""Annotate an expression to be a checkpoint for the checkpointing memory optimization.
6768
@@ -78,3 +79,43 @@ def checkpoint(data):
7879
return _make.checkpoint(data)
7980

8081
register_schedule("annotation.checkpoint", schedule_injective)
82+
83+
84+
def compiler_begin(data, compiler):
85+
"""Annotate an expression to indicate that it is the beginning of
86+
a regeion that will be handled by the given compiler.
87+
88+
Parameters
89+
----------
90+
data : tvm.relay.Expr
91+
The expression to be annotated.
92+
93+
compiler : Str
94+
The compiler used to generate code of the annotated region.
95+
96+
Returns
97+
-------
98+
result : tvm.relay.Expr
99+
The annotated expression.
100+
"""
101+
return _make.compiler_begin(data, compiler)
102+
103+
104+
def compiler_end(data, compiler):
105+
"""Annotate an expression to indicate that it is the end of a region that
106+
is handled by the provided compiler.
107+
108+
Parameters
109+
----------
110+
data : tvm.relay.Expr
111+
The expression to be annotated.
112+
113+
compiler : Str
114+
The compiler used to generate code of the annotated region.
115+
116+
Returns
117+
-------
118+
result : tvm.relay.Expr
119+
The annotated expression.
120+
"""
121+
return _make.compiler_end(data, compiler)

python/tvm/relay/op/contrib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
"""Neural network related operators."""
1919
from __future__ import absolute_import as _abs
2020
from .contrib import *
21+
from .annotate_compiler import *
2122
from . import _contrib
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-argument
18+
"""
19+
External compiler related feature registration.
20+
21+
It implements dispatchers that check if an operator should use a given compiler
22+
to generate code.
23+
24+
Each compiler can customize the support of an operator. For example, they can
25+
check the attribute of the operator and/or the features of the input arguments
26+
to decide if we should use the compiler for codegen.
27+
"""
28+
from __future__ import absolute_import
29+
30+
import logging
31+
import pkgutil
32+
from pathlib import Path
33+
from importlib import import_module
34+
35+
from .. import op as reg
36+
37+
logger = logging.getLogger('AnnotateCompiler')
38+
39+
# Load available contrib compilers
40+
compilers = {}
41+
for _, name, _ in pkgutil.iter_modules([Path(__file__).parent]):
42+
compilers[name] = import_module(
43+
'.%s' % name, package='.'.join(__name__.split('.')[:-1]))
44+
45+
46+
def get_annotate_compiler(compiler, op_name):
47+
"""Get the annotate_compiler function from the registered compilers.
48+
49+
Parameters
50+
----------
51+
compiler : Str
52+
The name of a compiler that is used to generate code.
53+
54+
op_name : Str
55+
The name of an operator.
56+
57+
Returns
58+
-------
59+
ret : bool
60+
If the operator uses the provided compiler for codegen.
61+
"""
62+
if compiler in compilers:
63+
if hasattr(compilers[compiler], 'annotate_compiler'):
64+
annotate_compiler = getattr(compilers[compiler], 'annotate_compiler')
65+
if hasattr(annotate_compiler, op_name):
66+
return getattr(annotate_compiler, op_name)
67+
68+
logger.warning("%s in %s is not registered. Fallback to CPU", op_name,
69+
compiler)
70+
return lambda x, y: False
71+
72+
73+
@reg.register_annotate_compiler("nn.conv2d")
74+
def annotate_conv2d(attrs, args, compiler):
75+
"""Check if the provided compiler should be used for conv2d.
76+
"""
77+
return get_annotate_compiler(compiler, 'conv2d')(attrs, args)
78+
79+
80+
@reg.register_annotate_compiler("nn.dense")
81+
def annotate_dense(attrs, args, compiler):
82+
"""Check if the provided compiler should be used for dense.
83+
"""
84+
return get_annotate_compiler(compiler, 'dense')(attrs, args)
85+
86+
87+
@reg.register_annotate_compiler("nn.relu")
88+
def annotate_relu(attrs, args, compiler):
89+
"""Check if the provided compiler should be used for relu.
90+
"""
91+
return get_annotate_compiler(compiler, 'relu')(attrs, args)
92+
93+
94+
@reg.register_annotate_compiler("nn.batch_norm")
95+
def annotate_batch_norm(attrs, args, compiler):
96+
"""Check if the provided compiler should be used for batch_norm.
97+
"""
98+
return get_annotate_compiler(compiler, 'batch_norm')(attrs, args)
99+
100+
101+
@reg.register_annotate_compiler("subtract")
102+
def annotate_subtract(attrs, args, compiler):
103+
"""Check if the provided compiler should be used for subtract.
104+
"""
105+
return get_annotate_compiler(compiler, 'subtract')(attrs, args)
106+
107+
108+
@reg.register_annotate_compiler("add")
109+
def annotate_add(attrs, args, compiler):
110+
"""Check if the provided compiler should be used for add.
111+
"""
112+
return get_annotate_compiler(compiler, 'add')(attrs, args)
113+
114+
115+
@reg.register_annotate_compiler("multiply")
116+
def annotate_multiply(attrs, args, compiler):
117+
"""Check if the provided compiler should be used for multiply.
118+
"""
119+
return get_annotate_compiler(compiler, 'multiply')(attrs, args)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=wildcard-import
18+
"""Neural network related operators."""
19+
from __future__ import absolute_import as _abs
20+
from .annotate_compiler import *

0 commit comments

Comments
 (0)