Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions include/tvm/script/printer/var_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_
#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_

#include <tvm/node/node.h>
#include <tvm/node/object_path.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/frame.h>
#include <tvm/script/printer/traced_object.h>

#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Variable Table manages mapping from variable object to ExprDoc during
* the process of printing TVMScript.
*
* The value type of this map is ExprDoc rather than IdDoc or String. It's
* because variables can be implicitly defined. For example in TIR buffer (tir::Buffer),
* `buf->data` is a variable, while its representation in TVMScript should be an
* expression `x.data`, where `x` is the variable for the buffer itself.
*/
class VarTableNode : public Object {
public:
void VisitAttrs(AttrVisitor*) {}

/*!
* \brief Define variable by name.
* \param obj The variable object.
* \param name_hint The hint for variable name.
* \param object_path The object_path for the returned ExprDoc.
* \param frame The frame that this variable is defined in.
*
* \return The id doc for this variable.
*
* This function will rename the variable to avoid name conflict with other variables
* in the table.
*/
IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path,
const Frame& frame);

/*!
* \brief Define variable by name.
* \param obj The variable object.
* \param name_hint The hint for variable name.
* \param frame The frame that this variable is defined in.
*
* \return The id doc for this variable.
*
* This is a shortcut version of `Define` which accepts a traced string.
*/
IdDoc Define(const ObjectRef& obj, const TracedObject<String>& name_hint, const Frame& frame) {
return Define(obj, name_hint.Get(), name_hint.GetPath(), frame);
}

using DocFactory = std::function<ExprDoc()>;

/*!
* \brief Define variable by doc factory.
* \param obj The variable object.
* \param doc_factory The function to return an ExprDoc object for this variable.
* \param frame The frame that this variable is defined in.
*
* This function is a special form of `Define`. Variable is mapped to ExprDoc rather
* than IdDoc. It's useful when a variable is implicitly defined without a name, like
* the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`.
*
* This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make sure I fully comprehend, would you mind elaborating a little bit with a short example? Thanks a lot!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Assume DefineByDoc accepts a Doc instead of std::function<Doc()>. It will be more difficult to handle case like

c[k] = a[i] + a[j]

when we call the GetVarDoc for the first a and second a, we expect the returned Docs to be different objects and have different source_path. Therefore the VarTable needs to hold a function that returns Doc, rather than Doc itself.

I also update the doc, trying to better clarify it. Let me know if it still looks confusing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me! Thanks for your detailed explanation!

* return a new Doc object every time it's called, as the returned doc will have
* different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject
* so VarTable needs to call a factory function to get a freshly-constructed Doc object
* every time GetVarDoc is called.
*/
void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame);

/*!
* \brief Get the doc for variable.
* \param obj The variable object.
* \param object_path The object path for the variable.
*
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;

/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
*
* \return a boolean for whether variable exists.
*/
bool IsVarDefined(const ObjectRef& obj) const;

static constexpr const char* _type_key = "script.printer.VarTable";
TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object);

private:
void RemoveVar(const ObjectRef& obj);

struct VariableInfo {
DocFactory doc_factory;
Optional<String> name;
};
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
std::unordered_set<String> defined_names;
};

/*!
* \brief Reference type of VarTableNode.
*/
class VarTable : public ObjectRef {
public:
/*!
* \brief Create an empty VarTable.
*/
VarTable();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode);
};

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_
118 changes: 118 additions & 0 deletions python/tvm/script/printer/var_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Functions to print doc into text format"""

from typing import Callable, Optional

from tvm._ffi import register_object
from tvm.runtime import Object, ObjectPath

from . import _ffi_api
from .doc import ExprDoc, IdDoc
from .frame import Frame


@register_object("script.printer.VarTable")
class VarTable(Object):
"""
Variable Table manages mapping from variable object to ExprDoc during
the process of printing TVMScript.
"""

def __init__(self):
"""
Create an empty VarTable.
"""
self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member

def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc:
"""
Define a variable by name.

Parameters
----------
obj : Object
The variable object.
name_hint : str
The hint for variable name.
object_path : ObjectPath
The object path to be associated with the returned ExprDoc.
frame : Frame
Then frame that this variable is defined in.

Returns
-------
doc : IdDoc
The doc for this variable.
"""
return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member

def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None:
"""
Define a variable by ExprDoc.

Parameters
----------
obj : Object
The variable object.
doc_factory : Callable[[], ExprDoc]
The hint for variable name.
frame : Frame
Then frame that this variable is defined in.

Returns
-------
None
"""
_ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member

def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]:
"""
Get the doc for a variable.

Parameters
----------
obj : Object
The variable object.
object_path : ObjectPath
The object path to be associated with the returned ExprDoc.

Returns
-------
doc : ExprDoc
The doc for this variable.
"""
return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member

def is_var_defined(self, obj: Object) -> bool:
"""
Check whether a variable is defined.

Parameters
----------
obj : Object
The variable object.

Returns
-------
is_defined : bool
Whether the variable is defined.
"""
return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member

def __contains__(self, obj: Object) -> bool:
return self.is_var_defined(obj)
108 changes: 108 additions & 0 deletions src/script/printer/var_table.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/node/object_path.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/var_table.h>

namespace tvm {
namespace script {
namespace printer {

String GenerateUniqueName(const String& name_hint, std::unordered_set<String>* defined_names) {
String name = name_hint;
for (int i = 1; !defined_names->insert(name).second; ++i) {
name = name_hint + "_" + std::to_string(i);
}
return name;
}

IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint,
const ObjectPath& object_path, const Frame& frame) {
String name = GenerateUniqueName(name_hint, &this->defined_names);
DocFactory doc_factory = [name]() { return IdDoc(name); };

auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
ICHECK(result.second) << "Duplicated object: " << obj;

IdDoc def_doc(name);
def_doc->source_paths.push_back(object_path);

frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });

return def_doc;
}

void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) {
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;

ICHECK(!doc_factory()->IsInstance<IdDocNode>())
<< "VarTableNode::Define cannot be used for variable that's mapped to IdDoc.";

obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});

frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
}

Optional<ExprDoc> VarTableNode::GetVarDoc(const ObjectRef& obj,
const ObjectPath& object_path) const {
auto it = obj2info.find(obj);
if (it == obj2info.end()) {
return NullOpt;
}
ExprDoc doc = it->second.doc_factory();
doc->source_paths.push_back(object_path);
return doc;
}

bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }

void VarTableNode::RemoveVar(const ObjectRef& obj) {
auto it = obj2info.find(obj);
ICHECK(it != obj2info.end()) << "No such object: " << obj;

if (it->second.name.defined()) {
defined_names.erase(it->second.name.value());
}
obj2info.erase(it);
}

VarTable::VarTable() { data_ = make_object<VarTableNode>(); }

TVM_REGISTER_NODE_TYPE(VarTableNode);
TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); });
TVM_REGISTER_GLOBAL("script.printer.VarTableDefine")
.set_body_method<VarTable, VarTableNode, IdDoc, const ObjectRef&, const String&,
const ObjectPath&, const Frame&>(&VarTableNode::Define);
TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
.set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory,
Frame frame) {
var_table->DefineByDoc(
obj, [f = std::move(factory)]() { return f(); }, frame);
});
TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
.set_body_method<VarTable>(&VarTableNode::GetVarDoc);
TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
.set_body_method<VarTable>(&VarTableNode::IsVarDefined);

} // namespace printer
} // namespace script
} // namespace tvm
Loading