@@ -2137,7 +2137,9 @@ def transform_layout(
21372137 Examples
21382138 --------
21392139 Before transform_layout, in TensorIR, the IR is:
2140+
21402141 .. code-block:: python
2142+
21412143 @T.prim_func
21422144 def before_transform_layout(a: T.handle, c: T.handle) -> None:
21432145 A = T.match_buffer(a, (128, 128), "float32")
@@ -2151,14 +2153,20 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None:
21512153 with T.block("C"):
21522154 vi, vj = T.axis.remap("SS", [i, j])
21532155 C[vi, vj] = B[vi, vj] + 1.0
2156+
21542157 Create the schedule and do transform_layout:
2158+
21552159 .. code-block:: python
2160+
21562161 sch = tir.Schedule(before_transform_layout)
21572162 sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True,
21582163 index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
21592164 print(sch.mod["main"].script())
2165+
21602166 After applying transform_layout, the IR becomes:
2167+
21612168 .. code-block:: python
2169+
21622170 @T.prim_func
21632171 def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
21642172 A = T.match_buffer(a, (128, 128), "float32")
@@ -2172,6 +2180,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
21722180 with T.block("C"):
21732181 vi, vj = T.axis.remap("SS", [i, j])
21742182 C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
2183+
21752184 """
21762185 if callable (index_map ):
21772186 index_map = IndexMap .from_func (index_map )
0 commit comments