From f95ad297aacbf84c0566fc3b51563d86e5ae0b72 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Tue, 2 Jan 2018 23:30:06 +0800 Subject: [PATCH 01/20] modified schedule_dataflow_rewrite.cc to fix losing tensor problem --- src/schedule/schedule_dataflow_rewrite.cc | 7 ++++++- .../unittest/test_schedule_schedule_ops.py | 20 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index d1a69ecf0203..1d24f280062f 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,7 +86,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - vsub[tensor] = cache; + //vsub[tensor] = cache; + Tensor sugar_tensor = tensor; + Stage s = operator[](tensor->op); + if (! s->op.same_as(tensor->op)) // can we just always use s->op.ouput(0) to map cache ? + sugar_tensor = s->op.output(0); + vsub[sugar_tensor] = cache; std::unordered_map vmap; for (Operation op : readers) { diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index a85db2a23e86..03b8dbf48c8c 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -182,6 +182,25 @@ def test_schedule_cache(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_middle_cache(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + + C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C') + D = tvm.compute((m, n), lambda i, j: C(i , j) , name='D') + + s = tvm.create_schedule(D.op) + AA = s.cache_read(A, "local", readers=[C]) + BB = s.cache_read(B, "local", readers=[C]) + CC = s.cache_read(C, "local", readers=[D]) + DD = s.cache_write(D, "local") + #s[AA].compute_at(s[CC], CC.op.axis[0]) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + def test_schedule_cache_relayout1(): m = tvm.var('m') @@ -231,6 +250,7 @@ def test_schedule_cache_relayout3(): if __name__ == "__main__": + test_schedule_middle_cache() test_inline_multi_reduce() test_schedule_cache_relayout3() test_schedule_cache_relayout2() From dd15d6556779397cb02d959d3268b5e9843fa0b6 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Tue, 2 Jan 2018 23:45:22 +0800 Subject: [PATCH 02/20] modified schedule_dataflow_rewrite.cc for lint scan --- src/schedule/schedule_dataflow_rewrite.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 1d24f280062f..fe44b60f5173 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,10 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - //vsub[tensor] = cache; Tensor sugar_tensor = tensor; Stage s = operator[](tensor->op); - if (! s->op.same_as(tensor->op)) // can we just always use s->op.ouput(0) to map cache ? + if (!(s->op.same_as(tensor->op))) sugar_tensor = s->op.output(0); vsub[sugar_tensor] = cache; From c2747cb254657cb044e4c4b32625214b9ca5edfe Mon Sep 17 00:00:00 2001 From: libing4752 Date: Tue, 2 Jan 2018 23:48:57 +0800 Subject: [PATCH 03/20] modified schedule_dataflow_rewrite.cc for lint scan --- src/schedule/schedule_dataflow_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index fe44b60f5173..938c61332b88 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -88,7 +88,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, std::unordered_map vsub; Tensor sugar_tensor = tensor; Stage s = operator[](tensor->op); - if (!(s->op.same_as(tensor->op))) + if (!(s->op.same_as(tensor->op))) sugar_tensor = s->op.output(0); vsub[sugar_tensor] = cache; From 70bcd580d4cfc62aaceef20da7a5c6e87fc0c8e8 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Wed, 3 Jan 2018 20:36:39 +0800 Subject: [PATCH 04/20] using tensor's value_index to index output of stage op --- src/schedule/schedule_dataflow_rewrite.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 938c61332b88..b58df9d0481f 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,10 +86,8 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - Tensor sugar_tensor = tensor; Stage s = operator[](tensor->op); - if (!(s->op.same_as(tensor->op))) - sugar_tensor = s->op.output(0); + Tensor sugar_tensor = s->op.output(tensor->value_index); vsub[sugar_tensor] = cache; std::unordered_map vmap; From 6d94540aa3b340ebe372292958e8c7e2bb9e602e Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 22:02:12 +0800 Subject: [PATCH 05/20] repare address offset for different kinds of dtype --- src/pass/storage_rewrite.cc | 24 +++++----- tests/python/unittest/test_address_offset.py | 48 ++++++++++++++++++++ 2 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 tests/python/unittest/test_address_offset.py diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 7215c3f97a43..8f354ed02d9e 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -580,29 +580,29 @@ class StoragePlanRewriter : public IRMutator { if (info.defined()) { align = (info->max_simd_bits + e->elem_type.bits() - 1) / e->elem_type.bits(); } - uint64_t total_elem = e->const_nbits / e->elem_type.bits(); - if (total_elem % align != 0) { - total_elem += align - (total_elem % align); + uint64_t total_bits = e->const_nbits; + if (total_bits % info->max_simd_bits != 0) { + total_bits += info->max_simd_bits - (total_bits % info->max_simd_bits); } e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { CHECK_NE(e->const_nbits, 0U); - CHECK_NE(total_elem, 0U); - size_t num_elem = child->const_nbits / child->elem_type.bits(); - child->elem_offset = total_elem; + CHECK_NE(total_bits, 0U); + child->elem_offset = total_bits / child->elem_type.bits(); child->alloc_var = e->alloc_var; - total_elem += num_elem; - if (total_elem % align != 0) { - total_elem += align - (total_elem % align); + total_bits += child->const_nbits; + if (total_bits % info->max_simd_bits != 0) { + total_bits += info->max_simd_bits - (total_bits % info->max_simd_bits); } } - Expr alloc_size = make_const(e->allocs[0]->extents[0].type(), - total_elem); + auto alloc_type = e->allocs[0]->extents[0].type(); + Expr alloc_size = make_const(alloc_type, + (total_bits + alloc_type.bits() - 1) / alloc_type.bits()); e->new_alloc = Allocate::make( e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate::make(0)); if (info.defined()) { - CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) + CHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); } } diff --git a/tests/python/unittest/test_address_offset.py b/tests/python/unittest/test_address_offset.py new file mode 100644 index 000000000000..fbd7150fd1aa --- /dev/null +++ b/tests/python/unittest/test_address_offset.py @@ -0,0 +1,48 @@ +import tvm + +local_buf = "local.a" + +@tvm.register_func("tvm.info.mem.%s" % local_buf) +def meminfo_cache(): + return tvm.make.node( + "MemoryInfo", + unit_bits=8, + max_simd_bits=32, + max_num_bits=128 * 128 * 128, + head_address = None + ) +def cast_intrin(src_dtype, dst_dtype): + shape = (256,) + a = tvm.placeholder(shape, name = "a", dtype = src_dtype) + + func = tvm.compute(shape, lambda i: a[i].astype(dst_dtype), name = "a_cast") + in_buff = tvm.decl_buffer(shape, dtype=src_dtype, name='buffer_src', data=None, scope='local', data_alignment=-1, offset_factor=0) + out_buff = tvm.decl_buffer(shape, dtype=dst_dtype, name='buffer_dst', data=None, scope='local', data_alignment=-1, offset_factor=0) + + def replace_intrin(ins, outs): + i = ins[0] + o = outs[0] + ib = tvm.ir_builder.create() + ib.emit(tvm.call_extern(dst_dtype, "cast", + i.access_ptr("r", "int32"), + o.access_ptr('rw',"int32"))) + return ib.get() + + return tvm.decl_tensor_intrin(func.op, replace_intrin, name='a', binds={i:in_buff, o:out_buff}) + +src_dtype = "float16" +dst_dtype = "float16" + +shape = (256,) +a = tvm.placeholder(shape, name = "a", dtype = src_dtype) +nocast = tvm.compute(shape, lambda i: a[i].astype(dst_dtype), name = "a_no_cast") +dst_dtype = "int32" +casts32 = tvm.compute(shape, lambda i: nocast[i].astype(dst_dtype), name = "a_cast_s32") +s = tvm.create_schedule(casts32.op) +print "shit" +s.cache_read(a, local_buf, [nocast]) +s.cache_write(casts32, local_buf) +s.cache_write(nocast, local_buf) + +s[nocast].compute_inline() +print tvm.lower(s, [a,casts32], simple_mode = True) From 370b588babc0eb9aa3a0bcabfc1c5d9c67c738fa Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 22:14:44 +0800 Subject: [PATCH 06/20] bc --- src/pass/storage_rewrite.cc | 4 ---- tests/python/unittest/test_pass_storage_rewrite.py | 3 ++- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 8f354ed02d9e..bffe92ae4c58 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -576,10 +576,6 @@ class StoragePlanRewriter : public IRMutator { // allocate with element type. CHECK_NE(e->const_nbits, 0U); MemoryInfo info = GetMemoryInfo(e->scope.to_string()); - size_t align = 1; - if (info.defined()) { - align = (info->max_simd_bits + e->elem_type.bits() - 1) / e->elem_type.bits(); - } uint64_t total_bits = e->const_nbits; if (total_bits % info->max_simd_bits != 0) { total_bits += info->max_simd_bits - (total_bits % info->max_simd_bits); diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 1e4dda684eb3..5df7f751763d 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -91,7 +91,6 @@ def test_storage_combine(): s = tvm.create_schedule(B.op) for S in stages[:-1]: s[S].set_scope("global:tag") - bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.schedule.ScheduleOps(s, bounds) @@ -100,7 +99,9 @@ def test_storage_combine(): stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.Simplify(stmt) + print stmt stmt = tvm.ir_pass.StorageRewrite(stmt) + print stmt num_alloc = [0] def verify(n): if isinstance(n, tvm.stmt.Allocate): From 4357ca89cbfea80fca380a6ecc3b2cac79e51a44 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 22:30:13 +0800 Subject: [PATCH 07/20] aaa --- src/pass/storage_rewrite.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index bffe92ae4c58..edf1978aef31 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -572,17 +572,23 @@ class StoragePlanRewriter : public IRMutator { } // New allocation for merged data void NewAllocTagMerged(StorageEntry* e) { + #include CHECK_NE(e->scope.tag.length(), 0U); // allocate with element type. CHECK_NE(e->const_nbits, 0U); MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_bits = e->const_nbits; + std::cout << total_bits << std::endl; + std::cout << "shit" << std::endl; + std::cout << 1 / info->max_simd_bits << std::endl; if (total_bits % info->max_simd_bits != 0) { total_bits += info->max_simd_bits - (total_bits % info->max_simd_bits); } + std::cout << total_bits << std::endl; e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { - CHECK_NE(e->const_nbits, 0U); + std::cout << total_bits << std::endl; + CHECK_NE(child->const_nbits, 0U); CHECK_NE(total_bits, 0U); child->elem_offset = total_bits / child->elem_type.bits(); child->alloc_var = e->alloc_var; From f0c79d677fddf6b03ec66976bd3730b425cff4b8 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 22:33:48 +0800 Subject: [PATCH 08/20] aaaaa --- src/pass/storage_rewrite.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index edf1978aef31..7fce1429546c 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -578,11 +578,12 @@ class StoragePlanRewriter : public IRMutator { CHECK_NE(e->const_nbits, 0U); MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_bits = e->const_nbits; - std::cout << total_bits << std::endl; - std::cout << "shit" << std::endl; - std::cout << 1 / info->max_simd_bits << std::endl; - if (total_bits % info->max_simd_bits != 0) { - total_bits += info->max_simd_bits - (total_bits % info->max_simd_bits); + size_t align = 1; + if (info.define()){ + align = info->max_simd_bits; + } + if (total_bits % align != 0) { + total_bits += align - (total_bits % align); } std::cout << total_bits << std::endl; e->alloc_var = e->allocs[0]->buffer_var; @@ -593,8 +594,8 @@ class StoragePlanRewriter : public IRMutator { child->elem_offset = total_bits / child->elem_type.bits(); child->alloc_var = e->alloc_var; total_bits += child->const_nbits; - if (total_bits % info->max_simd_bits != 0) { - total_bits += info->max_simd_bits - (total_bits % info->max_simd_bits); + if (total_bits % align != 0) { + total_bits += align - (total_bits % align); } } auto alloc_type = e->allocs[0]->extents[0].type(); From 9135f4fa9d7eaf0dbb28f9aee58c452314947bb9 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 23:23:59 +0800 Subject: [PATCH 09/20] repare address for different dtypes --- src/pass/storage_rewrite.cc | 5 +--- .../unittest/test_pass_storage_rewrite.py | 27 +++++++++++++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 7fce1429546c..959b7967bf1b 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -572,23 +572,20 @@ class StoragePlanRewriter : public IRMutator { } // New allocation for merged data void NewAllocTagMerged(StorageEntry* e) { - #include CHECK_NE(e->scope.tag.length(), 0U); // allocate with element type. CHECK_NE(e->const_nbits, 0U); MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_bits = e->const_nbits; size_t align = 1; - if (info.define()){ + if (info.defined()){ align = info->max_simd_bits; } if (total_bits % align != 0) { total_bits += align - (total_bits % align); } - std::cout << total_bits << std::endl; e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { - std::cout << total_bits << std::endl; CHECK_NE(child->const_nbits, 0U); CHECK_NE(total_bits, 0U); child->elem_offset = total_bits / child->elem_type.bits(); diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 5df7f751763d..59ca1ed66ba4 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -49,6 +49,30 @@ def verify(n): tvm.ir_pass.PostOrderVisit(body, verify) assert num_alloc[0] == 1 +def test_alloc_different_dtypes(): + ib = tvm.ir_builder.create() + n = tvm.var("n") + global_a = tvm.placeholder((256,), name = "global_a", dtype = "float32") + with ib.for_range(0, 1, name="i") as i: + with ib.for_range(0, 256, name="j") as j: + A = ib.allocate("float32", 256, name="A", scope="local.L0A") + A[j] = 2.5 + with ib.for_range(0, 256, name="j") as j: + B = ib.allocate("int16", 256, name="B", scope="local.L0A") + B[j] = tvm.const(1, dtype = "int16") + with ib.for_range(0, 256, name="j") as j: + C = ib.allocate("float16", 256, name="C", scope="local.L0A") + C[j] = tvm.const(1, dtype = "float16") + with ib.for_range(0, 256, name="j") as j: + D = ib.allocate("uint16", 256, name="D", scope="local.L0A") + D[j] = tvm.const(1, dtype = "uint16") + + body = ib.get() + body = tvm.ir_pass.StorageRewrite(body) + def verify(n): + if isinstance(n, tvm.stmt.Allocate): + assert n.extents[0].value == 640 + tvm.ir_pass.PostOrderVisit(body, verify) def test_inplace_rule(): @@ -99,9 +123,7 @@ def test_storage_combine(): stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.Simplify(stmt) - print stmt stmt = tvm.ir_pass.StorageRewrite(stmt) - print stmt num_alloc = [0] def verify(n): if isinstance(n, tvm.stmt.Allocate): @@ -176,6 +198,7 @@ def test_parallel_alloc(): if __name__ == "__main__": test_alloc_seq() + test_alloc_different_dtypes() test_inplace_rule() test_storage_share() test_parallel_alloc() From a26cee5b5a20d4e45b836f4879e8b1fd745731d4 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 23:37:35 +0800 Subject: [PATCH 10/20] remove nonsense files --- tests/python/unittest/test_address_offset.py | 48 -------------------- 1 file changed, 48 deletions(-) delete mode 100644 tests/python/unittest/test_address_offset.py diff --git a/tests/python/unittest/test_address_offset.py b/tests/python/unittest/test_address_offset.py deleted file mode 100644 index fbd7150fd1aa..000000000000 --- a/tests/python/unittest/test_address_offset.py +++ /dev/null @@ -1,48 +0,0 @@ -import tvm - -local_buf = "local.a" - -@tvm.register_func("tvm.info.mem.%s" % local_buf) -def meminfo_cache(): - return tvm.make.node( - "MemoryInfo", - unit_bits=8, - max_simd_bits=32, - max_num_bits=128 * 128 * 128, - head_address = None - ) -def cast_intrin(src_dtype, dst_dtype): - shape = (256,) - a = tvm.placeholder(shape, name = "a", dtype = src_dtype) - - func = tvm.compute(shape, lambda i: a[i].astype(dst_dtype), name = "a_cast") - in_buff = tvm.decl_buffer(shape, dtype=src_dtype, name='buffer_src', data=None, scope='local', data_alignment=-1, offset_factor=0) - out_buff = tvm.decl_buffer(shape, dtype=dst_dtype, name='buffer_dst', data=None, scope='local', data_alignment=-1, offset_factor=0) - - def replace_intrin(ins, outs): - i = ins[0] - o = outs[0] - ib = tvm.ir_builder.create() - ib.emit(tvm.call_extern(dst_dtype, "cast", - i.access_ptr("r", "int32"), - o.access_ptr('rw',"int32"))) - return ib.get() - - return tvm.decl_tensor_intrin(func.op, replace_intrin, name='a', binds={i:in_buff, o:out_buff}) - -src_dtype = "float16" -dst_dtype = "float16" - -shape = (256,) -a = tvm.placeholder(shape, name = "a", dtype = src_dtype) -nocast = tvm.compute(shape, lambda i: a[i].astype(dst_dtype), name = "a_no_cast") -dst_dtype = "int32" -casts32 = tvm.compute(shape, lambda i: nocast[i].astype(dst_dtype), name = "a_cast_s32") -s = tvm.create_schedule(casts32.op) -print "shit" -s.cache_read(a, local_buf, [nocast]) -s.cache_write(casts32, local_buf) -s.cache_write(nocast, local_buf) - -s[nocast].compute_inline() -print tvm.lower(s, [a,casts32], simple_mode = True) From 42701893a910369ae576a34775f01d526ed532eb Mon Sep 17 00:00:00 2001 From: libing4752 Date: Mon, 22 Jan 2018 23:44:02 +0800 Subject: [PATCH 11/20] add whitespace of line 581 --- src/pass/storage_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 959b7967bf1b..b37136173e17 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_bits = e->const_nbits; size_t align = 1; - if (info.defined()){ + if (info.defined()) { align = info->max_simd_bits; } if (total_bits % align != 0) { From d65953a8ad8f6f013ab4fd8ee1cafa1d1cb9cf90 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Tue, 23 Jan 2018 21:05:53 +0800 Subject: [PATCH 12/20] use base alloc elem_type --- src/pass/storage_rewrite.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index b37136173e17..4ce3047db575 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -595,9 +595,10 @@ class StoragePlanRewriter : public IRMutator { total_bits += align - (total_bits % align); } } - auto alloc_type = e->allocs[0]->extents[0].type(); + auto alloc_type = e->elem_type; + uint64_t type_bits = alloc_type.bits() * alloc_type.lanes(); Expr alloc_size = make_const(alloc_type, - (total_bits + alloc_type.bits() - 1) / alloc_type.bits()); + (total_bits + type_bits - 1) / type_bits); e->new_alloc = Allocate::make( e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate::make(0)); From 111f16b6562d43ece3b455a10843fa46df0f916b Mon Sep 17 00:00:00 2001 From: libing4752 Date: Tue, 23 Jan 2018 22:25:39 +0800 Subject: [PATCH 13/20] enhance the testcast of basic buffer is 64bits,32bits,16bits,8bits --- src/pass/storage_rewrite.cc | 2 +- .../unittest/test_pass_storage_rewrite.py | 68 ++++++++++++------- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 4ce3047db575..9b38e252276d 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -597,7 +597,7 @@ class StoragePlanRewriter : public IRMutator { } auto alloc_type = e->elem_type; uint64_t type_bits = alloc_type.bits() * alloc_type.lanes(); - Expr alloc_size = make_const(alloc_type, + Expr alloc_size = make_const(Int(64), (total_bits + type_bits - 1) / type_bits); e->new_alloc = Allocate::make( e->alloc_var, e->elem_type, {alloc_size}, const_true(), diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 59ca1ed66ba4..2df2d922e6f5 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -50,29 +50,51 @@ def verify(n): assert num_alloc[0] == 1 def test_alloc_different_dtypes(): - ib = tvm.ir_builder.create() - n = tvm.var("n") - global_a = tvm.placeholder((256,), name = "global_a", dtype = "float32") - with ib.for_range(0, 1, name="i") as i: - with ib.for_range(0, 256, name="j") as j: - A = ib.allocate("float32", 256, name="A", scope="local.L0A") - A[j] = 2.5 - with ib.for_range(0, 256, name="j") as j: - B = ib.allocate("int16", 256, name="B", scope="local.L0A") - B[j] = tvm.const(1, dtype = "int16") - with ib.for_range(0, 256, name="j") as j: - C = ib.allocate("float16", 256, name="C", scope="local.L0A") - C[j] = tvm.const(1, dtype = "float16") - with ib.for_range(0, 256, name="j") as j: - D = ib.allocate("uint16", 256, name="D", scope="local.L0A") - D[j] = tvm.const(1, dtype = "uint16") - - body = ib.get() - body = tvm.ir_pass.StorageRewrite(body) - def verify(n): - if isinstance(n, tvm.stmt.Allocate): - assert n.extents[0].value == 640 - tvm.ir_pass.PostOrderVisit(body, verify) + def stmt_generater(dtype_list, length): + ib = tvm.ir_builder.create() + base_dtype = dtype_list[0] + global_a = tvm.placeholder((length,), name = "global_a", dtype = base_dtype) + for index, dtype in enumerate(dtype_list): + with ib.for_range(0, length, name="j") as j: + A = ib.allocate(dtype, length, name="A_" + str(index), scope="local.L0A") + A[j] = tvm.const(1, dtype = dtype) + return ib.get() + + def dtype_bit_len(dtype): + index = 0 + for i in dtype: + if i.isdigit(): + break + index += 1 + return int(dtype[index:]) + + def offset_generater(dtype_list, length): + dtype_len_list = [dtype_bit_len(i) for i in dtype_list] + base_len = dtype_len_list[0] + return sum([i * length / base_len for i in dtype_len_list]) + + def dtype_test(dtype_list, length): + def verify(n): + if isinstance(n, tvm.stmt.Allocate): + assert n.extents[0].value == offset + + body = stmt_generater(dtype_list, length) + offset = offset_generater(dtype_list, length) + body = tvm.ir_pass.StorageRewrite(body) + tvm.ir_pass.PostOrderVisit(body, verify) + + length = 1024 + dtype_list = ["float16", "int32", "uint16", "int8"] + dtype_test(dtype_list, length) + + dtype_list = ["float32", "int32", "uint16", "int8"] + dtype_test(dtype_list, length) + + dtype_list = ["float64", "int32", "uint16", "int8"] + dtype_test(dtype_list, length) + + dtype_list = ["int8", "int32", "uint16", "uint8"] + dtype_test(dtype_list, length) def test_inplace_rule(): From 8436366a60e38aebb0630de190ce60642a6cef9e Mon Sep 17 00:00:00 2001 From: libing4752 Date: Wed, 24 Jan 2018 07:55:55 +0800 Subject: [PATCH 14/20] use extends[0]->type() as dtype of offset --- src/pass/storage_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 9b38e252276d..d9ab7ec71984 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -597,7 +597,7 @@ class StoragePlanRewriter : public IRMutator { } auto alloc_type = e->elem_type; uint64_t type_bits = alloc_type.bits() * alloc_type.lanes(); - Expr alloc_size = make_const(Int(64), + Expr alloc_size = make_const(e->allocs[0]->extents[0].type(), (total_bits + type_bits - 1) / type_bits); e->new_alloc = Allocate::make( e->alloc_var, e->elem_type, {alloc_size}, const_true(), From 942fdc591ea27bddca327c44fc7b7f1b0a94247e Mon Sep 17 00:00:00 2001 From: libing4752 Date: Wed, 24 Jan 2018 07:57:55 +0800 Subject: [PATCH 15/20] clear program writes --- src/pass/storage_rewrite.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index d9ab7ec71984..f052a9b05b90 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -595,8 +595,7 @@ class StoragePlanRewriter : public IRMutator { total_bits += align - (total_bits % align); } } - auto alloc_type = e->elem_type; - uint64_t type_bits = alloc_type.bits() * alloc_type.lanes(); + uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); Expr alloc_size = make_const(e->allocs[0]->extents[0].type(), (total_bits + type_bits - 1) / type_bits); e->new_alloc = Allocate::make( From 91c4769a9f16fd6eed747ef1544d585cdb468ee0 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Fri, 2 Feb 2018 20:40:32 +0800 Subject: [PATCH 16/20] enhance inject_copy_intin to support of pragma stmt with no loops --- src/pass/inject_copy_intrin.cc | 103 +++++++++++------- .../unittest/test_pass_inject_copy_intrin.py | 20 ++++ 2 files changed, 84 insertions(+), 39 deletions(-) diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index cafcddcb9dde..b9ccdd74b091 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator { private: bool MatchCopyPattern(Stmt stmt, Stmt *out) { Stmt body = stmt; + bool is_single_point_copy = false; // strip the loops std::vector loops; @@ -48,6 +49,9 @@ class CopyIntrinInjector : public IRMutator { loops.push_back(op); body = op->body; } + if (0 == loops.size()){ + is_single_point_copy = true; + } const Store* store = body.as(); if (store == nullptr) return false; const Select* select = store->value.as(); const Cast* cast = store->value.as(); const Load* load = store->value.as(); - + if (0 == loops.size()) { + is_single_point_copy = true; + CHECK(select == nullptr); + } // for now only support true condition matching if (select != nullptr) { load = select->true_value.as(); @@ -72,19 +73,17 @@ class CopyIntrinInjector : public IRMutator { for (const For* op : loops) { loop_vars.push_back(Var(op->loop_var.node_)); } - Array store_strides; - Array load_strides; + Array + store_strides = arith::DetectLinearEquation(store->index, loop_vars); + Array + load_strides = arith::DetectLinearEquation(load->index, loop_vars); + if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; - + auto loop_var_size = loop_vars.size(); if (is_single_point_copy) { - store_strides.push_back(make_const(Int(32), 1)); - load_strides.push_back(make_const(Int(32), 1)); + loop_var_size = 1; dst_shape.push_back(make_const(Int(32), 1)); } else { - store_strides = arith::DetectLinearEquation(store->index, loop_vars); - load_strides = arith::DetectLinearEquation(load->index, loop_vars); - if (load_strides.size() == 0 || store_strides.size() == 0) return false; - for (const For* op : loops) { dst_shape.push_back(op->extent); } @@ -92,52 +91,41 @@ class CopyIntrinInjector : public IRMutator { Array src_shape = dst_shape; Array pad_before, pad_after; Expr pad_value; - Expr src_elem_offset = load_strides[loop_vars.size()]; - if (!is_single_point_copy) { - if (select != nullptr) { - Array clip_bound = - arith::DetectClipBound(select->condition, loop_vars); - pad_value = select->false_value; - if (clip_bound.size() == 0) return false; - CHECK_EQ(src_shape.size(), loop_vars.size()); - CHECK_EQ(clip_bound.size(), loop_vars.size() * 2); - for (size_t i = 0; i < src_shape.size(); ++i) { - Expr min_value = clip_bound[2 * i]; - Expr max_value = clip_bound[2 * i + 1]; - Type t = loop_vars[i].type(); - Expr svalue = src_shape[i]; - if (min_value.defined()) { - Expr pbefore = Simplify(Max::make(min_value, make_zero(t))); - src_elem_offset = src_elem_offset + pbefore * load_strides[i]; - svalue = svalue - pbefore; - pad_before.push_back(pbefore); - } else { - pad_before.push_back(make_zero(t)); - } - if (max_value.defined()) { - Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1), - make_zero(t))); - svalue = svalue - pafter; - pad_after.push_back(pafter); - } else { - pad_after.push_back(make_zero(t)); - } - src_shape.Set(i, Simplify(svalue)); + Expr src_elem_offset = load_strides[loop_var_size]; + if (select != nullptr) { + Array clip_bound = + arith::DetectClipBound(select->condition, loop_vars); + pad_value = select->false_value; + if (clip_bound.size() == 0) return false; + CHECK_EQ(src_shape.size(), loop_vars.size()); + CHECK_EQ(clip_bound.size(), loop_vars.size() * 2); + for (size_t i = 0; i < src_shape.size(); ++i) { + Expr min_value = clip_bound[2 * i]; + Expr max_value = clip_bound[2 * i + 1]; + Type t = loop_vars[i].type(); + Expr svalue = src_shape[i]; + if (min_value.defined()) { + Expr pbefore = Simplify(Max::make(min_value, make_zero(t))); + src_elem_offset = src_elem_offset + pbefore * load_strides[i]; + svalue = svalue - pbefore; + pad_before.push_back(pbefore); + } else { + pad_before.push_back(make_zero(t)); + } + if (max_value.defined()) { + Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1), + make_zero(t))); + svalue = svalue - pafter; + pad_after.push_back(pafter); + } else { + pad_after.push_back(make_zero(t)); } - src_elem_offset = Simplify(src_elem_offset); + src_shape.Set(i, Simplify(svalue)); } + src_elem_offset = Simplify(src_elem_offset); } CHECK_EQ(load_strides.size(), store_strides.size()); - CHECK_EQ(load_strides.size(), loop_vars.size() + 1); - auto loop_var_size = loop_vars.size(); - Expr dst_elem_offset; - if (is_single_point_copy) { - loop_var_size = 1; - src_elem_offset = load->index; - dst_elem_offset = store->index; - } else { - dst_elem_offset = store_strides[loop_vars.size()]; - } + CHECK_EQ(load_strides.size(), loop_var_size + 1); Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); Buffer dst = BufferNode::make( @@ -145,7 +133,7 @@ class CopyIntrinInjector : public IRMutator { store->value.type(), dst_shape, dst_strides, - dst_elem_offset, + store_strides[loop_var_size], store->buffer_var->name_hint, GetStorageScope(store->buffer_var.get()), 0, 0); From 4a52de0508480469378279108c3702a53e2b92ca Mon Sep 17 00:00:00 2001 From: libing4752 Date: Sun, 4 Feb 2018 10:37:06 +0800 Subject: [PATCH 20/20] fix cpplint errors --- src/arithmetic/detect_linear_equation.cc | 2 +- src/pass/inject_copy_intrin.cc | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index dad37a9281fb..642a866866d2 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -127,7 +127,7 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { Array coeff; if (0 == vars.size()) { - coeff.push_back(make_const(Int(32),1)); + coeff.push_back(make_const(Int(32), 1)); } else { for (Var v : vars) { LinearEqEntry ret; diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 78b6c74b486c..ba44253a0cd5 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -49,7 +49,6 @@ class CopyIntrinInjector : public IRMutator { loops.push_back(op); body = op->body; } - const Store* store = body.as(); if (store == nullptr) return false; const Select* select = store->value.as