Skip to content

Commit 3398250

Browse files
committed
[TIR] Allow memory (aka storage) scopes to be retrieved/applied to PrimFuncs
This is in support of #9613 which allows memory scopes to flow out of already-lowered PrimFuncs into the rest of the Relay program. This means scope choices made during lowering can be accounted for in the rest of the program, with device_copies inserted as required. Somewhat more speculatively we also allow memory scopes to flow in to PrimFuncs. This is in preparation for when we can split lowering into two phases: i) lower "primitive" fused Relay functions to TensorIR in a schedulable form roughly isomorphic to TE, and ii) actual scheduling down to traditional TIR. Once that split is made it will be possible to flow memory scopes out of one PrimFunc and into another so as to avoid unnecessary device_copies being necessary due to independently chosen memory scopes. I also suspect we'll want to put our focus on layouts rather than memory scopes, but this at least sets up some of the machinery.
1 parent 0e0adf5 commit 3398250

File tree

9 files changed

+759
-4
lines changed

9 files changed

+759
-4
lines changed

include/tvm/target/se_scope.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
170170
*
171171
* kInvalidDeviceType denotes unconstrained.
172172
*/
173-
int device_type_int;
173+
int /* actually DLDeviceType */ device_type_int;
174174

175175
DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }
176176

@@ -303,6 +303,11 @@ class SEScope : public ObjectRef {
303303
return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
304304
}
305305

306+
/*! \brief Returns the \p SEScope for \p memory_scope alone. */
307+
static SEScope ForMemoryScope(MemoryScope memory_scope) {
308+
return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope));
309+
}
310+
306311
/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
307312
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
308313
MemoryScope memory_scope) {

include/tvm/tir/analysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@
2626

2727
#include <tvm/ir/module.h>
2828
#include <tvm/ir/transform.h>
29+
#include <tvm/target/se_scope.h>
2930
#include <tvm/tir/expr.h>
3031
#include <tvm/tir/function.h>
3132
#include <tvm/tir/op_attr_types.h>
3233
#include <tvm/tir/stmt.h>
3334

3435
#include <string>
36+
#include <unordered_map>
3537

3638
namespace tvm {
3739
namespace tir {
@@ -242,4 +244,5 @@ TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
242244
} // namespace transform
243245
} // namespace tir
244246
} // namespace tvm
247+
245248
#endif // TVM_TIR_ANALYSIS_H_

include/tvm/tir/buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class Buffer : public ObjectRef {
144144
public:
145145
// User can specify data_alignment and offset_factor to be 0
146146
// A default value will be picked.
147-
TVM_DLL Buffer(Var ptr, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
147+
TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
148148
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
149149
BufferType buffer_type, Span span = Span());
150150

include/tvm/tir/stmt_functor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
280280
*/
281281
Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
282282
std::function<Stmt(const Stmt&)> fmutate = nullptr);
283+
283284
// internal helper.
284285
class Internal;
285286
};

include/tvm/tir/var.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class VarNode : public PrimExprNode {
5252
*/
5353
String name_hint;
5454
/*!
55-
* \brief type annotaion of the variable.
55+
* \brief type annotation of the variable.
5656
*
5757
* It is an optional field that provides a refined type of the variable than dtype.
5858
*

python/tvm/tir/analysis/analysis.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
# under the License.
1717
"""Wrapping existing analysis utils."""
1818
# pylint: disable=invalid-name
19-
from typing import Dict, List
19+
from typing import Dict, List, AnyStr
2020

21+
from tvm import Object
2122
from tvm.tir.stmt import Block, BufferRegion
2223
from tvm.tir.stmt import PrimExpr
2324
from tvm.tir.expr import Var
@@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
196197
Map from buffer to the LCA of all access to it.
197198
"""
198199
return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member
200+
201+
202+
# NOTE: relay_func_type in the following two functions should be relay.FuncType however that would
203+
# introduce a cycling dependency. We make do with Object.
204+
205+
206+
def get_prim_func_arg_and_result_memory_constraints(
207+
func: PrimFunc, relay_func_type: Object
208+
) -> List[AnyStr]:
209+
"""Returns the memory (aka storage) scope constraints for all the arguments and result
210+
of func. However the result will be w.r.t. the func's representation as a Relay Function
211+
of relay_func_type before lowering and conversion to DPS.
212+
213+
Visible for testing.
214+
215+
Parameters
216+
----------
217+
func: tvm.tir.PrimFunc
218+
The function to retrieve constraints from.
219+
220+
relay_func_type: tvm.relay.FuncType
221+
The type of the Relay Function from which the func was derived.
222+
223+
Returns
224+
-------
225+
result: List[AnyStr]
226+
Memory scope constraints for funcs args and result in Relay form. The empty string
227+
denotes 'no constraint'.
228+
"""
229+
return _ffi_api.GetPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member
230+
func, relay_func_type
231+
)
232+
233+
234+
def apply_prim_func_arg_and_result_memory_constraints(
235+
func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[AnyStr]
236+
) -> PrimFunc:
237+
"""Returns func written to capture the memory (aka storage) scope constraints
238+
for each of the func's parameters given by arg_and_result_memory_scopes. However,
239+
arg_and_result_memory_scopes should be w.r.t. the func's representation as a Relay
240+
Function of relay_func_type before lowering and conversion to DPS.
241+
242+
Visible for testing.
243+
244+
CAUTION: This is experimental. The resulting PrimFunc may not have fully accounted
245+
for all new memory scopes.
246+
247+
Parameters
248+
----------
249+
func: tvm.tir.PrimFunc
250+
The function to retrieve constraints from.
251+
252+
relay_func_type: tvm.relay.FuncType
253+
The type of the Relay Function from which the func was derived.
254+
255+
arg_and_result_memory_scopes: Array[AnyStr]
256+
Memory constraints for funcs args and result in Relay form. The empty string denotes
257+
'no constraint'.
258+
259+
Returns
260+
-------
261+
result: tvm.tir.PrimFunc
262+
The rewritten func.
263+
"""
264+
return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member
265+
func, relay_func_type, arg_and_result_memory_scopes
266+
)

0 commit comments

Comments
 (0)