Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ extern "C" float TVMTestAddOne(float y) { return y + 1; }
// This way can be helpful when we want to use a header only
// minimum version of TVM Runtime.
extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) {
const PackedFunc& fregister = *static_cast<PackedFunc*>(pregister);
const PackedFunc& fregister = GetRef<PackedFunc>(static_cast<PackedFuncObj*>(pregister));
auto mul = [](TVMArgs args, TVMRetValue* rv) {
int x = args[0];
int y = args[1];
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ struct TypeIndex {
kRuntimeMap = 5,
/*! \brief runtime::ShapeTuple. */
kRuntimeShapeTuple = 6,
/*! \brief runtime::PackedFunc. */
kRuntimePackedFunc = 7,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
Expand Down
203 changes: 138 additions & 65 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,72 @@ class TVMMovableArgValueWithContext_;
class TVMRetValue;
class TVMArgsSetter;

/*!
* \brief Object container class that backs PackedFunc.
* \note Do not use this function directly, use PackedFunc.
*/
class PackedFuncObj : public Object {
public:
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param rv The return value.
*/
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;

static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc;
static constexpr const char* _type_key = "runtime.PackedFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object);

protected:
/*!
* \brief Internal struct for extracting the callable method from callable type.
*/
template <class TPackedFuncSubObj>
struct Extractor {
/*!
* \brief Extracting the callable method from callable type.
* \param obj The base packed function object class.
* \param args The arguments
* \param rv The return value.
*/
static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv);
};

/*! \brief The internal callable function type. */
using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*);

/*!
* \brief Constructing a packed function object from a function pointer.
* \param f_call_pack The function pointer used to call the packed function.
*/
explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {}

/*! \brief Delete the default constructor explicitly. */
PackedFuncObj() = delete;

/*! \brief Internal callable function pointer used to call the packed function. */
FCallPacked* f_call_packed_;
};

/*! \brief Derived object class for constructing PackedFuncObj. */
template <class TCallable>
class PackedFuncSubObj : public PackedFuncObj {
using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;

public:
/*! \brief The type of derived object class */
using TSelf = PackedFuncSubObj<TCallable>;
/*!
* \brief Derived object class for constructing PackedFuncObj.
* \param callable The type-erased callable object.
*/
explicit PackedFuncSubObj(TCallable callable)
: PackedFuncObj(Extractor<TSelf>::Call), callable_(callable) {}
/*! \brief Type-erased filed for storing callable object*/
mutable TStorage callable_;
};

/*!
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
Expand All @@ -65,36 +131,23 @@ class TVMArgsSetter;
* It is the unified function function type of TVM.
* It corresponds to TVMFunctionHandle in C runtime API.
*/
class PackedFunc {
class PackedFunc : public ObjectRef {
public:
/*! \brief Constructor from null */
PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*)
/*!
* \brief The internal std::function
* \param args The arguments to the function.
* \param rv The return value.
*
* \code
* // Example code on how to implemented FType
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
* // automatically convert arguments to desired type.
* int a0 = args[0];
* float a1 = args[1];
* ...
* // automatically assign values to rv
* std::string my_return_value = "x";
* *rv = my_return_value;
* }
* \endcode
*/
using FType = std::function<void(TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*! \brief constructor from null */
PackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `PackedFunc`
* \param data the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
template <typename TCallable,
typename = std::enable_if_t<
std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
!std::is_base_of<TCallable, PackedFunc>::value>>
explicit PackedFunc(TCallable data) {
using ObjType = PackedFuncSubObj<TCallable>;
data_ = make_object<ObjType>(std::forward<TCallable>(data));
}
/*!
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
Expand All @@ -116,17 +169,13 @@ class PackedFunc {
* \param args The arguments
* \param rv The return value.
*/
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return the internal body function */
inline FType body() const;
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }
bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }

private:
/*! \brief internal container of packed function */
FType body_;
TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj);
};

