Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 99 additions & 23 deletions include/tvm/relax/nested_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>

#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -115,20 +116,14 @@ namespace relax {
* use this class or logic of a similar kind.
*/
template <typename T>
class NestedMsg : public ObjectRef {
class NestedMsg {
public:
// default constructors.
NestedMsg() = default;
NestedMsg(const NestedMsg<T>&) = default;
NestedMsg(NestedMsg<T>&&) = default;
NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
NestedMsg<T>& operator=(NestedMsg<T>&&) = default;
/*!
* \brief Construct from an ObjectPtr
* whose type already satisfies the constraint
* \param ptr
*/
explicit NestedMsg(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief Nullopt handling */
NestedMsg(std::nullopt_t) {} // NOLINT(*)
// nullptr handling.
Expand All @@ -140,16 +135,17 @@ class NestedMsg : public ObjectRef {
}
// normal value handling.
NestedMsg(T other) // NOLINT(*)
: ObjectRef(std::move(other)) {}
: data_(std::move(other)) {}
NestedMsg<T>& operator=(T other) {
ObjectRef::operator=(std::move(other));
data_ = std::move(other);
return *this;
}
// Array<NestedMsg<T>> handling
NestedMsg(Array<NestedMsg<T>, void> other) // NOLINT(*)
: ObjectRef(std::move(other)) {}
: data_(other) {}

NestedMsg<T>& operator=(Array<NestedMsg<T>, void> other) {
ObjectRef::operator=(std::move(other));
data_ = std::move(other);
return *this;
}

Expand All @@ -170,38 +166,40 @@ class NestedMsg : public ObjectRef {
bool operator!=(std::nullptr_t) const { return data_ != nullptr; }

/*! \return Whether the nested message is not-null leaf value */
bool IsLeaf() const { return data_ != nullptr && data_->IsInstance<LeafContainerType>(); }
bool IsLeaf() const {
return data_.type_index() != ffi::TypeIndex::kTVMFFINone &&
data_.type_index() != ffi::TypeIndex::kTVMFFIArray;
}

/*! \return Whether the nested message is null */
bool IsNull() const { return data_ == nullptr; }
bool IsNull() const { return data_.type_index() == ffi::TypeIndex::kTVMFFINone; }

/*! \return Whether the nested message is nested */
bool IsNested() const { return data_ != nullptr && data_->IsInstance<ffi::ArrayObj>(); }
bool IsNested() const { return data_.type_index() == ffi::TypeIndex::kTVMFFIArray; }

/*!
* \return The underlying leaf value.
* \note This function checks if the msg is leaf.
*/
T LeafValue() const {
ICHECK(IsLeaf());
return T(data_);
return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
}

/*!
* \return a corresponding nested array.
* \note This checks if the underlying data type is array.
*/
Array<NestedMsg<T>, void> NestedArray() const {
ICHECK(IsNested());
return Array<NestedMsg<T>, void>(data_);
return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<NestedMsg<T>, void>>(data_);
}

using ContainerType = Object;
using LeafContainerType = typename T::ContainerType;

static_assert(std::is_base_of<ObjectRef, T>::value, "NestedMsg is only defined for ObjectRef.");

static constexpr bool _type_is_nullable = true;
private:
ffi::Any data_;
// private constructor
explicit NestedMsg(ffi::Any data) : data_(data) {}
template <typename, typename>
friend struct ffi::TypeTraits;
};

/*!
Expand Down Expand Up @@ -598,5 +596,83 @@ StructInfo TransformTupleLeaf(StructInfo sinfo, std::array<NestedMsg<T>, N> msgs
}

} // namespace relax

namespace ffi {

template <typename T>
inline constexpr bool use_default_type_traits_v<relax::NestedMsg<T>> = false;

template <typename T>
struct TypeTraits<relax::NestedMsg<T>> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const relax::NestedMsg<T>& src, TVMFFIAny* result) {
*result = ffi::AnyView(src.data_).CopyToTVMFFIAny();
}

TVM_FFI_INLINE static void MoveToAny(relax::NestedMsg<T> src, TVMFFIAny* result) {
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_));
}

TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
return TypeTraitsBase::GetMismatchTypeInfo(src);
}

static bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFINone) {
return true;
}
if (TypeTraits<T>::CheckAnyStrict(src)) {
return true;
}
if (src->type_index == TypeIndex::kTVMFFIArray) {
const ffi::ArrayObj* array = reinterpret_cast<const ffi::ArrayObj*>(src->v_obj);
for (size_t i = 0; i < array->size(); ++i) {
const Any& any_v = (*array)[i];
if (!details::AnyUnsafe::CheckAnyStrict<relax::NestedMsg<T>>(any_v)) return false;
}
}
return true;
}

TVM_FFI_INLINE static relax::NestedMsg<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return relax::NestedMsg<T>(Any(AnyView::CopyFromTVMFFIAny(*src)));
}

TVM_FFI_INLINE static relax::NestedMsg<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
return relax::NestedMsg<T>(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src)));
}

static std::optional<relax::NestedMsg<T>> TryCastFromAnyView(const TVMFFIAny* src) {
if (CheckAnyStrict(src)) {
return CopyFromAnyViewAfterCheck(src);
}
// slow path run conversion
if (src->type_index == TypeIndex::kTVMFFINone) {
return relax::NestedMsg<T>(std::nullopt);
}
if (auto opt_value = TypeTraits<T>::TryCastFromAnyView(src)) {
return relax::NestedMsg<T>(*std::move(opt_value));
}
if (src->type_index == TypeIndex::kTVMFFIArray) {
const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
Array<relax::NestedMsg<T>> result;
result.reserve(n->size());
for (size_t i = 0; i < n->size(); i++) {
const Any& any_v = (*n)[i];
if (auto opt_v = any_v.try_cast<relax::NestedMsg<T>>()) {
result.push_back(*std::move(opt_v));
} else {
return std::nullopt;
}
}
return relax::NestedMsg<T>(result);
}
return std::nullopt;
}

TVM_FFI_INLINE static std::string TypeStr() {
return "NestedMsg<" + details::Type2Str<T>::v() + ">";
}
};
} // namespace ffi
} // namespace tvm
#endif // TVM_RELAX_NESTED_MSG_H_
23 changes: 17 additions & 6 deletions tests/cpp/nested_msg_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TEST(NestedMsg, Basic) {
EXPECT_ANY_THROW(msg.LeafValue());

auto arr = msg.NestedArray();
EXPECT_TRUE(arr[0].same_as(x));
EXPECT_TRUE(arr[0].LeafValue().same_as(x));
EXPECT_TRUE(arr[1] == nullptr);
EXPECT_TRUE(arr[1].IsNull());

Expand All @@ -72,13 +72,24 @@ TEST(NestedMsg, Basic) {
EXPECT_TRUE(a0.IsNested());
auto t0 = a0.NestedArray()[1];
EXPECT_TRUE(t0.IsNested());
EXPECT_TRUE(t0.NestedArray()[2].same_as(y));
EXPECT_TRUE(t0.NestedArray()[2].LeafValue().same_as(y));

// assign leaf
a0 = x;

EXPECT_TRUE(a0.IsLeaf());
EXPECT_TRUE(a0.same_as(x));
EXPECT_TRUE(a0.LeafValue().same_as(x));
}

TEST(NestedMsg, IntAndAny) {
NestedMsg<int64_t> msg({1, std::nullopt, 2});
Any any_msg = msg;
NestedMsg<int64_t> msg2 = any_msg.cast<NestedMsg<int64_t>>();

EXPECT_TRUE(msg2.IsNested());
EXPECT_EQ(msg2.NestedArray()[0].LeafValue(), 1);
EXPECT_TRUE(msg2.NestedArray()[1].IsNull());
EXPECT_EQ(msg2.NestedArray()[2].LeafValue(), 2);
}

TEST(NestedMsg, ForEachLeaf) {
Expand Down Expand Up @@ -174,13 +185,13 @@ TEST(NestedMsg, MapAndDecompose) {

DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg<Integer> msg) {
if (value.same_as(x)) {
EXPECT_TRUE(msg.same_as(c0));
EXPECT_TRUE(msg.LeafValue().same_as(c0));
++x_count;
} else if (value.same_as(y)) {
EXPECT_TRUE(msg.same_as(c1));
EXPECT_TRUE(msg.LeafValue().same_as(c1));
++y_count;
} else {
EXPECT_TRUE(msg.same_as(c2));
EXPECT_TRUE(msg.LeafValue().same_as(c2));
++z_count;
}
});
Expand Down
Loading