Skip to content

Commit 8294bcf

Browse files
committed
[TIR][USMP] Added buffer info extraction pass
Updating the USMP utility tests to include tests that test creation of PoolInfo and PoolAllocation Objects. Change-Id: I5d349d0ffcac6b0160072d832dd9d5418699228e
1 parent c6e166f commit 8294bcf

File tree

3 files changed

+186
-117
lines changed

3 files changed

+186
-117
lines changed

python/tvm/tir/usmp/utils.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""USMP Utilities and Data Structures"""
1818
# pylint: disable=invalid-name
1919

20-
from typing import Dict, Optional
20+
from typing import Dict, Optional, List
2121

2222
from tvm._ffi import register_object
2323
from tvm.runtime import Object
@@ -27,6 +27,43 @@
2727
CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools"
2828

2929

30+
@register_object("tir.usmp.PoolInfo")
31+
class PoolInfo(Object):
32+
"""PoolInfo object holds information related to memory pools
33+
where the statically sized allocate nodes will pooled into.
34+
35+
Parameters
36+
----------
37+
pool_name : str
38+
The name of the memory pool
39+
40+
target_access : Dict[Target, str]
41+
A dictionary where keys describe which targets could
42+
access the pool where value could take the values :
43+
a) "rw" : read-write access
44+
b) "ro" : write-only acesss
45+
46+
size_hint_bytes : Optional[int]
47+
The expected size hint to be used by the allocator.
48+
The default value would be -1 which means the pool
49+
is not size restricted.
50+
51+
"""
52+
53+
READ_WRITE_ACCESS = "rw"
54+
READ_ONLY_ACCESS = "ro"
55+
56+
def __init__(
57+
self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1
58+
):
59+
self.__init_handle_by_constructor__(
60+
_ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member
61+
pool_name,
62+
target_access,
63+
size_hint_bytes,
64+
)
65+
66+
3067
@register_object("tir.usmp.BufferInfo")
3168
class BufferInfo(Object):
3269
"""BufferInfo object holds information related to buffers
@@ -41,7 +78,10 @@ class BufferInfo(Object):
4178
size_bytes : int
4279
The size in bytes
4380
44-
alignment : int
81+
pool_candidates : List[PoolInfo]
82+
The list of candidates pools this buffer could be placed
83+
84+
alignment : Optional[int]
4585
The byte alignment required in the workspace memory
4686
4787
"""
@@ -50,12 +90,14 @@ def __init__(
5090
self,
5191
name_hint: str,
5292
size_bytes: int,
53-
alignment: int = None,
93+
pool_candidates: List[PoolInfo],
94+
alignment: Optional[int] = None,
5495
):
5596
self.__init_handle_by_constructor__(
5697
_ffi_api.BufferInfo, # type: ignore # pylint: disable=no-member
5798
name_hint,
5899
size_bytes,
100+
pool_candidates,
59101
alignment,
60102
)
61103

@@ -72,43 +114,6 @@ def set_conflicts(self, conflicts: list):
72114
_ffi_api.BufferInfoSetConflicts(self, conflicts)
73115

74116

75-
@register_object("tir.usmp.PoolInfo")
76-
class PoolInfo(Object):
77-
"""PoolInfo object holds information related to memory pools
78-
where the statically sized allocate nodes will pooled into.
79-
80-
Parameters
81-
----------
82-
pool_name : str
83-
The name of the memory pool
84-
85-
target_access : Dict[Target, str]
86-
A dictionary where keys describe which targets could
87-
access the pool where value could take the values :
88-
a) "rw" : read-write access
89-
b) "ro" : write-only acesss
90-
91-
size_hint_bytes : Optional[int]
92-
The expected size hint to be used by the allocator.
93-
The default value would be -1 which means the pool
94-
is not size restricted.
95-
96-
"""
97-
98-
READ_WRITE_ACCESS = "rw"
99-
READ_ONLY_ACCESS = "ro"
100-
101-
def __init__(
102-
self, pool_name: str, target_access: Dict[Target, str], size_hint_bytes: Optional[int] = -1
103-
):
104-
self.__init_handle_by_constructor__(
105-
_ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member
106-
pool_name,
107-
target_access,
108-
size_hint_bytes,
109-
)
110-
111-
112117
@register_object("tir.usmp.PoolAllocation")
113118
class PoolAllocation(Object):
114119
"""PoolAllocation object holds information related to an allocation

tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def get_allocate(stmt):
5353
return allocates
5454

5555

56-
def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
56+
def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
5757
"""helper to assing poolinfos to allocate nodes in a tir.PrimFunc"""
5858

5959
def set_poolinfos(stmt):
@@ -70,12 +70,12 @@ def set_poolinfos(stmt):
7070
return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
7171

7272

73-
def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
73+
def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
7474
"""helper to assing poolinfos to allocate nodes in a IRModule"""
7575
ret = tvm.IRModule()
7676
for global_var, basefunc in mod.functions.items():
7777
if isinstance(basefunc, tvm.tir.PrimFunc):
78-
ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
78+
ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
7979
return ret
8080

8181

@@ -158,7 +158,7 @@ def test_linear():
158158
pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}
159159
)
160160
tir_mod = LinearStructure
161-
tir_mod = assign_poolinfos_to_allocates_in_irmodule(
161+
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
162162
tir_mod, [fast_memory_pool, slow_memory_pool]
163163
)
164164
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(
@@ -274,7 +274,7 @@ def test_parallel_serial_mixed_for_loops():
274274
target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
275275
)
276276
all_serial_tir_mod = AllSerialForLoops
277-
all_serial_tir_mod = assign_poolinfos_to_allocates_in_irmodule(
277+
all_serial_tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
278278
all_serial_tir_mod, [global_ws_pool]
279279
)
280280
main_func = all_serial_tir_mod["tvmgen_default_run_model"]
@@ -287,7 +287,7 @@ def test_parallel_serial_mixed_for_loops():
287287
assert name in ["dummy_allocate", "Conv2dOutput_8", "PaddedInput_8"]
288288

289289
parallel_serial_mixed_tir_mod = ParallelSerialMixedForLoops
290-
parallel_serial_mixed_tir_mod = assign_poolinfos_to_allocates_in_irmodule(
290+
parallel_serial_mixed_tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
291291
parallel_serial_mixed_tir_mod, [global_ws_pool]
292292
)
293293
main_func = parallel_serial_mixed_tir_mod["tvmgen_default_run_model"]
@@ -634,7 +634,7 @@ def test_inception_structure():
634634
target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
635635
)
636636
tir_mod = InceptionStructure
637-
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
637+
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
638638
main_func = tir_mod["tvmgen_default_run_model"]
639639
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
640640
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)

0 commit comments

Comments
 (0)