Skip to content

Commit 028c204

Browse files
mbs-octomlSergey Shtin
authored and committed
[Relay] Support 'external codegen targets'. (apache#11173)
* [Relay] Support 'external codegen targets'. (Part of Collage, https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md) This change prepares the VM and Relay target handling machinery to support external codegen targets in addition to 'regular' targets. This allows us to configure the build with Collage as follows: ``` host_target = tvm.target.Target("llvm") targets = [tvm.target.Target("cuda", host_target), tvm.target.Target("cutlass", host_target), tvm.target.Target("cudnn", host_target)] with tvm.transform.PassContext(...): exe = tvm.relay.vm.compile(module, target=targets) ``` Four changes are required: 1. I introduce four new target kinds for the external codegens currently supported by Collage. Others can be added as they are vetted for use by Collage. These are given a device type matching the external codegen's assumption (ie just CUDA currently), and given a target kind attribute "is_external_codegen" of True. The latter is needed by Collage to signal the target kind name represents an external codegen 'compiler' name. See the RFC for specifics. 2. I introduce the binary relation Target::IsExternalCodegenFor so that external codegen targets can be related back to the 'underlying' targets they are implicitly using in their codegen. 3. I rework the VMCompiler and BuildModule interfaces to accept an Array<Target> of 'raw targets' instead of a Map<Integer, Target>. This more general representation is needed because we may now have multiple targets of the same device type active simultaneously. I add new static methods on the Python Target to convert to this form in a way that mimics check_and_update_host_consist. 4. I rework CompilationConfig to work from Array<Target> directly, to not depend on the host_target argument (since dealt with on the Python side), and to understand that if we have two targets for the same device type the non-external codegen target takes precedence. 
The change to CompilationConfig seems neutral with respect to the recent discussions on compilation configuration representation and tvmc. I made a few attempts to remove Target.check_and_update_host_consist entirely in favor of using CompilationConfig as the definitive target handling choke point but backed out once they became too large. * - Working on unit tests * - Fix two Debug-only failures * - Use Array<Target> in GraphExecutorCodegen/AOTExecutorCodegen ifaces instead of CompilationConfig (don't want to bake it into any official APIs). - Started unit tests. * - Lints * - Moar Lints * - Fix some unit tests * - Fix last unit test failures * - whitespace * - Address Eric's comments. CI likely to fail due to stricter FindPrimitiveTargetOrFail but let's see. * - Comment adjustments. - Unit test for new Target members.
1 parent f4ad0c4 commit 028c204

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+979
-665
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,12 @@ if(GTEST_FOUND)
647647
target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} GTest::GTest GTest::Main GTest::gmock pthread dl)
648648
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1)
649649
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
650+
if(USE_RELAY_DEBUG)
651+
target_compile_definitions(cpptest PRIVATE "USE_RELAY_DEBUG")
652+
target_compile_definitions(cpptest PRIVATE "TVM_LOG_DEBUG")
653+
else()
654+
target_compile_definitions(cpptest PRIVATE "NDEBUG")
655+
endif()
650656
# For some reason, compile definitions are not propagated correctly, so we manually add them here
651657
target_compile_definitions(cpptest PUBLIC $<TARGET_PROPERTY:tvm,INTERFACE_COMPILE_DEFINITIONS>)
652658
gtest_discover_tests(cpptest)

