18 changes: 12 additions & 6 deletions python/tvm/relay/frontend/mxnet.py
@@ -615,12 +615,17 @@ def _mx_arange(inputs, attrs):
     if attrs.get_int("repeat", 1) != 1:
         raise tvm.error.OpAttributeUnimplemented(
             'Attribute "repeat" is not supported in operator arange.')
-    new_attrs = {}
-    new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
+    dtype = attrs.get_str("dtype", "float32")
     stop = attrs.get_str("stop", "None")
-    new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
-    new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
-    new_attrs["dtype"] = attrs.get_str("dtype", "float32")
+    if stop == "None":
+        stop = None
+    else:
+        stop = _expr.const(float(stop), dtype=dtype)
+    new_attrs = {}
+    new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0), dtype=dtype)
+    new_attrs["stop"] = stop
+    new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0), dtype=dtype)
+    new_attrs["dtype"] = dtype
     return _op.arange(**new_attrs)
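As a sanity check on this hunk, a minimal end-to-end sketch (assuming MXNet and TVM are installed; the exact from_mxnet signature is an assumption of this example, not part of the diff): converting an arange with a non-default dtype should now yield start/stop/step constants in that dtype rather than float32.

    import mxnet as mx
    from tvm import relay

    # arange with dtype float64: the converter must emit float64
    # constants for start/stop/step or Relay type checking fails.
    sym = mx.sym.arange(start=0, stop=20, step=2, dtype="float64")
    mod, params = relay.frontend.from_mxnet(sym, shape={})
    print(mod)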


@@ -863,7 +868,8 @@ def _mx_contrib_div_sqrt_dim(inputs, _):
     assert len(inputs) == 1
     ndim = len(_infer_type(inputs[0]).checked_type.shape)
     dim = _op.take(_op.shape_of(inputs[0]), _expr.const(ndim-1, dtype="int32"))
-    sqrt_dim = _op.sqrt(dim.astype('float32'))
+    dtype = _infer_type(inputs[0]).checked_type.dtype
+    sqrt_dim = _op.sqrt(dim.astype(dtype))
     out = inputs[0] / sqrt_dim
     return out

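A rough numpy analogue of what the fixed converter computes (illustrative only, not TVM code): the square root of the last dimension is taken in the input's own dtype, so a float16 graph stays float16 throughout.

    import numpy as np

    def div_sqrt_dim(x):
        # Divide by sqrt(last dimension), as _contrib_div_sqrt_dim does.
        dim = x.shape[-1]
        return x / np.sqrt(np.asarray(dim, dtype=x.dtype))

    y = div_sqrt_dim(np.ones((2, 4, 8), dtype="float16"))
    assert y.dtype == np.float16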
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/nnvm_common.py
@@ -21,6 +21,7 @@
 from .. import expr as _expr
 from .. import op as _op
 from .common import get_relay_op
+from .common import infer_type as _infer_type

def _warn_not_used(attr, op='nnvm'):
import warnings
@@ -123,20 +124,22 @@ def _elemwise_sum(inputs, _, _dtype='float32'):


 def _binop_scalar(new_op):
-    def _impl(inputs, attrs, odtype='float32'):
+    def _impl(inputs, attrs, odtype=None):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
+        if odtype is None:
+            odtype = _infer_type(inputs[0]).checked_type.dtype
         scalar = _expr.const(scalar, dtype=odtype)
         return new_op(inputs[0], scalar)
     return _impl


 def _rbinop_scalar(new_op):
-    def _impl(inputs, attrs, odtype='float32'):
+    def _impl(inputs, attrs, odtype=None):
         assert len(inputs) == 1
         scalar = attrs.get_float("scalar")
-        # Note: binary scalar only works for float op for now
+        if odtype is None:
+            odtype = _infer_type(inputs[0]).checked_type.dtype
         scalar = _expr.const(scalar, dtype=odtype)
         return new_op(scalar, inputs[0])
     return _impl
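For readers unfamiliar with the factory pattern above, a self-contained sketch in plain numpy (illustrative; np stands in for Relay ops and x.dtype for _infer_type): the wrapper now materializes the scalar operand in the input's dtype unless the caller pins one explicitly.

    import numpy as np

    def binop_scalar(new_op):
        def impl(x, scalar, odtype=None):
            if odtype is None:
                # stand-in for _infer_type(inputs[0]).checked_type.dtype
                odtype = x.dtype
            return new_op(x, np.asarray(scalar, dtype=odtype))
        return impl

    add_scalar = binop_scalar(np.add)
    y = add_scalar(np.zeros(3, dtype="float16"), 1.5)
    assert y.dtype == np.float16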
39 changes: 35 additions & 4 deletions src/relay/op/tensor/transform.cc
@@ -28,6 +28,7 @@
 #include <tvm/expr_operator.h>
 #include <tvm/ir.h>
 #include <tvm/data_layout.h>
+#include <tvm/runtime/packed_func.h>
 #include <topi/transform.h>
 #include <topi/elemwise.h>
 #include <topi/broadcast.h>
@@ -1127,11 +1128,41 @@ and type as the input array.
TVM_REGISTER_NODE_TYPE(ArangeAttrs);

 double ToScalar(const runtime::NDArray& array) {
-  if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
-    return reinterpret_cast<int32_t*>(array->data)[0];
-  } else {
-    return reinterpret_cast<float*>(array->data)[0];
+  if (array->dtype.code == kDLInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<int8_t*>(array->data)[0];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<int16_t*>(array->data)[0];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<int32_t*>(array->data)[0];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<int64_t*>(array->data)[0];
+    }
+  } else if (array->dtype.code == kDLUInt) {
+    if (array->dtype.bits == 8) {
+      return reinterpret_cast<uint8_t*>(array->data)[0];
+    } else if (array->dtype.bits == 16) {
+      return reinterpret_cast<uint16_t*>(array->data)[0];
+    } else if (array->dtype.bits == 32) {
+      return reinterpret_cast<uint32_t*>(array->data)[0];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<uint64_t*>(array->data)[0];
+    }
+  } else if (array->dtype.code == kDLFloat) {
+#if (__ARM_FP16_FORMAT_IEEE == 1)
+    if (array->dtype.bits == 16) {
+      return reinterpret_cast<__fp16*>(array->data)[0];
+    }
+#endif
+    if (array->dtype.bits == 32) {
+      return reinterpret_cast<float*>(array->data)[0];
+    } else if (array->dtype.bits == 64) {
+      return reinterpret_cast<double*>(array->data)[0];
+    }
   }
+  LOG(FATAL) << "Unknown data type: " << tvm::runtime::TVMType2String(array->dtype);
+  // make compiler happy
+  return -std::numeric_limits<double>::infinity();
 }

bool ArangeRel(const Array<Type>& types,
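The intent of the widened dispatch, restated as a short Python sketch (numpy purely for illustration; the real code reads element 0 through a reinterpret_cast selected by dtype code and bit width):

    import numpy as np

    def to_scalar(array: np.ndarray) -> float:
        # Accept any int/uint/float width, mirroring the kDLInt /
        # kDLUInt / kDLFloat branches, instead of assuming int32/float32.
        if array.dtype.kind in "iuf":
            return float(array.flat[0])
        raise TypeError("Unknown data type: %s" % array.dtype)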
2 changes: 1 addition & 1 deletion src/relay/pass/simplify_inference.cc
@@ -75,7 +75,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   const auto param = attrs.as<LayerNormAttrs>();
   CHECK(param);

-  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
   Expr mean = Mean(data, {param->axis}, true, false);
   Expr var = Variance(data, mean, {param->axis}, true, false);
   Expr denom = Sqrt(Add(var, epsilon));
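A numpy sketch of the lowering this function performs; the gamma/beta scale-and-shift tail is assumed from context (it is not shown in the hunk). The one-line fix is that epsilon is now created in the data's dtype instead of hard-coded float32, so e.g. float16 models type-check.

    import numpy as np

    def layer_norm(x, gamma, beta, axis=-1, eps=1e-5):
        mean = x.mean(axis=axis, keepdims=True)
        var = x.var(axis=axis, keepdims=True)
        # epsilon in x.dtype, matching MakeConstantScalar(ttype->dtype, ...)
        denom = np.sqrt(var + np.asarray(eps, dtype=x.dtype))
        return gamma * (x - mean) / denom + beta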