-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TVMScript] Printer VarTable #12336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[TVMScript] Printer VarTable #12336
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
| #ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_ | ||
| #define TVM_SCRIPT_PRINTER_VAR_TABLE_H_ | ||
|
|
||
| #include <tvm/node/node.h> | ||
| #include <tvm/node/object_path.h> | ||
| #include <tvm/script/printer/doc.h> | ||
| #include <tvm/script/printer/frame.h> | ||
| #include <tvm/script/printer/traced_object.h> | ||
|
|
||
| #include <unordered_map> | ||
| #include <unordered_set> | ||
|
|
||
| namespace tvm { | ||
| namespace script { | ||
| namespace printer { | ||
|
|
||
| /*! | ||
| * \brief Variable Table manages mapping from variable object to ExprDoc during | ||
| * the process of printing TVMScript. | ||
| * | ||
| * The value type of this map is ExprDoc rather than IdDoc or String. It's | ||
| * because variables can be implicitly defined. For example in TIR buffer (tir::Buffer), | ||
| * `buf->data` is a variable, while its representation in TVMScript should be an | ||
| * expression `x.data`, where `x` is the variable for the buffer itself. | ||
| */ | ||
| class VarTableNode : public Object { | ||
| public: | ||
| void VisitAttrs(AttrVisitor*) {} | ||
|
|
||
| /*! | ||
| * \brief Define variable by name. | ||
| * \param obj The variable object. | ||
| * \param name_hint The hint for variable name. | ||
| * \param object_path The object_path for the returned ExprDoc. | ||
| * \param frame The frame that this variable is defined in. | ||
| * | ||
| * \return The id doc for this variable. | ||
| * | ||
| * This function will rename the variable to avoid name conflict with other variables | ||
| * in the table. | ||
| */ | ||
| IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path, | ||
| const Frame& frame); | ||
|
|
||
| /*! | ||
| * \brief Define variable by name. | ||
| * \param obj The variable object. | ||
| * \param name_hint The hint for variable name. | ||
| * \param frame The frame that this variable is defined in. | ||
| * | ||
| * \return The id doc for this variable. | ||
| * | ||
| * This is a shortcut version of `Define` which accepts a traced string. | ||
| */ | ||
| IdDoc Define(const ObjectRef& obj, const TracedObject<String>& name_hint, const Frame& frame) { | ||
| return Define(obj, name_hint.Get(), name_hint.GetPath(), frame); | ||
| } | ||
|
|
||
| using DocFactory = std::function<ExprDoc()>; | ||
|
|
||
| /*! | ||
| * \brief Define variable by doc factory. | ||
| * \param obj The variable object. | ||
| * \param doc_factory The function to return an ExprDoc object for this variable. | ||
| * \param frame The frame that this variable is defined in. | ||
| * | ||
| * This function is a special form of `Define`. Variable is mapped to ExprDoc rather | ||
| * than IdDoc. It's useful when a variable is implicitly defined without a name, like | ||
| * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`. | ||
| * | ||
| * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to | ||
| * return a new Doc object every time it's called, as the returned doc will have | ||
| * different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject | ||
| * so VarTable needs to call a factory function to get a freshly-constructed Doc object | ||
| * every time GetVarDoc is called. | ||
| */ | ||
| void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame); | ||
|
|
||
| /*! | ||
| * \brief Get the doc for variable. | ||
| * \param obj The variable object. | ||
| * \param object_path The object path for the variable. | ||
| * | ||
| * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. | ||
| */ | ||
| Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; | ||
|
|
||
| /*! | ||
| * \brief Check if a variable exists in the table. | ||
| * \param obj The variable object. | ||
| * | ||
| * \return a boolean for whether variable exists. | ||
| */ | ||
| bool IsVarDefined(const ObjectRef& obj) const; | ||
|
|
||
| static constexpr const char* _type_key = "script.printer.VarTable"; | ||
| TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object); | ||
|
|
||
| private: | ||
| void RemoveVar(const ObjectRef& obj); | ||
|
|
||
| struct VariableInfo { | ||
| DocFactory doc_factory; | ||
| Optional<String> name; | ||
| }; | ||
| std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info; | ||
| std::unordered_set<String> defined_names; | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Reference type of VarTableNode. | ||
| */ | ||
| class VarTable : public ObjectRef { | ||
| public: | ||
| /*! | ||
| * \brief Create an empty VarTable. | ||
| */ | ||
| VarTable(); | ||
| TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode); | ||
| }; | ||
|
|
||
| } // namespace printer | ||
| } // namespace script | ||
| } // namespace tvm | ||
|
|
||
| #endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_ | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Functions to print doc into text format""" | ||
|
|
||
| from typing import Callable, Optional | ||
|
|
||
| from tvm._ffi import register_object | ||
| from tvm.runtime import Object, ObjectPath | ||
|
|
||
| from . import _ffi_api | ||
| from .doc import ExprDoc, IdDoc | ||
| from .frame import Frame | ||
|
|
||
|
|
||
| @register_object("script.printer.VarTable") | ||
| class VarTable(Object): | ||
| """ | ||
| Variable Table manages mapping from variable object to ExprDoc during | ||
| the process of printing TVMScript. | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| """ | ||
| Create an empty VarTable. | ||
| """ | ||
| self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member | ||
|
|
||
| def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc: | ||
| """ | ||
| Define a variable by name. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| obj : Object | ||
| The variable object. | ||
| name_hint : str | ||
| The hint for variable name. | ||
| object_path : ObjectPath | ||
| The object path to be associated with the returned ExprDoc. | ||
| frame : Frame | ||
| Then frame that this variable is defined in. | ||
|
|
||
| Returns | ||
| ------- | ||
| doc : IdDoc | ||
| The doc for this variable. | ||
| """ | ||
| return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member | ||
|
|
||
| def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None: | ||
| """ | ||
| Define a variable by ExprDoc. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| obj : Object | ||
| The variable object. | ||
| doc_factory : Callable[[], ExprDoc] | ||
| The hint for variable name. | ||
| frame : Frame | ||
| Then frame that this variable is defined in. | ||
|
|
||
| Returns | ||
| ------- | ||
| None | ||
| """ | ||
| _ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member | ||
|
|
||
| def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]: | ||
| """ | ||
| Get the doc for a variable. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| obj : Object | ||
| The variable object. | ||
| object_path : ObjectPath | ||
| The object path to be associated with the returned ExprDoc. | ||
|
|
||
| Returns | ||
| ------- | ||
| doc : ExprDoc | ||
| The doc for this variable. | ||
| """ | ||
| return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member | ||
|
|
||
| def is_var_defined(self, obj: Object) -> bool: | ||
| """ | ||
| Check whether a variable is defined. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| obj : Object | ||
| The variable object. | ||
|
|
||
| Returns | ||
| ------- | ||
| is_defined : bool | ||
| Whether the variable is defined. | ||
| """ | ||
| return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member | ||
|
|
||
| def __contains__(self, obj: Object) -> bool: | ||
| return self.is_var_defined(obj) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| #include <tvm/node/object_path.h> | ||
| #include <tvm/runtime/container/optional.h> | ||
| #include <tvm/runtime/logging.h> | ||
| #include <tvm/runtime/registry.h> | ||
| #include <tvm/script/printer/var_table.h> | ||
|
|
||
| namespace tvm { | ||
| namespace script { | ||
| namespace printer { | ||
|
|
||
| String GenerateUniqueName(const String& name_hint, std::unordered_set<String>* defined_names) { | ||
| String name = name_hint; | ||
| for (int i = 1; !defined_names->insert(name).second; ++i) { | ||
| name = name_hint + "_" + std::to_string(i); | ||
| } | ||
| return name; | ||
| } | ||
|
|
||
| IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint, | ||
| const ObjectPath& object_path, const Frame& frame) { | ||
| String name = GenerateUniqueName(name_hint, &this->defined_names); | ||
| DocFactory doc_factory = [name]() { return IdDoc(name); }; | ||
|
|
||
| auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); | ||
| ICHECK(result.second) << "Duplicated object: " << obj; | ||
|
|
||
| IdDoc def_doc(name); | ||
| def_doc->source_paths.push_back(object_path); | ||
|
|
||
| frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); | ||
|
|
||
| return def_doc; | ||
| } | ||
|
|
||
| void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) { | ||
| ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; | ||
|
|
||
| ICHECK(!doc_factory()->IsInstance<IdDocNode>()) | ||
| << "VarTableNode::Define cannot be used for variable that's mapped to IdDoc."; | ||
|
|
||
| obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); | ||
|
|
||
| frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); | ||
| } | ||
|
|
||
| Optional<ExprDoc> VarTableNode::GetVarDoc(const ObjectRef& obj, | ||
| const ObjectPath& object_path) const { | ||
| auto it = obj2info.find(obj); | ||
| if (it == obj2info.end()) { | ||
| return NullOpt; | ||
| } | ||
| ExprDoc doc = it->second.doc_factory(); | ||
| doc->source_paths.push_back(object_path); | ||
| return doc; | ||
| } | ||
|
|
||
| bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } | ||
|
|
||
| void VarTableNode::RemoveVar(const ObjectRef& obj) { | ||
| auto it = obj2info.find(obj); | ||
| ICHECK(it != obj2info.end()) << "No such object: " << obj; | ||
|
|
||
| if (it->second.name.defined()) { | ||
| defined_names.erase(it->second.name.value()); | ||
| } | ||
| obj2info.erase(it); | ||
| } | ||
|
|
||
| VarTable::VarTable() { data_ = make_object<VarTableNode>(); } | ||
|
|
||
| TVM_REGISTER_NODE_TYPE(VarTableNode); | ||
| TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); }); | ||
| TVM_REGISTER_GLOBAL("script.printer.VarTableDefine") | ||
| .set_body_method<VarTable, VarTableNode, IdDoc, const ObjectRef&, const String&, | ||
| const ObjectPath&, const Frame&>(&VarTableNode::Define); | ||
| TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") | ||
| .set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory, | ||
| Frame frame) { | ||
| var_table->DefineByDoc( | ||
| obj, [f = std::move(factory)]() { return f(); }, frame); | ||
| }); | ||
| TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") | ||
| .set_body_method<VarTable>(&VarTableNode::GetVarDoc); | ||
| TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") | ||
| .set_body_method<VarTable>(&VarTableNode::IsVarDefined); | ||
|
|
||
| } // namespace printer | ||
| } // namespace script | ||
| } // namespace tvm |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make sure I fully comprehend, would you mind elaborating a little bit with a short example? Thanks a lot!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. Assume
DefineByDocaccepts aDocinstead ofstd::function<Doc()>. It will be more difficult to handle case likewhen we call the
GetVarDocfor the firstaand seconda, we expect the returned Docs to be different objects and have differentsource_path. Therefore theVarTableneeds to hold a function that returns Doc, rather than Doc itself.I also update the doc, trying to better clarify it. Let me know if it still looks confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense to me! Thanks for your detailed explanation!