@@ -2138,6 +2138,7 @@ def transform_layout(
21382138 --------
21392139 Before transform_layout, in TensorIR, the IR is:
21402140 .. code-block:: python
2141+
21412142 @T.prim_func
21422143 def before_transform_layout(a: T.handle, c: T.handle) -> None:
21432144 A = T.match_buffer(a, (128, 128), "float32")
@@ -2151,14 +2152,18 @@ def before_transform_layout(a: T.handle, c: T.handle) -> None:
21512152 with T.block("C"):
21522153 vi, vj = T.axis.remap("SS", [i, j])
21532154 C[vi, vj] = B[vi, vj] + 1.0
2155+
21542156 Create the schedule and do transform_layout:
21552157 .. code-block:: python
2158+
21562159 sch = tir.Schedule(before_transform_layout)
21572160 sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True,
21582161 index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
21592162 print(sch.mod["main"].script())
2163+
21602164 After applying transform_layout, the IR becomes:
21612165 .. code-block:: python
2166+
21622167 @T.prim_func
21632168 def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
21642169 A = T.match_buffer(a, (128, 128), "float32")
@@ -2172,6 +2177,7 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
21722177 with T.block("C"):
21732178 vi, vj = T.axis.remap("SS", [i, j])
21742179 C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
2180+
21752181 """
21762182 if callable (index_map ):
21772183 index_map = IndexMap .from_func (index_map )
0 commit comments