Skip to content

Commit 6e55de0

Browse files
tqchenShiboXing
authored andcommitted
[FFI][REFACTOR] Move Downcast out of ffi for now (apache#18198)
Downcast was added for backward compact reasons and it have duplicated features as Any.cast. This PR moves it out of ffi to node for now so the ffi part contains minimal set of implementations.
1 parent 17f9eed commit 6e55de0

File tree

16 files changed

+171
-131
lines changed

16 files changed

+171
-131
lines changed

ffi/include/tvm/ffi/any.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,6 @@ struct AnyEqual {
635635
}
636636
}
637637
};
638-
639638
} // namespace ffi
640639

641640
// Expose to the tvm namespace for usability

ffi/include/tvm/ffi/cast.h

Lines changed: 4 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,18 @@
1818
*/
1919
/*!
2020
* \file tvm/ffi/cast.h
21-
* \brief Value casting support
21+
* \brief Extra value casting helpers
2222
*/
2323
#ifndef TVM_FFI_CAST_H_
2424
#define TVM_FFI_CAST_H_
2525

2626
#include <tvm/ffi/any.h>
27-
#include <tvm/ffi/dtype.h>
28-
#include <tvm/ffi/error.h>
2927
#include <tvm/ffi/object.h>
3028
#include <tvm/ffi/optional.h>
3129

32-
#include <utility>
33-
3430
namespace tvm {
3531
namespace ffi {
32+
3633
/*!
3734
* \brief Get a reference type from a raw object ptr type
3835
*
@@ -46,7 +43,7 @@ namespace ffi {
4643
* \return The corresponding RefType
4744
*/
4845
template <typename RefType, typename ObjectType>
49-
TVM_FFI_INLINE RefType GetRef(const ObjectType* ptr) {
46+
inline RefType GetRef(const ObjectType* ptr) {
5047
static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
5148
"Can only cast to the ref of same container type");
5249

@@ -75,92 +72,9 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr) {
7572
"Can only cast to the ref of same container type");
7673
return details::ObjectUnsafe::ObjectPtrFromUnowned<BaseType>(ptr);
7774
}
78-
79-
/*!
80-
* \brief Downcast a base reference type to a more specific type.
81-
*
82-
* \param ref The input reference
83-
* \return The corresponding SubRef.
84-
* \tparam SubRef The target specific reference type.
85-
* \tparam BaseRef the current reference type.
86-
*/
87-
template <typename SubRef, typename BaseRef,
88-
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, BaseRef>>>
89-
inline SubRef Downcast(BaseRef ref) {
90-
if (ref.defined()) {
91-
if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
92-
TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to "
93-
<< SubRef::ContainerType::_type_key << " failed.";
94-
}
95-
return SubRef(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(ref)));
96-
} else {
97-
if constexpr (is_optional_type_v<SubRef> || SubRef::_type_is_nullable) {
98-
return SubRef(ObjectPtr<Object>(nullptr));
99-
}
100-
TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `"
101-
<< SubRef::ContainerType::_type_key
102-
<< "` is not allowed. Use Downcast<Optional<T>> instead.";
103-
TVM_FFI_UNREACHABLE();
104-
}
105-
}
106-
107-
/*!
108-
* \brief Downcast any to a specific type
109-
*
110-
* \param ref The input reference
111-
* \return The corresponding SubRef.
112-
* \tparam T The target specific reference type.
113-
*/
114-
template <typename T>
115-
inline T Downcast(const Any& ref) {
116-
if constexpr (std::is_same_v<T, Any>) {
117-
return ref;
118-
} else {
119-
return ref.cast<T>();
120-
}
121-
}
122-
123-
/*!
124-
* \brief Downcast any to a specific type
125-
*
126-
* \param ref The input reference
127-
* \return The corresponding SubRef.
128-
* \tparam T The target specific reference type.
129-
*/
130-
template <typename T>
131-
inline T Downcast(Any&& ref) {
132-
if constexpr (std::is_same_v<T, Any>) {
133-
return std::move(ref);
134-
} else {
135-
return std::move(ref).cast<T>();
136-
}
137-
}
138-
139-
/*!
140-
* \brief Downcast std::optional<Any> to std::optional<T>
141-
*
142-
* \param ref The input reference
143-
* \return The corresponding SubRef.
144-
* \tparam OptionalType The target optional type
145-
*/
146-
template <typename OptionalType, typename = std::enable_if_t<is_optional_type_v<OptionalType>>>
147-
inline OptionalType Downcast(const std::optional<Any>& ref) {
148-
if (ref.has_value()) {
149-
if constexpr (std::is_same_v<OptionalType, Any>) {
150-
return *ref;
151-
} else {
152-
return (*ref).cast<OptionalType>();
153-
}
154-
} else {
155-
return OptionalType(std::nullopt);
156-
}
157-
}
158-
15975
} // namespace ffi
16076

