|
16 | 16 | # under the License. |
17 | 17 | """Wrapping existing analysis utils.""" |
18 | 18 | # pylint: disable=invalid-name |
19 | | -from typing import Dict, List |
| 19 | +from typing import Dict, List, AnyStr |
20 | 20 |
|
| 21 | +from tvm import Object |
21 | 22 | from tvm.tir.stmt import Block, BufferRegion |
22 | 23 | from tvm.tir.stmt import PrimExpr |
23 | 24 | from tvm.tir.expr import Var |
@@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: |
196 | 197 | Map from buffer to the LCA of all access to it. |
197 | 198 | """ |
198 | 199 | return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member |
| 200 | + |
| 201 | + |
| 202 | +# NOTE: relay_func_type in the following two functions should be relay.FuncType however that would |
| 203 | +# introduce a cycling dependency. We make do with Object. |
| 204 | + |
| 205 | + |
| 206 | +def get_prim_func_arg_and_result_memory_constraints( |
| 207 | + func: PrimFunc, relay_func_type: Object |
| 208 | +) -> List[AnyStr]: |
| 209 | + """Returns the memory (aka storage) scope constraints for all the arguments and result |
| 210 | + of func. However the result will be w.r.t. the func's representation as a Relay Function |
| 211 | + of relay_func_type before lowering and conversion to DPS. |
| 212 | +
|
| 213 | + Visible for testing. |
| 214 | +
|
| 215 | + Parameters |
| 216 | + ---------- |
| 217 | + func: tvm.tir.PrimFunc |
| 218 | + The function to retrieve constraints from. |
| 219 | +
|
| 220 | + relay_func_type: tvm.relay.FuncType |
| 221 | + The type of the Relay Function from which the func was derived. |
| 222 | +
|
| 223 | + Returns |
| 224 | + ------- |
| 225 | + result: List[AnyStr] |
| 226 | + Memory scope constraints for funcs args and result in Relay form. The empty string |
| 227 | + denotes 'no constraint'. |
| 228 | + """ |
| 229 | + return _ffi_api.GetPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member |
| 230 | + func, relay_func_type |
| 231 | + ) |
| 232 | + |
| 233 | + |
| 234 | +def apply_prim_func_arg_and_result_memory_constraints( |
| 235 | + func: PrimFunc, relay_func_type: Object, arg_and_result_memory_scopes: List[AnyStr] |
| 236 | +) -> PrimFunc: |
| 237 | + """Returns func written to capture the memory (aka storage) scope constraints |
| 238 | + for each of the func's parameters given by arg_and_result_memory_scopes. However, |
| 239 | + arg_and_result_memory_scopes should be w.r.t. the func's representation as a Relay |
| 240 | + Function of relay_func_type before lowering and conversion to DPS. |
| 241 | +
|
| 242 | + Visible for testing. |
| 243 | +
|
| 244 | + CAUTION: This is experimental. The resulting PrimFunc may not have fully accounted |
| 245 | + for all new memory scopes. |
| 246 | +
|
| 247 | + Parameters |
| 248 | + ---------- |
| 249 | + func: tvm.tir.PrimFunc |
| 250 | + The function to retrieve constraints from. |
| 251 | +
|
| 252 | + relay_func_type: tvm.relay.FuncType |
| 253 | + The type of the Relay Function from which the func was derived. |
| 254 | +
|
| 255 | + arg_and_result_memory_scopes: Array[AnyStr] |
| 256 | + Memory constraints for funcs args and result in Relay form. The empty string denotes |
| 257 | + 'no constraint'. |
| 258 | +
|
| 259 | + Returns |
| 260 | + ------- |
| 261 | + result: tvm.tir.PrimFunc |
| 262 | + The rewritten func. |
| 263 | + """ |
| 264 | + return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member |
| 265 | + func, relay_func_type, arg_and_result_memory_scopes |
| 266 | + ) |
0 commit comments