|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +/*! |
| 21 | + * \file tvm/relay/executor.h |
| 22 | + * \brief Object representation of Executor configuration and registry |
| 23 | + */ |
| 24 | +#ifndef TVM_RELAY_EXECUTOR_H_ |
| 25 | +#define TVM_RELAY_EXECUTOR_H_ |
| 26 | + |
| 27 | +#include <dmlc/registry.h> |
| 28 | +#include <tvm/ir/attrs.h> |
| 29 | +#include <tvm/ir/expr.h> |
| 30 | +#include <tvm/ir/type.h> |
| 31 | +#include <tvm/ir/type_relation.h> |
| 32 | +#include <tvm/node/attr_registry_map.h> |
| 33 | +#include <tvm/runtime/registry.h> |
| 34 | + |
| 35 | +#include <string> |
| 36 | +#include <unordered_map> |
| 37 | +#include <utility> |
| 38 | +#include <vector> |
| 39 | + |
| 40 | +namespace tvm { |
| 41 | + |
| 42 | +template <typename, typename> |
| 43 | +class AttrRegistry; |
| 44 | + |
| 45 | +namespace relay { |
| 46 | + |
| 47 | +/*! |
| 48 | + * \brief Executor information. |
| 49 | + * |
| 50 | + * This data structure stores the meta-data |
| 51 | + * about executors which can be used to pass around information. |
| 52 | + * |
| 53 | + * \sa Executor |
| 54 | + */ |
| 55 | +class ExecutorNode : public Object { |
| 56 | + public: |
| 57 | + /*! \brief name of the Executor */ |
| 58 | + String name; |
| 59 | + /* \brief Additional attributes storing meta-data about the Executor. */ |
| 60 | + DictAttrs attrs; |
| 61 | + |
| 62 | + /*! |
| 63 | + * \brief Get an attribute. |
| 64 | + * |
| 65 | + * \param attr_key The attribute key. |
| 66 | + * \param default_value The default value if the key does not exist, defaults to nullptr. |
| 67 | + * |
| 68 | + * \return The result |
| 69 | + * |
| 70 | + * \tparam TObjectRef the expected object type. |
| 71 | + * \throw Error if the key exists but the value does not match TObjectRef |
| 72 | + * |
| 73 | + * \code |
| 74 | + * |
| 75 | + * void GetAttrExample(const Executor& executor) { |
| 76 | + * auto value = executor->GetAttr<Integer>("AttrKey", 0); |
| 77 | + * } |
| 78 | + * |
| 79 | + * \endcode |
| 80 | + */ |
| 81 | + template <typename TObjectRef> |
| 82 | + Optional<TObjectRef> GetAttr( |
| 83 | + const std::string& attr_key, |
| 84 | + Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { |
| 85 | + return attrs.GetAttr(attr_key, default_value); |
| 86 | + } |
| 87 | + // variant that uses TObjectRef to enable implicit conversion to default value. |
| 88 | + template <typename TObjectRef> |
| 89 | + Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { |
| 90 | + return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); |
| 91 | + } |
| 92 | + |
| 93 | + void VisitAttrs(AttrVisitor* v) { |
| 94 | + v->Visit("name", &name); |
| 95 | + v->Visit("attrs", &attrs); |
| 96 | + } |
| 97 | + |
| 98 | + bool SEqualReduce(const ExecutorNode* other, SEqualReducer equal) const { |
| 99 | + return name == other->name && equal.DefEqual(attrs, other->attrs); |
| 100 | + } |
| 101 | + |
| 102 | + void SHashReduce(SHashReducer hash_reduce) const { |
| 103 | + hash_reduce(name); |
| 104 | + hash_reduce(attrs); |
| 105 | + } |
| 106 | + |
| 107 | + static constexpr const char* _type_key = "Executor"; |
| 108 | + TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); |
| 109 | +}; |
| 110 | + |
| 111 | +/*! |
| 112 | + * \brief Managed reference class to ExecutorNode. |
| 113 | + * \sa ExecutorNode |
| 114 | + */ |
| 115 | +class Executor : public ObjectRef { |
| 116 | + public: |
| 117 | + /*! |
| 118 | + * \brief Create a new Executor object using the registry |
| 119 | + * \throws Error if name is not registered |
| 120 | + * \param name The name of the executor. |
| 121 | + * \param attrs Attributes for the executor. |
| 122 | + * \return the new Executor object. |
| 123 | + */ |
| 124 | + TVM_DLL static Executor Create(String name, Map<String, ObjectRef> attrs); |
| 125 | + |
| 126 | + /*! |
| 127 | + * \brief List all registered Executors |
| 128 | + * \return the list of Executors |
| 129 | + */ |
| 130 | + TVM_DLL static Array<String> ListExecutors(); |
| 131 | + |
| 132 | + /*! |
| 133 | + * \brief List all options for a specific Executor |
| 134 | + * \param name The name of the Executor |
| 135 | + * \return Map of option name to type |
| 136 | + */ |
| 137 | + TVM_DLL static Map<String, String> ListExecutorOptions(const String& name); |
| 138 | + |
| 139 | + /*! \brief specify container node */ |
| 140 | + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); |
| 141 | + |
| 142 | + private: |
| 143 | + /*! |
| 144 | + * \brief Private Constructor |
| 145 | + * \param name The executor name |
| 146 | + * \param attrs Attributes to apply to this Executor node |
| 147 | + */ |
| 148 | + TVM_DLL Executor(String name, DictAttrs attrs) { |
| 149 | + auto n = make_object<ExecutorNode>(); |
| 150 | + n->name = std::move(name); |
| 151 | + n->attrs = std::move(attrs); |
| 152 | + data_ = std::move(n); |
| 153 | + } |
| 154 | +}; |
| 155 | + |
| 156 | +/*! |
| 157 | + * \brief Helper structure to register Executors |
| 158 | + * \sa TVM_REGISTER_EXECUTOR |
| 159 | + */ |
| 160 | +class ExecutorRegEntry { |
| 161 | + public: |
| 162 | + /*! \brief Set name of the Executor to be the same as registry if it is empty */ |
| 163 | + inline ExecutorRegEntry& set_name(); |
| 164 | + |
| 165 | + /*! |
| 166 | + * \brief Register a valid configuration option and its ValueType for validation |
| 167 | + * \param key The configuration key |
| 168 | + * \tparam ValueType The value type to be registered |
| 169 | + */ |
| 170 | + template <typename ValueType> |
| 171 | + inline ExecutorRegEntry& add_attr_option(const String& key); |
| 172 | + |
| 173 | + /*! |
| 174 | + * \brief Register a valid configuration option and its ValueType for validation |
| 175 | + * \param key The configuration key |
| 176 | + * \param default_value The default value of the key |
| 177 | + * \tparam ValueType The value type to be registered |
| 178 | + */ |
| 179 | + template <typename ValueType> |
| 180 | + inline ExecutorRegEntry& add_attr_option(const String& key, ObjectRef default_value); |
| 181 | + |
| 182 | + /*! |
| 183 | + * \brief Register or get a new entry. |
| 184 | + * \param name The name of the operator. |
| 185 | + * \return the corresponding entry. |
| 186 | + */ |
| 187 | + TVM_DLL static ExecutorRegEntry& RegisterOrGet(const String& name); |
| 188 | + |
| 189 | + private: |
| 190 | + /*! \brief Internal storage of value types */ |
| 191 | + struct ValueTypeInfo { |
| 192 | + std::string type_key; |
| 193 | + uint32_t type_index; |
| 194 | + }; |
| 195 | + std::unordered_map<std::string, ValueTypeInfo> key2vtype_; |
| 196 | + /*! \brief A hash table that stores the default value of each attr */ |
| 197 | + std::unordered_map<String, ObjectRef> key2default_; |
| 198 | + |
| 199 | + /*! \brief Index used for internal lookup of attribute registry */ |
| 200 | + uint32_t index_; |
| 201 | + |
| 202 | + // the name |
| 203 | + std::string name; |
| 204 | + |
| 205 | + /*! \brief Return the index stored in attr registry */ |
| 206 | + uint32_t AttrRegistryIndex() const { return index_; } |
| 207 | + /*! \brief Return the name stored in attr registry */ |
| 208 | + String AttrRegistryName() const { return name; } |
| 209 | + |
| 210 | + /*! \brief private constructor */ |
| 211 | + explicit ExecutorRegEntry(uint32_t reg_index) : index_(reg_index) {} |
| 212 | + |
| 213 | + // friend class |
| 214 | + template <typename> |
| 215 | + friend class AttrRegistryMapContainerMap; |
| 216 | + template <typename, typename> |
| 217 | + friend class tvm::AttrRegistry; |
| 218 | + friend class Executor; |
| 219 | +}; |
| 220 | + |
| 221 | +inline ExecutorRegEntry& ExecutorRegEntry::set_name() { |
| 222 | + if (name.empty()) { |
| 223 | + name = name; |
| 224 | + } |
| 225 | + return *this; |
| 226 | +} |
| 227 | + |
| 228 | +template <typename ValueType> |
| 229 | +inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key) { |
| 230 | + ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key |
| 231 | + << "' has been set once"; |
| 232 | + |
| 233 | + using ValueNodeType = typename ValueType::ContainerType; |
| 234 | + // NOTE: we could further update the function later. |
| 235 | + uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); |
| 236 | + |
| 237 | + ValueTypeInfo info; |
| 238 | + info.type_index = value_type_index; |
| 239 | + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); |
| 240 | + key2vtype_[key] = info; |
| 241 | + return *this; |
| 242 | +} |
| 243 | + |
| 244 | +template <typename ValueType> |
| 245 | +inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key, |
| 246 | + ObjectRef default_value) { |
| 247 | + add_attr_option<ValueType>(key); |
| 248 | + key2default_[key] = default_value; |
| 249 | + return *this; |
| 250 | +} |
| 251 | + |
| 252 | +// internal macros to make executor entries |
| 253 | +#define TVM_EXECUTOR_REGISTER_VAR_DEF \ |
| 254 | + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::ExecutorRegEntry& __make_##Executor |
| 255 | + |
| 256 | +/*! |
| 257 | + * \def TVM_REGISTER_EXECUTOR |
| 258 | + * \brief Register a new executor, or set attribute of the corresponding executor. |
| 259 | + * |
| 260 | + * \param ExecutorName The name of registry |
| 261 | + * |
| 262 | + * \code |
| 263 | + * |
| 264 | + * TVM_REGISTER_EXECUTOR("aot") |
| 265 | + * .add_attr_option<String>("my_option"); |
| 266 | + * .add_attr_option<String>("my_option_default", String("default")); |
| 267 | + * |
| 268 | + * \endcode |
| 269 | + */ |
| 270 | +#define TVM_REGISTER_EXECUTOR(ExecutorName) \ |
| 271 | + TVM_STR_CONCAT(TVM_EXECUTOR_REGISTER_VAR_DEF, __COUNTER__) = \ |
| 272 | + ::tvm::relay::ExecutorRegEntry::RegisterOrGet(ExecutorName).set_name() |
| 273 | +} // namespace relay |
| 274 | +} // namespace tvm |
| 275 | + |
| 276 | +#endif // TVM_RELAY_EXECUTOR_H_ |
0 commit comments