161-
// Expose to the tvm namespace
162-
// Rationale: convinience and no ambiguity
163-
using ffi::Downcast;
77+
using ffi::GetObjectPtr;
16478
using ffi::GetRef;
16579
} // namespace tvm
16680
#endif // TVM_FFI_CAST_H_

ffi/tests/cpp/test_string.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
*/
1919
#include <gtest/gtest.h>
2020
#include <tvm/ffi/any.h>
21-
#include <tvm/ffi/cast.h>
2221
#include <tvm/ffi/string.h>
2322

2423
namespace {
@@ -266,7 +265,7 @@ TEST(String, Cast) {
266265
string source = "this is a string";
267266
String s{source};
268267
Any r = s;
269-
String s2 = Downcast<String>(r);
268+
String s2 = r.cast<String>();
270269
}
271270

272271
TEST(String, Concat) {

include/tvm/node/cast.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file tvm/node/cast.h
21+
* \brief Value casting helpers
22+
*/
23+
#ifndef TVM_NODE_CAST_H_
24+
#define TVM_NODE_CAST_H_
25+
26+
#include <tvm/ffi/any.h>
27+
#include <tvm/ffi/cast.h>
28+
#include <tvm/ffi/dtype.h>
29+
#include <tvm/ffi/error.h>
30+
#include <tvm/ffi/object.h>
31+
#include <tvm/ffi/optional.h>
32+
33+
#include <utility>
34+
35+
namespace tvm {
36+
37+
/*!
38+
* \brief Downcast a base reference type to a more specific type.
39+
*
40+
* \param ref The input reference
41+
* \return The corresponding SubRef.
42+
* \tparam SubRef The target specific reference type.
43+
* \tparam BaseRef the current reference type.
44+
*/
45+
template <typename SubRef, typename BaseRef,
46+
typename = std::enable_if_t<std::is_base_of_v<ffi::ObjectRef, BaseRef>>>
47+
inline SubRef Downcast(BaseRef ref) {
48+
if (ref.defined()) {
49+
if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
50+
TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to "
51+
<< SubRef::ContainerType::_type_key << " failed.";
52+
}
53+
return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(std::move(ref)));
54+
} else {
55+
if constexpr (ffi::is_optional_type_v<SubRef> || SubRef::_type_is_nullable) {
56+
return SubRef(ffi::ObjectPtr<ffi::Object>(nullptr));
57+
}
58+
TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `"
59+
<< SubRef::ContainerType::_type_key
60+
<< "` is not allowed. Use Downcast<Optional<T>> instead.";
61+
TVM_FFI_UNREACHABLE();
62+
}
63+
}
64+
65+
/*!
66+
* \brief Downcast any to a specific type
67+
*
68+
* \param ref The input reference
69+
* \return The corresponding SubRef.
70+
* \tparam T The target specific reference type.
71+
*/
72+
template <typename T>
73+
inline T Downcast(const ffi::Any& ref) {
74+
if constexpr (std::is_same_v<T, Any>) {
75+
return ref;
76+
} else {
77+
return ref.cast<T>();
78+
}
79+
}
80+
81+
/*!
82+
* \brief Downcast any to a specific type
83+
*
84+
* \param ref The input reference
85+
* \return The corresponding SubRef.
86+
* \tparam T The target specific reference type.
87+
*/
88+
template <typename T>
89+
inline T Downcast(ffi::Any&& ref) {
90+
if constexpr (std::is_same_v<T, Any>) {
91+
return std::move(ref);
92+
} else {
93+
return std::move(ref).cast<T>();
94+
}
95+
}
96+
97+
/*!
98+
* \brief Downcast std::optional<Any> to std::optional<T>
99+
*
100+
* \param ref The input reference
101+
* \return The corresponding SubRef.
102+
* \tparam OptionalType The target optional type
103+
*/
104+
template <typename OptionalType, typename = std::enable_if_t<ffi::is_optional_type_v<OptionalType>>>
105+
inline OptionalType Downcast(const std::optional<ffi::Any>& ref) {
106+
if (ref.has_value()) {
107+
if constexpr (std::is_same_v<OptionalType, ffi::Any>) {
108+
return *ref;
109+
} else {
110+
return (*ref).cast<OptionalType>();
111+
}
112+
} else {
113+
return OptionalType(std::nullopt);
114+
}
115+
}
116+
} // namespace tvm
117+
#endif // TVM_NODE_CAST_H_

include/tvm/node/node.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#define TVM_NODE_NODE_H_
3636

