Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call.");
TVM_ATTR_FIELD(metadata)
.describe("Metadata attached to the lowered function call.")
.set_default(Map<String, ObjectRef>());
}
};

Expand Down
62 changes: 33 additions & 29 deletions include/tvm/relay/attrs/on_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file tvm/relay/attrs/on_device.h
* \brief Attribute for the on device annotation.
* \brief Attribute for the "on_device" annotation (ie operator).
*/
#ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_
#define TVM_RELAY_ATTRS_ON_DEVICE_H_
Expand All @@ -33,9 +33,9 @@ namespace tvm {
namespace relay {

/*!
* \brief Attributes for the "on_device" special operator.
* \brief Attributes for the "on_device" annotation (ie operator).
*
* The Relay call (aka 'annotation'):
* The Relay call:
* \code
* on_device(sub_expr, se_scope=S)
* \endcode
Expand All @@ -54,44 +54,48 @@ namespace relay {
* multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z)
* \endcode
*
* The Relay call
* \code
* on_device(sub_expr, se_scope=S, is_fixed=True)
* \endcode
* is similar to the above, however the annotation itself must appear in an expression on the
* same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not
* insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs.
* However it is needed by the \p PlanDevices pass to fully specify the results of device planning
* so that the pass is idempotent.
*
* E.g.: The following program is equivalent to the above:
* \code
* let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True)
* multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z)
* \endcode
* The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored
* on the GPU.
* The \p constraint_body (default true) and \p constraint_result (default false) fields can be
* used by passes for finer-grained control over how the \p SEScope constraint should be applied.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
/*!
* \brief (Virtual) \p SEScope on which the result of the argument expression should be stored.
* \brief The \p SEScope to constraint to apply to the body, result, or both body and result
* of the "on_device" call.
*/
SEScope se_scope = SEScope::FullyUnconstrained();

/*!
* \brief If fales (the default), the result of the "on_device" call is not constrained to be
* \p se_scope.
*/
bool constrain_result = false;

/*!
* \brief If true (the default), the body of the "on_device" call is constrained to be \p
* se_scope.
*/
bool constrain_body = true;

/*!
* \brief Returns true if both the body and result are constrained.
*/
bool is_fixed() const { return constrain_result && constrain_body; }

/*!
* \brief If true, the result \p SEScope must also be \p se_scope, and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
* \brief Returns true only the body is constrained (the 'normal' case).
*/
bool is_fixed = false;
bool is_normal() const { return !constrain_result && constrain_body; }

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(se_scope)
.describe("The (virtual) device and scope holding the expression result.")
.describe("The (virtual) device to constrain to.")
.set_default(SEScope::FullyUnconstrained());
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
TVM_ATTR_FIELD(constrain_result)
.describe("Whether the constraint applies to the overall expression")
.set_default(false);
TVM_ATTR_FIELD(constrain_body)
.describe("Whether the constraint applies to the body sub-expression.")
.set_default(true);
}
};

Expand Down
7 changes: 6 additions & 1 deletion include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
*
* kInvalidDeviceType denotes unconstrained.
*/
int device_type_int;
int /* actually DLDeviceType */ device_type_int;

DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }

Expand Down Expand Up @@ -303,6 +303,11 @@ class SEScope : public ObjectRef {
return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
}

/*! \brief Returns the \p SEScope for \p memory_scope alone. */
static SEScope ForMemoryScope(MemoryScope memory_scope) {
return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
24 changes: 15 additions & 9 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,35 @@ def _make_se_scope(device):
raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))


def on_device(data, device, is_fixed=False):
"""Annotates an expression with the device type on which its result should be stored.
def on_device(body, device, constrain_result=False, constrain_body=True):
"""Annotates a body expression with device constraints. The constraint influences
how the body is compiled, where the body is evaluated, and where the result of
evaluation is stored.

Note that the defaults for the constrain_body and constrain_result parameters should
almost never need to be overridden by the user. These parameters are exposed here
to help unit tests exercise the PlanDevices pass machinery.

Parameters
----------
data : tvm.relay.Expr
body : tvm.relay.Expr
The expression to be annotated.

device : Union[:py:class:`Device`, str]
The device to annotate with.

is_fixed : bool
If false (the default), a device_copy
If true, the annotation does not imply a device_copy may be inserted to
reconcile the device of the data argument with the device for the context of the
annotated expression.
constrain_result : bool
If false (the default), the result of the on_device is not constrained to be on device.

constrain_body : bool
If true (the default), the body of the on_device is constrained to be on device.

Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.OnDevice(data, _make_se_scope(device), is_fixed)
return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body)


def function_on_device(function, param_devices, result_device):
Expand Down
11 changes: 5 additions & 6 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Expand All @@ -521,6 +515,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
#if TVM_LOG_DEBUG
for (const Type& type_arg : op->type_args) {
args.push_back(Print(type_arg));
}
#endif
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
Expand Down
57 changes: 40 additions & 17 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/tir/function.h>

#include <algorithm>
#include <string>

namespace tvm {
Expand All @@ -36,49 +37,71 @@ static const char* kSemVer = "0.0.5";
Doc TextPrinter::PrintMod(const IRModule& mod) {
Doc doc;
int counter = 0;

// We'll print in alphabetical order to make a/b diffs easier to work with.

// type definitions
std::vector<GlobalTypeVar> tyvars;
for (const auto& kv : mod->type_definitions) {
tyvars.emplace_back(kv.first);
}
std::sort(tyvars.begin(), tyvars.end(),
[](const GlobalTypeVar& left, const GlobalTypeVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& tyvar : tyvars) {
if (counter++ != 0) {
doc << Doc::NewLine();
}
doc << relay_text_printer_.Print(kv.second);
doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
doc << Doc::NewLine();
}

// functions
std::vector<GlobalVar> vars;
for (const auto& kv : mod->functions) {
if (kv.second.as<relay::FunctionNode>()) {
vars.emplace_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& var : vars) {
const BaseFunc& base_func = mod->functions[var];
if (base_func.as<relay::FunctionNode>()) {
relay_text_printer_.dg_ =
relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
}
if (counter++ != 0) {
doc << Doc::NewLine();
}
if (kv.second.as<relay::FunctionNode>()) {
if (base_func.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
os << "def @" << var->name_hint;
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
} else if (base_func.as<tir::PrimFuncNode>()) {
doc << "@" << var->name_hint;
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
}
doc << Doc::NewLine();
}

#if TVM_LOG_DEBUG
// attributes
// TODO(mbs): Make this official, including support from parser.
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
doc << "attributes {" << Doc::NewLine();
std::vector<String> keys;
for (const auto& kv : mod->attrs->dict) {
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
keys.emplace_back(kv.first);
}
std::sort(keys.begin(), keys.end());
doc << "attributes {" << Doc::NewLine();
for (const auto& key : keys) {
doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
}
doc << "}" << Doc::NewLine();
}
#endif

return doc;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ class ConvertAddToSubtract : public MixedModeMutator {

// Since we are replacing the Relay function with a call to a TIR function, we must use the
// call_lowered op.
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
return CallLowered(std::move(new_global_var), call->args,
std::move(Attrs(call_lowered_attrs)), call->type_args, call->span);
CallLoweredAttrs attrs;
attrs.metadata.Set("relay_attrs", call->attrs);
ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic";
return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span);
}
}

Expand All @@ -144,5 +144,4 @@ transform::Pass RelayToTIR() {
} // namespace example_target_hooks
} // namespace contrib
} // namespace relay

} // namespace tvm
Loading