Skip to content

Commit af01281

Browse files
authored
[TIR][Schedule] Allow named block and buffer arguments in Schedule (#11624)
* [Schedule] Allowed string argument as block arg This has previously been implemented for `Schedule.transform_layout` in #11296, extending to allow for block arguments in all `Schedule` methods. This change was only made for arguments that must be a `BlockRV`. For arguments that may be either a `BlockRV` or another type (e.g. `Schedule.get_child_blocks` accepts either `BlockRV` or `LoopRV`), this sugar is not implemented, to avoid ambiguity. * [Schedule] Allowed string argument to Schedule.reindex Similar to #11269, which added this functionality to `Schedule.transform_layout`. * CI test update
1 parent 81b42e6 commit af01281

12 files changed

+291
-227
lines changed

python/tvm/tir/schedule/schedule.py

Lines changed: 76 additions & 36 deletions
Large diffs are not rendered by default.

src/tir/schedule/primitive/cache_read_write.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,11 +1241,10 @@ struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
12411241
Integer buffer_index_type) {
12421242
PythonAPICall py("reindex");
12431243
py.Input("block", block);
1244-
py.Input("buffer_index", buffer_index);
1245-
py.Input("buffer_index_type", '"' +
1246-
std::string(BufferIndexType2Str(
1247-
static_cast<BufferIndexType>(buffer_index_type->value))) +
1248-
'"');
1244+
std::ostringstream os;
1245+
os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
1246+
<< "\", " << buffer_index << ")";
1247+
py.Input("buffer", os.str());
12491248
py.SingleOutput(outputs);
12501249
return py.Str();
12511250
}

tests/python/unittest/test_tir_schedule_cache_read_write.py

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -741,13 +741,15 @@ def block_predicate_cache_write_output_buf() -> None:
741741

742742
########## Testcases for cache_read ##########
743743

744+
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
744745

745-
def test_cache_read_elementwise():
746+
747+
def test_cache_read_elementwise(use_block_name):
746748
sch = tir.Schedule(elementwise, debug_mask="all")
747749
block_b = sch.get_block("B")
748750
block_c = sch.get_block("C")
749-
cached_a = sch.cache_read(block_b, 0, "global")
750-
cached_b = sch.cache_read(block_c, 0, "local")
751+
cached_a = sch.cache_read("B" if use_block_name else block_b, 0, "global")
752+
cached_b = sch.cache_read("C" if use_block_name else block_c, 0, "local")
751753
assert sch.get(cached_a) == sch.get(sch.get_block("A_global"))
752754
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
753755
assert sch.get(block_b) == sch.get(sch.get_block("B"))
@@ -756,87 +758,87 @@ def test_cache_read_elementwise():
756758
verify_trace_roundtrip(sch=sch, mod=elementwise)
757759

758760

759-
def test_cache_read_under_scope():
761+
def test_cache_read_under_scope(use_block_name):
760762
sch = tir.Schedule(access_under_scope, debug_mask="all")
761-
block_b = sch.get_block("B")
762-
block_c = sch.get_block("C")
763+
block_b = "B" if use_block_name else sch.get_block("B")
764+
block_c = "C" if use_block_name else sch.get_block("C")
763765
sch.cache_read(block_b, 0, "local")
764766
sch.cache_read(block_c, 0, "global")
765767
tvm.ir.assert_structural_equal(cache_read_under_scope, sch.mod["main"])
766768
verify_trace_roundtrip(sch=sch, mod=access_under_scope)
767769

768770

769-
def test_cache_read_opaque_access():
771+
def test_cache_read_opaque_access(use_block_name):
770772
sch = tir.Schedule(opaque_access, debug_mask="all")
771-
block = sch.get_block("load_store")
773+
block = "load_store" if use_block_name else sch.get_block("load_store")
772774
sch.cache_read(block, 0, "global")
773775
tvm.ir.assert_structural_equal(cache_read_opaque_access, sch.mod["main"])
774776
verify_trace_roundtrip(sch=sch, mod=opaque_access)
775777

776778

777-
def test_cache_read_location():
779+
def test_cache_read_location(use_block_name):
778780
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
779-
block_b = sch.get_block("B")
781+
block_b = "B" if use_block_name else sch.get_block("B")
780782
sch.cache_read(block_b, 0, "global")
781783
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
782784
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
783785

784786

