Skip to content

Commit 5a31897

Browse files
committed
lint
1 parent 16f230a commit 5a31897

File tree

7 files changed

+26
-21
lines changed

7 files changed

+26
-21
lines changed

python/tvm/tir/function.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,7 @@ def __init__(self, desc_func, intrin_func):
182182

183183
@staticmethod
184184
def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc):
185-
return _ffi_api.TensorIntrinRegister( # type: ignore
186-
name, desc_func, intrin_func
187-
)
185+
return _ffi_api.TensorIntrinRegister(name, desc_func, intrin_func) # type: ignore
188186

189187
@staticmethod
190188
def get(name: str):

python/tvm/tir/schedule/schedule.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,7 +1766,7 @@ def before_blockize(
17661766
T.reads(A[vi, vj])
17671767
T.writes(B[vi, vj])
17681768
B[vi, vj] = A[vi, vj] * T.float32(2)
1769-
1769+
17701770
Create the schedule and do set_scope:
17711771
17721772
.. code-block:: python
@@ -1803,9 +1803,9 @@ def after_blockize(
18031803
----
18041804
blockize requires there is exactly one block under the given loop and the bindings of the
18051805
block are divisible by the subspace represented by the loops starting at the given loop.
1806-
1806+
18071807
"""
1808-
1808+
18091809
return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member
18101810

18111811
@type_checked
@@ -1860,7 +1860,7 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
18601860
A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
18611861
B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
18621862
C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
1863-
1863+
18641864
with T.block("root"):
18651865
vi = T.axis.S(16, 0)
18661866
vj = T.axis.S(16, 0)
@@ -1871,14 +1871,14 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
18711871
vjj = T.axis.S(16, vj + j)
18721872
vkk = T.axis.R(16, vk + k)
18731873
C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
1874-
1875-
1874+
1875+
18761876
@T.prim_func
18771877
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
18781878
A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
18791879
B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
18801880
C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
1881-
1881+
18821882
with T.block("root"):
18831883
vi = T.axis.S(16, 0)
18841884
vj = T.axis.S(16, 0)
@@ -1977,12 +1977,13 @@ def after_tensoirze(
19771977
dtype="handle",
19781978
)
19791979
)
1980-
1980+
19811981
"""
19821982
if isinstance(tensor_intrin, str):
19831983
tensor_intrin = String(tensor_intrin)
19841984
_ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member
1985-
self, loop, tensor_intrin)
1985+
self, loop, tensor_intrin
1986+
)
19861987

19871988
########## Schedule: Annotation ##########
19881989

src/tir/ir/function.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
8080
CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
8181

8282
// check both functions' bodies are directly block
83-
const auto* desc_realize = Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
84-
const auto* intrin_realize = Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
83+
const auto* desc_realize =
84+
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
85+
const auto* intrin_realize =
86+
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
8587
CHECK(desc_realize != nullptr) << "description function's body expect a directly block";
8688
CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block";
8789

@@ -144,7 +146,6 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
144146
return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
145147
});
146148

147-
148149
TVM_REGISTER_GLOBAL("tir.TensorIntrin")
149150
.set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) {
150151
return TensorIntrin(desc_func, intrin_func);

src/tir/schedule/analysis.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,13 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
447447
using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;
448448
using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
449449

450-
/* \brief Deep comparison to check if two IR ASTs are equivalent */
450+
/*! \brief Deep comparison to check if two IR ASTs are equivalent */
451451
class TensorizeComparator : public ExprComparator, public StmtComparator {
452452
public:
453+
/*!
454+
* \brief Constructor of TensorizeComparator
455+
* \param assert_mode Whether to raise an error if the two IR ASTs do not match.
456+
*/
453457
explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {}
454458

455459
// Map from rhs buffer to lhs buffer
@@ -507,10 +511,10 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
507511

508512
protected:
509513
bool assert_mode_;
510-
bool is_scope_block = true, is_inner_block = true;
514+
bool is_scope_block = true;
515+
bool is_inner_block = true;
511516
};
512517

513-
514518
} // namespace tir
515519
} // namespace tvm
516520

src/tir/schedule/analysis/analysis.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,7 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) {
13741374
} catch (...) {
13751375
throw InvalidStorageScopeError(self->mod, std::move(storage_scope));
13761376
}
1377-
};
1377+
}
13781378

13791379
/******** Tensorize Comparator ********/
13801380

src/tir/schedule/primitive/blockize_tensorize.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region,
364364
new_region.push_back(relaxed_int_set[i].CoverRange(max_range));
365365
}
366366
return BufferRegion(buffer_region->buffer, std::move(new_region));
367-
};
367+
}
368368

369369
/*!
370370
* \brief Generate the outer block after blockize.

src/tir/schedule/traced_schedule.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,8 @@ void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_n
364364
}
365365

366366
void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& tensor_intrin) {
367-
LOG(FATAL) << "TensorIntrin cannot be directly passed to meta schedule. Please register the tensor intrin and pass the intrin name instead.";
367+
LOG(FATAL) << "TensorIntrin cannot be directly passed to meta schedule. Please register the "
368+
"tensor intrin and pass the intrin name instead.";
368369
}
369370

370371
/******** Schedule: Annotation ********/

0 commit comments

Comments
 (0)