3737
#include <tvm/ffi/memory.h>
38+
#include <tvm/node/cast.h>
3839
#include <tvm/node/repr_printer.h>
3940
#include <tvm/node/structural_equal.h>
4041
#include <tvm/node/structural_hash.h>
@@ -57,8 +58,6 @@ using ffi::ObjectPtrHash;
5758
using ffi::ObjectRef;
5859
using ffi::PackedArgs;
5960
using ffi::TypeIndex;
60-
using runtime::Downcast;
61-
using runtime::GetRef;
6261

6362
} // namespace tvm
6463
#endif // TVM_NODE_NODE_H_

include/tvm/runtime/disco/session.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ inline std::string DiscoAction2String(DiscoAction action) {
124124
LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast<int>(action);
125125
}
126126

127+
class SessionObj;
128+
127129
/*!
128130
* \brief An object that exists on all workers.
129131
*
@@ -156,6 +158,9 @@ class DRefObj : public Object {
156158
int64_t reg_id;
157159
/*! \brief Back-pointer to the host controler session */
158160
ObjectRef session{nullptr};
161+
162+
private:
163+
inline SessionObj* GetSession();
159164
};
160165

161166
/*!
@@ -321,18 +326,22 @@ class WorkerZeroData {
321326

322327
// Implementation details
323328

329+
inline SessionObj* DRefObj::GetSession() {
330+
return const_cast<SessionObj*>(static_cast<const SessionObj*>(session.get()));
331+
}
332+
324333
DRefObj::~DRefObj() {
325334
if (this->session.defined()) {
326-
Downcast<Session>(this->session)->DeallocReg(reg_id);
335+
GetSession()->DeallocReg(reg_id);
327336
}
328337
}
329338

330339
ffi::Any DRefObj::DebugGetFromRemote(int worker_id) {
331-
return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id, worker_id);
340+
return GetSession()->DebugGetFromRemote(this->reg_id, worker_id);
332341
}
333342

334343
void DRefObj::DebugCopyFrom(int worker_id, ffi::AnyView value) {
335-
return Downcast<Session>(this->session)->DebugSetRegister(this->reg_id, value, worker_id);
344+
return GetSession()->DebugSetRegister(this->reg_id, value, worker_id);
336345
}
337346

338347
template <typename... Args>

include/tvm/runtime/object.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ using tvm::ffi::ObjectPtrEqual;
3939
using tvm::ffi::ObjectPtrHash;
4040
using tvm::ffi::ObjectRef;
4141

42-
using tvm::ffi::Downcast;
4342
using tvm::ffi::GetObjectPtr;
4443
using tvm::ffi::GetRef;
4544

include/tvm/runtime/vm/vm.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,12 @@ class VirtualMachine : public runtime::ModuleNode {
189189
using ContainerType = typename T::ContainerType;
190190
uint32_t key = ContainerType::RuntimeTypeIndex();
191191
if (auto it = extensions.find(key); it != extensions.end()) {
192-
return Downcast<T>((*it).second);
192+
ffi::Any value = (*it).second;
193+
return value.cast<T>();
193194
}
194195
auto [it, _] = extensions.emplace(key, T::Create());
195-
return Downcast<T>((*it).second);
196+
ffi::Any value = (*it).second;
197+
return value.cast<T>();
196198
}
197199

198200
/*!
@@ -224,7 +226,7 @@ class VirtualMachine : public runtime::ModuleNode {
224226
std::vector<Device> devices;
225227
/*! \brief The VM extensions. Mapping from the type index of the extension to the extension
226228
* instance. */
227-
std::unordered_map<uint32_t, VMExtension> extensions;
229+
std::unordered_map<uint32_t, Any> extensions;
228230
};
229231

230232
} // namespace vm

src/node/container_printing.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
* \file node/container_printint.cc
2323
*/
2424
#include <tvm/ffi/function.h>
25+
#include <tvm/node/cast.h>
2526
#include <tvm/node/functor.h>
2627
#include <tvm/node/repr_printer.h>
2728

@@ -62,6 +63,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
6263

6364
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
6465
.set_dispatch<ffi::ShapeObj>([](const ObjectRef& node, ReprPrinter* p) {
65-
p->stream << ffi::Downcast<ffi::Shape>(node);
66+
p->stream << Downcast<ffi::Shape>(node);
6667
});
6768
} // namespace tvm

src/node/repr_printer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
*/
2424
#include <tvm/ffi/function.h>
2525
#include <tvm/ffi/reflection/registry.h>
26+
#include <tvm/node/cast.h>
2627
#include <tvm/node/repr_printer.h>
2728
#include <tvm/runtime/device_api.h>
2829

0 commit comments

Comments
 (0)