785-
def test_continuous_cache_read():
787+
def test_continuous_cache_read(use_block_name):
786788
sch = tir.Schedule(elementwise, debug_mask="all")
787-
block_c = sch.get_block("C")
789+
block_c = "C" if use_block_name else sch.get_block("C")
788790
sch.cache_read(block_c, 0, "shared")
789791
sch.cache_read(block_c, 0, "local")
790792
tvm.ir.assert_structural_equal(continuous_cache_read, sch.mod["main"])
791793
verify_trace_roundtrip(sch=sch, mod=elementwise)
792794

793795

794-
def test_cache_read_with_block_predicate():
796+
def test_cache_read_with_block_predicate(use_block_name):
795797
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
796-
block = sch.get_block("consumer")
798+
block = "consumer" if use_block_name else sch.get_block("consumer")
797799
sch.cache_read(block, 0, "shared")
798800
tvm.ir.assert_structural_equal(block_predicate_cache_read, sch.mod["main"])
799801
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
800802

801803

802-
def test_cache_read_non_int32_shape():
804+
def test_cache_read_non_int32_shape(use_block_name):
803805
sch = tir.Schedule(elementwise_shape_int64, debug_mask="all")
804-
block_b = sch.get_block("B")
806+
block_b = "B" if use_block_name else sch.get_block("B")
805807
sch.cache_read(block_b, 0, "global")
806808
tvm.ir.assert_structural_equal(cache_read_shape_int64, sch.mod["main"])
807809
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
808810

809811

810-
def test_cache_read_fail_multi_producer():
812+
def test_cache_read_fail_multi_producer(use_block_name):
811813
sch = tir.Schedule(func_multi_producer, debug_mask="all")
812-
block_b = sch.get_block("B")
814+
block_b = "B" if use_block_name else sch.get_block("B")
813815
with pytest.raises(tvm.tir.ScheduleError):
814816
sch.cache_read(block_b, 0, "global")
815817

816818

817-
def test_cache_read_fail_index_out_of_bound():
819+
def test_cache_read_fail_index_out_of_bound(use_block_name):
818820
sch = tir.Schedule(elementwise, debug_mask="all")
819-
block_b = sch.get_block("B")
821+
block_b = "B" if use_block_name else sch.get_block("B")
820822
with pytest.raises(tvm.tir.ScheduleError):
821823
sch.cache_read(block_b, 1, "global")
822824

823825

824-
def test_cache_read_fail_invalid_storage_scope():
826+
def test_cache_read_fail_invalid_storage_scope(use_block_name):
825827
sch = tir.Schedule(elementwise, debug_mask="all")
826-
block_b = sch.get_block("B")
828+
block_b = "B" if use_block_name else sch.get_block("B")
827829
with pytest.raises(tvm.tir.ScheduleError):
828830
sch.cache_read(block_b, 0, "test_scope")
829831

830832

831833
########## Testcases for cache_write ##########
832834

833835

834-
def test_cache_write_elementwise():
836+
def test_cache_write_elementwise(use_block_name):
835837
sch = tir.Schedule(elementwise, debug_mask="all")
836838
block_b = sch.get_block("B")
837839
block_c = sch.get_block("C")
838-
cached_b = sch.cache_write(block_b, 0, "local")
839-
cached_c = sch.cache_write(block_c, 0, "global")
840+
cached_b = sch.cache_write("B" if use_block_name else block_b, 0, "local")
841+
cached_c = sch.cache_write("C" if use_block_name else block_c, 0, "global")
840842
assert sch.get(cached_b) == sch.get(sch.get_block("B_local"))
841843
assert sch.get(cached_c) == sch.get(sch.get_block("C_global"))
842844
assert sch.get(block_b) == sch.get(sch.get_block("B"))
@@ -845,10 +847,10 @@ def test_cache_write_elementwise():
845847
verify_trace_roundtrip(sch=sch, mod=elementwise)
846848

847849

848-
def test_cache_write_under_scope():
850+
def test_cache_write_under_scope(use_block_name):
849851
sch = tir.Schedule(access_under_scope, debug_mask="all")
850-
block_a = sch.get_block("A")
851-
block_b = sch.get_block("B")
852+
block_a = "A" if use_block_name else sch.get_block("A")
853+
block_b = "B" if use_block_name else sch.get_block("B")
852854
block_scope = sch.get_block("scope")
853855
sch.cache_write(block_a, 0, "local")
854856
sch.cache_write(block_b, 0, "global")
@@ -857,70 +859,70 @@ def test_cache_write_under_scope():
857859
verify_trace_roundtrip(sch=sch, mod=access_under_scope)
858860

859861

