Skip to content

Commit 0533b4c

Browse files
committed
- Address Eric's comments.
CI likely to fail due to stricter FindPrimitiveTargetOrFail but let's see.
1 parent 320caf4 commit 0533b4c

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

include/tvm/target/target.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ class Target : public ObjectRef {
177177
*/
178178
static Target WithHost(const Target& target, const Target& host);
179179

180+
/*!
181+
* \brief Returns true if \p this target represents an external codegen. If so,
182+
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,
183+
* and can be used to retrieve a partitioning pattern table using
184+
* \p get_pattern_table.
185+
*/
186+
bool IsExternalCodegen() const;
187+
180188
/*!
181189
* \brief Returns true if \p this target represents an external codegen which is compatible
182190
* with \p that target. In particular:

src/target/compilation_config.cc

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,7 @@ void CompilationConfigNode::VisitAttrs(AttrVisitor* v) {
3939
}
4040

4141
Target CompilationConfigNode::FindPrimitiveTargetOrFail(DLDeviceType device_type) const {
42-
if (device_type < 0 && primitive_targets.size() == 1) {
43-
// In the homogenous case don't be fussy with device types.
44-
return primitive_targets.front();
45-
}
46-
ICHECK_GT(device_type, 0);
42+
ICHECK_GT(device_type, 0) << "Invalid device type";
4743
auto itr = std::find_if(
4844
primitive_targets.begin(), primitive_targets.end(),
4945
[device_type](const Target& target) { return target->kind->device_type == device_type; });
@@ -144,6 +140,8 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
144140

145141
//
146142
// Check the primitive_targets are ordered correctly re Target::IsExternalCodegenFor.
143+
// Note we could just sort the list, but given all the implicit defaulting for backwards
144+
// compat it seems we should avoid making this any more magical than necessarny.
147145
//
148146
std::unordered_set<DLDeviceType> primitive_target_device_types;
149147
for (const auto& target : primitive_targets) {
@@ -157,13 +155,18 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
157155
}
158156
if (!first_primitive_target.defined()) {
159157
first_primitive_target = current_primitive_target;
160-
continue;
158+
CHECK(!first_primitive_target.IsExternalCodegen())
159+
<< "The first given target for device type " << device_type
160+
<< " must not be for an external codegen, however given "
161+
<< first_primitive_target->ToDebugString();
162+
} else {
163+
CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target))
164+
<< "When given multiple targets for the device type " << device_type
165+
<< " the first must be for non external codegen, and all subsequent must be for "
166+
"external codegen. However have been given first "
167+
<< first_primitive_target->ToDebugString() << " and subsequent "
168+
<< current_primitive_target->ToDebugString();
161169
}
162-
CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target))
163-
<< "The first given target for device type " << device_type << " is "
164-
<< first_primitive_target->ToDebugString() << ", however a later target "
165-
<< current_primitive_target->ToDebugString()
166-
<< " for the same device type is not an external codegen target.";
167170
}
168171
}
169172

src/target/target.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,10 +494,14 @@ Target::Target(TargetKind kind, Optional<ObjectRef> host, String tag, Array<Stri
494494
data_ = std::move(data);
495495
}
496496

497-
bool Target::IsExternalCodegenFor(const Target& that) const {
497+
bool Target::IsExternalCodegen() const {
498498
TargetKindAttrMap<Bool> attr_map = TargetKind::GetAttrMap<Bool>(::tvm::attr::kIsExternalCodegen);
499-
return get()->kind->device_type == that->kind->device_type &&
500-
attr_map.get(get()->kind, Bool(false)) && !attr_map.get(that->kind, Bool(false));
499+
return attr_map.get(get()->kind, Bool(false));
500+
}
501+
502+
bool Target::IsExternalCodegenFor(const Target& that) const {
503+
return get()->kind->device_type == that->kind->device_type && IsExternalCodegen() &&
504+
!that.IsExternalCodegen();
501505
}
502506

503507
std::vector<std::string> TargetNode::GetKeys() const {

0 commit comments

Comments
 (0)