Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
*
* kInvalidDeviceType denotes unconstrained.
*/
int device_type_int;
int /* actually DLDeviceType */ device_type_int;

DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }

Expand Down Expand Up @@ -303,6 +303,11 @@ class SEScope : public ObjectRef {
return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
}

/*! \brief Returns the \p SEScope for \p memory_scope alone. */
static SEScope ForMemoryScope(MemoryScope memory_scope) {
return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class Buffer : public ObjectRef {
public:
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL Buffer(Var ptr, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Span span = Span());

Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
*/
Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
std::function<Stmt(const Stmt&)> fmutate = nullptr);

// internal helper.
class Internal;
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class VarNode : public PrimExprNode {
*/
String name_hint;
/*!
* \brief type annotaion of the variable.
* \brief type annotation of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
Expand Down
68 changes: 68 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=invalid-name
from typing import Dict, List

from tvm import Object
from tvm.tir.stmt import Block, BufferRegion
from tvm.tir.stmt import PrimExpr
from tvm.tir.expr import Var
Expand Down Expand Up @@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
Map from buffer to the LCA of all access to it.
"""
return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member


# NOTE: relay_func_type in the following two functions should be relay.FuncType however that would
# introduce a cycling dependency. We make do with Object.


def get_prim_func_arg_and_result_memory_constraints(
func: PrimFunc, relay_func_type: Object
) -> List[str]:
"""Returns the memory (aka storage) scope constraints for all the arguments and result
of func. However the result will be w.r.t. the func's representation as a Relay Function
of relay_func_type before lowering and conversion to DPS.

Visible for testing.

Parameters
----------
func: tvm.tir.PrimFunc
The function to retrieve constraints from.

relay_func_type: tvm.relay.FuncType
The type of the Relay Function from which the func was derived.

Returns
-------
result: List[AnyStr]
Memory scope constraints for funcs args and result in Relay form. The empty string
denotes 'no constraint'.
"""
return _ffi_api.GetPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member
func, relay_func_type
)


def apply_prim_func_arg_and_result_memory_constraints(
func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[str]
) -> PrimFunc:
"""Returns func written to capture the memory (aka storage) scope constraints
for each of the func's parameters given by arg_and_result_memory_scopes. However,
arg_and_result_memory_scopes should be w.r.t. the func's representation as a Relay
Function of relay_func_type before lowering and conversion to DPS.

Visible for testing.

CAUTION: This is experimental. The resulting PrimFunc may not have fully accounted
for all new memory scopes.

Parameters
----------
func: tvm.tir.PrimFunc
The function to retrieve constraints from.

relay_func_type: tvm.relay.FuncType
The type of the Relay Function from which the func was derived.

arg_and_result_memory_scopes: Array[AnyStr]
Memory constraints for funcs args and result in Relay form. The empty string denotes
'no constraint'.

Returns
-------
result: tvm.tir.PrimFunc
The rewritten func.
"""
return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member
func, relay_func_type, arg_and_result_memory_scopes
)
Loading