Skip to content

Commit e92d012

Browse files
author
sqing
committed
[TIR] Add parameter extent for access_ptr.
1 parent 1e579f8 commit e92d012

File tree

6 files changed

+59
-38
lines changed

6 files changed

+59
-38
lines changed

include/tvm/tir/buffer.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,11 @@ class Buffer : public ObjectRef {
186186
* \param ptr_type The type of the pointer.
187187
* \param content_lanes The number of lanes for the (data) type.
188188
* \param offset The offset of ptr.
189+
* \param input_extent The extent of ptr. If defined, it overrides the default extent.
189190
*/
190191
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
191-
int content_lanes = 1,
192-
PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
192+
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
193+
Optional<PrimExpr> input_extent = NullOpt) const;
193194
/*!
194195
* \brief Create an Expr that does a vector load at begin index.
195196
* \param begin The beginning index

python/tvm/tir/buffer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Buffer(Object):
4242
READ = 1
4343
WRITE = 2
4444

45-
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
45+
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, extent=None):
4646
"""Get an access pointer to the head of buffer.
4747
4848
This is the recommended method to get buffer data
@@ -66,6 +66,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
6666
The offset of pointer. We can use it to offset by
6767
the number of elements from the address of ptr.
6868
69+
extent: Expr, optional
70+
The extent of the pointer. If not given, the default extent of the buffer is used.
71+
6972
Examples
7073
--------
7174
.. code-block:: python
@@ -78,6 +81,8 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
7881
buffer.access_ptr("rw")
7982
# Get access ptr for read with offset
8083
buffer.access_ptr("r", offset = 100)
84+
# Get access ptr for read with extent
85+
buffer.access_ptr("r", extent = 100)
8186
"""
8287
if isinstance(access_mask, string_types):
8388
mask = 0
@@ -90,8 +95,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
9095
raise ValueError("Unknown access_mask %s" % access_mask)
9196
access_mask = mask
9297
offset = convert(offset)
98+
extent = convert(extent)
9399
return _ffi_api.BufferAccessPtr(
94-
self, access_mask, ptr_type, content_lanes, offset # type: ignore
100+
self, access_mask, ptr_type, content_lanes, offset, extent # type: ignore
95101
)
96102

97103
def vload(self, begin, dtype=None):

src/tir/ir/buffer.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
495495
return slice;
496496
}
497497

498-
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
499-
PrimExpr offset) const {
498+
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
499+
Optional<PrimExpr> input_extent) const {
500500
const BufferNode* self = operator->();
501501
ICHECK(self != nullptr);
502502
PrimExpr e_dtype;
@@ -519,6 +519,10 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
519519
} else {
520520
e_dtype = tir::TypeAnnotation(self->dtype);
521521
}
522+
523+
if (input_extent.defined()) {
524+
extent = input_extent.value();
525+
}
522526
Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
523527
make_const(DataType::Int(32), access_mask)};
524528
return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);

tests/python/unittest/test_tir_analysis_get_block_access_region.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,19 @@ def opaque_access_func() -> None:
105105
)
106106

107107

108+
@T.prim_func
109+
def opaque_access_with_tvm_access_ptr_func() -> None:
110+
A = T.alloc_buffer([1024])
111+
B = T.alloc_buffer([1024])
112+
C = T.alloc_buffer([1024])
113+
with T.block("opaque"):
114+
T.reads(A[0:1024], C[0:1024])
115+
T.writes(B[0:1024], C[0:1024])
116+
T.evaluate(A.access_ptr("r"))
117+
T.evaluate(B.access_ptr("w"))
118+
T.evaluate(C.access_ptr("rw"))
119+
120+
108121
@T.prim_func
109122
def access_in_if_then_else_func() -> None:
110123
A = T.alloc_buffer([8])
@@ -235,6 +248,21 @@ def test_opaque_access():
235248
tvm.ir.assert_structural_equal(ret0[1], ret1[1])
236249

237250

251+
def test_opaque_access_with_tvm_access_ptr():
252+
block = opaque_access_with_tvm_access_ptr_func.body.block.body.block
253+
alloc_buffers = opaque_access_with_tvm_access_ptr_func.body.block.alloc_buffers
254+
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
255+
256+
ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map)
257+
ret1 = tir.analysis.get_block_access_region(block, buffer_var_map)
258+
tvm.ir.assert_structural_equal(block.reads, ret0[0])
259+
tvm.ir.assert_structural_equal(block.writes, ret0[1])
260+
with pytest.raises(ValueError):
261+
tvm.ir.assert_structural_equal(ret0[0], ret1[0])
262+
with pytest.raises(ValueError):
263+
tvm.ir.assert_structural_equal(ret0[1], ret1[1])
264+
265+
238266
def test_match_buffer():
239267
root_block = match_buffer_func.body.block
240268
block = root_block.body.body.body.block
@@ -333,6 +361,7 @@ def test_access_of_decompose_reduction():
333361
test_block_access_region_detector()
334362
test_opaque_block()
335363
test_opaque_access()
364+
test_opaque_access_with_tvm_access_ptr()
336365
test_match_buffer()
337366
test_access_in_if_then_else_func()
338367
test_access_in_branch_func()

tests/python/unittest/test_tir_buffer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ def test_buffer_access_ptr_extent():
7676
aptr = Ab.access_ptr("rw", offset=100)
7777
assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100)
7878

79+
# Test extent from input params
80+
aptr = Ab.access_ptr("rw", extent=200)
81+
assert tvm.ir.structural_equal(aptr.args[3], 200)
82+
aptr = Ab.access_ptr("rw", offset=100, extent=100)
83+
assert tvm.ir.structural_equal(aptr.args[3], 100)
84+
7985

8086
def test_buffer_vload():
8187
m = te.size_var("m")

tests/python/unittest/test_tir_schedule_compute_inline.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,7 @@ def opaque_access_load(a: T.handle, c: T.handle) -> None:
183183
vi, vj = T.axis.remap("SS", [i, j])
184184
T.reads(B[0:128, 0:128])
185185
T.writes(C[0:128, 0:128])
186-
T.evaluate(
187-
T.tvm_access_ptr(
188-
T.type_annotation(dtype="float32"), B.data, 0, 128, 1, dtype="handle"
189-
)
190-
)
186+
T.evaluate(B.access_ptr("r", extent=128))
191187
C[vi, vj] = B[vi, vj] + 1.0
192188

193189

@@ -205,16 +201,8 @@ def opaque_access_store(a: T.handle, c: T.handle) -> None:
205201
vi, vj = T.axis.remap("SS", [i, j])
206202
T.reads(B[0:128, 0:128])
207203
T.writes(C[0:128, 0:128])
208-
T.evaluate(
209-
T.tvm_access_ptr(
210-
T.type_annotation(dtype="float32"), B.data, 0, 128, 1, dtype="handle"
211-
)
212-
)
213-
T.evaluate(
214-
T.tvm_access_ptr(
215-
T.type_annotation(dtype="float32"), C.data, 0, 128, 2, dtype="handle"
216-
)
217-
)
204+
T.evaluate(B.access_ptr("r", extent=128))
205+
T.evaluate(C.access_ptr("w", extent=128))
218206
C[vi, vj] = B[vi, vj] + 1.0
219207

220208

@@ -296,14 +284,8 @@ def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
296284
# annotated opaque partial access
297285
T.reads(A[0:512])
298286
T.writes(A_cache[0:512])
299-
T.evaluate(
300-
T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 512, 1, dtype="handle")
301-
)
302-
T.evaluate(
303-
T.tvm_access_ptr(
304-
T.type_annotation(dtype="float32"), A_cache.data, 0, 512, 2, dtype="handle"
305-
)
306-
)
287+
T.evaluate(A.access_ptr("r", extent=512))
288+
T.evaluate(A_cache.access_ptr("w", extent=512))
307289
for i in range(512):
308290
with T.block("BB"):
309291
vi = T.axis.remap("S", [i])
@@ -323,14 +305,8 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
323305
# annotated opaque partial access should be kept
324306
T.reads(A[0:512])
325307
T.writes([A_cache[0:512]])
326-
T.evaluate(
327-
T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 512, 1, dtype="handle")
328-
)
329-
T.evaluate(
330-
T.tvm_access_ptr(
331-
T.type_annotation(dtype="float32"), A_cache.data, 0, 512, 2, dtype="handle"
332-
)
333-
)
308+
T.evaluate(A.access_ptr("r", extent=512))
309+
T.evaluate(A_cache.access_ptr("w", extent=512))
334310
for i in T.serial(0, 512):
335311
with T.block("B"):
336312
vi = T.axis.spatial(512, i)
@@ -614,7 +590,6 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr():
614590
sch = tir.Schedule(exp_exp_opaque_access_with_tvm_access_ptr, debug_mask="all")
615591
compute = sch.get_block("compute")
616592
sch.compute_inline(compute)
617-
print(sch.mod.script())
618593
tvm.ir.assert_structural_equal(
619594
exp_exp_opaque_access_with_tvm_access_ptr_inlined, sch.mod["main"]
620595
)

0 commit comments

Comments
 (0)