diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 620b7e76e836..332f78a2fe78 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -269,9 +269,5 @@ inline constexpr bool type_contains_v, Tuple> = (type_contains } // namespace details } // namespace ffi - -// Expose to the tvm namespace -// Rationale: convinience and no ambiguity -using ffi::Tuple; } // namespace tvm #endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h deleted file mode 100644 index 0445c3d3baa2..000000000000 --- a/include/tvm/node/object_path.h +++ /dev/null @@ -1,287 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/node/object_path.h - * ObjectPath class that represents a path from a root object to one of its descendants - * via attribute access, array indexing etc. - */ - -#ifndef TVM_NODE_OBJECT_PATH_H_ -#define TVM_NODE_OBJECT_PATH_H_ - -#include -#include -#include - -#include - -namespace tvm { - -using runtime::Object; -using runtime::ObjectPtr; -using runtime::ObjectRef; - -class ObjectPath; - -/*! - * \brief Path to an object from some root object. - * - * Motivation: - * - * Same IR node object can be referenced in several different contexts inside a larger IR object. - * For example, a variable could be referenced in several statements within a block. - * - * This makes it impossible to use an object pointer to uniquely identify a "location" within - * the larger IR object for error reporting purposes. The ObjectPath class addresses this problem - * by serving as a unique "locator". - */ -class ObjectPathNode : public Object { - public: - /*! \brief Get the parent path */ - Optional GetParent() const; - /*! - * \brief Get the length of the path. - * - * For example, the path returned by `ObjectPath::Root()` has length 1. - */ - int32_t Length() const; - - /*! - * \brief Get a path prefix of the given length. - * - * Provided `length` must not exceed the `Length()` of this path. - */ - ObjectPath GetPrefix(int32_t length) const; - - /*! - * \brief Check if this path is a prefix of another path. - * - * The prefix is not strict, i.e. a path is considered a prefix of itself. - */ - bool IsPrefixOf(const ObjectPath& other) const; - - /*! \brief Check if two paths are equal. */ - bool PathsEqual(const ObjectPath& other) const; - - /*! \brief Extend this path with access to an object attribute. */ - ObjectPath Attr(const char* attr_key) const; - - /*! \brief Extend this path with access to an object attribute. */ - ObjectPath Attr(Optional attr_key) const; - - /*! \brief Extend this path with access to an array element. */ - ObjectPath ArrayIndex(int32_t index) const; - - /*! \brief Extend this path with access to a missing array element. */ - ObjectPath MissingArrayElement(int32_t index) const; - - /*! \brief Extend this path with access to a map value. */ - ObjectPath MapValue(ffi::Any key) const; - - /*! \brief Extend this path with access to a missing map entry. */ - ObjectPath MissingMapEntry() const; - - static constexpr const char* _type_key = "node.ObjectPath"; - TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); - - protected: - explicit ObjectPathNode(const ObjectPathNode* parent); - - friend class ObjectPath; - friend std::string GetObjectPathRepr(const ObjectPathNode* node); - - const ObjectPathNode* ParentNode() const; - - /*! Compares just the last node of the path, without comparing the whole path. */ - virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0; - - virtual std::string LastNodeString() const = 0; - - private: - Optional parent_; - int32_t length_; -}; - -class ObjectPath : public ObjectRef { - public: - /*! \brief Create a path that represents the root object itself. */ - static ObjectPath Root(Optional name = std::nullopt); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); -}; - -//------------------------------------------------------------------------- -//----- Concrete object path nodes ------------------------------------ -//------------------------------------------------------------------------- - -// ----- Root ----- - -class RootPathNode final : public ObjectPathNode { - public: - Optional name; - - explicit RootPathNode(Optional name = std::nullopt); - - static constexpr const char* _type_key = "node.RootPath"; - TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class RootPath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode); -}; - -// ----- Attribute access ----- - -class AttributeAccessPathNode final : public ObjectPathNode { - public: - /*! \brief Name of the attribute being accessed. Must be a static string. */ - String attr_key; - - explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key); - - static constexpr const char* _type_key = "node.AttributeAccessPath"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class AttributeAccessPath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath, - AttributeAccessPathNode); -}; - -// ----- Unknown attribute access ----- - -class UnknownAttributeAccessPathNode final : public ObjectPathNode { - public: - explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent); - - static constexpr const char* _type_key = "node.UnknownAttributeAccessPath"; - TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class UnknownAttributeAccessPath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath, - UnknownAttributeAccessPathNode); -}; - -// ----- Array element access by index ----- - -class ArrayIndexPathNode : public ObjectPathNode { - public: - /*! \brief Index of the array element that is being accessed. */ - int32_t index; - - explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index); - - static constexpr const char* _type_key = "node.ArrayIndexPath"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class ArrayIndexPath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode); -}; - -// ----- Missing array element ----- - -class MissingArrayElementPathNode : public ObjectPathNode { - public: - /*! \brief Index of the array element that is missing. */ - int32_t index; - - explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index); - - static constexpr const char* _type_key = "node.MissingArrayElementPath"; - TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class MissingArrayElementPath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath, - MissingArrayElementPathNode); -}; - -// ----- Map value ----- - -class MapValuePathNode : public ObjectPathNode { - public: - /*! \brief Key of the map entry that is being accessed */ - ffi::Any key; - - explicit MapValuePathNode(const ObjectPathNode* parent, ffi::Any key); - - static constexpr const char* _type_key = "node.MapValuePath"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class MapValuePath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode); -}; - -// ----- Missing map entry ----- - -class MissingMapEntryPathNode : public ObjectPathNode { - public: - explicit MissingMapEntryPathNode(const ObjectPathNode* parent); - - static constexpr const char* _type_key = "node.MissingMapEntryPath"; - TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); - - protected: - bool LastNodeEqual(const ObjectPathNode* other) const final; - std::string LastNodeString() const final; -}; - -class MissingMapEntryPath : public ObjectPath { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath, - MissingMapEntryPathNode); -}; - -} // namespace tvm - -#endif // TVM_NODE_OBJECT_PATH_H_ diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index e3baf397f25f..05687d70d742 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -23,6 +23,7 @@ #ifndef TVM_NODE_REPR_PRINTER_H_ #define TVM_NODE_REPR_PRINTER_H_ +#include #include #include @@ -87,6 +88,51 @@ inline std::ostream& operator<<(std::ostream& os, const Variant& n) { // return os; } +namespace reflection { + +inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { + namespace refl = ffi::reflection; + switch (step->kind) { + case refl::AccessKind::kAttr: { + os << '.' << step->key.cast(); + return os; + } + case refl::AccessKind::kArrayItem: { + os << "[" << step->key.cast() << "]"; + return os; + } + case refl::AccessKind::kMapItem: { + os << "[" << step->key << "]"; + return os; + } + case refl::AccessKind::kAttrMissing: { + os << ".key.cast() << "`>"; + return os; + } + case refl::AccessKind::kArrayItemMissing: { + os << "[key.cast() << ">]"; + return os; + } + case refl::AccessKind::kMapItemMissing: { + os << "[key << ">]"; + return os; + } + default: { + LOG(FATAL) << "Unknown access step kind: " << static_cast(step->kind); + } + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const AccessPath& path) { + Array steps = path->ToSteps(); + os << ""; + for (const auto& step : steps) { + os << step; + } + return os; +} +} // namespace reflection } // namespace ffi } // namespace tvm #endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index c55335380fe8..d046dbfae732 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -26,10 +26,10 @@ #include #include #include +#include #include #include #include -#include #include #include @@ -37,7 +37,7 @@ namespace tvm { -class PrinterConfigNode : public Object { +class PrinterConfigNode : public ffi::Object { public: /*! \brief A stack that tracks the names of the binding hierarchy */ Array binding_names = {}; @@ -113,9 +113,9 @@ class PrinterConfigNode : public Object { bool show_all_struct_info = true; /* \brief Object path to be underlined */ - Array path_to_underline = Array(); + Array path_to_underline; /*! \brief Object path to be annotated. */ - Map path_to_annotate = Map(); + Map path_to_annotate; /*! \brief Object to be underlined. */ Array obj_to_underline = Array(); /*! \brief Object to be annotated. */ diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 0e7dc246a3e7..12ba59118b72 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -24,8 +24,8 @@ #define TVM_NODE_STRUCTURAL_EQUAL_H_ #include +#include #include -#include #include #include @@ -74,27 +74,6 @@ class BaseValueEqual { } }; -/*! - * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality. - */ -class ObjectPathPairNode : public Object { - public: - ObjectPath lhs_path; - ObjectPath rhs_path; - - ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path); - - static constexpr const char* _type_key = "node.ObjectPathPair"; - TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object); -}; - -class ObjectPathPair : public ObjectRef { - public: - ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode); -}; - /*! * \brief Content-aware structural equality comparator for objects. * @@ -129,5 +108,6 @@ class StructuralEqual : public BaseValueEqual { TVM_DLL bool operator()(const ffi::Any& lhs, const ffi::Any& rhs, const bool map_free_params = false) const; }; + } // namespace tvm #endif // TVM_NODE_STRUCTURAL_EQUAL_H_ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index de3fb0bbad2c..b045ee00315b 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_DOC_H_ #define TVM_SCRIPT_PRINTER_DOC_H_ +#include #include #include #include @@ -31,6 +32,8 @@ namespace tvm { namespace script { namespace printer { +using AccessPath = ffi::reflection::AccessPath; + // Forward declaration class Doc; @@ -61,7 +64,7 @@ class DocNode : public Object { * this Doc is generated, in order to position the diagnostic * message. */ - mutable Array source_paths; + mutable Array source_paths; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -266,20 +269,20 @@ class LiteralDocNode : public ExprDocNode { */ class LiteralDoc : public ExprDoc { protected: - explicit LiteralDoc(ffi::Any value, const Optional& object_path); + explicit LiteralDoc(ffi::Any value, const Optional& object_path); public: /*! * \brief Create a LiteralDoc to represent None/null/empty value. * \param p The object path */ - static LiteralDoc None(const Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } + static LiteralDoc None(const Optional& p) { return LiteralDoc(ffi::Any(nullptr), p); } /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. * \param p The object path */ - static LiteralDoc Int(int64_t v, const Optional& p) { + static LiteralDoc Int(int64_t v, const Optional& p) { return LiteralDoc(IntImm(DataType::Int(64), v), p); } /*! @@ -287,7 +290,7 @@ class LiteralDoc : public ExprDoc { * \param v The boolean value. * \param p The object path */ - static LiteralDoc Boolean(bool v, const Optional& p) { + static LiteralDoc Boolean(bool v, const Optional& p) { return LiteralDoc(IntImm(DataType::Bool(), v), p); } /*! @@ -295,7 +298,7 @@ class LiteralDoc : public ExprDoc { * \param v The float value. * \param p The object path */ - static LiteralDoc Float(double v, const Optional& p) { + static LiteralDoc Float(double v, const Optional& p) { return LiteralDoc(FloatImm(DataType::Float(64), v), p); } /*! @@ -303,13 +306,13 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc Str(const String& v, const Optional& p) { return LiteralDoc(v, p); } + static LiteralDoc Str(const String& v, const Optional& p) { return LiteralDoc(v, p); } /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, const Optional& p) { + static LiteralDoc DataType(const runtime::DataType& v, const Optional& p) { std::string dtype = v.is_void() ? "void" : runtime::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } @@ -318,7 +321,7 @@ class LiteralDoc : public ExprDoc { * \param v The device. * \param p The object path */ - static LiteralDoc Device(const DLDevice& v, const Optional& p) { + static LiteralDoc Device(const DLDevice& v, const Optional& p) { std::ostringstream os; runtime::operator<<(os, v); return LiteralDoc::Str(os.str(), p); diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 8a181cf853ab..dd7eaff7cc69 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ +#include #include #include #include @@ -35,6 +36,8 @@ namespace tvm { namespace script { namespace printer { +using AccessPath = ffi::reflection::AccessPath; + //////////////////////// Frame //////////////////////// class IRDocsifier; @@ -235,7 +238,7 @@ class IRDocsifierNode : public Object { * \return The Doc object. */ template - inline TDoc AsDoc(const Any& obj, const ObjectPath& path) const; + inline TDoc AsDoc(const Any& obj, const AccessPath& path) const; }; /*! @@ -243,7 +246,7 @@ class IRDocsifierNode : public Object { */ class IRDocsifier : public ObjectRef { public: - using FType = IRDocsifierFunctor; + using FType = IRDocsifierFunctor; /*! \brief Create a IRDocsifier. */ explicit IRDocsifier(const PrinterConfig& cfg); /*! \brief The registration table for IRDocsifier. */ @@ -271,7 +274,7 @@ inline void FrameNode::ExitWithScope() { } template -inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const ObjectPath& path, +inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const AccessPath& path, const PrinterConfig& cfg) { if (cfg->obj_to_annotate.count(obj)) { if (const auto* stmt = d.as()) { @@ -291,7 +294,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ob } } for (const auto& pair : cfg->path_to_annotate) { - ObjectPath p = pair.first; + AccessPath p = pair.first; String attn = pair.second; if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { if (const auto* stmt = d.as()) { @@ -309,7 +312,7 @@ inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const Ob } template -inline TDoc IRDocsifierNode::AsDoc(const Any& value, const ObjectPath& path) const { +inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) const { switch (value.type_index()) { case ffi::TypeIndex::kTVMFFINone: return Downcast(LiteralDoc::None(path)); diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 62133ef2c9da..e4be2d31aa57 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -35,7 +35,7 @@ namespace script { namespace printer { /*! - * \brief Dynamic dispatch functor based on ObjectPath. + * \brief Dynamic dispatch functor based on AccessPath. * * This functor dispatches based on the type of object and the input dispatch token. */ diff --git a/python/tvm/ffi/access_path.py b/python/tvm/ffi/access_path.py index c4822074ebb8..fb8ab1b2edea 100644 --- a/python/tvm/ffi/access_path.py +++ b/python/tvm/ffi/access_path.py @@ -177,3 +177,5 @@ def to_steps(self) -> List["AccessStep"]: The list of access steps """ return self._to_steps() + + __hash__ = core.Object.__hash__ diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index d34137101119..088ca6b96506 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -196,8 +196,8 @@ def structural_equal(lhs, rhs, map_free_vars=False): return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) # type: ignore # pylint: disable=no-member -def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): - """Like structural_equal(), but returns the ObjectPaths of the first detected mismatch. +def get_first_structural_mismatch(lhs, rhs, map_free_vars=False, skip_ndarray_content=False): + """Like structural_equal(), but returns the AccessPath pair of the first detected mismatch. Parameters ---------- @@ -211,19 +211,18 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): Whether free variables (i.e. variables without a definition site) should be mapped as equal to each other. + skip_ndarray_content : bool + Whether to skip the content of ndarray. + Returns ------- - mismatch: Optional[Tuple[ObjectPath, ObjectPath]] + mismatch: Optional[Tuple[AccessPath, AccessPath]] `None` if `lhs` and `rhs` are structurally equal. - Otherwise, a tuple of two ObjectPath objects that point to the first detected mismtach. + Otherwise, a tuple of two AccessPath objects that point to the first detected mismtach. """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - mismatch = _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars) # type: ignore # pylint: disable=no-member - if mismatch is None: - return None - else: - return mismatch.lhs_path, mismatch.rhs_path + return _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars, skip_ndarray_content) # type: ignore # pylint: disable=no-member def assert_structural_equal(lhs, rhs, map_free_vars=False): diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 774c8dd635dd..ca70cf0f45a7 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -19,7 +19,6 @@ # class exposures from .packed_func import PackedFunc from .object import Object -from .object_path import ObjectPath, ObjectPathPair from .script_printer import Scriptable from .object_generic import ObjectGeneric from .device import Device diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py deleted file mode 100644 index 957db558a45b..000000000000 --- a/python/tvm/runtime/object_path.py +++ /dev/null @@ -1,144 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -ObjectPath class that represents a path from a root object to one of its descendants -via attribute access, array indexing etc. -""" - -from typing import Optional - -import tvm.ffi -from tvm.runtime import Object -from . import _ffi_node_api - - -__all__ = ( - "ObjectPath", - "RootPath", - "AttributeAccessPath", - "UnknownAttributeAccessPath", - "ArrayIndexPath", - "MissingArrayElementPath", - "MapValuePath", - "MissingMapEntryPath", - "ObjectPathPair", -) - - -@tvm.ffi.register_object("node.ObjectPath") -class ObjectPath(Object): - """ - Path to an object from some root object. - """ - - def __init__(self) -> None: - super().__init__() - raise ValueError( - "ObjectPath can't be initialized directly. " - "Use ObjectPath.root() to create a path to the root object" - ) - - @staticmethod - def root(root_name: Optional[str] = None) -> "ObjectPath": - return _ffi_node_api.ObjectPathRoot(root_name) - - def __eq__(self, other): - return _ffi_node_api.ObjectPathEqual(self, other) - - def __ne__(self, other): - return not _ffi_node_api.ObjectPathEqual(self, other) - - @property - def parent(self) -> "ObjectPath": - return _ffi_node_api.ObjectPathGetParent(self) - - def __len__(self) -> int: - return _ffi_node_api.ObjectPathLength(self) - - def get_prefix(self, length) -> "ObjectPath": - return _ffi_node_api.ObjectPathGetPrefix(self, length) - - def is_prefix_of(self, other) -> "ObjectPath": - return _ffi_node_api.ObjectPathIsPrefixOf(self, other) - - def attr(self, attr_key) -> "ObjectPath": - return _ffi_node_api.ObjectPathAttr(self, attr_key) - - def array_index(self, index) -> "ObjectPath": - return _ffi_node_api.ObjectPathArrayIndex(self, index) - - def missing_array_element(self, index) -> "ObjectPath": - return _ffi_node_api.ObjectPathMissingArrayElement(self, index) - - def map_value(self, key) -> "ObjectPath": - return _ffi_node_api.ObjectPathMapValue(self, tvm.runtime.convert(key)) - - def missing_map_entry(self) -> "ObjectPath": - return _ffi_node_api.ObjectPathMissingMapEntry(self) - - __hash__ = Object.__hash__ - - -@tvm.ffi.register_object("node.RootPath") -class RootPath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.AttributeAccessPath") -class AttributeAccessPath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.UnknownAttributeAccessPath") -class UnknownAttributeAccessPath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.ArrayIndexPath") -class ArrayIndexPath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.MissingArrayElementPath") -class MissingArrayElementPath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.MapValuePath") -class MapValuePath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.MissingMapEntryPath") -class MissingMapEntryPath(ObjectPath): - pass - - -@tvm.ffi.register_object("node.ObjectPathPair") -class ObjectPathPair(Object): - """ - Pair of ObjectPaths, one for each object being tested for structural equality. - """ - - @property - def lhs_path(self) -> ObjectPath: - return _ffi_node_api.ObjectPathPairLhsPath(self) - - @property - def rhs_path(self) -> ObjectPath: - return _ffi_node_api.ObjectPathPairRhsPath(self) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 2820f7b97bc9..a00281b435ef 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -19,10 +19,10 @@ from typing import Dict, List, Optional, Sequence from tvm.ffi import get_global_func, register_object +from tvm.ffi.access_path import AccessPath from tvm.runtime import Object from . import _ffi_node_api -from .object_path import ObjectPath @register_object("script.PrinterConfig") @@ -45,10 +45,10 @@ class PrinterConfig(Object): syntax_sugar: bool show_object_address: bool show_all_struct_info: bool - path_to_underline: Optional[List[ObjectPath]] - path_to_annotate: Optional[Dict[ObjectPath, str]] - obj_to_underline: Optional[List[Object]] - obj_to_annotate: Optional[Dict[Object, str]] + path_to_underline: Optional[List[AccessPath]] + path_to_annotate: Optional[Dict[AccessPath, str]] + obj_to_underline: Optional[List[AccessPath]] + obj_to_annotate: Optional[Dict[AccessPath, str]] def __init__( self, @@ -69,8 +69,8 @@ def __init__( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, - path_to_underline: Optional[List[ObjectPath]] = None, - path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + path_to_underline: Optional[List[AccessPath]] = None, + path_to_annotate: Optional[Dict[AccessPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> None: @@ -136,8 +136,8 @@ def script( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, - path_to_underline: Optional[List[ObjectPath]] = None, - path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + path_to_underline: Optional[List[AccessPath]] = None, + path_to_annotate: Optional[Dict[AccessPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> str: @@ -180,9 +180,9 @@ def script( If True (default), annotate all variable bindings with the struct info of that variable. If False, only add annotations where required for unambiguous round-trip of Relax -> TVMScript -> Relax. - path_to_underline : Optional[List[ObjectPath]] = None + path_to_underline : Optional[List[AccessPath]] = None Object path to be underlined - path_to_annotate : Optional[Dict[ObjectPath, str]] = None + path_to_annotate : Optional[Dict[AccessPath, str]] = None Object path to be annotated obj_to_underline : Optional[List[Object]] = None Object to be underlined @@ -239,8 +239,8 @@ def _relax_script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, - path_to_underline: Optional[List[ObjectPath]] = None, - path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + path_to_underline: Optional[List[AccessPath]] = None, + path_to_annotate: Optional[Dict[AccessPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> str: @@ -290,8 +290,8 @@ def show( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, - path_to_underline: Optional[List[ObjectPath]] = None, - path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + path_to_underline: Optional[List[AccessPath]] = None, + path_to_annotate: Optional[Dict[AccessPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> None: @@ -357,9 +357,9 @@ def show( If True (default), annotate all variable bindings with the struct info of that variable. If False, only add annotations where required for unambiguous round-trip of Relax -> TVMScript -> Relax. - path_to_underline : Optional[List[ObjectPath]] = None + path_to_underline : Optional[List[AccessPath]] = None Object path to be underlined - path_to_annotate : Optional[Dict[ObjectPath, str]] = None + path_to_annotate : Optional[Dict[AccessPath, str]] = None Object path to be annotated obj_to_underline : Optional[List[Object]] = None Object to be underlined diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index bf468b17ec18..382128ef33d7 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -20,7 +20,8 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union from tvm.ffi import register_object -from tvm.runtime import Object, ObjectPath +from tvm.ffi.access_path import AccessPath +from tvm.runtime import Object from tvm.tir import FloatImm, IntImm from . import _ffi_api @@ -129,7 +130,7 @@ class LiteralDoc(ExprDoc): def __init__( self, value: Union[str, float, bool, int, None], - path: Optional[ObjectPath] = None, + path: Optional[AccessPath] = None, ): if value is None: self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone, path) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py index b43ca3b5333e..5f1f9800848b 100644 --- a/python/tvm/script/printer/doc_printer.py +++ b/python/tvm/script/printer/doc_printer.py @@ -18,7 +18,7 @@ from typing import List, Optional -from tvm.runtime import ObjectPath +from tvm.ffi.access_path import AccessPath from tvm.runtime.script_printer import PrinterConfig from . import _ffi_api @@ -30,7 +30,7 @@ def to_python_script( indent_spaces: int = 4, print_line_numbers: bool = False, num_context_lines: Optional[int] = None, - path_to_underline: Optional[List[ObjectPath]] = None, + path_to_underline: Optional[List[AccessPath]] = None, ) -> str: """Convert Doc into Python script. @@ -44,7 +44,7 @@ def to_python_script( Whether to print line numbers num_context_lines : Optional[int] Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] + path_to_underline : Optional[AccessPath] Object path to be underlined Returns diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 4670abe52ec1..7f84978105ea 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -168,7 +168,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { continue; } if (func_params_.count(p) && func_params_[p]->IsInstance()) { - const auto& tuple = Downcast(func_params_[p]); + const auto& tuple = Downcast(func_params_[p]); Array tuple_names; for (const auto& f : tuple->fields) { if (expr_tensor_map_.count(f)) { @@ -735,7 +735,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const CallNode* void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { ExprVisitor::VisitBinding_(binding, val); const String& name = config_.use_var_name ? binding->var->name_hint() : ""; - AddNode(GetRef(val), binding->var, name); + AddNode(GetRef(val), binding->var, name); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) { @@ -806,7 +806,7 @@ Array GraphBuilder::GetPluginInputs(const Expr& expr) { ICHECK(expr->IsInstance()) << "plugin expr should be call"; const auto& call = Downcast(expr); ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; - return Downcast(call->args[1])->fields; + return Downcast(call->args[1])->fields; } Map WeightsExtractor::GetWeights(const Function& func) { diff --git a/src/node/object_path.cc b/src/node/object_path.cc deleted file mode 100644 index 3e68e0d0efa0..000000000000 --- a/src/node/object_path.cc +++ /dev/null @@ -1,345 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include - -#include -#include - -using namespace tvm::runtime; - -namespace tvm { - -// ============== ObjectPathNode ============== - -ObjectPathNode::ObjectPathNode(const ObjectPathNode* parent) - : parent_(GetRef(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {} - -// --- GetParent --- - -Optional ObjectPathNode::GetParent() const { - return Downcast>(parent_); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathGetParent", &ObjectPathNode::GetParent); -}); - -// --- Length --- - -int32_t ObjectPathNode::Length() const { return length_; } - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathLength", &ObjectPathNode::Length); -}); - -// --- GetPrefix --- - -ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { - CHECK_GE(length, 1) << "IndexError: Prefix length must be at least 1"; - CHECK_LE(length, Length()) << "IndexError: Attempted to get a prefix longer than the path itself"; - - const ObjectPathNode* node = this; - int32_t suffix_len = Length() - length; - for (int32_t i = 0; i < suffix_len; ++i) { - node = node->ParentNode(); - } - - return GetRef(node); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathGetPrefix", &ObjectPathNode::GetPrefix); -}); - -// --- IsPrefixOf --- - -bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { - int32_t this_len = Length(); - if (this_len > other->Length()) { - return false; - } - return this->PathsEqual(other->GetPrefix(this_len)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathIsPrefixOf", &ObjectPathNode::IsPrefixOf); -}); - -// --- Attr --- - -ObjectPath ObjectPathNode::Attr(const char* attr_key) const { - if (attr_key != nullptr) { - return ObjectPath(make_object(this, attr_key)); - } else { - return ObjectPath(make_object(this)); - } -} - -ObjectPath ObjectPathNode::Attr(Optional attr_key) const { - if (attr_key.has_value()) { - return ObjectPath(make_object(this, attr_key.value())); - } else { - return ObjectPath(make_object(this)); - } -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("node.ObjectPathAttr", - [](const ObjectPath& object_path, Optional attr_key) { - return object_path->Attr(attr_key); - }); -}); - -// --- ArrayIndex --- - -ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { - return ObjectPath(make_object(this, index)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathArrayIndex", &ObjectPathNode::ArrayIndex); -}); - -// --- MissingArrayElement --- - -ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { - return ObjectPath(make_object(this, index)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathMissingArrayElement", - &ObjectPathNode::MissingArrayElement); -}); - -// --- MapValue --- - -ObjectPath ObjectPathNode::MapValue(Any key) const { - return ObjectPath(make_object(this, std::move(key))); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathMapValue", &ObjectPathNode::MapValue); -}); - -// --- MissingMapEntry --- - -ObjectPath ObjectPathNode::MissingMapEntry() const { - return ObjectPath(make_object(this)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathMissingMapEntry", &ObjectPathNode::MissingMapEntry); -}); - -// --- PathsEqual ---- - -bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { - if (!other.defined() || Length() != other->Length()) { - return false; - } - - const ObjectPathNode* lhs = this; - const ObjectPathNode* rhs = static_cast(other.get()); - - while (lhs != nullptr && rhs != nullptr) { - if (lhs->type_index() != rhs->type_index()) { - return false; - } - if (!lhs->LastNodeEqual(rhs)) { - return false; - } - lhs = lhs->ParentNode(); - rhs = rhs->ParentNode(); - } - - return lhs == nullptr && rhs == nullptr; -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_method("node.ObjectPathEqual", &ObjectPathNode::PathsEqual); -}); - -// --- Repr --- - -std::string GetObjectPathRepr(const ObjectPathNode* node) { - std::string ret; - while (node != nullptr) { - std::string node_str = node->LastNodeString(); - ret.append(node_str.rbegin(), node_str.rend()); - node = static_cast(node->GetParent().get()); - } - std::reverse(ret.begin(), ret.end()); - return ret; -} - -static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) { - p->stream << GetObjectPathRepr(static_cast(node.get())); -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); - -// --- Private/protected methods --- - -const ObjectPathNode* ObjectPathNode::ParentNode() const { - return static_cast(parent_.get()); -} - -// ============== ObjectPath ============== - -/* static */ ObjectPath ObjectPath::Root(Optional name) { - return ObjectPath(make_object(name)); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("node.ObjectPathRoot", ObjectPath::Root); -}); - -// ============== Individual path classes ============== - -// ----- Root ----- - -RootPathNode::RootPathNode(Optional name) : ObjectPathNode(nullptr), name(name) {} - -bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const { - const auto* other = static_cast(other_path); - - if (other->name.has_value() != name.has_value()) { - return false; - } else if (name && other->name) { - return name.value() == other->name.value(); - } else { - return true; - } -} - -std::string RootPathNode::LastNodeString() const { return name.value_or(""); } - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); - -// ----- AttributeAccess ----- - -AttributeAccessPathNode::AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key) - : ObjectPathNode(parent), attr_key(std::move(attr_key)) {} - -bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { - const auto* otherAttrAccess = static_cast(other); - return attr_key == otherAttrAccess->attr_key; -} - -std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; } - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch(PrintObjectPathRepr); - -// ----- UnknownAttributeAccess ----- - -UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(const ObjectPathNode* parent) - : ObjectPathNode(parent) {} - -bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { - // Consider any two unknown attribute accesses unequal - return false; -} - -std::string UnknownAttributeAccessPathNode::LastNodeString() const { - return "."; -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch(PrintObjectPathRepr); - -// ----- ArrayIndexPath ----- - -ArrayIndexPathNode::ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index) - : ObjectPathNode(parent), index(index) {} - -bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const { - const auto* otherArrayIndex = static_cast(other); - return index == otherArrayIndex->index; -} - -std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]"; } - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); - -// ----- MissingArrayElement ----- - -MissingArrayElementPathNode::MissingArrayElementPathNode(const ObjectPathNode* parent, - int32_t index) - : ObjectPathNode(parent), index(index) {} - -bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const { - const auto* otherMissingElement = static_cast(other); - return index == otherMissingElement->index; -} - -std::string MissingArrayElementPathNode::LastNodeString() const { - return "[]"; -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch(PrintObjectPathRepr); - -// ----- MapValue ----- - -MapValuePathNode::MapValuePathNode(const ObjectPathNode* parent, Any key) - : ObjectPathNode(parent), key(std::move(key)) {} - -bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const { - const auto* otherMapValue = static_cast(other); - return ffi::AnyEqual()(key, otherMapValue->key); -} - -std::string MapValuePathNode::LastNodeString() const { - std::ostringstream s; - s << "[" << key << "]"; - return s.str(); -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(PrintObjectPathRepr); - -// ----- MissingMapEntry ----- - -MissingMapEntryPathNode::MissingMapEntryPathNode(const ObjectPathNode* parent) - : ObjectPathNode(parent) {} - -bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } - -std::string MissingMapEntryPathNode::LastNodeString() const { return "[]"; } - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch(PrintObjectPathRepr); - -} // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index d3b62b5e8775..6a60b9723d3d 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -116,6 +116,16 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + p->stream << Downcast(node); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + p->stream << Downcast(node); + }); + TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) { diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 5b52ba635a37..9b1565d2ab3a 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -26,6 +26,8 @@ namespace tvm { +using AccessPath = ffi::reflection::AccessPath; + TVM_FFI_STATIC_INIT_BLOCK({ PrinterConfigNode::RegisterReflection(); }); TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { @@ -99,11 +101,11 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->num_context_lines = v.value().cast(); } if (auto v = config_dict.Get("path_to_underline")) { - n->path_to_underline = Downcast>>(v).value_or(Array()); + n->path_to_underline = Downcast>>(v).value_or(Array()); } if (auto v = config_dict.Get("path_to_annotate")) { n->path_to_annotate = - Downcast>>(v).value_or(Map()); + Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("obj_to_underline")) { n->obj_to_underline = Downcast>>(v).value_or(Array()); diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index c6875d3fca4b..be009a77c305 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include @@ -34,75 +33,19 @@ namespace tvm { -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("node.ObjectPathPairLhsPath", - [](const ObjectPathPair& object_path_pair) { return object_path_pair->lhs_path; }) - .def("node.ObjectPathPairRhsPath", - [](const ObjectPathPair& object_path_pair) { return object_path_pair->rhs_path; }); -}); - -ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path) - : lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {} - -ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) { - data_ = make_object(std::move(lhs_path), std::move(rhs_path)); -} - -Optional ObjectPathPairFromAccessPathPair( - Optional src) { - if (!src.has_value()) return std::nullopt; - auto translate_path = [](ffi::reflection::AccessPath path) { - ObjectPath result = ObjectPath::Root(); - for (const auto& step : path->ToSteps()) { - switch (step->kind) { - case ffi::reflection::AccessKind::kAttr: { - result = result->Attr(step->key.cast()); - break; - } - case ffi::reflection::AccessKind::kArrayItem: { - result = result->ArrayIndex(step->key.cast()); - break; - } - case ffi::reflection::AccessKind::kMapItem: { - result = result->MapValue(step->key); - break; - } - case ffi::reflection::AccessKind::kArrayItemMissing: { - result = result->MissingArrayElement(step->key.cast()); - break; - } - case ffi::reflection::AccessKind::kMapItemMissing: { - result = result->MissingMapEntry(); - break; - } - default: { - LOG(FATAL) << "Invalid access path kind: " << static_cast(step->kind); - break; - } - } - } - return result; - }; - - return ObjectPathPair(translate_path((*src).get<0>()), translate_path((*src).get<1>())); -} - bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool assert_mode, bool map_free_vars) { if (assert_mode) { - auto first_mismatch = ObjectPathPairFromAccessPathPair( - ffi::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); + auto first_mismatch = ffi::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars); if (first_mismatch.has_value()) { std::ostringstream oss; oss << "StructuralEqual check failed, caused by lhs"; - oss << " at " << (*first_mismatch)->lhs_path; + oss << " at " << (*first_mismatch).get<0>(); { // print lhs PrinterConfig cfg; cfg->syntax_sugar = false; - cfg->path_to_underline.push_back((*first_mismatch)->lhs_path); + cfg->path_to_underline.push_back((*first_mismatch).get<0>()); // The TVMScriptPrinter::Script will fallback to Repr printer, // if the root node to print is not supported yet, // e.g. Relax nodes, ArrayObj, MapObj, etc. @@ -111,11 +54,11 @@ bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool assert_mode oss << std::endl << "and rhs"; { // print rhs - oss << " at " << (*first_mismatch)->rhs_path; + oss << " at " << (*first_mismatch).get<1>(); { PrinterConfig cfg; cfg->syntax_sugar = false; - cfg->path_to_underline.push_back((*first_mismatch)->rhs_path); + cfg->path_to_underline.push_back((*first_mismatch).get<1>()); // The TVMScriptPrinter::Script will fallback to Repr printer, // if the root node to print is not supported yet, // e.g. Relax nodes, ArrayObj, MapObj, etc. @@ -134,18 +77,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("node.StructuralEqual", NodeStructuralEqualAdapter) - .def("node.GetFirstStructuralMismatch", - [](const Any& lhs, const Any& rhs, bool map_free_vars) { - /* - Optional first_mismatch; - bool equal = - SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); - ICHECK(equal == !first_mismatch.defined()); - return first_mismatch; - */ - return ObjectPathPairFromAccessPathPair( - ffi::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars)); - }); + .def("node.GetFirstStructuralMismatch", ffi::StructuralEqual::GetFirstMismatch); }); bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs, diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 6c6f50785221..7b0846051609 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -23,10 +23,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 7d1f6281df23..aa7e0473488b 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -79,7 +79,7 @@ StmtBlockDoc::StmtBlockDoc(Array stmts) { this->data_ = std::move(n); } -LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path) { +LiteralDoc::LiteralDoc(ffi::Any value, const Optional& object_path) { ObjectPtr n = make_object(); n->value = value; if (object_path.defined()) { @@ -268,7 +268,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "script.printer.DocSetSourcePaths", - [](Doc doc, Array source_paths) { doc->source_paths = source_paths; }); + [](Doc doc, Array source_paths) { doc->source_paths = source_paths; }); }); TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index e5a47d7ca2a5..7e6d76c4bf9a 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -264,9 +264,9 @@ DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) { void DocPrinter::Append(const Doc& doc) { Append(doc, PrinterConfig()); } void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) { - for (const ObjectPath& p : cfg->path_to_underline) { + for (const AccessPath& p : cfg->path_to_underline) { path_to_underline_.push_back(p); - current_max_path_length_.push_back(0); + current_max_path_depth_.push_back(0); current_underline_candidates_.push_back(std::vector()); } PrintDoc(doc); @@ -348,18 +348,18 @@ void DocPrinter::PrintDoc(const Doc& doc) { } size_t end_pos = output_.tellp(); - for (const ObjectPath& path : doc->source_paths) { + for (const AccessPath& path : doc->source_paths) { MarkSpan({start_pos, end_pos}, path); } } -void DocPrinter::MarkSpan(const ByteSpan& span, const ObjectPath& path) { +void DocPrinter::MarkSpan(const ByteSpan& span, const AccessPath& path) { int n = path_to_underline_.size(); for (int i = 0; i < n; ++i) { - ObjectPath p = path_to_underline_[i]; - if (path->Length() >= current_max_path_length_[i] && path->IsPrefixOf(p)) { - if (path->Length() > current_max_path_length_[i]) { - current_max_path_length_[i] = path->Length(); + AccessPath p = path_to_underline_[i]; + if (path->depth >= current_max_path_depth_[i] && path->IsPrefixOf(p)) { + if (path->depth > current_max_path_depth_[i]) { + current_max_path_depth_[i] = path->depth; current_underline_candidates_[i].clear(); } current_underline_candidates_[i].push_back(span); diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index aff587062d07..b92c9dbe7aa2 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -255,7 +255,7 @@ class DocPrinter { std::vector underlines_exempted_; private: - void MarkSpan(const ByteSpan& span, const ObjectPath& path); + void MarkSpan(const ByteSpan& span, const AccessPath& path); /*! \brief Options to customize certain aspects of the output */ PrinterConfig options_; @@ -267,7 +267,7 @@ class DocPrinter { std::vector line_starts_; /*! \brief Path of the object that we would like to underline */ - Array path_to_underline_; + Array path_to_underline_; /*! * \brief Candidate spans to be underlined, until we find a better match. @@ -276,7 +276,7 @@ class DocPrinter { std::vector> current_underline_candidates_; /*! \brief Path length of the objects that are current candidates for underlining. */ - std::vector current_max_path_length_; + std::vector current_max_path_depth_; /*! \brief Spans that we have already committed to underline. */ std::vector underlines_; diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 194c8f52b1aa..fd478768bf32 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -26,12 +26,12 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](ffi::Shape n, ObjectPath n_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](ffi::Shape n, AccessPath n_p, IRDocsifier d) -> Doc { int s = n.size(); Array results; results.reserve(s); for (int i = 0; i < s; ++i) { - results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayIndex(i))); + results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayItem(i))); } return TupleDoc(results); }); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 8bfbcb69ce50..70be98f4c425 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -56,7 +56,7 @@ struct SortableFunction { }; TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IRModule mod, AccessPath p, IRDocsifier d) -> Doc { std::vector functions; for (const auto& kv : mod->functions) { functions.push_back(SortableFunction(kv)); @@ -89,7 +89,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) const GlobalVar& gv = entry.gv; const BaseFunc& base_func = entry.func; d->cfg->binding_names.push_back(gv->name_hint); - Doc doc = d->AsDoc(base_func, p->Attr("functions")->MapValue(gv)); + Doc doc = d->AsDoc(base_func, p->Attr("functions")->MapItem(gv)); d->cfg->binding_names.pop_back(); if (const auto* stmt_block = doc.as()) { (*f)->stmts.push_back(stmt_block->stmts.back()); @@ -113,22 +113,22 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](DictAttrs attrs, AccessPath p, IRDocsifier d) -> Doc { return d->AsDoc(attrs->dict, p->Attr("dict")); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](GlobalVar gv, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](GlobalInfo ginfo, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "dummy_global_info")->Call({}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](VDevice vdev, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](VDevice vdev, AccessPath p, IRDocsifier d) -> Doc { d->AddGlobalInfo("vdevice", vdev); Map config = vdev->target->Export(); return IR(d, "vdevice") @@ -138,12 +138,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](Op op, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FuncType func_type, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "FuncType") ->Call({ d->AsDoc(func_type->arg_types, p->Attr("arg_types")), @@ -152,7 +152,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("ir", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("ir", [](Range range, AccessPath p, IRDocsifier d) -> Doc { return IR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index 63e703be5565..5643ab4de43a 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -24,19 +24,19 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // - "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { + "", [](Array array, AccessPath p, IRDocsifier d) -> Doc { int n = array.size(); Array results; results.reserve(n); for (int i = 0; i < n; ++i) { - results.push_back(d->AsDoc(array[i], p->ArrayIndex(i))); + results.push_back(d->AsDoc(array[i], p->ArrayItem(i))); } return ListDoc(results); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // - "", [](Map dict, ObjectPath p, IRDocsifier d) -> Doc { + "", [](Map dict, AccessPath p, IRDocsifier d) -> Doc { using POO = std::pair; std::vector items{dict.begin(), dict.end()}; bool is_str_map = true; @@ -57,8 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ks.reserve(n); vs.reserve(n); for (int i = 0; i < n; ++i) { - ks.push_back(d->AsDoc(items[i].first, p->MissingMapEntry())); - vs.push_back(d->AsDoc(items[i].second, p->MapValue(items[i].first))); + ks.push_back(d->AsDoc(items[i].first, p->MapItemMissing(items[i].first))); + vs.push_back(d->AsDoc(items[i].second, p->MapItem(items[i].first))); } return DictDoc(ks, vs); }); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 0e22b5000cb8..efe7bc2f937a 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -207,7 +207,7 @@ IRDocsifier::FType& IRDocsifier::vtable() { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_fallback([](ObjectRef obj, ObjectPath p, IRDocsifier d) -> Doc { + .set_fallback([](ObjectRef obj, AccessPath p, IRDocsifier d) -> Doc { return d->AddMetadata(obj); }); diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 22baf1c21c74..d4580af96891 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -22,7 +22,7 @@ namespace tvm { namespace script { namespace printer { -IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& d, // +IfDoc PrintIfExpr(const relax::If& n, const AccessPath& n_p, const IRDocsifier& d, // const Optional& var, const Optional& ann) { using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); @@ -41,7 +41,7 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::MatchCast n, AccessPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; Optional ann = std::nullopt; @@ -57,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::VarBinding n, AccessPath n_p, IRDocsifier d) -> Doc { if (const auto if_ = n->value.as()) { Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](relax::If n, ObjectPath n_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](relax::If n, AccessPath n_p, IRDocsifier d) -> Doc { return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); }); diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 663f8712d9a9..e7e7e21380e4 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -29,7 +29,7 @@ namespace printer { class AttrPrinter { public: - explicit AttrPrinter(ObjectPath p, const IRDocsifier& d, Array* keys, + explicit AttrPrinter(AccessPath p, const IRDocsifier& d, Array* keys, Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} @@ -54,13 +54,13 @@ class AttrPrinter { } } - ObjectPath p; + AccessPath p; const IRDocsifier& d; Array* keys; Array* values; }; -ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifier& d) { +ExprDoc PrintCallee(const relax::Expr& n, const AccessPath& n_p, const IRDocsifier& d) { // TODO(@junrushao): handle callee better if (const auto* ext = n.as()) { return LiteralDoc::Str(ext->global_symbol, n_p); @@ -69,7 +69,7 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi } } -Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p, +Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace"); @@ -87,22 +87,22 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& Array kwargs_keys; Array kwargs_values; // Step 1. Print n->args[0], the callee - args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); // Step 2. Print n->args[1], the input arguments - args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayIndex(1))); + args.push_back(d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1))); // Step 3. Print n->sinfo_args, the output struct info relax::StructInfo o_sinfo = n->sinfo_args[0]; - ObjectPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayIndex(0); + AccessPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayItem(0); bool is_dtensor = false; kwargs_keys.push_back("out_sinfo"); if (const auto* o = o_sinfo.as()) { Array fields; - ObjectPath fields_p = o_sinfo_p->Attr("fields"); + AccessPath fields_p = o_sinfo_p->Attr("fields"); for (int i = 0, l = o->fields.size(); i < l; ++i) { if (o->fields[i].as()) { is_dtensor = true; } - fields.push_back(d->AsDoc(o->fields[i], fields_p->ArrayIndex(i))); + fields.push_back(d->AsDoc(o->fields[i], fields_p->ArrayItem(i))); } kwargs_values.push_back(ListDoc(fields)); } else { @@ -147,7 +147,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& // Step 4. Print n->args[2], the tir variables if (n->args.size() == 3) { kwargs_keys.push_back("tir_vars"); - kwargs_values.push_back(d->AsDoc(n->args[2], n_p->Attr("args")->ArrayIndex(2))); + kwargs_values.push_back(d->AsDoc(n->args[2], n_p->Attr("args")->ArrayItem(2))); } if (n->op.same_as(call_tir_local_view)) { return Relax(d, "dist.call_tir_local_view")->Call(args, kwargs_keys, kwargs_values); @@ -160,7 +160,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& } } -Optional PrintAssertOp(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { +Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { return std::nullopt; @@ -171,15 +171,15 @@ Optional PrintAssertOp(const relax::Call& n, const ObjectPath& n_p, con // (the format string will be interpreted as an argument and there will be a new default format // string given) Array args; - args.push_back(d->AsDoc(n->args[0], n_p->Attr("args")->ArrayIndex(0))); - ExprDoc second_arg = d->AsDoc(n->args[1], n_p->Attr("args")->ArrayIndex(1)); + args.push_back(d->AsDoc(n->args[0], n_p->Attr("args")->ArrayItem(0))); + ExprDoc second_arg = d->AsDoc(n->args[1], n_p->Attr("args")->ArrayItem(1)); for (size_t i = 2; i < n->args.size(); i++) { - args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayItem(i))); } return Relax(d, "assert_op")->Call(args, {"format"}, {second_arg}); } -Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, +Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { @@ -187,7 +187,7 @@ Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, } Array args; - args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); Array kwargs_keys; Array kwargs_values; ICHECK(n->attrs.defined()); @@ -198,7 +198,7 @@ Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, return Relax(d, "hint_on_device")->Call(args); } -Optional PrintToVDevice(const relax::Call& n, const ObjectPath& n_p, +Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { @@ -206,7 +206,7 @@ Optional PrintToVDevice(const relax::Call& n, const ObjectPath& n_p, } Array args; - args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); Array kwargs_keys; Array kwargs_values; ICHECK(n->attrs.defined()); @@ -221,7 +221,7 @@ Optional PrintToVDevice(const relax::Call& n, const ObjectPath& n_p, return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } -Optional PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p, +Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n_p, const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { @@ -232,17 +232,17 @@ Optional PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p, // is the _format_ string, or else roundtripping will fail // (the format string will be interpreted as an argument and there will be a new default format // string given) - ExprDoc first_arg = d->AsDoc(n->args[0], n_p->Attr("args")->ArrayIndex(0)); + ExprDoc first_arg = d->AsDoc(n->args[0], n_p->Attr("args")->ArrayItem(0)); Array args; for (size_t i = 1; i < n->args.size(); i++) { - args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayItem(i))); } return Relax(d, "print")->Call(args, {"format"}, {first_arg}); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Call n, AccessPath n_p, IRDocsifier d) -> Doc { // Special case: call_tir, call_dps_packed, call_tir_with_grad if (Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); @@ -287,10 +287,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 2. Print args if (!n->args.empty()) { - args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d)); + args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); } for (int i = 1, l = n->args.size(); i < l; ++i) { - args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayIndex(i))); + args.push_back(d->AsDoc(n->args[i], n_p->Attr("args")->ArrayItem(i))); } // Step 3. Print attrs if (n->attrs.defined()) { @@ -316,11 +316,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 4. Print type_args if (n->sinfo_args.size() > 0) { - ObjectPath sinfo_args_p = n_p->Attr("sinfo_args"); + AccessPath sinfo_args_p = n_p->Attr("sinfo_args"); Array sinfo_args; for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) { - sinfo_args.push_back( - d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayIndex(i))); + sinfo_args.push_back(d->AsDoc(n->sinfo_args[i], sinfo_args_p->ArrayItem(i))); } kwargs_keys.push_back("sinfo_args"); kwargs_values.push_back(TupleDoc(sinfo_args)); diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index 9bf49a2830db..d8b3871b35bc 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -29,14 +29,14 @@ namespace printer { // distributed::Placement TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", - [](relax::distributed::Placement n, ObjectPath n_p, + [](relax::distributed::Placement n, AccessPath n_p, IRDocsifier d) -> Doc { return d->AsDoc(n->ToString(), n_p); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::distributed::DTensorStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::distributed::DTensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { Array args; Array kwargs_keys; Array kwargs_values; @@ -45,11 +45,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->tensor_sinfo->shape.value().as()) { auto shape_expr = GetRef(shape); - ObjectPath shape_p = n_p->Attr("shape")->Attr("values"); + AccessPath shape_p = n_p->Attr("shape")->Attr("values"); Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( - PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d)); + PrintShapeVar(shape_expr->values[i], shape_p->ArrayItem(i), d)); } args.push_back(TupleDoc(shape_docs)); } else { @@ -90,7 +90,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::distributed::DeviceMesh n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::distributed::DeviceMesh n, AccessPath n_p, IRDocsifier d) -> Doc { bool has_relax_frame = false; const IRFrameNode* f = nullptr; for (const Frame& frame : d->frames) { diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index 808177b15020..c411622e6409 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -29,57 +29,57 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PrimValue n, AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): float numbers return Relax(d, "prim_value")->Call({d->AsDoc(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::StringImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::StringImm n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "str")->Call({LiteralDoc::Str(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::DataTypeImm n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::DataTypeImm n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value, n_p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Tuple n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Tuple n, AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): revisit tuple printing if (n->fields.empty()) { return Relax(d, "tuple")->Call({}); } Array fields_doc; - ObjectPath fields_p = n_p->Attr("fields"); + AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { - fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } return TupleDoc(fields_doc); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TupleGetItem n, AccessPath n_p, IRDocsifier d) -> Doc { ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index")); return d->AsDoc(n->tuple, n_p->Attr("tuple"))[{idx}]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ShapeExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeExpr n, AccessPath n_p, IRDocsifier d) -> Doc { Array values_doc; - ObjectPath values_p = n_p->Attr("values"); + AccessPath values_p = n_p->Attr("values"); for (int i = 0, l = n->values.size(); i < l; ++i) { - values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d)); + values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayItem(i), d)); } return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); -Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) { +Optional SpecialScalar(const runtime::NDArray& n, const AccessPath& p) { DataType dtype = n.DataType(); const void* data = n->data; if (n->ndim != 0 || n->device.device_type != kDLCPU) { @@ -134,7 +134,7 @@ Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::Constant n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc { if (Optional s = SpecialScalar(n->data, n_p->Attr("data"))) { if (n->struct_info_.as()) { ExprDoc ann = d->AsDoc(n->struct_info_, n_p->Attr("struct_info_")); @@ -149,7 +149,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return d->AddMetadata(n); }); -Doc PrintRelaxVar(relax::Var n, ObjectPath p, IRDocsifier d) { +Doc PrintRelaxVar(relax::Var n, AccessPath p, IRDocsifier d) { if (!d->IsVarDefined(n)) { ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); Frame f = d->frames.back(); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 490254fb672d..aa6182f189fe 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -40,7 +40,7 @@ bool AtTopLevelFunction(const IRDocsifier& d) { TVM_FFI_STATIC_INIT_BLOCK({ RelaxFrameNode::RegisterReflection(); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](relax::Function n, AccessPath n_p, IRDocsifier d) -> Doc { std::unordered_set func_vars; With f(d); @@ -64,12 +64,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 2. Print params Array params; { - ObjectPath params_p = n_p->Attr("params"); + AccessPath params_p = n_p->Attr("params"); for (int i = 0, l = n->params.size(); i < l; ++i) { params.push_back(AssignDoc( /*lhs=*/DefineVar(n->params[i], *f, d), /*rhs=*/std::nullopt, - StructInfoAsAnn(n->params[i], params_p->ArrayIndex(i), d, std::nullopt))); + StructInfoAsAnn(n->params[i], params_p->ArrayItem(i), d, std::nullopt))); } } // Step 3. Clean up func variables @@ -106,14 +106,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array dec_values; if (!n->is_pure) { dec_keys.push_back("pure"); - dec_values.push_back(LiteralDoc::Boolean(false, Optional())); + dec_values.push_back(LiteralDoc::Boolean(false, Optional())); } // if the function is global or is not in a module and does not have a global symbol, // indicate that it's private if (AtTopLevelFunction(d) && (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) { dec_keys.push_back("private"); - dec_values.push_back(LiteralDoc::Boolean(true, Optional())); + dec_values.push_back(LiteralDoc::Boolean(true, Optional())); } if (dec_keys.size()) { decorator = decorator->Call(pos_args, dec_keys, dec_values); @@ -127,7 +127,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ExternFunc n, AccessPath n_p, IRDocsifier d) -> Doc { // TODO(@junrushao): print more information out of extern function. return Relax(d, "ExternFunc")->Call({LiteralDoc::Str(n->global_symbol, n_p)}); }); diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc index c0010034e436..7cedc63c271c 100644 --- a/src/script/printer/relax/region.cc +++ b/src/script/printer/relax/region.cc @@ -22,14 +22,14 @@ namespace tvm { namespace script { namespace printer { -Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, +Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, bool use_ret) { With f(d); const Array& blocks = n->blocks; - ObjectPath blocks_p = n_p->Attr("blocks"); + AccessPath blocks_p = n_p->Attr("blocks"); Array* stmts = &(*f)->stmts; for (int i = 0, l = blocks.size(); i < l; ++i) { - Doc block = d->AsDoc(blocks[i], blocks_p->ArrayIndex(i)); + Doc block = d->AsDoc(blocks[i], blocks_p->ArrayItem(i)); if (const auto* stmt_block = block.as()) { stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); } else if (const auto* stmt = block.as()) { @@ -48,18 +48,18 @@ Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, cons } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](relax::SeqExpr n, ObjectPath n_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](relax::SeqExpr n, AccessPath n_p, IRDocsifier d) -> Doc { return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false)); }); -Array PrintBindingBlock(const relax::BindingBlock& n, const ObjectPath& n_p, +Array PrintBindingBlock(const relax::BindingBlock& n, const AccessPath& n_p, const IRDocsifier& d, Array* non_dataflow_vars) { const Array& bindings = n->bindings; - ObjectPath bindings_p = n_p->Attr("bindings"); + AccessPath bindings_p = n_p->Attr("bindings"); Array stmts; for (int i = 0, l = bindings.size(); i < l; ++i) { const relax::Binding& binding = bindings[i]; - ObjectPath binding_p = bindings_p->ArrayIndex(i); + AccessPath binding_p = bindings_p->ArrayItem(i); ICHECK(binding->var.defined()); Doc binding_doc = d->AsDoc(binding, binding_p); if (const auto* stmt = binding_doc.as()) { @@ -78,13 +78,13 @@ Array PrintBindingBlock(const relax::BindingBlock& n, const ObjectPath& TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::BindingBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::BindingBlock n, AccessPath n_p, IRDocsifier d) -> Doc { return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::DataflowBlock n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::DataflowBlock n, AccessPath n_p, IRDocsifier d) -> Doc { Array non_dataflow_vars; Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index 7043952c7c15..87de6a8335f5 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -26,11 +26,11 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ObjectStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ObjectStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); -ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) { +ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d) { ExprDoc expr_doc = d->AsDoc(e, e_p); // Step 1. Find if `func_vars` are being collected const RelaxFrameNode* f = nullptr; @@ -62,7 +62,7 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifie TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PrimStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { Array args; Array kwargs_keys; Array kwargs_values; @@ -79,13 +79,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->values.defined()) { Array shape = n->values.value(); - ObjectPath shape_p = n_p->Attr("values"); + AccessPath shape_p = n_p->Attr("values"); Array shape_docs; for (int i = 0, ndim = shape.size(); i < ndim; ++i) { - shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayIndex(i), d)); + shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayItem(i), d)); } return Relax(d, "Shape")->Call({ListDoc(shape_docs)}); } @@ -95,7 +95,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TensorStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TensorStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { Array args; Array kwargs_keys; Array kwargs_values; @@ -103,11 +103,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Need to dig into ShapeExpr to preserve the `R.shape` prefix if (const auto* shape = n->shape.value().as()) { auto shape_expr = GetRef(shape); - ObjectPath shape_p = n_p->Attr("shape")->Attr("values"); + AccessPath shape_p = n_p->Attr("shape")->Attr("values"); Array shape_docs; for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { shape_docs.push_back( - PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d)); + PrintShapeVar(shape_expr->values[i], shape_p->ArrayItem(i), d)); } args.push_back(TupleDoc(shape_docs)); } else { @@ -137,21 +137,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TupleStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TupleStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->fields.empty()) { return Relax(d, "Tuple"); } Array fields_doc; - ObjectPath fields_p = n_p->Attr("fields"); + AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { - fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } return Relax(d, "Tuple")->Call(fields_doc); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::FuncStructInfo n, AccessPath n_p, IRDocsifier d) -> Doc { auto ret_doc = d->AsDoc(n->ret, n_p->Attr("ret")); auto purity_doc = LiteralDoc::Boolean(n->purity, n_p->Attr("purity")); @@ -177,9 +177,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TODO(@junrushao): track symbolic shape relation Array params_doc; Array params = n->params.value(); - ObjectPath params_p = n_p->Attr("params"); + AccessPath params_p = n_p->Attr("params"); for (int i = 0, n_params = params.size(); i < n_params; ++i) { - params_doc.push_back(d->AsDoc(params[i], params_p->ArrayIndex(i))); + params_doc.push_back(d->AsDoc(params[i], params_p->ArrayItem(i))); } return Relax(d, "Callable")->Call({TupleDoc(params_doc), ret_doc, purity_doc}); }); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index eafd67365dad..67f39a6f6c45 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -41,7 +41,7 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { return f; } -Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { +Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { ICHECK(n->dtype.is_scalar()) << "TypeError: " << "Relax only uses scalar TIR variables," << "but received TIR variable " << n << " with dtype " << n->dtype; @@ -74,7 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", P TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // + "relax", [](tvm::IntImm n, AccessPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases if (n->dtype.is_bool()) { return LiteralDoc::Boolean(n->value, n_p); @@ -85,7 +85,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "relax", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // + "relax", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // if (Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -97,7 +97,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "relax", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc { // + "relax", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // Optional doc = d->GetVarDoc(mod); ICHECK(doc) << "Unable to print IRModule before definition in Relax."; if (d->cfg->module_alias.empty()) { @@ -117,7 +117,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("relax", [](Range range, AccessPath p, IRDocsifier d) -> Doc { return Relax(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc index de9ef1ae914b..d4ad35a13ee5 100644 --- a/src/script/printer/relax/type.cc +++ b/src/script/printer/relax/type.cc @@ -26,20 +26,20 @@ namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ShapeType n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ShapeType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Shape") ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::ObjectType n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::ObjectType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Object"); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::TensorType n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::TensorType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "Tensor") ->Call({}, {"ndim", "dtype"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")), @@ -48,32 +48,32 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](relax::PackedFuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + "", [](relax::PackedFuncType n, AccessPath n_p, IRDocsifier d) -> Doc { return Relax(d, "PackedFunc"); // TODO(@junrushao): verify if this is correct }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "relax", [](tvm::TupleType n, ObjectPath n_p, IRDocsifier d) -> Doc { + "relax", [](tvm::TupleType n, AccessPath n_p, IRDocsifier d) -> Doc { if (n->fields.empty()) { return Relax(d, "Tuple"); } Array fields_doc; - ObjectPath fields_p = n_p->Attr("fields"); + AccessPath fields_p = n_p->Attr("fields"); for (int i = 0, l = n->fields.size(); i < l; ++i) { - fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayIndex(i))); + fields_doc.push_back(d->AsDoc(n->fields[i], fields_p->ArrayItem(i))); } return Relax(d, "Tuple")->Call(fields_doc); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "relax", [](tvm::FuncType n, ObjectPath n_p, IRDocsifier d) -> Doc { + "relax", [](tvm::FuncType n, AccessPath n_p, IRDocsifier d) -> Doc { Array arg_types_doc; Array arg_types = n->arg_types; - ObjectPath arg_types_p = n_p->Attr("arg_types"); + AccessPath arg_types_p = n_p->Attr("arg_types"); for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) { - arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayIndex(i))); + arg_types_doc.push_back(d->AsDoc(arg_types[i], arg_types_p->ArrayItem(i))); } return Relax(d, "Callable") ->Call({TupleDoc(arg_types_doc), // diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 022c6e9c66f4..37ae86220051 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -81,7 +81,7 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint()); } -inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& v_p, +inline Optional StructInfoAsAnn(const relax::Var& v, const AccessPath& v_p, const IRDocsifier& d, const Optional& rhs) { if (!v->struct_info_.defined()) { return std::nullopt; @@ -133,10 +133,10 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); } -Array PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d, +Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, const IRDocsifier& d, bool use_ret); -ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d); +ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifier& d); inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) { Array vdevices = d->global_infos["vdevice"]; diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 519bc9d66ca6..fb4f8a9d772b 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -22,13 +22,13 @@ namespace tvm { namespace script { namespace printer { -Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // - Optional opt_realize, Optional opt_realize_p) { +Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath block_p, // + Optional opt_realize, Optional opt_realize_p) { With frame(d, block); ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); const tir::BlockRealizeNode* realize = opt_realize.defined() ? opt_realize.value().get() : nullptr; - const ObjectPathNode* realize_p = opt_realize_p.defined() ? opt_realize_p.get() : nullptr; + AccessPath realize_p = *opt_realize_p; // Step 1. Handle block var and block bindings // Step 1.1. Obtain all loop var defined along path std::unordered_map loop_vars; @@ -67,7 +67,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // auto print_single_iter_var = [&](int i) { tir::IterVar iter_var = block->iter_vars[i]; - ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); + AccessPath iter_var_p = block_p->Attr("iter_var")->ArrayItem(i); ExprDoc rhs = TIR(d, "axis"); if (iter_var->iter_type == tir::IterVarType::kDataPar) { rhs = rhs->Attr("spatial"); @@ -94,7 +94,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // } if (realize) { ExprDoc binding = d->AsDoc(realize->iter_values[i], // - realize_p->Attr("iter_values")->ArrayIndex(i)); + realize_p->Attr("iter_values")->ArrayItem(i)); rhs = rhs->Call({dom, binding}); } else { rhs = rhs->Call({dom}); @@ -118,13 +118,13 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // lhs.reserve(m); loop_var_doc.reserve(m); std::string binding_type = ""; - Array binding_paths; + Array binding_paths; for (int i : remap_vars_indices) { tir::IterVar iter_var = block->iter_vars[i]; - ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i); + AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); lhs.push_back(DefineVar(iter_var->var, *frame, d)); loop_var_doc.push_back(d->AsDoc(realize->iter_values[i], - realize_p->Attr("iter_values")->ArrayIndex(i))); + realize_p->Attr("iter_values")->ArrayItem(i))); binding_paths.push_back(iter_var_p->Attr("iter_type")); binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; } @@ -160,12 +160,12 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // { Array reads; for (int i = 0, n = block->reads.size(); i < n; ++i) { - reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayIndex(i))); + reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayItem(i))); } (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads))); Array writes; for (int i = 0, n = block->writes.size(); i < n; ++i) { - writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayIndex(i))); + writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayItem(i))); } (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "writes")->Call(writes))); } @@ -178,7 +178,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // // Step 5. Handle `alloc_buffer` for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) { tir::Buffer buffer = block->alloc_buffers[i]; - ObjectPath buffer_p = block_p->Attr("alloc_buffers")->ArrayIndex(i); + AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *frame, d); ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d, BufferVarDefinition::DataPointer); @@ -187,7 +187,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // // Step 6. Handle `match_buffer` for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { tir::MatchBufferRegion buffer_region = block->match_buffers[i]; - ObjectPath buffer_region_p = block_p->Attr("match_buffers")->ArrayIndex(i); + AccessPath buffer_region_p = block_p->Attr("match_buffers")->ArrayItem(i); StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); (*frame)->stmts.push_back(doc); } @@ -216,7 +216,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tir::BlockRealize realize, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::BlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p); // since we do not have d->AsDoc for realize->block, // we should add possible doc decoration manually. @@ -225,7 +225,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Block block, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Block block, AccessPath p, IRDocsifier d) -> Doc { return PrintBlock(d, block, p, std::nullopt, std::nullopt); }); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index f1b75066cdba..0e7ae3a843cf 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -24,7 +24,7 @@ namespace tvm { namespace script { namespace printer { -Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, const Frame& frame, +Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { using tvm::tir::Var; using tvm::tir::VarNode; @@ -52,14 +52,14 @@ Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, auto is_new_var = [&](const PrimExpr& e) { return e->IsInstance() && !d->IsVarDefined(e); }; - auto add_out_of_line_var_def = [&](const Var& var, const ObjectPath& var_p) { + auto add_out_of_line_var_def = [&](const Var& var, const AccessPath& var_p) { ICHECK(!d->IsVarDefined(var)); ExprDoc lhs = DefineVar(var, frame, d); lhs->source_paths.push_back(var_p); var_def_lhs.push_back(lhs); var_def_rhs.push_back(PrintVarCreation(var, var_p, d)); }; - auto try_inline_def = [&](const PrimExpr& e, const ObjectPath& e_p, + auto try_inline_def = [&](const PrimExpr& e, const AccessPath& e_p, std::function inline_f) { ICHECK(is_new_var(e)); Var var = Downcast(e); @@ -74,13 +74,13 @@ Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, // Step 1. Handle `buffer.shape` { const Array& shape = buffer->shape; - ObjectPath shape_p = buffer_p->Attr("shape"); + AccessPath shape_p = buffer_p->Attr("shape"); int n = shape.size(); Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = shape[i]; - ObjectPath e_p = shape_p->ArrayIndex(i); + AccessPath e_p = shape_p->ArrayItem(i); if (is_new_var(e)) { add_out_of_line_var_def(Downcast(e), e_p); } @@ -109,13 +109,13 @@ Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, // Step 4. Handle `buffer.strides` if (!buffer->strides.empty()) { const Array& strides = buffer->strides; - ObjectPath strides_p = buffer_p->Attr("strides"); + AccessPath strides_p = buffer_p->Attr("strides"); int n = strides.size(); Array results; results.reserve(n); for (int i = 0; i < n; ++i) { PrimExpr e = strides[i]; - ObjectPath e_p = strides_p->ArrayIndex(i); + AccessPath e_p = strides_p->ArrayItem(i); if (is_new_var(e)) { if (try_inline_def(e, e_p, [=]() { return d->AsDoc(buffer, buffer_p) @@ -201,14 +201,14 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Arr } ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, - const ObjectPath& p, const Frame& frame, const IRDocsifier& d, + const AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d, var_definitions), /*args=*/args); } -ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, +ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d) { Map attrs = BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); ExprDoc shape = attrs.Get("shape").value(); @@ -217,7 +217,7 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } -Array BufferIndices(const Array& indices, const ObjectPath& p, +Array BufferIndices(const Array& indices, const AccessPath& p, const IRDocsifier& d) { int n = indices.size(); Array indices_doc; @@ -225,8 +225,8 @@ Array BufferIndices(const Array& indices, const ObjectPath& p, for (int i = 0; i < n; ++i) { if (const auto* ramp = indices[i].as()) { if (const auto* stride = ramp->stride.as()) { - ObjectPath ramp_p = p->Attr("indices")->ArrayIndex(i); - ObjectPath stride_p = ramp_p->Attr("stride"); + AccessPath ramp_p = p->Attr("indices")->ArrayItem(i); + AccessPath stride_p = ramp_p->Attr("stride"); ExprDoc start = d->AsDoc(ramp->base, // ramp_p->Attr("base")); ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, // @@ -239,18 +239,18 @@ Array BufferIndices(const Array& indices, const ObjectPath& p, continue; } } - indices_doc.push_back(d->AsDoc(indices[i], p->Attr("indices")->ArrayIndex(i))); + indices_doc.push_back(d->AsDoc(indices[i], p->Attr("indices")->ArrayItem(i))); } return indices_doc; } -Array BufferSlices(const Array& region, const ObjectPath& p, const IRDocsifier& d) { +Array BufferSlices(const Array& region, const AccessPath& p, const IRDocsifier& d) { int n = region.size(); Array indices; indices.reserve(n); for (int i = 0; i < n; ++i) { Range range = region[i]; - ObjectPath range_p = p->ArrayIndex(i); + AccessPath range_p = p->ArrayItem(i); ExprDoc min = d->AsDoc(range->min, range_p->Attr("min")); if (tir::is_one(range->extent)) { indices.push_back(min); @@ -264,14 +264,14 @@ Array BufferSlices(const Array& region, const ObjectPath& p, const I TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tir::BufferRegion buffer_region, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::BufferRegion buffer_region, AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = d->AsDoc(buffer_region->buffer, p->Attr("buffer")); return prefix[BufferSlices(buffer_region->region, p->Attr("region"), d)]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::BufferStore store, AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); ExprDoc value = d->AsDoc(store->value, p->Attr("value")); @@ -290,7 +290,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::BufferLoad load, AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); // Use .vload(...) syntax when there is a predicate @@ -304,7 +304,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tir::Buffer buffer, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { if (!d->IsVarDefined(buffer)) { if (Optional opt_f = FindLowestVarDef(buffer, d)) { ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); @@ -322,7 +322,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tir::MatchBufferRegion stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::MatchBufferRegion stmt, AccessPath p, IRDocsifier d) -> Doc { Frame frame = d->frames.back(); ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d); ExprDoc src_buffer = d->AsDoc(stmt->source, p->Attr("source")); @@ -333,7 +333,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::ProducerLoad load, AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(load->producer->GetNameHint()); return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index d0b14753cc16..78b52edf859c 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -24,9 +24,9 @@ namespace tvm { namespace script { namespace printer { -ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) { +ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; - ObjectPath type_p = var_p->Attr("type_annotation"); + AccessPath type_p = var_p->Attr("type_annotation"); ExprDoc rhs{nullptr}; Array kwargs_keys; Array kwargs_values; @@ -64,7 +64,7 @@ ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRD return rhs; } -Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) { +Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { if (Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); @@ -82,17 +82,17 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tir::Var var, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Var var, AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tir::SizeVar var, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::SizeVar var, AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::IterVar var, ObjectPath var_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::IterVar var, AccessPath var_p, IRDocsifier d) -> Doc { return TIR(d, "iter_var") ->Call({ d->AsDoc(var->var, var_p->Attr("var")), @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Not node, AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); if (a->IsInstance()) { return TIR(d, "Not")->Call({a}); @@ -112,7 +112,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::StringImm s, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::StringImm s, AccessPath p, IRDocsifier d) -> Doc { if (HasMultipleLines(s->value)) { return d->AddMetadata(s); } else { @@ -121,14 +121,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Cast cast, AccessPath p, IRDocsifier d) -> Doc { ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); return TIR(d, "Cast")->Call({dtype, value}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Select select, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Select") ->Call({ d->AsDoc(select->condition, p->Attr("condition")), @@ -138,7 +138,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Ramp ramp, ObjectPath ramp_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Ramp ramp, AccessPath ramp_p, IRDocsifier d) -> Doc { return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), @@ -147,7 +147,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Broadcast bc, ObjectPath bc_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Broadcast bc, AccessPath bc_p, IRDocsifier d) -> Doc { return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), @@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::Shuffle shuffle, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Shuffle") ->Call({ d->AsDoc(shuffle->vectors, p->Attr("vectors")), @@ -167,7 +167,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::CommReducer r, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { ICHECK_EQ(r->lhs.size(), r->rhs.size()); LambdaDoc lambda{nullptr}; { @@ -185,7 +185,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array results; results.reserve(n_results); for (int i = 0; i < n_results; ++i) { - results.push_back(d->AsDoc(r->result[i], p->Attr("result")->ArrayIndex(i))); + results.push_back(d->AsDoc(r->result[i], p->Attr("result")->ArrayItem(i))); } if (results.size() == 1) { lambda = LambdaDoc(vars, results[0]); @@ -197,8 +197,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "comm_reducer")->Call({lambda, id}); }); -LambdaDoc PrintIndexMap(const ObjectRef& map, const Array& vs, const ObjectPath& vs_p, - const Array& es, const ObjectPath& es_p, const IRDocsifier& d) { +LambdaDoc PrintIndexMap(const ObjectRef& map, const Array& vs, const AccessPath& vs_p, + const Array& es, const AccessPath& es_p, const IRDocsifier& d) { With f(d, map); Array vars; for (int i = 0, l = vs.size(); i < l; ++i) { @@ -206,14 +206,14 @@ LambdaDoc PrintIndexMap(const ObjectRef& map, const Array& vs, const O } Array exprs; for (int i = 0, l = es.size(); i < l; ++i) { - exprs.push_back(d->AsDoc(es[i], es_p->ArrayIndex(i))); + exprs.push_back(d->AsDoc(es[i], es_p->ArrayItem(i))); } return LambdaDoc(vars, TupleDoc(exprs)); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::IndexMap m, ObjectPath m_p, IRDocsifier d) -> Doc { + "", [](tir::IndexMap m, AccessPath m_p, IRDocsifier d) -> Doc { LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices"), m->final_indices, m_p->Attr("final_indices"), d); if (m->inverse_index_map.defined()) { @@ -229,7 +229,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Let let, AccessPath p, IRDocsifier d) -> Doc { DictDoc where({d->AsDoc(let->var, p->Attr("var"))}, {d->AsDoc(let->value, p->Attr("value"))}); return TIR(d, "Let")->Call({d->AsDoc(let->body, p->Attr("body"))}, // @@ -237,7 +237,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Call call, AccessPath call_p, IRDocsifier d) -> Doc { static const OpAttrMap& op_names = Op::GetAttrMap("TScriptPrinterName"); static const OpAttrMap dtype_locations = @@ -270,9 +270,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n_args; ++i) { if ((i == 0) && (f_llvm_lookup_intrinsic_name)) { String name = (*f_llvm_lookup_intrinsic_name)(id).cast(); - args.push_back(LiteralDoc::Str(name.c_str(), call_p->Attr("args")->ArrayIndex(i))); + args.push_back(LiteralDoc::Str(name.c_str(), call_p->Attr("args")->ArrayItem(i))); } else { - args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); + args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayItem(i))); } } if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { @@ -293,7 +293,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } for (int i = 0; i < n_args; ++i) { - args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); + args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayItem(i))); } if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); @@ -302,7 +302,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Reduce r, AccessPath p, IRDocsifier d) -> Doc { ExprDoc combiner = d->AsDoc(r->combiner, p->Attr("combiner")); ExprDoc source = d->AsDoc(r->source, p->Attr("source")); ExprDoc init = d->AsDoc(r->init, p->Attr("init")); @@ -318,7 +318,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch("", \ - [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + [](tir::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ return TIR(d, OpString)->Call({a, b}); \ @@ -334,7 +334,7 @@ bool IsNumber(const ExprDoc& e) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Div node, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Div node, AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); ExprDoc b = d->AsDoc(node->b, p->Attr("b")); PrimExpr ret = tvm::div(node->a, node->b); @@ -351,7 +351,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ .set_dispatch( \ - "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + "", [](tir::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 0df53c481f0c..bfdae3b14221 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -23,7 +23,7 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::For loop, ObjectPath loop_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::For loop, AccessPath loop_p, IRDocsifier d) -> Doc { // Step 1. Check syntactic sugar: `T.grid` std::vector grid; std::unordered_set grid_loop_vars; diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 1d035609cc9d..688c58e6de09 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -65,7 +65,7 @@ int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::PrimFunc func, AccessPath p, IRDocsifier d) -> Doc { With f(d, func); (*f)->AddDispatchToken(d, "tir"); IdDoc func_name = IdDoc(FindFunctionName(d, func).value_or("main")); @@ -87,12 +87,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { tir::Var var = func->params[i]; - ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + AccessPath var_p = p->Attr("params")->ArrayItem(i); if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { tir::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { - ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); + AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d); args.push_back(AssignDoc(lhs, std::nullopt, annotation)); @@ -134,7 +134,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) continue; } ExprDoc param_doc = args[i]->lhs; - ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); + AccessPath buffer_p = p->Attr("buffer_map")->MapItem(param); ExprDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d, BufferVarDefinition::MatchBuffer); @@ -163,12 +163,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }(); if (d->cfg->syntax_sugar && implicit_root_block) { tir::Block root_block = implicit_root_block.value(); - ObjectPath root_block_p = p->Attr("body")->Attr("block"); + AccessPath root_block_p = p->Attr("body")->Attr("block"); (*f)->stmts.push_back(CommentDoc("with T.block(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { tir::Buffer buffer = root_block->alloc_buffers[i]; - ObjectPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayIndex(i); + AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *f, d, BufferVarDefinition::DataPointer); @@ -191,7 +191,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (!func->attrs.defined() || !func->attrs->dict.count(tvm::attr::kGlobalSymbol)) { Array pos_args; decorator = decorator->Call(pos_args, {"private"}, - {LiteralDoc::Boolean(true, Optional())}); + {LiteralDoc::Boolean(true, Optional())}); } return HeaderWrapper(d, FunctionDoc( @@ -206,7 +206,7 @@ TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "tir", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc { // + "tir", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // if (Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -218,7 +218,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "tir", [](tvm::IRModule mod, ObjectPath n_p, IRDocsifier d) -> Doc { // + "tir", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // Optional doc = d->GetVarDoc(mod); ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 9aaf3e3411aa..a99d4236158f 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -27,7 +27,7 @@ namespace printer { TVM_FFI_STATIC_INIT_BLOCK({ TIRFrameNode::RegisterReflection(); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](IntImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](IntImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == d->cfg->int_dtype) { return LiteralDoc::Int(imm->value, imm_p->Attr("value")); @@ -40,7 +40,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](FloatImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](FloatImm imm, AccessPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; if (dtype == d->cfg->float_dtype) { return LiteralDoc::Float(imm->value, imm_p->Attr("value")); @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("tir", [](Range range, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), @@ -60,12 +60,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PrimType ty, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, DType2Str(ty->dtype)); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](PointerType ty, ObjectPath ty_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](PointerType ty, AccessPath ty_p, IRDocsifier d) -> Doc { ExprDoc element_type{nullptr}; if (const auto* prim_type = ty->element_type.as()) { element_type = LiteralDoc::DataType(prim_type->dtype, // @@ -82,7 +82,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](TupleType ty, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](TupleType ty, AccessPath p, IRDocsifier d) -> Doc { if (ty->fields.empty()) { return LiteralDoc::None(p); } @@ -90,7 +90,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](Target target, AccessPath p, IRDocsifier d) -> Doc { Map config = target->Export(); return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 50756bceb706..5a52de1849f1 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -79,11 +79,11 @@ Optional FindReturnValue(const tir::Stmt& node) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc { if (d->cfg->syntax_sugar) { if (auto return_value = FindReturnValue(eval)) { - ExprDoc value = d->AsDoc(return_value.value(), - p->Attr("value")->Attr("args")->ArrayIndex(0)); + ExprDoc value = + d->AsDoc(return_value.value(), p->Attr("value")->Attr("args")->ArrayItem(0)); return ReturnDoc(value); } } @@ -96,7 +96,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::LetStmt stmt, AccessPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); // Step 1. Type annotation Optional type_doc = d->AsDoc(stmt->var->type_annotation, // @@ -132,7 +132,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); ExprDoc msg = d->AsDoc(stmt->message, p->Attr("message")); @@ -147,7 +147,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::While stmt, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::While stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); @@ -155,7 +155,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); namespace { -Doc DeclBufferDoc(tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d, +Doc DeclBufferDoc(tir::DeclBuffer stmt, AccessPath p, IRDocsifier d, BufferVarDefinition var_definitions) { bool concise = AllowConciseScoping(d, stmt); ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d, @@ -169,13 +169,13 @@ Doc DeclBufferDoc(tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d, TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { return DeclBufferDoc(stmt, p, d, BufferVarDefinition::None); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::IfThenElse stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); Array then_branch; Array else_branch; @@ -193,7 +193,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::SeqStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tir::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc { With f(d, stmt); AsDocBody(stmt, p, f->get(), d); return StmtBlockDoc((*f)->stmts); @@ -220,7 +220,7 @@ bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { + "", [](tir::Allocate stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt_p); if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) { return DeclBufferDoc(Downcast(stmt->body), stmt_p->Attr("body"), d, @@ -278,7 +278,7 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( - "", [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { + "", [](tir::AllocateConst stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); Array args; @@ -333,14 +333,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional value, // - ObjectPath p, IRDocsifier d) { + AccessPath p, IRDocsifier d) { ExprDoc buffer = d->AsDoc(stmt->buffer, p->Attr("buffer")); { Array bounds; bounds.reserve(stmt->bounds.size()); for (int i = 0, n = stmt->bounds.size(); i < n; ++i) { Range range = stmt->bounds[i]; - ObjectPath range_p = p->Attr("bounds")->ArrayIndex(i); + AccessPath range_p = p->Attr("bounds")->ArrayItem(i); bounds.push_back( SliceDoc(d->AsDoc(range->min, range_p->Attr("min")), d->AsDoc(range->min + range->extent, range_p->Attr("extent")), // @@ -361,7 +361,7 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, OptionalCall(args, kwargs_keys, kwargs_values); } -void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p, +void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath& iter_var_p, const IRDocsifier& d) { Frame f = FindLowestVarDef(iter_var->var, d).value(); DefineVar(iter_var->var, f, d); @@ -372,10 +372,10 @@ void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p, f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } -ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p, +ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, Optional* define_var, const IRDocsifier& d) { tir::IterVar iter_var = Downcast(attr_stmt->node); - ObjectPath iter_var_p = attr_stmt_p->Attr("node"); + AccessPath iter_var_p = attr_stmt_p->Attr("node"); ExprDoc var_doc{nullptr}; if (d->IsVarDefined(iter_var->var)) { @@ -396,7 +396,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& at TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { + "", [](tir::BufferRealize stmt, AccessPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); ExprDoc rhs = DocsifyBufferRealize(stmt.get(), std::nullopt, p, d); With f(d, stmt); @@ -406,13 +406,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // - "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { + "", [](tir::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); Optional lhs = std::nullopt; Optional rhs = std::nullopt; Optional define_var = std::nullopt; tir::Stmt body = stmt->body; - ObjectPath body_p = stmt_p->Attr("body"); + AccessPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { if (const auto* realize = stmt->body.as()) { // TODO(tqchen): add any.same_as(ObjectRef) diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 4114600f0c46..4474a83ca8ff 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -109,12 +109,12 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I * \param f The frame * \param d The IRDocsifier */ -inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, const IRDocsifier& d) { +inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { Array body = seq_stmt->seq; for (int i = 0, n = body.size(); i < n; ++i) { f->allow_concise_scoping = (i == n - 1); - Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayIndex(i)); + Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); doc->source_paths.push_back(p); if (const auto* block = doc.as()) { f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); @@ -215,7 +215,7 @@ enum class BufferVarDefinition { * \return The ExprDoc corresponding to the buffer declaration */ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, - const ObjectPath& p, const Frame& frame, const IRDocsifier& d, + const AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions); /*! @@ -226,7 +226,7 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< * \param d The IRDocsifier * \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, +ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d); /*! @@ -236,7 +236,7 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& * \param d The IRDocsifier * \return The ExprDoc corresponding to the Var creation */ -ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d); +ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d); /*! \brief A Var occurrence counter visitor */ class OccurrenceCounter : public tir::StmtExprVisitor { diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 95d24c91c41e..1e3a258579a2 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -50,7 +50,7 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f, const PrinterConfig& cfg) { - Doc doc = d->AsDoc(obj, ObjectPath::Root()); + Doc doc = d->AsDoc(obj, AccessPath::Root()); bool move_source_paths = false; if (const auto* expr_doc = doc.as()) { if (!cfg->verbose_expr) { diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index fdc93f0e90d3..5d85ef31e88e 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index 0e94a0e88da0..9e85e4cc86c7 100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -31,6 +31,8 @@ namespace tvm { namespace tir { +using AccessPath = ffi::reflection::AccessPath; + namespace { class PurityChecker : TIRVisitorWithPath { public: @@ -43,12 +45,12 @@ class PurityChecker : TIRVisitorWithPath { private: explicit PurityChecker(bool assert_on_error) : assert_on_error_(assert_on_error) {} - void VisitStmt_(const AllocateNode* op, ObjectPath path) override { + void VisitStmt_(const AllocateNode* op, AccessPath path) override { internal_allocations_.insert(op->buffer_var); TIRVisitorWithPath::VisitStmt_(op, path); } - void VisitStmt_(const BufferStoreNode* op, ObjectPath path) override { + void VisitStmt_(const BufferStoreNode* op, AccessPath path) override { TIRVisitorWithPath::VisitStmt_(op, path); if (!internal_allocations_.count(op->buffer->data)) { @@ -60,7 +62,7 @@ class PurityChecker : TIRVisitorWithPath { } } - void VisitExpr_(const CallNode* call, ObjectPath path) override { + void VisitExpr_(const CallNode* call, AccessPath path) override { TIRVisitorWithPath::VisitExpr_(call, path); static auto op_call_effect = Op::GetAttrMap("TCallEffectKind"); diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index f2dc526df31f..2efd3648a5bb 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -39,6 +39,8 @@ namespace tvm { namespace tir { +using AccessPath = ffi::reflection::AccessPath; + namespace { template @@ -230,19 +232,19 @@ class UndefinedVarVerifier : public Verifier { private: using Verifier::Visit; - void Visit(const PrimFunc& prim_func, ObjectPath path) override { + void Visit(const PrimFunc& prim_func, AccessPath path) override { Verifier::Visit(prim_func, path); redefine_allowed_within_function_.clear(); } - void EnterDef(const IterVar& iter_var, ObjectPath path) override { + void EnterDef(const IterVar& iter_var, AccessPath path) override { Verifier::EnterDef(iter_var, path); if (iter_var->iter_type == IterVarType::kThreadIndex) { redefine_allowed_within_function_.insert(iter_var->var); } } - void EnterDef(const Var& var, ObjectPath path) override { + void EnterDef(const Var& var, AccessPath path) override { bool redefine_is_allowed = redefine_allowed_within_function_.count(var); { auto it = currently_defined_.find(var); @@ -265,14 +267,14 @@ class UndefinedVarVerifier : public Verifier { currently_defined_.insert({var, path}); } - void ExitDef(const Var& var, ObjectPath path) override { + void ExitDef(const Var& var, AccessPath path) override { auto active_def = currently_defined_.find(var); currently_defined_.erase(active_def); previously_defined_.insert({var, path}); } - void VisitExpr_(const VarNode* op, ObjectPath path) override { + void VisitExpr_(const VarNode* op, AccessPath path) override { auto var = GetRef(op); auto active_def = currently_defined_.find(var); @@ -292,10 +294,10 @@ class UndefinedVarVerifier : public Verifier { } // Variables that are defined in the currently-visited scope. - std::unordered_map currently_defined_; + std::unordered_map currently_defined_; // Variables that were previously defined, and are now out of scope. - std::unordered_map previously_defined_; + std::unordered_map previously_defined_; // Special variables that are allowed to be re-defined, so long as // that re-definition occurs within the same PrimFunc. For example @@ -315,12 +317,12 @@ class SingleEnvThreadVerifier : public Verifier { using Verifier::Verifier; private: - void Visit(const PrimFunc& prim_func, ObjectPath path) override { + void Visit(const PrimFunc& prim_func, AccessPath path) override { Verifier::Visit(prim_func, path); env_thread_vars_.clear(); } - void EnterDef(const IterVar& iter_var, ObjectPath path) override { + void EnterDef(const IterVar& iter_var, AccessPath path) override { if (iter_var->iter_type == IterVarType::kThreadIndex) { if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) { const auto& [prev_var, prev_path] = it->second; @@ -340,7 +342,7 @@ class SingleEnvThreadVerifier : public Verifier { } } - std::unordered_map> env_thread_vars_; + std::unordered_map> env_thread_vars_; }; bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 1dbbe75528d7..aa3ca1959c5d 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -21,9 +21,10 @@ * \file tir/ir/tir_visitor_with_path.cc * \brief Provide a TIR visitor that tracks the current location */ - #include "tir_visitor_with_path.h" +#include + #include #include #include @@ -33,7 +34,9 @@ namespace tvm { namespace tir { -void TIRVisitorWithPath::Visit(const IRModule& mod, ObjectPath path) { +using AccessPath = ffi::reflection::AccessPath; + +void TIRVisitorWithPath::Visit(const IRModule& mod, AccessPath path) { // To ensure deterministic order of visits, sort the GlobalVar first // by visibility (public then private), then alphabetically by name. std::vector gvars; @@ -59,20 +62,20 @@ void TIRVisitorWithPath::Visit(const IRModule& mod, ObjectPath path) { std::vector> context; for (const auto& gvar : gvars) { - context.push_back(WithDef(gvar, path->Attr("global_var_map_")->MapValue(gvar->name_hint))); + context.push_back(WithDef(gvar, path->Attr("global_var_map_")->MapItem(gvar->name_hint))); } for (const auto& gvar : gvars) { auto base_func = mod->functions[gvar]; if (auto prim_func = base_func.as()) { - Visit(prim_func.value(), path->Attr("functions")->MapValue(gvar)); + Visit(prim_func.value(), path->Attr("functions")->MapItem(gvar)); } } while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { +void TIRVisitorWithPath::Visit(const PrimFunc& func, AccessPath path) { // The implicit definitions from a PrimFunc::buffer_map are pretty // weird. They only apply if no previous definition of that // variable has occurred. Therefore, to ensure that we only avoid @@ -82,14 +85,14 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { auto ppath = path->Attr("params"); for (size_t i = 0; i < func->params.size(); i++) { - context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i))); + context.push_back(WithDef(func->params[i], ppath->ArrayItem(i))); } auto buffer_map_path = path->Attr("buffer_map"); for (size_t i = 0; i < func->params.size(); i++) { if (auto opt = func->buffer_map.Get(func->params[i])) { auto buf = opt.value(); - auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); + auto buf_path = buffer_map_path->MapItem(ppath->ArrayItem(i)); for (auto& def : WithMatchBufferDefs(buf, buf_path)) { context.push_back(std::move(def)); @@ -101,7 +104,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { // visit the buffer definition itself. for (size_t i = 0; i < func->params.size(); i++) { if (auto opt = func->buffer_map.Get(func->params[i])) { - auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); + auto buf_path = buffer_map_path->MapItem(ppath->ArrayItem(i)); context.push_back(WithDef(opt.value(), buf_path)); } } @@ -111,18 +114,18 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::EnterDef(const IterVar& iter_var, ObjectPath path) { +void TIRVisitorWithPath::EnterDef(const IterVar& iter_var, AccessPath path) { if (iter_var->dom.defined()) { Visit(iter_var->dom, path->Attr("dom")); } EnterDef(iter_var->var, path->Attr("var")); } -void TIRVisitorWithPath::ExitDef(const IterVar& iter_var, ObjectPath path) { +void TIRVisitorWithPath::ExitDef(const IterVar& iter_var, AccessPath path) { ExitDef(iter_var->var, path->Attr("var")); } -void TIRVisitorWithPath::EnterDef(const Buffer& buffer, ObjectPath path) { +void TIRVisitorWithPath::EnterDef(const Buffer& buffer, AccessPath path) { // Defining a buffer counts as using all parameters in the buffer // (e.g. shape/strides). Visit(buffer->data, path->Attr("data")); @@ -130,9 +133,9 @@ void TIRVisitorWithPath::EnterDef(const Buffer& buffer, ObjectPath path) { Visit(buffer->strides, path->Attr("strides")); Visit(buffer->elem_offset, path->Attr("elem_offset")); } -void TIRVisitorWithPath::ExitDef(const Buffer& buffer, ObjectPath path) {} +void TIRVisitorWithPath::ExitDef(const Buffer& buffer, AccessPath path) {} -void TIRVisitorWithPath::Visit(const Buffer& buffer, ObjectPath path) { +void TIRVisitorWithPath::Visit(const Buffer& buffer, AccessPath path) { // Using a buffer *also* counts as using all parameters in the buffer. Visit(buffer->data, path->Attr("data")); Visit(buffer->shape, path->Attr("shape")); @@ -140,12 +143,12 @@ void TIRVisitorWithPath::Visit(const Buffer& buffer, ObjectPath path) { Visit(buffer->elem_offset, path->Attr("elem_offset")); } -void TIRVisitorWithPath::Visit(const BufferRegion& region, ObjectPath path) { +void TIRVisitorWithPath::Visit(const BufferRegion& region, AccessPath path) { Visit(region->buffer, path->Attr("buffer")); Visit(region->region, path->Attr("region")); } -void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, ObjectPath path) { +void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, AccessPath path) { Visit(match->source, path->Attr("source")); // MatchBufferRegion define the match->buffer, but do not own the @@ -153,25 +156,25 @@ void TIRVisitorWithPath::Visit(const MatchBufferRegion& match, ObjectPath path) // definitions are handled in the BlockNode visitor. } -void TIRVisitorWithPath::Visit(const IterVar& iter_var, ObjectPath path) { +void TIRVisitorWithPath::Visit(const IterVar& iter_var, AccessPath path) { if (iter_var->dom.defined()) { Visit(iter_var->dom, path->Attr("dom")); } Visit(iter_var->var, path->Attr("var")); } -void TIRVisitorWithPath::Visit(const Range& range, ObjectPath path) { +void TIRVisitorWithPath::Visit(const Range& range, AccessPath path) { Visit(range->min, path->Attr("min")); Visit(range->extent, path->Attr("extent")); } -void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); auto context = WithDef(op->var, path->Attr("var")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); std::vector, DefContext>> context; @@ -194,9 +197,9 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { ICHECK_EQ(arr.size(), 2U); Buffer buffer_view = Downcast(arr[0]); Buffer orig_buffer = Downcast(arr[1]); - Visit(orig_buffer, path->Attr("node")->ArrayIndex(1)); + Visit(orig_buffer, path->Attr("node")->ArrayItem(1)); - for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) { + for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayItem(0))) { context.push_back(std::move(var)); } @@ -210,43 +213,43 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { } } -void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const ForNode* op, AccessPath path) { Visit(op->min, path->Attr("min")); Visit(op->extent, path->Attr("extent")); auto context = WithDef(op->loop_var, path->Attr("loop_var")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const WhileNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const AllocateNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const AllocateNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->extents, path->Attr("extents")); auto context = WithDef(op->buffer_var, path->Attr("buffer_var")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const AllocateConstNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const AllocateConstNode* op, AccessPath path) { Visit(op->extents, path->Attr("extents")); auto context = WithDef(op->buffer_var, path->Attr("buffer_var")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, AccessPath path) { auto context = WithDef(op->buffer, path->Attr("buffer")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); Visit(op->buffer, path->Attr("buffer")); Visit(op->indices, path->Attr("indices")); } -void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->bounds, path->Attr("bounds")); auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data")); @@ -254,33 +257,33 @@ void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const IfThenElseNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->then_case, path->Attr("then_case")); Visit(op->else_case, path->Attr("else_case")); } -void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->message, path->Attr("message")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, AccessPath path) { Visit(op->seq, path->Attr("seq")); } -void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const EvaluateNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); } -void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, AccessPath path) { std::vector, DefContext, DefContext>> context; { auto iter_path = path->Attr("iter_vars"); for (size_t i = 0; i < op->iter_vars.size(); i++) { - context.push_back(WithDef(op->iter_vars[i], iter_path->ArrayIndex(i))); + context.push_back(WithDef(op->iter_vars[i], iter_path->ArrayItem(i))); } } Visit(op->reads, path->Attr("reads")); @@ -289,7 +292,7 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { { auto alloc_path = path->Attr("alloc_buffers"); for (size_t i = 0; i < op->alloc_buffers.size(); i++) { - auto buffer_path = alloc_path->ArrayIndex(i); + auto buffer_path = alloc_path->ArrayItem(i); auto buf = op->alloc_buffers[i]; context.push_back(WithDef(buf->data, buffer_path->Attr("data"))); context.push_back(WithDef(buf, buffer_path)); @@ -302,7 +305,7 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { for (size_t i = 0; i < op->match_buffers.size(); i++) { auto buf = op->match_buffers[i]->buffer; - auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer"); + auto buffer_path = match_path->ArrayItem(i)->Attr("buffer"); for (auto& def : WithMatchBufferDefs(buf, buffer_path)) { context.push_back(std::move(def)); @@ -316,34 +319,34 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { while (context.size()) context.pop_back(); } -void TIRVisitorWithPath::VisitStmt_(const BlockRealizeNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitStmt_(const BlockRealizeNode* op, AccessPath path) { Visit(op->iter_values, path->Attr("iter_values")); Visit(op->predicate, path->Attr("predicate")); Visit(op->block, path->Attr("block")); } -void TIRVisitorWithPath::VisitExpr_(const VarNode* op, ObjectPath path) {} +void TIRVisitorWithPath::VisitExpr_(const VarNode* op, AccessPath path) {} -void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, AccessPath path) { VisitExpr_(static_cast(op), path); } -void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, AccessPath path) { Visit(op->buffer, path->Attr("buffer")); Visit(op->indices, path->Attr("indices")); } -void TIRVisitorWithPath::VisitExpr_(const ProducerLoadNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const ProducerLoadNode* op, AccessPath path) { Visit(op->indices, path->Attr("indices")); } -void TIRVisitorWithPath::VisitExpr_(const LetNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const LetNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); auto context = WithDef(op->var, path->Attr("var")); Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitExpr_(const CallNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const CallNode* op, AccessPath path) { if (auto gvar = op->op.as()) { Visit(gvar.value(), path->Attr("op")); } @@ -351,7 +354,7 @@ void TIRVisitorWithPath::VisitExpr_(const CallNode* op, ObjectPath path) { } #define DEFINE_BINOP_VISIT_(OP) \ - void TIRVisitorWithPath::VisitExpr_(const OP* op, ObjectPath path) { \ + void TIRVisitorWithPath::VisitExpr_(const OP* op, AccessPath path) { \ Visit(op->a, path->Attr("a")); \ Visit(op->b, path->Attr("b")); \ } @@ -376,43 +379,43 @@ DEFINE_BINOP_VISIT_(OrNode); #undef DEFINE_BINOP_VISIT_ -void TIRVisitorWithPath::VisitExpr_(const IntImmNode* op, ObjectPath path) {} -void TIRVisitorWithPath::VisitExpr_(const FloatImmNode* op, ObjectPath path) {} -void TIRVisitorWithPath::VisitExpr_(const StringImmNode* op, ObjectPath path) {} +void TIRVisitorWithPath::VisitExpr_(const IntImmNode* op, AccessPath path) {} +void TIRVisitorWithPath::VisitExpr_(const FloatImmNode* op, AccessPath path) {} +void TIRVisitorWithPath::VisitExpr_(const StringImmNode* op, AccessPath path) {} -void TIRVisitorWithPath::VisitExpr_(const ReduceNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const ReduceNode* op, AccessPath path) { Visit(op->axis, path->Attr("axis")); Visit(op->source, path->Attr("source")); Visit(op->init, path->Attr("init")); Visit(op->condition, path->Attr("condition")); } -void TIRVisitorWithPath::VisitExpr_(const CastNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const CastNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); } -void TIRVisitorWithPath::VisitExpr_(const NotNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const NotNode* op, AccessPath path) { Visit(op->a, path->Attr("a")); } -void TIRVisitorWithPath::VisitExpr_(const SelectNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const SelectNode* op, AccessPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->true_value, path->Attr("true_value")); Visit(op->false_value, path->Attr("false_value")); } -void TIRVisitorWithPath::VisitExpr_(const RampNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const RampNode* op, AccessPath path) { Visit(op->base, path->Attr("base")); Visit(op->stride, path->Attr("stride")); Visit(op->lanes, path->Attr("lanes")); } -void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op, AccessPath path) { Visit(op->indices, path->Attr("indices")); Visit(op->vectors, path->Attr("vectors")); } -void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, ObjectPath path) { +void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, AccessPath path) { Visit(op->value, path->Attr("value")); Visit(op->lanes, path->Attr("lanes")); } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 6b1cd8ace487..0ff9da33eb6d 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -37,118 +37,119 @@ namespace tvm { namespace tir { -/*! \brief Visit TIR while tracking the ObjectPath */ -class TIRVisitorWithPath : protected ExprFunctor, - protected StmtFunctor { +/*! \brief Visit TIR while tracking the ffi::reflection::AccessPath */ +class TIRVisitorWithPath + : protected ExprFunctor, + protected StmtFunctor { public: template void operator()(TObjectRef&& obj) { - Visit(std::forward(obj), ObjectPath::Root()); + Visit(std::forward(obj), ffi::reflection::AccessPath::Root()); } protected: // Delegate to ExprFunctor::VisitExpr for PrimExpr, and any subclasses - inline void Visit(const PrimExpr& obj, ObjectPath path) { VisitExpr(obj, path); } + inline void Visit(const PrimExpr& obj, ffi::reflection::AccessPath path) { VisitExpr(obj, path); } // Delegate to ExprFunctor::VisitStmt for Stmt, and any subclasses - inline void Visit(const Stmt& obj, ObjectPath path) { VisitStmt(obj, path); } + inline void Visit(const Stmt& obj, ffi::reflection::AccessPath path) { VisitStmt(obj, path); } // Visitors for TIR constructs that are neither PrimExpr nor Stmt - virtual void Visit(const IRModule& obj, ObjectPath path); - virtual void Visit(const PrimFunc& obj, ObjectPath path); - virtual void Visit(const GlobalVar& obj, ObjectPath path) {} - virtual void Visit(const Range& obj, ObjectPath path); - virtual void Visit(const Buffer& obj, ObjectPath path); - virtual void Visit(const BufferRegion& obj, ObjectPath path); - virtual void Visit(const MatchBufferRegion& obj, ObjectPath path); - virtual void Visit(const IterVar& obj, ObjectPath path); + virtual void Visit(const IRModule& obj, ffi::reflection::AccessPath path); + virtual void Visit(const PrimFunc& obj, ffi::reflection::AccessPath path); + virtual void Visit(const GlobalVar& obj, ffi::reflection::AccessPath path) {} + virtual void Visit(const Range& obj, ffi::reflection::AccessPath path); + virtual void Visit(const Buffer& obj, ffi::reflection::AccessPath path); + virtual void Visit(const BufferRegion& obj, ffi::reflection::AccessPath path); + virtual void Visit(const MatchBufferRegion& obj, ffi::reflection::AccessPath path); + virtual void Visit(const IterVar& obj, ffi::reflection::AccessPath path); // Called when entering/exiting the scope of a GlobalVar definition. - virtual void EnterDef(const GlobalVar& var, ObjectPath path) {} - virtual void ExitDef(const GlobalVar& var, ObjectPath path) {} + virtual void EnterDef(const GlobalVar& var, ffi::reflection::AccessPath path) {} + virtual void ExitDef(const GlobalVar& var, ffi::reflection::AccessPath path) {} // Called when entering/exiting the scope of a tir::Var definition. - virtual void EnterDef(const Var& var, ObjectPath path) {} - virtual void ExitDef(const Var& var, ObjectPath path) {} + virtual void EnterDef(const Var& var, ffi::reflection::AccessPath path) {} + virtual void ExitDef(const Var& var, ffi::reflection::AccessPath path) {} // Called when entering/exiting the scope of an IterVar definition. // By default, visits the `Range IterVarNode::dom`, then enters the // scope of the internal `tir::Var`. - virtual void EnterDef(const IterVar& var, ObjectPath path); - virtual void ExitDef(const IterVar& var, ObjectPath path); + virtual void EnterDef(const IterVar& var, ffi::reflection::AccessPath path); + virtual void ExitDef(const IterVar& var, ffi::reflection::AccessPath path); // Called when entering/exiting the scope of a Buffer definition. // By default, visits the buffer's data pointer, shape, strides, and // elem_offset, which must be defined prior to defining the Buffer. - virtual void EnterDef(const Buffer& buffer, ObjectPath path); - virtual void ExitDef(const Buffer& buffer, ObjectPath path); + virtual void EnterDef(const Buffer& buffer, ffi::reflection::AccessPath path); + virtual void ExitDef(const Buffer& buffer, ffi::reflection::AccessPath path); // Utility to visit an array of nodes template - inline void Visit(const Array& arr, ObjectPath path) { + inline void Visit(const Array& arr, ffi::reflection::AccessPath path) { for (size_t i = 0; i < arr.size(); i++) { - Visit(arr[i], path->ArrayIndex(i)); + Visit(arr[i], path->ArrayItem(i)); } } // Utility to visit an optional node nodes template - inline void Visit(const Optional& opt, ObjectPath path) { + inline void Visit(const Optional& opt, ffi::reflection::AccessPath path) { if (opt) { Visit(opt.value(), path); } } using StmtFunctor::VisitStmt; - void VisitStmt_(const AttrStmtNode* op, ObjectPath path) override; - void VisitStmt_(const IfThenElseNode* op, ObjectPath path) override; - void VisitStmt_(const LetStmtNode* op, ObjectPath path) override; - void VisitStmt_(const ForNode* op, ObjectPath path) override; - void VisitStmt_(const WhileNode* op, ObjectPath path) override; - void VisitStmt_(const AllocateNode* op, ObjectPath path) override; - void VisitStmt_(const AllocateConstNode* op, ObjectPath path) override; - void VisitStmt_(const DeclBufferNode* op, ObjectPath path) override; - void VisitStmt_(const BufferStoreNode* op, ObjectPath path) override; - void VisitStmt_(const BufferRealizeNode* op, ObjectPath path) override; - void VisitStmt_(const AssertStmtNode* op, ObjectPath path) override; - void VisitStmt_(const SeqStmtNode* op, ObjectPath path) override; - void VisitStmt_(const EvaluateNode* op, ObjectPath path) override; - void VisitStmt_(const BlockNode* op, ObjectPath path) override; - void VisitStmt_(const BlockRealizeNode* op, ObjectPath path) override; + void VisitStmt_(const AttrStmtNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const IfThenElseNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const LetStmtNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const ForNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const WhileNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const AllocateNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const AllocateConstNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const DeclBufferNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const BufferStoreNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const BufferRealizeNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const AssertStmtNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const SeqStmtNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const EvaluateNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const BlockNode* op, ffi::reflection::AccessPath path) override; + void VisitStmt_(const BlockRealizeNode* op, ffi::reflection::AccessPath path) override; using ExprFunctor::VisitExpr; - void VisitExpr_(const VarNode* op, ObjectPath path) override; - void VisitExpr_(const SizeVarNode* op, ObjectPath path) override; - void VisitExpr_(const BufferLoadNode* op, ObjectPath path) override; - void VisitExpr_(const ProducerLoadNode* op, ObjectPath path) override; - void VisitExpr_(const LetNode* op, ObjectPath path) override; - void VisitExpr_(const CallNode* op, ObjectPath path) override; - void VisitExpr_(const AddNode* op, ObjectPath path) override; - void VisitExpr_(const SubNode* op, ObjectPath path) override; - void VisitExpr_(const MulNode* op, ObjectPath path) override; - void VisitExpr_(const DivNode* op, ObjectPath path) override; - void VisitExpr_(const ModNode* op, ObjectPath path) override; - void VisitExpr_(const FloorDivNode* op, ObjectPath path) override; - void VisitExpr_(const FloorModNode* op, ObjectPath path) override; - void VisitExpr_(const MinNode* op, ObjectPath path) override; - void VisitExpr_(const MaxNode* op, ObjectPath path) override; - void VisitExpr_(const EQNode* op, ObjectPath path) override; - void VisitExpr_(const NENode* op, ObjectPath path) override; - void VisitExpr_(const LTNode* op, ObjectPath path) override; - void VisitExpr_(const LENode* op, ObjectPath path) override; - void VisitExpr_(const GTNode* op, ObjectPath path) override; - void VisitExpr_(const GENode* op, ObjectPath path) override; - void VisitExpr_(const AndNode* op, ObjectPath path) override; - void VisitExpr_(const OrNode* op, ObjectPath path) override; - void VisitExpr_(const ReduceNode* op, ObjectPath path) override; - void VisitExpr_(const CastNode* op, ObjectPath path) override; - void VisitExpr_(const NotNode* op, ObjectPath path) override; - void VisitExpr_(const SelectNode* op, ObjectPath path) override; - void VisitExpr_(const RampNode* op, ObjectPath path) override; - void VisitExpr_(const BroadcastNode* op, ObjectPath path) override; - void VisitExpr_(const ShuffleNode* op, ObjectPath path) override; - void VisitExpr_(const IntImmNode* op, ObjectPath path) override; - void VisitExpr_(const FloatImmNode* op, ObjectPath path) override; - void VisitExpr_(const StringImmNode* op, ObjectPath path) override; + void VisitExpr_(const VarNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const SizeVarNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const BufferLoadNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const ProducerLoadNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const LetNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const CallNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const AddNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const SubNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const MulNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const DivNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const ModNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const FloorDivNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const FloorModNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const MinNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const MaxNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const EQNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const NENode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const LTNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const LENode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const GTNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const GENode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const AndNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const OrNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const ReduceNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const CastNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const NotNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const SelectNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const RampNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const BroadcastNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const ShuffleNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const IntImmNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const FloatImmNode* op, ffi::reflection::AccessPath path) override; + void VisitExpr_(const StringImmNode* op, ffi::reflection::AccessPath path) override; // Utility to call EnterDef/ExitDef. Used in the implementation of // WithDef. @@ -180,7 +181,7 @@ class TIRVisitorWithPath : protected ExprFunctorin_scope_definitions_.insert(obj_); self_->EnterDef(obj_, path_); @@ -195,19 +196,19 @@ class TIRVisitorWithPath : protected ExprFunctor - DefContext WithDef(T obj, ObjectPath path) { + DefContext WithDef(T obj, ffi::reflection::AccessPath path) { return DefContext(this, obj, path); } /* \brief Utility to track the scope of a node's definition. */ template - std::optional> WithDefIfUndefined(T obj, ObjectPath path) { + std::optional> WithDefIfUndefined(T obj, ffi::reflection::AccessPath path) { if (in_scope_definitions_.count(obj)) { return std::nullopt; } else { @@ -215,10 +216,11 @@ class TIRVisitorWithPath : protected ExprFunctor> WithMatchBufferDefs(Buffer buf, ObjectPath path) { + std::vector> WithMatchBufferDefs(Buffer buf, ffi::reflection::AccessPath path) { std::vector> context; - auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, ObjectPath path) { + auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, + ffi::reflection::AccessPath path) { if (auto opt = expr.as()) { auto var = opt.value(); if (auto var_def = WithDefIfUndefined(var, path)) { @@ -227,9 +229,10 @@ class TIRVisitorWithPath : protected ExprFunctor& arr, ObjectPath path) { + const Array& arr, + ffi::reflection::AccessPath path) { for (size_t i = 0; i < arr.size(); i++) { - try_visit_implicit_var_def(arr[i], path->ArrayIndex(i)); + try_visit_implicit_var_def(arr[i], path->ArrayItem(i)); } }; diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 84556aab6b27..251b33f910e7 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -18,8 +18,8 @@ import tvm import tvm.testing +from tvm.ffi.access_path import AccessPath from tvm.ir.base import get_first_structural_mismatch -from tvm.runtime import ObjectPath def get_first_mismatch_ensure_symmetry(a, b): @@ -56,32 +56,32 @@ def get_first_mismatch_ensure_symmetry(a, b): ( [1, 2, 3], [1, 4, 3], - ObjectPath.root().array_index(1), - ObjectPath.root().array_index(1), + AccessPath.root().array_item(1), + AccessPath.root().array_item(1), ), ( [1, 2, 3], [10, 2, 30], - ObjectPath.root().array_index(0), - ObjectPath.root().array_index(0), + AccessPath.root().array_item(0), + AccessPath.root().array_item(0), ), ( [1, 3, 4], [1, 2, 3, 4], - ObjectPath.root().array_index(1), - ObjectPath.root().array_index(1), + AccessPath.root().array_item(1), + AccessPath.root().array_item(1), ), ( [1, 2, 3], [1, 2, 3, 4], - ObjectPath.root().missing_array_element(3), - ObjectPath.root().array_index(3), + AccessPath.root().array_item_missing(3), + AccessPath.root().array_item(3), ), ( [], [1], - ObjectPath.root().missing_array_element(0), - ObjectPath.root().array_index(0), + AccessPath.root().array_item_missing(0), + AccessPath.root().array_item(0), ), ], ) @@ -141,14 +141,14 @@ def test_string_map_structural_equal_to_self(contents): ( dict(a=3, b=4), dict(a=3, b=5), - ObjectPath.root().map_value("b"), - ObjectPath.root().map_value("b"), + AccessPath.root().map_item("b"), + AccessPath.root().map_item("b"), ), ( dict(a=3, b=4), dict(a=3, b=4, c=5), - ObjectPath.root().missing_map_entry(), - ObjectPath.root().map_value("c"), + AccessPath.root().map_item_missing("c"), + AccessPath.root().map_item("c"), ), ], ) diff --git a/tests/python/ir/test_object_path.py b/tests/python/ir/test_object_path.py deleted file mode 100644 index 3fea5141c745..000000000000 --- a/tests/python/ir/test_object_path.py +++ /dev/null @@ -1,159 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pytest -import tvm -from tvm.runtime import object_path -from tvm.runtime.object_path import ObjectPath - - -def test_root_path(): - root = ObjectPath.root() - assert isinstance(root, object_path.RootPath) - assert str(root) == "" - assert len(root) == 1 - assert root == ObjectPath.root() - assert root.parent is None - - -def test_named_root_path(): - root = ObjectPath.root("base_name") - assert isinstance(root, object_path.RootPath) - assert str(root) == "base_name" - assert len(root) == 1 - assert root != ObjectPath.root() - assert root == ObjectPath.root("base_name") - assert root.parent is None - - -def test_path_attr(): - path = ObjectPath.root().attr("foo") - assert isinstance(path, object_path.AttributeAccessPath) - assert str(path) == ".foo" - assert len(path) == 2 - assert path.parent == ObjectPath.root() - - -def test_path_attr_unknown(): - path = ObjectPath.root().attr(None) - assert isinstance(path, object_path.UnknownAttributeAccessPath) - assert str(path) == "." - assert len(path) == 2 - assert path.parent == ObjectPath.root() - - -def test_path_array_index(): - path = ObjectPath.root().array_index(2) - assert isinstance(path, object_path.ArrayIndexPath) - assert str(path) == "[2]" - assert len(path) == 2 - assert path.parent == ObjectPath.root() - - -def test_path_missing_array_element(): - path = ObjectPath.root().missing_array_element(2) - assert isinstance(path, object_path.MissingArrayElementPath) - assert str(path) == "[]" - assert len(path) == 2 - assert path.parent == ObjectPath.root() - - -def test_path_map_value(): - path = ObjectPath.root().map_value("foo") - assert isinstance(path, object_path.MapValuePath) - assert str(path) == '["foo"]' - assert len(path) == 2 - assert path.parent == ObjectPath.root() - - -def test_path_missing_map_entry(): - path = ObjectPath.root().missing_map_entry() - assert isinstance(path, object_path.MissingMapEntryPath) - assert str(path) == "[]" - assert len(path) == 2 - assert path.parent == ObjectPath.root() - - -@pytest.mark.parametrize( - "a, b, expected", - [ - (ObjectPath.root(), ObjectPath.root(), True), - (ObjectPath.root(), ObjectPath.root().attr("foo"), True), - (ObjectPath.root().attr("foo"), ObjectPath.root(), False), - (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo"), True), - (ObjectPath.root().attr("bar"), ObjectPath.root().attr("foo"), False), - (ObjectPath.root().attr("foo"), ObjectPath.root().attr("foo").array_index(2), True), - (ObjectPath.root().attr("foo").array_index(2), ObjectPath.root().attr("foo"), False), - (ObjectPath.root().attr("foo"), ObjectPath.root().attr("bar").array_index(2), False), - ], -) -def test_path_is_prefix_of(a, b, expected): - assert a.is_prefix_of(b) == expected - - -paths_for_equality_test = [ - ObjectPath.root(), - ObjectPath.root().attr("foo"), - ObjectPath.root().attr("bar"), - ObjectPath.root().array_index(3), - ObjectPath.root().array_index(4), - ObjectPath.root().missing_array_element(3), - ObjectPath.root().missing_array_element(4), - ObjectPath.root().map_value("foo"), - ObjectPath.root().map_value("bar"), - ObjectPath.root().missing_map_entry(), - ObjectPath.root().attr("foo").missing_map_entry(), -] - - -def make_test_params_for_eq_test(): - return [ - pytest.param(idx, path, id="path{}".format(idx)) - for idx, path in enumerate(paths_for_equality_test) - ] - - -@pytest.mark.parametrize("a_idx, a_path", make_test_params_for_eq_test()) -@pytest.mark.parametrize("b_idx, b_path", make_test_params_for_eq_test()) -def test_path_equal(a_idx, a_path, b_idx, b_path): - expected = a_idx == b_idx - result = a_path == b_path - assert result == expected - - -def test_path_get_prefix(): - p1 = ObjectPath.root() - p2 = p1.attr("foo") - p3 = p2.array_index(5) - - assert p3.parent == p2 - assert p2.parent == p1 - assert p1.parent is None - - assert p2.get_prefix(1) == p1 - - assert p3.get_prefix(1) == p1 - assert p3.get_prefix(2) == p2 - assert p3.get_prefix(3) == p3 - - with pytest.raises(IndexError) as e: - p3.get_prefix(0) - assert "Prefix length must be at least 1" in str(e.value) - - with pytest.raises(IndexError) as e: - p3.get_prefix(4) - assert "Attempted to get a prefix longer than the path itself" in str(e.value) diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index 32099cecf4b2..601afd8f164f 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -18,7 +18,7 @@ import numpy as np import pytest from tvm import te -from tvm.runtime import ObjectPath +from tvm.ffi.access_path import AccessPath from tvm.script import tir as T, ir as I @@ -139,8 +139,8 @@ def test_prim_func_param_count_mismatch(): func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x)) func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x)) lhs_path, rhs_path = get_sequal_mismatch(func0, func1) - expected_lhs_path = ObjectPath.root().attr("params").missing_array_element(2) - expected_rhs_path = ObjectPath.root().attr("params").array_index(2) + expected_lhs_path = AccessPath.root().attr("params").array_item_missing(2) + expected_rhs_path = AccessPath.root().attr("params").array_item(2) assert lhs_path == expected_lhs_path assert rhs_path == expected_rhs_path @@ -153,7 +153,7 @@ def test_prim_func_param_dtype_mismatch(): func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x)) func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x)) lhs_path, rhs_path = get_sequal_mismatch(func0, func1) - expected_path = ObjectPath.root().attr("params").array_index(1).attr("dtype") + expected_path = AccessPath.root().attr("params").array_item(1).attr("dtype") assert lhs_path == expected_path assert rhs_path == expected_path @@ -167,7 +167,7 @@ def test_prim_func_body_mismatch(): func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0)) func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1)) lhs_path, rhs_path = get_sequal_mismatch(func0, func1) - expected_path = ObjectPath.root().attr("body").attr("value").attr("b") + expected_path = AccessPath.root().attr("body").attr("value").attr("b") assert lhs_path == expected_path assert rhs_path == expected_path @@ -260,7 +260,7 @@ def test_buffer_map_mismatch(): lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) expected_path = ( - ObjectPath.root().attr("buffer_map").map_value(x).attr("shape").array_index(1).attr("value") + AccessPath.root().attr("buffer_map").map_item(x).attr("shape").array_item(1).attr("value") ) assert lhs_path == expected_path assert rhs_path == expected_path @@ -280,9 +280,9 @@ def test_buffer_map_length_mismatch(): lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) - expected_lhs_path = ObjectPath.root().attr("buffer_map").missing_map_entry() + expected_lhs_path = AccessPath.root().attr("buffer_map").map_item_missing(y) assert lhs_path == expected_lhs_path - expected_rhs_path = ObjectPath.root().attr("buffer_map").map_value(y) + expected_rhs_path = AccessPath.root().attr("buffer_map").map_item(y) assert rhs_path == expected_rhs_path @@ -316,7 +316,7 @@ def test_while_condition_mismatch(): w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) w_1 = tvm.tir.While(x < 0, tvm.tir.Evaluate(x)) lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) - expected_path = ObjectPath.root().attr("condition") + expected_path = AccessPath.root().attr("condition") assert lhs_path == expected_path assert rhs_path == expected_path @@ -326,7 +326,7 @@ def test_while_body_mismatch(): w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) w_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1)) lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) - expected_path = ObjectPath.root().attr("body").attr("value") + expected_path = AccessPath.root().attr("body").attr("value") assert lhs_path == expected_path assert rhs_path == expected_path @@ -351,7 +351,7 @@ def test_seq_mismatch(): ) lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) expected_path = ( - ObjectPath.root().attr("seq").array_index(2).attr("value").attr("b").attr("value") + AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value") ) assert lhs_path == expected_path assert rhs_path == expected_path @@ -371,7 +371,7 @@ def test_seq_mismatch_different_lengths(): seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)]) lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) expected_path = ( - ObjectPath.root().attr("seq").array_index(2).attr("value").attr("b").attr("value") + AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value") ) assert lhs_path == expected_path assert rhs_path == expected_path @@ -389,8 +389,8 @@ def test_seq_length_mismatch(): ) seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)]) lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) - expected_lhs_path = ObjectPath.root().attr("seq").array_index(3) - expected_rhs_path = ObjectPath.root().attr("seq").missing_array_element(3) + expected_lhs_path = AccessPath.root().attr("seq").array_item(3) + expected_rhs_path = AccessPath.root().attr("seq").array_item_missing(3) assert lhs_path == expected_lhs_path assert rhs_path == expected_rhs_path diff --git a/tests/python/tvmscript/test_tvmscript_printer_annotation.py b/tests/python/tvmscript/test_tvmscript_printer_annotation.py index fb57ae9ce635..c45c0a91c5c5 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_annotation.py +++ b/tests/python/tvmscript/test_tvmscript_printer_annotation.py @@ -18,7 +18,7 @@ from typing import Optional import pytest -from tvm.runtime import ObjectPath +from tvm.ffi.access_path import AccessPath from tvm.script import tir as T @@ -34,13 +34,13 @@ def _func(): T.evaluate(7) -def test_annotation_multi_object_paths(): +def test_annotation_multi_access_paths(): result = _func.with_attr("global_symbol", "main").script( path_to_annotate={ - ObjectPath.root().attr("body").attr("seq").array_index(1): "annotation 1", - ObjectPath.root().attr("body").attr("seq").array_index(3): "annotation 3", - ObjectPath.root().attr("body").attr("seq").array_index(5): "annotation 5", - ObjectPath.root().attr("body").attr("seq").array_index(7): "annotation 7", + AccessPath.root().attr("body").attr("seq").array_item(1): "annotation 1", + AccessPath.root().attr("body").attr("seq").array_item(3): "annotation 3", + AccessPath.root().attr("body").attr("seq").array_item(5): "annotation 5", + AccessPath.root().attr("body").attr("seq").array_item(7): "annotation 7", } ) assert ( diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py index e3d1280b32ae..20a705f9ff83 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_doc.py +++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py @@ -20,9 +20,8 @@ """ import pytest - import tvm -from tvm.runtime import ObjectPath +from tvm.ffi.access_path import AccessPath from tvm.script.printer.doc import ( AssertDoc, AssignDoc, @@ -547,7 +546,7 @@ def test_doc_source_paths(): doc = IdDoc("x") assert len(doc.source_paths) == 0 - source_paths = [ObjectPath.root(), ObjectPath.root().attr("x")] + source_paths = [AccessPath.root(), AccessPath.root().attr("x")] doc.source_paths = source_paths # This should triggers the __getattr__ and gets a tvm.ir.container.Array diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py index bbf95801ed0a..f3a385ca0911 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py +++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py @@ -19,7 +19,7 @@ import tvm from tvm.ir import assert_structural_equal -from tvm.runtime import ObjectPath +from tvm.ffi.access_path import AccessPath from tvm.script import ir as I, tir as T @@ -53,17 +53,17 @@ def func2(a: T.handle, b: T.handle): assert _error_message(ve.value) == _expected_result( func1, func2, - ObjectPath.root() + AccessPath.root() .attr("buffer_map") - .map_value(func1.params[1]) + .map_item(func1.params[1]) .attr("shape") - .array_index(1) + .array_item(1) .attr("value"), - ObjectPath.root() + AccessPath.root() .attr("buffer_map") - .map_value(func2.params[1]) + .map_item(func2.params[1]) .attr("shape") - .array_index(1) + .array_item(1) .attr("value"), ) @@ -86,15 +86,15 @@ def func(): assert _error_message(ve.value) == _expected_result( module1, module2, - ObjectPath.root() + AccessPath.root() .attr("functions") - .map_value(module1.get_global_var("func")) + .map_item(module1.get_global_var("func")) .attr("body") .attr("value") .attr("value"), - ObjectPath.root() + AccessPath.root() .attr("functions") - .map_value(module2.get_global_var("func")) + .map_item(module2.get_global_var("func")) .attr("body") .attr("value") .attr("value"), @@ -121,8 +121,8 @@ def func2(): assert _error_message(ve.value) == _expected_result( func1, func2, - ObjectPath.root().attr("body").attr("extents").array_index(0).attr("value"), - ObjectPath.root().attr("body").attr("extents").array_index(0).attr("value"), + AccessPath.root().attr("body").attr("extents").array_item(0).attr("value"), + AccessPath.root().attr("body").attr("extents").array_item(0).attr("value"), ) @@ -147,8 +147,8 @@ def func2(): assert _error_message(ve.value) == _expected_result( func1, func2, - ObjectPath.root().attr("body").attr("block").attr("body").attr("body").attr("body"), - ObjectPath.root().attr("body").attr("block").attr("body").attr("body").attr("body"), + AccessPath.root().attr("body").attr("block").attr("body").attr("body").attr("body"), + AccessPath.root().attr("body").attr("block").attr("body").attr("body").attr("body"), ) diff --git a/tests/python/tvmscript/test_tvmscript_printer_underlining.py b/tests/python/tvmscript/test_tvmscript_printer_underlining.py index a0fc139a2d29..e36e96c77d7f 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_underlining.py +++ b/tests/python/tvmscript/test_tvmscript_printer_underlining.py @@ -18,7 +18,7 @@ from typing import Optional import pytest -from tvm.runtime import ObjectPath +from tvm.ffi.access_path import AccessPath from tvm.script.printer.doc import ( ExprStmtDoc, IdDoc, @@ -30,8 +30,8 @@ from tvm.script import ir as I, tir as T -def make_path(name: str) -> ObjectPath: - return ObjectPath.root().attr(name) +def make_path(name: str) -> AccessPath: + return AccessPath.root().attr(name) def make_id_doc(name: str, path_name: Optional[str] = None) -> IdDoc: @@ -312,7 +312,7 @@ def test_underline_and_print_line_numbers(): ) -def test_underline_multi_object_paths(): +def test_underline_multi_access_paths(): doc = StmtBlockDoc([ExprStmtDoc(make_id_doc(f"line{i + 1}")) for i in range(10)]) result = to_python_script( doc, @@ -479,7 +479,7 @@ def func(): result = func.with_attr("global_symbol", "main").script( path_to_underline=[ - ObjectPath.root(), + AccessPath.root(), ] ) assert result == format_script( @@ -505,7 +505,7 @@ def func(): result = irmodule.script( path_to_underline=[ - ObjectPath.root().attr("functions").map_value(irmodule.get_global_var("func")), + AccessPath.root().attr("functions").map_item(irmodule.get_global_var("func")), ] ) assert result == format_script( @@ -534,7 +534,7 @@ def func(): result = irmodule.script( path_to_underline=[ - ObjectPath.root(), + AccessPath.root(), ] ) assert result == format_script(