include/tvm/topi/elemwise.h (86 changes: 46 additions & 40 deletions)

@@ -456,54 +456,60 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
 }
 
 /*!
- * \brief Fast_tanh_float implementation from Eigen
+ * \brief Fast_erf_float expression from Eigen
  * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
+ * \param arg The input expression.
+ * \param bits The number of bits in the type.
  */
-inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
-  auto plus_4 = make_const(DataType::Float(32), 4.f);
-  auto minus_4 = make_const(DataType::Float(32), -4.f);
+inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {
+  auto plus_4 = make_const(DataType::Float(bits), 4.f);
+  auto minus_4 = make_const(DataType::Float(bits), -4.f);
 
   // The monomial coefficients of the numerator polynomial (odd).
-  auto alpha_1 = make_const(DataType::Float(32), -1.60960333262415e-02f);
-  auto alpha_3 = make_const(DataType::Float(32), -2.95459980854025e-03f);
-  auto alpha_5 = make_const(DataType::Float(32), -7.34990630326855e-04f);
-  auto alpha_7 = make_const(DataType::Float(32), -5.69250639462346e-05f);
-  auto alpha_9 = make_const(DataType::Float(32), -2.10102402082508e-06f);
-  auto alpha_11 = make_const(DataType::Float(32), 2.77068142495902e-08f);
-  auto alpha_13 = make_const(DataType::Float(32), -2.72614225801306e-10f);
+  auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f);
+  auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f);
+  auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f);
+  auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f);
+  auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f);
+  auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f);
+  auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f);
 
   // The monomial coefficients of the denominator polynomial (even).
-  auto beta_0 = make_const(DataType::Float(32), -1.42647390514189e-02f);
-  auto beta_2 = make_const(DataType::Float(32), -7.37332916720468e-03f);
-  auto beta_4 = make_const(DataType::Float(32), -1.68282697438203e-03f);
-  auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f);
-  auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f);
+  auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f);
+  auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f);
+  auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f);
+  auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f);
+  auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f);
 
+  // clamp x
+  auto x = tvm::max(tvm::min(arg, plus_4), minus_4);
+  auto x2 = x * x;
+
+  // Evaluate the numerator polynomial p.
+  auto p = x2 * alpha_13 + alpha_11;
+  p = x2 * p + alpha_9;
+  p = x2 * p + alpha_7;
+  p = x2 * p + alpha_5;
+  p = x2 * p + alpha_3;
+  p = x2 * p + alpha_1;
+  p = x * p;
+
+  // Evaluate the denominator polynomial q.
+  auto q = x2 * beta_8 + beta_6;
+  q = x2 * q + beta_4;
+  q = x2 * q + beta_2;
+  q = x2 * q + beta_0;
+
+  return p / q;
+}
+
+/*!
+ * \brief Fast_erf_float expression from Eigen
+ */
+inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
   return compute(
-      data->shape,
-      [&](const Array<Var>& i) {
-        // clamp x
-        auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
-        auto x2 = x * x;
-
-        // Evaluate the numerator polynomial p.
-        auto p = x2 * alpha_13 + alpha_11;
-        p = x2 * p + alpha_9;
-        p = x2 * p + alpha_7;
-        p = x2 * p + alpha_5;
-        p = x2 * p + alpha_3;
-        p = x2 * p + alpha_1;
-        p = x * p;
-
-        // Evaluate the denominator polynomial p.
-        auto q = x2 * beta_8 + beta_6;
-        q = x2 * q + beta_4;
-        q = x2 * q + beta_2;
-        q = x2 * q + beta_0;
-
-        return p / q;
-      },
-      name, tag);
+      data->shape, [&](const Array<Var>& i) { return fast_erf_float_expr(data(i), 32); }, name,
+      tag);
 }
 
 /*!
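Note: the refactor above only threads a `bits` parameter through Eigen's rational approximation; the coefficients and the clamped Horner evaluation are unchanged. As a sanity check, here is a minimal NumPy sketch (not part of the PR; `fast_erf_ref` is our name) of the same p(x)/q(x) evaluation, compared against math.erf:

import math
import numpy as np

# Coefficients copied from the diff above, ordered alpha_13..alpha_1 / beta_8..beta_0.
ALPHA = [-2.72614225801306e-10, 2.77068142495902e-08, -2.10102402082508e-06,
         -5.69250639462346e-05, -7.34990630326855e-04, -2.95459980854025e-03,
         -1.60960333262415e-02]
BETA = [-1.45660718464996e-05, -2.13374055278905e-04, -1.68282697438203e-03,
        -7.37332916720468e-03, -1.42647390514189e-02]

def fast_erf_ref(arg):
    # clamp x to [-4, 4], then evaluate p(x)/q(x) by Horner's rule
    x = np.clip(np.asarray(arg, dtype=np.float32), -4.0, 4.0)
    x2 = x * x
    p = np.full_like(x, ALPHA[0])
    for a in ALPHA[1:]:
        p = x2 * p + np.float32(a)
    p = x * p
    q = np.full_like(x, BETA[0])
    for b in BETA[1:]:
        q = x2 * q + np.float32(b)
    return p / q

xs = np.linspace(-4.0, 4.0, 1001, dtype=np.float32)
err = max(abs(float(fast_erf_ref(v)) - math.erf(float(v))) for v in xs)
print("max abs error vs math.erf:", err)  # roughly 1e-7 in float32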
src/target/source/codegen_metal.cc (20 changes: 20 additions & 0 deletions)

@@ -302,6 +302,26 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
   }
 }
 
+void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
+  std::ostringstream temp;
+  if (std::isinf(op->value)) {
+    if (op->value < 0) {
+      temp << "-";
+    }
+    temp << "INFINITY";
+  } else if (std::isnan(op->value)) {
+    temp << "NAN";
+  } else {
+    temp << std::scientific << op->value;
+    if (op->dtype.bits() == 32)
+      temp << 'f';
+    else if (op->dtype.bits() == 16)
+      temp << 'h';
+  }
+  MarkConst(temp.str());
+  os << temp.str();
+}
+
 runtime::Module BuildMetal(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
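For reference, the literal forms this visitor emits: infinities and NaN become the INFINITY / NAN macros (which Metal Shading Language provides), and finite values get printed in scientific notation with an 'f' (float) or 'h' (half) suffix. A hedged Python mimic of just the formatting rule, not any TVM API:

import math

def metal_float_literal(value: float, bits: int) -> str:
    # Mirrors the C++ above: macros for non-finite values, suffixed
    # scientific notation otherwise (std::scientific defaults to 6 digits,
    # the same as Python's "%e").
    if math.isinf(value):
        return ("-" if value < 0 else "") + "INFINITY"
    if math.isnan(value):
        return "NAN"
    suffix = {32: "f", 16: "h"}.get(bits, "")
    return f"{value:e}{suffix}"

print(metal_float_literal(float("-inf"), 32))  # -INFINITY
print(metal_float_literal(float("nan"), 16))   # NAN
print(metal_float_literal(1.5, 16))            # 1.500000e+00h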
src/target/source/codegen_metal.h (4 changes: 2 additions & 2 deletions)

@@ -51,8 +51,8 @@ class CodeGenMetal final : public CodeGenC {
   void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
   // overload visitor
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
-  // overload visitor
-  void VisitExpr_(const CallNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const CallNode* op, std::ostream& os) final;      // NOLINT(*)
+  void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
   // reuse parent's function.
   using CodeGenC::PrintType;
src/target/source/intrin_rule_metal.cc (19 changes: 19 additions & 0 deletions)

@@ -22,6 +22,7 @@
  * \brief Metal intrinsic rules.
  */
 #include <tvm/tir/op_attr_types.h>
+#include <tvm/topi/elemwise.h>
 
 #include "../intrin_rule.h"
 
@@ -90,6 +91,24 @@ TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
 TVM_REGISTER_OP("tir.cosh")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
+// There is no erf function in Metal, so lower tir.erf to fast_erf instead.
+static PrimExpr DispatchFastErf(const PrimExpr& e) {
+  LOG(WARNING) << "Metal doesn't have a built-in erf function; fast_erf will be used instead.";
+  const CallNode* call = e.as<CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 1);
+  PrimExpr arg = call->args[0];
+  int bits = arg.dtype().bits();
+  bool isFloat = arg.dtype().is_float();
+  PrimExpr res;
+  if (isFloat && (bits == 16 || bits == 32))
+    res = topi::fast_erf_float_expr(arg, bits);
+  else
+    LOG(FATAL) << "Unsupported type in Metal fast_erf";
+  return res;
+}
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
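One way to observe this lowering end to end (a hedged sketch, assuming a macOS host with a Metal-enabled TVM build): compile a small erf kernel and dump the generated Metal source via the imported source module. The generated kernel should contain the inlined rational polynomial rather than any erf( call.

import tvm
from tvm import te

n = 16
A = te.placeholder((n,), name="A", dtype="float32")
C = te.compute((n,), lambda i: te.erf(A[i]), name="C")
s = te.create_schedule(C.op)
s[C].bind(C.op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], "metal")
# The Metal kernel source lives on the imported device module.
print(fun.imported_modules[0].get_source())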
tests/python/unittest/test_target_codegen_metal.py (81 changes: 81 additions & 0 deletions, new file)

@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import tvm.testing

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_inf_nan():
target = "metal"

def check_inf_nan(dev, n, value, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
inf_value = tvm.tir.const(value, dtype=dtype)
C = te.compute((n,), lambda i: inf_value, name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
fun(a, c)

dev = tvm.device(target, 0)

check_inf_nan(dev, 1, -float("inf"), "float32")
check_inf_nan(dev, 1, -float("inf"), "float16")
check_inf_nan(dev, 1, float("inf"), "float32")
check_inf_nan(dev, 1, float("inf"), "float16")
check_inf_nan(dev, 1, float("nan"), "float32")
check_inf_nan(dev, 1, float("nan"), "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_erf():
target = "metal"

def check_erf(dev, n, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
fun(a, c)

dev = tvm.device(target, 0)

check_erf(dev, 1, "float32")
check_erf(dev, 1, "float16")


if __name__ == "__main__":
test_metal_inf_nan()
test_metal_erf()
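The tests above only check that the kernels compile and run. A possible follow-up numerical check (a self-contained sketch, not part of the PR; check_erf_numeric and its tolerance are our choices) would compare device results against math.erf:

import math
import numpy as np
import tvm
import tvm.testing
from tvm import te

def check_erf_numeric(n=64, dtype="float32", rtol=1e-4):
    dev = tvm.device("metal", 0)
    A = te.placeholder((n,), name="A", dtype=dtype)
    C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
    s = te.create_schedule(C.op)
    s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, C], "metal")
    a_np = np.random.uniform(-3.0, 3.0, n).astype(dtype)
    a = tvm.nd.array(a_np, dev)
    c = tvm.nd.empty((n,), dtype, dev)
    fun(a, c)
    # Reference values from the exact erf, element by element.
    ref = np.array([math.erf(float(v)) for v in a_np], dtype=dtype)
    tvm.testing.assert_allclose(c.numpy(), ref, rtol=rtol, atol=rtol)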