
Conversation

@qsqqsqqsq-intellif (Contributor) commented Sep 27, 2024

Overview

This PR introduces a new TIR schedule primitive annotate_buffer_access that allows explicit annotation of buffer access regions for both reads and writes.

Motivation

TVM's region analysis currently cannot infer the numerical range of index expressions that involve floating-point calculations. As a result, buffer access regions that depend on such expressions cannot be inferred precisely and default to the full extent of the buffer. This new primitive addresses the limitation by allowing the access regions to be specified manually.

Usage scenarios

This primitive is particularly useful for operations where the default buffer region inference may not capture the precise access patterns, such as in resize operations. It overrides the automatically inferred region for the specified buffer.

Example

Trivial Example

before:

    @T.prim_func
    def before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")):
        for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, 0:32, 0:32])
                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
                resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)]))

Perform annotate_buffer_access:

    sch.annotate_buffer_access(block, 0, "read",
        gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [
            v_i0,
            v_i1,
            (v_i2 * 2 - 3, v_i2 * 2 + 3),
            (v_i3 * 2 - 3, v_i3 * 2 + 3),
        ],
    )
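
For context, a minimal driver for the call above could look like the following sketch. It assumes the before PrimFunc shown earlier is in scope and uses the standard tir.Schedule and get_block entry points; printing the scheduled module then yields the "after" form below.

    import tvm
    from tvm import tir

    # Sketch: build a schedule over the "before" PrimFunc and apply the primitive.
    sch = tir.Schedule(before)
    block = sch.get_block("resize")
    sch.annotate_buffer_access(
        block,
        0,        # index of the buffer within the block's read regions (here: x)
        "read",   # override the read region; "write" would override a write region
        # One entry per buffer dimension: a single expression keeps a point access,
        # a (start, stop) pair replaces the inferred range with an explicit one.
        gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [
            v_i0,
            v_i1,
            (v_i2 * 2 - 3, v_i2 * 2 + 3),
            (v_i3 * 2 - 3, v_i3 * 2 + 3),
        ],
    )
    sch.mod.show()  # prints the annotated PrimFunc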

after:

    @T.prim_func
    def after(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")):
        for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3])
                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"explicit_read_region": [0]})
                resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)]))

The primitive adds an annotation (T.block_attr({"explicit_read_region": [0]})) to the block, indicating that an explicit region has been provided for the buffer at the given index. The CompactBufferAllocation pass uses this annotation to respect the manually specified region instead of relying on automatic inference.
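
For reference, the pass can also be invoked directly on a scheduled module to observe the effect. This is a sketch; it assumes sch is a tir.Schedule whose annotated block reads from a buffer allocated with T.alloc_buffer (as in the tiling example below; in the trivial example above both buffers are function parameters, so there is nothing to compact).

    import tvm

    # Sketch: run the compaction pass standalone; during a normal build it runs as
    # part of lowering. Allocations that are only read through an annotated region
    # are shrunk to that region.
    compacted = tvm.tir.transform.CompactBufferAllocation()(sch.mod)
    compacted.show()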

Resize Op Tile Example

We can optimize the tiling of the "cache" block for the "resize" operation using the annotate_buffer_access primitive.
before:

    @T.prim_func
    def before(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
        x_global = T.alloc_buffer([1, 3, 200, 200], dtype="float32")
        for ax0, ax1, ax2, ax3 in T.grid(1, 3, 200, 200):
            with T.block("cache"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
        for i0, i1, i2, i3 in T.grid(1, 3, 100, 100):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(v_i2 * 2 + 0.5)), T.Cast("int32", T.floor(v_i3 * 2 + 0.5))]

Let's split the i2 and i3 loops of the "resize" block, and then compute_at the "cache" block to the outer loop of the resize. This is a typical tiling schedule.

    h, w = s.get_loops(resize_block)[-2:]
    ho, hi = s.split(h, factors=[10, 10])
    wo, wi = s.split(w, factors=[10, 10])
    s.reorder(ho, wo, hi, wi)
    s.compute_at(cache_block, wo)
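
The snippet above assumes a schedule s and the block handles have already been created; a minimal sketch of that setup over the before PrimFunc:

    from tvm import tir

    # Assumed setup for the tiling snippet above.
    s = tir.Schedule(before)
    cache_block = s.get_block("cache")
    resize_block = s.get_block("resize")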

After tiling without annotate_buffer_access:

    @T.prim_func
    def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
        x_global = T.alloc_buffer((1, 3, 200, 200))
        for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10):
            for ax0, ax1 in T.grid(200, 200):
                with T.block("cache"):
                    v0 = T.axis.spatial(1, 0)
                    v1, v2, v3 = T.axis.remap("SSS", [i1, ax0, ax1])
                    T.reads(x[v0, v1, v2, v3])
                    T.writes(x_global[v0, v1, v2, v3])
                    x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
            for i2_1, i3_1 in T.grid(10, 10):
                with T.block("resize"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1)
                    v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1)
                    T.reads(x_global[v_i0, v_i1, 0:200, 0:200])
                    T.writes(y[v_i0, v_i1, v_i2, v_i3])
                    y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))]

Notice that the "cache" block still reads the entire 200x200 region after compute_at. To optimize this, we can instead annotate the read region of the "resize" block with annotate_buffer_access before applying compute_at:

    s.annotate_buffer_access(
        resize_block,
        0,
        "read",
        lambda vn, vc, vh, vw: (vn, vc, (vh * 2 - 3, vh * 2 + 3), (vw * 2 - 3, vw * 2 + 3)),
    )
    s.compute_at(cache_block, wo)

After tiling with annotate_buffer_access:

    @T.prim_func
    def after_with_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
        x_global = T.alloc_buffer((1, 3, 200, 200))
        for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10):
            for ax0, ax1 in T.grid(24, 24):
                with T.block("cache"):
                    v0 = T.axis.spatial(1, 0)
                    v1 = T.axis.spatial(3, i1)
                    v2 = T.axis.spatial(200, i2_0 * 20 - 3 + ax0)
                    v3 = T.axis.spatial(200, i3_0 * 20 - 3 + ax1)
                    T.where(3 <= i2_0 * 20 + ax0 and i2_0 * 20 + ax0 < 203 and 3 <= i3_0 * 20 + ax1 and i3_0 * 20 + ax1 < 203)
                    T.reads(x[v0, v1, v2, v3])
                    T.writes(x_global[v0, v1, v2, v3])
                    x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
            for i2_1, i3_1 in T.grid(10, 10):
                with T.block("resize"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1)
                    v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1)
                    T.reads(x_global[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 - 3 + 6, v_i3 * 2 - 3:v_i3 * 2 - 3 + 6])
                    T.writes(y[v_i0, v_i1, v_i2, v_i3])
                    T.block_attr({"explicit_read_region": [0]})
                    y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))]

The "cache" block now only reads the necessary 24x24 region instead of the entire 200x200 input. These optimizations significantly reduce memory bandwidth requirements and improve cache efficiency, especially for larger input sizes.

Note

Use this primitive with caution: an incorrect annotation may lead to incorrect code generation or runtime errors. It is crucial that the specified region covers all reads or writes the block actually performs on the given buffer.
cc @Hzfengsy @junrushao

@Hzfengsy merged commit 35d6a1b into apache:main on Oct 16, 2024 (19 checks passed).