860-
def test_cache_write_opaque_access():
862+
def test_cache_write_opaque_access(use_block_name):
861863
sch = tir.Schedule(opaque_access, debug_mask="all")
862-
block_store = sch.get_block("load_store")
863-
block_opaque = sch.get_block("opaque")
864-
block_match_buffer = sch.get_block("match_buffer")
864+
block_store = "load_store" if use_block_name else sch.get_block("load_store")
865+
block_opaque = "opaque" if use_block_name else sch.get_block("opaque")
866+
block_match_buffer = "match_buffer" if use_block_name else sch.get_block("match_buffer")
865867
sch.cache_write(block_store, 0, "global")
866868
sch.cache_write(block_opaque, 0, "global")
867869
sch.cache_write(block_match_buffer, 0, "global")
868870
tvm.ir.assert_structural_equal(cache_write_opaque_access, sch.mod["main"])
869871
verify_trace_roundtrip(sch=sch, mod=opaque_access)
870872

871873

872-
def test_cache_write_location():
874+
def test_cache_write_location(use_block_name):
873875
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
874-
block_a = sch.get_block("A")
876+
block_a = "A" if use_block_name else sch.get_block("A")
875877
sch.cache_write(block_a, 0, "global")
876878
tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
877879
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
878880

879881

880-
def test_continuous_cache_write():
882+
def test_continuous_cache_write(use_block_name):
881883
sch = tir.Schedule(elementwise, debug_mask="all")
882-
block_b = sch.get_block("B")
884+
block_b = "B" if use_block_name else sch.get_block("B")
883885
sch.cache_write(block_b, 0, "shared")
884886
sch.cache_write(block_b, 0, "local")
885887
tvm.ir.assert_structural_equal(continuous_cache_write, sch.mod["main"])
886888
verify_trace_roundtrip(sch=sch, mod=elementwise)
887889

888890

889-
def test_cache_write_with_block_predicate():
891+
def test_cache_write_with_block_predicate(use_block_name):
890892
# cache write for intermediate buffer
891893
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
892-
block = sch.get_block("producer")
894+
block = "producer" if use_block_name else sch.get_block("producer")
893895
sch.cache_write(block, 0, "shared")
894896
tvm.ir.assert_structural_equal(block_predicate_cache_write_intermediate_buf, sch.mod["main"])
895897
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
896898
# cache write for external buffer
897899
sch = tir.Schedule(func_with_block_predicate, debug_mask="all")
898-
block = sch.get_block("consumer")
900+
block = "consumer" if use_block_name else sch.get_block("consumer")
899901
sch.cache_write(block, 0, "shared")
900902
tvm.ir.assert_structural_equal(block_predicate_cache_write_output_buf, sch.mod["main"])
901903
verify_trace_roundtrip(sch=sch, mod=func_with_block_predicate)
902904

903905

904-
def test_cache_write_fail_multi_producer():
906+
def test_cache_write_fail_multi_producer(use_block_name):
905907
sch = tir.Schedule(func_multi_producer, debug_mask="all")
906-
block_a0 = sch.get_block("A0")
907-
block_a1 = sch.get_block("A1")
908+
block_a0 = "A0" if use_block_name else sch.get_block("A0")
909+
block_a1 = "A1" if use_block_name else sch.get_block("A1")
908910
with pytest.raises(tvm.tir.ScheduleError):
909911
sch.cache_write(block_a0, 0, "global")
910912
with pytest.raises(tvm.tir.ScheduleError):
911913
sch.cache_write(block_a1, 0, "global")
912914

913915

914-
def test_cache_write_fail_index_out_of_bound():
916+
def test_cache_write_fail_index_out_of_bound(use_block_name):
915917
sch = tir.Schedule(elementwise, debug_mask="all")
916-
block_b = sch.get_block("B")
918+
block_b = "B" if use_block_name else sch.get_block("B")
917919
with pytest.raises(tvm.tir.ScheduleError):
918920
sch.cache_write(block_b, 1, "global")
919921

920922

921-
def test_cache_write_fail_invalid_storage_scope():
923+
def test_cache_write_fail_invalid_storage_scope(use_block_name):
922924
sch = tir.Schedule(elementwise, debug_mask="all")
923-
block_b = sch.get_block("B")
925+
block_b = "B" if use_block_name else sch.get_block("B")
924926
with pytest.raises(tvm.tir.ScheduleError):
925927
sch.cache_write(block_b, 0, "test_scope")
926928

0 commit comments

Comments
 (0)