cmake/modules/CUDA.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,17 @@ if(USE_CUDA)
4141
if(USE_CUDNN)
4242
message(STATUS "Build with cuDNN support")
4343
include_directories(SYSTEM ${CUDA_CUDNN_INCLUDE_DIRS})
44+
tvm_file_glob(GLOB CUDNN_RELAY_CONTRIB_SRC src/relay/backend/contrib/cudnn/*.cc)
45+
list(APPEND COMPILER_SRCS ${CUDNN_RELAY_CONTRIB_SRC})
4446
tvm_file_glob(GLOB CONTRIB_CUDNN_SRCS src/runtime/contrib/cudnn/*.cc)
4547
list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_SRCS})
4648
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY})
4749
endif(USE_CUDNN)
4850

4951
if(USE_CUBLAS)
5052
message(STATUS "Build with cuBLAS support")
53+
tvm_file_glob(GLOB CUBLAS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc)
54+
list(APPEND COMPILER_SRCS ${CUBLAS_RELAY_CONTRIB_SRC})
5155
tvm_file_glob(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc)
5256
list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS})
5357
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY})

include/tvm/target/compilation_config.h

Lines changed: 69 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
/*!
2121
* \file tvm/target/compilation_config.h
2222
* \brief A helper class to collect all the targets in canonical form necessary for compilation.
23-
* CAUTION: Preliminary, currently only used to support device planning, very likely to change.
2423
*/
2524

2625
#ifndef TVM_TARGET_COMPILATION_CONFIG_H_
@@ -32,40 +31,30 @@ namespace tvm {
3231

3332
/*!
3433
* \brief Gathers the \p Targets and distinguished \p VirtualDevices in canonical form needed to
35-
* compile a Relay module. Centralizes any setup and validation logic needed to transition
36-
* from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly
37-
* (eg a a list of \p Targets) to the configuration.
34+
* compile a Relay module for execution over possibly heterogeneous devices. Centralizes the
35+
* validation and canonicalization logic needed to transition from targets supplied by the Python
36+
* APIs to a single internal representation. Also holds a cache of canonical \p VirtualDevices
37+
* so that structurally equal virtual devices have pointer equal canonical virtual devices.
3838
*
39-
* CAUTION: This is subject to change as we rework compilation options in general. See
40-
* https://github.com/apache/tvm-rfcs/blob/main/rfcs/0028-command-line-registry-composition.md.
41-
* So far this class is only focussed on carrying just the configuration needed by PlanDevices,
42-
* and removing target-munging code duplication and inconsistencies between the three major build
43-
* flows for the VM (relay/backend/vm/compile.cc), Graph/AOT (relay/backend/build_module.cc) and
44-
* Interpreter (relay/backend/interpreter.cc). Over time we expect more global compiler
45-
* configuration (eg for executor and runtime config, for system memory pool configuration, etc)
46-
* to migrate into this class, and instances thereof to be attached to \p IRModules using a
47-
* well-known attribute.
39+
* The construction of \p CompilationConfig is idempotent, in that given the same \p PassContext
40+
* \p ctxt and an arbitrary \p Array<Target> \p raw_targets:
41+
*
42+
* \code
43+
* CompilationConfig(ctxt, raw_targets)
44+
* is structurally equal to
45+
* CompilationConfig(ctxt, CompilationConfig(ctxt, raw_targets)->primitive_targets)
46+
* \endcode
47+
*
48+
* TODO(mbs): This is subject to change as we rework compilation options in general. This class
49+
* is probably better called a 'CompositeTarget', and may be better made a sub-class of Target or
50+
* some other common-target-root class.
4851
*/
4952
class CompilationConfigNode : public Object {
5053
public:
51-
/*!
52-
* \brief The legacy targets map, mapping device type to the corresponding \p Target to use
53-
* when compiling primitive functions. Does not include an entry for the host target, however
54-
* each \p Target in this map will have it's \p host field set to the \p host_target.
55-
*
56-
* Currently we require at most one \p Target per \p DLDeviceType, though we want to get rid of
57-
* that limitation.
58-
*
59-
* CAUTION: Since keys are \p Integers they are compared by object equality not integer
60-
* value.
61-
*
62-
* TODO(mbs): Remove once codegen updated for new target conventions.
63-
*/
64-
TargetMap legacy_target_map;
65-
6654
/*!
6755
* \brief The host target. Used for 'scalar' data and code (such as shapes and shape
6856
* functions) and residual Relay expressions and data (such as conditionals and ADTs).
57+
* Each \p primitive_target below will have this exact target object as its 'host'.
6958
*
7059
* Note that it is possible for a \p Target used for primitive operations to be structurally
7160
* equal to the host \p Target (up to the \p host field.) However the \p Target objects will
@@ -74,16 +63,37 @@ class CompilationConfigNode : public Object {
7463
Target host_target;
7564

7665
/*!
77-
* \brief Vector of all available \p Targets for compiling primitive operators. May contain
78-
* a \p Target for the same device type as for the \p host_target, however the \p host_target
79-
* should be used for all host computations and data. Each \p Target will have \p host_target
80-
* as its host.
66+
* \brief Vector of all available \p Targets for partitioning or compiling primitive tensor
67+
* operators (kernels). May contain a \p Target for the same device type as for the
68+
* \p host_target, however the \p host_target should be used for all host computations and data.
69+
* Each \p Target will have \p host_target as its 'host'.
70+
*
71+
* It is possible to have multiple primitive targets for the same device type. However given
72+
* primitive targets left and right where:
73+
* - left appears before right in the array
74+
* - left->kind->device_type == right->kind->device_type
75+
* then:
76+
* - right.IsExternalCodegenFor(left) must be true
77+
* In this way the FindPrimitiveTargetOrFail method will find the 'most general' target for
78+
* the requested device type.
79+
*
80+
* In the homogeneous case primitive_targets will have just one entry, which will be pointer equal
81+
* to optional_homogeneous_target.
82+
*
83+
* In the homogeneous case where the 'host' is the same device as used for compiling kernels it
84+
* is *not* the case that optional_homogeneous_target == host_target. This is because all
85+
* primitive targets always have their host field set to the host_target. Ie, it is valid to have:
86+
* \code
87+
* host_target=Target("llvm")
88+
* optional_homogeneous_target=Target("llvm", host=host_target)
89+
* \endcode
8190
*/
8291
Array<Target> primitive_targets;
8392

8493
/*!
8594
* \brief \p VirtualDevice for primitive operators which are not otherwise constrained to a
86-
* particular device.
95+
* particular device. Used by the PlanDevices pass to determine a virtual device for every
96+
* sub-expression.
8797
*/
8898
VirtualDevice default_primitive_virtual_device = VirtualDevice::FullyUnconstrained();
8999

@@ -94,25 +104,33 @@ class CompilationConfigNode : public Object {
94104
* \brief If defined then compile and/or run in 'homogeneous execution mode'. In this mode all
95105
* primitives are compiled for this target only.
96106
*
97-
* This is to support legacy passes which have not been adapted to hetrogeneous execution and
107+
* This is to support legacy passes which have not been adapted to heterogeneous execution and
98108
* rely on an implicit global \p Target to be in scope.
99109
*
100-
* TODO(mbs): Remove once all passes are 'hetrogeneous aware'.
110+
* TODO(mbs): Remove once all passes are 'heterogeneous aware'.
101111
*/
102112
Target optional_homogeneous_target;
103113

104114
void VisitAttrs(AttrVisitor* v);
105115

116+
/*!
117+
* \brief Return the unique \p Target to use for \p device_type. Fail if no such target exists.
118+
*
119+
* This will be the first primitive target with matching device type.
120+
*/
121+
Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const;
122+
106123
/*!
107124
* \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained
108125
* fields, however:
109-
* - If the target is null then it is filled in from the known available primitive targets by
110-
* matching on device type. Fails if no such target is known.
126+
* - If the target is null then it is filled in using \p FindPrimitiveTargetOrFail to match
127+
* the device type.
111128
* - The returned object is unique for the field values w.r.t. all other \p VirtualDevices
112-
* returned by this method.
129+
* returned by this method.
113130
*
114131
* We call the result the 'canonical' \p VirtualDevice. Two canonical \p VirtualDevices are
115-
* structurally equal if and only if they are pointer equal.
132+
* structurally equal if and only if they are pointer equal. In this way we can build maps
133+
* from virtual devices using just pointer equality.
116134
*/
117135
VirtualDevice CanonicalVirtualDevice(const VirtualDevice& virtual_device) const;
118136

@@ -121,31 +139,20 @@ class CompilationConfigNode : public Object {
121139

122140
private:
123141
/*!
124-
* \brief Establishes the default \p VirtualDevice for primitives and the \p VirtualDevice for the
125-
* host given:
126-
* - the vector of available primitive \p Targets.
127-
* - any host \p Target.
142+
* \brief Sets the primitive targets, the host target, the default primitive virtual device, and
143+
* the host virtual device given:
144+
* - the vector of 'raw' targets (in any order) supplied by one of the TVM entry points.
128145
* - any "relay.fallback_device_type" attribute on \p pass_ctx.
129146
* - whether the LLVM backend is available.
130-
* If necessary, creates new default \p Targets to match the required devices.
131-
*
132-
* NOTE: The implementation is a bit convoluted since it tries to maintain backwards
133-
* compatibility with legacy methods for conveying \p Targets.
134-
*
135-
* CAUTION: Recreated the primitive_targets so that they all have the given/constructed
136-
* host_target as their host (cf CheckAndUpdateHostConsistency).
147+
* Will look for a suitable host target in the given primitive targets, but if none found may
148+
* reuse a raw target or create a default CPU target.
137149
*/
138-
void EstablishDefaultVirtualDevices(const transform::PassContext& pass_ctx);
150+
void Init(const transform::PassContext& pass_ctx, const Array<Target>& raw_targets);
139151

140152
/*!
141-
* \brief Returns a freshly constructed \p Target to represent \p device_type.
153+
* \brief Returns a freshly constructed CPU \p Target.
142154
*/
143-
static Target MakeDefaultTarget(DLDeviceType device_type);
144-
145-
/*!
146-
* \brief Return the \p Target to use for \p device_type. Fail if no such target exists.
147-
*/
148-
Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const;
155+
static Target MakeDefaultCPUTarget();
149156

150157
/*!
151158
* \brief A cache of constructed virtual devices.
@@ -163,13 +170,11 @@ class CompilationConfigNode : public Object {
163170
class CompilationConfig : public ObjectRef {
164171
public:
165172
/*!
166-
* \brief Constructs the compilation config given the available \p Targets in the
167-
* \p legacy_target_map_arg and an optional \p optional_host_target_arg. May use
168-
* 'relay.fallback_device_type' and the availability of the LLVM compilation module
169-
* to decide on appropriate default devices.
173+
* \brief Constructs the compilation config given the settings in \p pass_ctx and supplied
174+
* \p raw_targets. See \p CompilationConfigNode::Init for details.
170175
*/
171-
TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg,
172-
Target optional_host_target_arg);
176+
TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx,
177+
const Array<Target>& raw_targets);
173178

174179
TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode);
175180
};

include/tvm/target/target.h

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,34 @@ class Target : public ObjectRef {
177177
*/
178178
static Target WithHost(const Target& target, const Target& host);
179179

180+
/*!
181+
* \brief Returns true if \p this target represents an external codegen. If so,
182+
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,
183+
* and can be used to retrieve a partitioning pattern table using
184+
* \p get_pattern_table.
185+
*/
186+
bool IsExternalCodegen() const;
187+
188+
/*!
189+
* \brief Returns true if \p this target represents an external codegen which is compatible
190+
* with \p that target. In particular:
191+
* - \p this has a true ::tvm::attr::kIsExternalCodegen attribute
192+
* - \p that does not have a true ::tvm::attr::kIsExternalCodegen attribute
193+
* - \p this and \p that have the same kind->device_type
194+
*
195+
* After partitioning, the external codegen compilation path may use \p that to guide its
196+
* compilation to a \p runtime::Module. Given \p this, an appropriate \p that can be
197+
* found using \p CompilationConfig::FindPrimitiveTargetOrFail(this->kind->device_type).
198+
*
199+
* The \p CollagePartition pass uses this method to guide its search over candidate partitions
200+
* using external codegen.
201+
*/
202+
bool IsExternalCodegenFor(const Target& that) const;
203+
180204
private:
205+
Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<String> keys,
206+
Map<String, ObjectRef> attrs);
207+
181208
// enable with syntax.
182209
friend class TargetInternal;
183210
friend class With<Target>;
@@ -194,8 +221,6 @@ class Target : public ObjectRef {
194221
TVM_DLL void ExitWithScope();
195222
};
196223

197-
using TargetMap = Map<Integer, Target>;
198-
199224
/*!
200225
* \brief Check and update host field of the given legacy target and target host pair.
201226
* Note that this function is for legacy target api compatibility issue only, not
@@ -205,15 +230,6 @@ using TargetMap = Map<Integer, Target>;
205230
*/
206231
void CheckAndUpdateHostConsistency(Target* target, Target* host);
207232

208-
/*!
209-
* \brief Check and update host field of the given legacy heterogeneous targets and
210-
* target host.Note that this function is for legacy target api compatibility issue only,
211-
* not recommended for other use.
212-
* \param target_map The pointer to a Map objects with values being Target objects
213-
* \param host The Target typed object for target host to be updated
214-
*/
215-
void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host);
216-
217233
/*!
218234
* \brief Check and update host field of the given legacy heterogeneous targets and
219235
* target host. Note that this function is for legacy target api compatibility issue only,

include/tvm/target/target_kind.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,26 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
384384
#define TVM_TARGET_KIND_REGISTER_VAR_DEF \
385385
static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetKindRegEntry& __make_##TargetKind
386386

387+
namespace attr {
388+
//
389+
// Distinguished TargetKind attribute names.
390+
//
391+
392+
/*!
393+
* \brief A \p TargetKind attribute of type \p Bool. If true, then the target kind name also
394+
* corresponds to an external codegen 'compiler' name. That name may be used:
395+
* - To retrieve partitioning rules using \p get_pattern_table.
396+
* - To attach to Relay Functions under the \p attr::kCompiler attribute to indicate
397+
* the function is to be compiled by the external codegen path.
398+
*
399+
* The \p CollagePartition pass uses this attribute to guide its search over candidate partitions
400+
* using external codegen.
401+
*
402+
* See also \p Target::IsExternalCodegenFor
403+
*/
404+
constexpr const char* kIsExternalCodegen = "is_external_codegen";
405+
} // namespace attr
406+
387407
/*!
388408
* \def TVM_REGISTER_TARGET_KIND
389409
* \brief Register a new target kind, or set attribute of the corresponding target kind.

python/tvm/autotvm/task/relay_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _lower(mod, target, params, opt_level=3):
4444
import vta
4545

4646
with vta.build_config(opt_level=opt_level, disabled_pass={"AlterOpLayout"}):
47-
mod, _ = relay.optimize(mod, target, params)
47+
mod, _ = relay.optimize(mod, target=target, params=params)
4848
grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
4949
grc.codegen(mod, mod["main"])
5050
return

python/tvm/autotvm/tophub.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from os import getenv
2727
import sys
2828
from pathlib import Path
29+
from tvm.ir.container import Array
2930

3031
from .task import ApplyHistoryBest
3132
from ..target import Target
@@ -87,7 +88,7 @@ def context(target, extra_files=None):
8788
Parameters
8889
----------
8990
target: Target or List of Target
90-
The compilation target
91+
The compilation targets
9192
extra_files: list of str, optional
9293
Extra log files to load
9394
"""
@@ -97,7 +98,7 @@ def context(target, extra_files=None):
9798

9899
best_context = ApplyHistoryBest([])
99100

100-
targets = target if isinstance(target, (list, tuple)) else [target]
101+
targets = target if isinstance(target, (Array, list, tuple)) else [target]
101102

102103
for tgt in targets:
103104
if isinstance(tgt, str):

python/tvm/relay/backend/graph_executor_codegen.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from tvm.runtime.ndarray import empty
3737
from tvm.relay import _build_module
3838
from tvm.target import Target
39-
from tvm.tir import expr as _expr
4039
from .utils import mangle_module_name
4140

4241

@@ -54,15 +53,8 @@ def __init__(self, mod, target):
5453
self._setup(mod, target)
5554

5655
def _setup(self, mod, target):
57-
tgts = {}
58-
if isinstance(target, dict):
59-
for dev, tgt in target.items():
60-
if not isinstance(tgt, (str, Target)):
61-
raise Exception("Unknown target type")
62-
tgts[dev] = Target(tgt)
63-
elif isinstance(target, (str, Target)):
64-
tgts[_expr.IntImm("int32", 0)] = Target(target)
65-
self._init(mod, tgts)
56+
raw_targets = Target.canonicalize_target_and_host(target)
57+
self._init(mod, raw_targets)
6658

6759
def codegen(self, ir_module, func):
6860
"""Compile a single function into a graph.

0 commit comments

Comments
 (0)