69 changes: 44 additions & 25 deletions src/relax/transform/to_mixed_precision.cc
@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
  }

  Array<Expr> RemapArgs(const Array<Expr>& args) {
    Array<Expr> new_args;
    for (const auto& arg : args) {
      new_args.push_back(VarReplacer::Replace(arg, var_remap_));
    }
    return new_args;
    return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); });
  }

  // Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
      ReEmitBinding(binding, call_node->args[0]);
      return;
    }
    DataType to;
    ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);

    Call new_call = GetRef<Call>(call_node);

    // First remap the args to the current vars according to var_remap_
    new_call->args = std::move(RemapArgs(call_node->args));
    new_call.CopyOnWrite()->args = RemapArgs(new_call->args);

    // Then we rewrite the args according to the policy
    std::optional<DataType> opt_new_dtype = std::nullopt;

    if (policy == kAlways) {
      to = fp16_;
      opt_new_dtype = fp16_;
      auto attr_map = Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
      ICHECK(attr_map.count(op));
      auto f = attr_map[op];
      new_call = make_object<CallNode>(*(f(Call(new_call), output_dtype_).get()));
      new_call = attr_map[op](new_call, output_dtype_);
    } else if (policy == kFollow) {
      to = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
      opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
    } else if (policy == kNever) {
      to = fp32_;
      // An upstream operation may have changed the datatype of the
      // arguments. Because this operation must be provided with
      // exactly the same dtype as it previously had, it may require a
      // cast back to the original datatype.

      if (!new_call->args.same_as(call_node->args)) {
        Array<Expr> new_typed_args;
        for (size_t i = 0; i < call_node->args.size(); i++) {
          auto arg = new_call->args[i];
          auto old_ntype = NTypeFrom(call_node->args[i]);
          new_typed_args.push_back(RewriteExpr(arg, old_ntype));
        }
        new_call.CopyOnWrite()->args = new_typed_args;
      }

    } else {
      LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
    }
    new_call->args = std::move(RewriteArgs(new_call->args, to));
    new_call->struct_info_ = NullOpt;
    Expr new_value = builder_->Normalize(Call(new_call));
    if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
      // kAlways: store the tensors to fp16
      // But global vars will be stored to the original dtype anyway (see below)
      new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
    }
    if (!binding->var->IsInstance<DataflowVarNode>()) {
      // Global var: store the tensors to the original dtype
      NType to = NTypeFrom(binding->var);
      new_value = RewriteExpr(new_value, to);

    Expr new_value = new_call;
    if (opt_new_dtype) {
      auto new_dtype = opt_new_dtype.value();
      new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype);
      new_call.CopyOnWrite()->struct_info_ = NullOpt;

      new_value = builder_->Normalize(Call(new_call));

      if (!binding->var->IsInstance<DataflowVarNode>()) {
        // Non-Dataflow var: store the tensors to the original dtype
        new_value = RewriteExpr(new_value, NTypeFrom(binding->var));
      } else if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
        // kAlways: store the tensors to fp16
        // But non-dataflow vars will be stored to the original dtype anyway (see above)
        new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype));
      }
    }

    ReEmitBinding(binding, builder_->Normalize(new_value));
  }

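For context on the rewritten kNever branch above: it only casts an argument back to its original dtype when remapping has already replaced that argument upstream, and calls whose policy leaves opt_new_dtype unset (for example R.call_tir in the new test below) are re-emitted with only their arguments remapped. A minimal usage sketch, not part of this PR, assuming nn.conv2d is registered as kAlways and nn.softmax as kNever, as in the existing tests in this file:

import tvm
from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"),
        w: R.Tensor((3, 3, 3, 3), "float32"),
    ):
        with R.dataflow():
            gv = R.nn.conv2d(x, w)  # kAlways policy: expected to compute in float16
            gv1 = R.nn.softmax(gv, axis=1)  # kNever policy: argument expected to be cast back
            R.output(gv1)
        return gv1


# With out_dtype="float16", conv2d keeps a float16 result, so the pass is expected
# to insert a cast back to float32 before the kNever softmax.
After = ToMixedPrecision(out_dtype="float16")(Module)
After.show()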
34 changes: 31 additions & 3 deletions tests/python/relax/test_transform_to_mixed_precision.py
@@ -20,7 +20,7 @@
from tvm import relax
import tvm.testing
from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R
from tvm.script.parser import ir as I, relax as R, tir as T


def _assert_test(input, expected=None, expected2=None):
@@ -614,8 +614,8 @@ def main(
            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
                gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1)
                gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
                gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x, axis=1)
                gv2 = R.add(gv, gv1)
                R.output(gv2)
            return gv2
@@ -1036,5 +1036,33 @@ def main(
    tvm.ir.assert_structural_equal(mod, Expected)


def test_call_tir_with_float16_args():
    @I.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor([64], "float16")):
            cls = Before
            with R.dataflow():
                B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
                C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16"))
                R.output(C)
            return C

        @T.prim_func
        def tir_identity(
            Input: T.Buffer(64, "float16"),
            Output: T.Buffer(64, "float16"),
        ):
            for i in range(64):
                with T.block("copy"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = Input[vi]

    Expected = Before

    After = ToMixedPrecision()(Before)
    tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
    tvm.testing.main()