/*!
Expand Down Expand Up @@ -540,6 +589,13 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) {
return PackedFunc(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return PackedFunc(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator Device() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLDevice);
return value_.v_device;
Expand Down Expand Up @@ -601,6 +657,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Device;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::IsObjectRef;

Expand All @@ -620,11 +677,6 @@ class TVMArgValue : public TVMPODValue_ {
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return *ptr<PackedFunc>();
}
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -661,9 +713,9 @@ class TVMMovableArgValue_ : public TVMPODValue_ {
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Device;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
// reuse conversion rule from ArgValue.
operator std::string() const { return AsArgValue().operator std::string(); }
operator PackedFunc() const { return AsArgValue().operator PackedFunc(); }
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -756,6 +808,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator Device;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::IsObjectRef;

Expand All @@ -778,11 +831,6 @@ class TVMRetValue : public TVMPODValue_ {
return value_.v_type;
}
operator DataType() const { return DataType(operator DLDataType()); }
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return *ptr<PackedFunc>();
}
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -852,6 +900,7 @@ class TVMRetValue : public TVMPODValue_ {
ObjectRef::FFIClearAfterMove(&other);
} else {
SwitchToPOD(kTVMNullptr);
value_.v_handle = nullptr;
}
return *this;
}
Expand All @@ -860,11 +909,7 @@ class TVMRetValue : public TVMPODValue_ {
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
if (f == nullptr) {
this->SwitchToPOD(kTVMNullptr);
} else {
this->SwitchToClass(kTVMPackedFuncHandle, f);
}
this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_));
return *this;
}
template <typename FType>
Expand Down Expand Up @@ -941,7 +986,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kTVMPackedFuncHandle: {
SwitchToClass<PackedFunc>(kTVMPackedFuncHandle, other);
*this = other.operator PackedFunc();
break;
}
case kTVMModuleHandle: {
Expand Down Expand Up @@ -995,6 +1040,7 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr;
} else {
SwitchToPOD(kTVMNullptr);
value_.v_handle = nullptr;
}
}
void Clear() {
Expand All @@ -1005,7 +1051,7 @@ class TVMRetValue : public TVMPODValue_ {
delete ptr<std::string>();
break;
case kTVMPackedFuncHandle:
delete ptr<PackedFunc>();
static_cast<Object*>(value_.v_handle)->DecRef();
break;
case kTVMNDArrayHandle: {
NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
Expand Down Expand Up @@ -1148,9 +1194,19 @@ inline TVMArgValue TVMArgs::operator[](int i) const {

inline int TVMArgs::size() const { return num_args; }

inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); }
template <class TPackedFuncSubObj>
void PackedFuncObj::Extractor<TPackedFuncSubObj>::Call(const PackedFuncObj* obj, TVMArgs args,
TVMRetValue* rv) {
(static_cast<const TPackedFuncSubObj*>(obj))->callable_(args, rv);
}

inline PackedFunc::FType PackedFunc::body() const { return body_; }
TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const {
(*f_call_packed_)(this, args, rv);
}

TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
(static_cast<PackedFuncObj*>(data_.get()))->CallPacked(args, rv);
}

// internal namespace
inline const char* ArgTypeCode2Str(int type_code) {
Expand Down Expand Up @@ -1312,15 +1368,6 @@ class TVMArgsSetter {
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kTVMBytes;
}
TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const {
if (value != nullptr) {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kTVMPackedFuncHandle;
} else {
values_[i].v_handle = nullptr;
type_codes_[i] = kTVMNullptr;
}
}
template <typename FType>
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
operator()(i, value.packed());
Expand Down Expand Up @@ -1366,7 +1413,8 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
(static_cast<PackedFuncObj*>(data_.get()))
->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}

Expand Down Expand Up @@ -1518,6 +1566,11 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
} else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
ptr->IsInstance<PackedFunc::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMPackedFuncHandle;
} else if (std::is_rvalue_reference<decltype(value)>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg;
Expand All @@ -1527,6 +1580,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
}
} else {
type_codes_[i] = kTVMNullptr;
values_[i].v_handle = nullptr;
}
}

Expand All @@ -1543,6 +1597,10 @@ inline bool TVMPODValue_::IsObjectRef() const {
return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
return type_code_ == kTVMPackedFuncHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
// NOTE: we don't pass NDArray and runtime::Module as RValue ref.
if (type_code_ == kTVMObjectRValueRefArg) {
return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle));
Expand All @@ -1551,6 +1609,8 @@ inline bool TVMPODValue_::IsObjectRef() const {
type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) ||
(std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
type_code_ == kTVMPackedFuncHandle) ||
(type_code_ == kTVMObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
Expand Down Expand Up @@ -1584,6 +1644,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
// Casting to a sub-class of PackedFunc
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
return TObjectRef(data);
}
if (type_code_ == kTVMObjectHandle) {
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
Expand All @@ -1607,6 +1675,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
type_code_ == kTVMModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
type_code_ == kTVMPackedFuncHandle) {
// Casting to a base class that PackedFunc can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else {
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
Expand All @@ -1631,6 +1703,7 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
SwitchToObject(kTVMObjectHandle, std::move(other.data_));
} else {
SwitchToPOD(kTVMNullptr);
value_.v_handle = nullptr;
}
return *this;
}
Expand Down
Loading