Skip to content

Commit 9c7aaac

Browse files
authored
[TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)
* [TIR] Moved PrimExpr operator overload from op.h to expr.h If a compilation unit includes `<tvm/ir/expr.h>`, but does not include `<tvm/tir/op.h>`, the operator overloads for `ObjectRef` are declared, but the operator overloads for `PrimExpr` are not. In this case, any use of `expr_a == expr_b` would use `ObjectRef`'s implementation and compare reference equality of the two expressions, rather than returning a `PrimExpr` that represents the comparison. By having the operator overloads in the `<tvm/ir/expr.h>` header file, directly adjacent to the `PrimExpr` declaration, the correct overload must be available whenever the `PrimExpr` can be used. Even though this would only impact `operator==`, `operator!=`, and `operator<`, the three operators defined for `ObjectRef`, this PR moves all operator overloads to `expr.h` for consistency. The named version of the operators (e.g. `tvm::add`) do not have overloaded variants, and so they are intentionally kept in `<tvm/tir/op.h>`. * Explicitly convert TVMRetValue to bool in target.cc Needed to avoid ambiguity between `TVMRetValue -> bool` conversion and `TVMRetValue -> int -> PrimExpr` conversion. * Used vector/unordered_set to track BufferInfoExtractor::call_order_ Use of `std::set<Call>` had ambiguity between `operator<` by `PrimExpr` or by `ObjectRef`. The comment for `call_order_` implied that the previous usage of `std::set<Call>` was intended to have a de-duplicated list in the order of occurrence. However, the `std::set` was ordered by `ObjectRef::operator<`, not by insertion order. Switching to using a `vector` for ordering and `unordered_set` for de-duplication resolves this issue, and also removes the use of `operator<`. * Remove C-style cast to fix lint error
1 parent b84ed27 commit 9c7aaac

File tree

4 files changed

+228
-201
lines changed

4 files changed

+228
-201
lines changed

include/tvm/ir/expr.h

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,220 @@ class PrimExpr : public BaseExpr {
133133
TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
134134
};
135135

136+
/*!
137+
* \brief add operator
138+
*
139+
* \param a left operand
140+
* \param b right operand
141+
* \return The result expression.
142+
* \note this function does eager constant folding for
143+
* index types(int32, int64) when possible.
144+
*/
145+
TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
146+
147+
/*!
148+
* \brief subtraction operator
149+
*
150+
* \param a left operand
151+
* \param b right operand
152+
* \return The result expression.
153+
* \note this function does eager constant folding for
154+
* index types(int32, int64) when possible.
155+
*/
156+
TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
157+
158+
/*!
159+
* \brief negation.
160+
*
161+
* \param a input.
162+
* \return The result expression.
163+
* \note this function does eager constant folding for
164+
* index types(int32, int64) when possible.
165+
*/
166+
TVM_DLL PrimExpr operator-(PrimExpr a);
167+
168+
/*!
169+
* \brief multiplication operator
170+
*
171+
* \param a left operand
172+
* \param b right operand
173+
* \return The result expression.
174+
* \note this function does eager constant folding for
175+
* index types(int32, int64) when possible.
176+
*/
177+
TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
178+
179+
/*!
180+
* \brief division operator
181+
*
182+
* \param a left operand
183+
* \param b right operand
184+
* \return The result expression.
185+
* \note this function does eager constant folding for
186+
* index types(int32, int64) when possible.
187+
*/
188+
TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
189+
190+
/*!
191+
* \brief left shift operator
192+
*
193+
* \param a left operand
194+
* \param b right operand
195+
* \return The result expression.
196+
* \note this function does eager constant folding for
197+
* index types(int32, int64) when possible.
198+
*/
199+
TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
200+
201+
/*!
202+
* \brief right shift operator
203+
*
204+
* \param a left operand
205+
* \param b right operand
206+
* \return The result expression.
207+
* \note this function does eager constant folding for
208+
* index types(int32, int64) when possible.
209+
*/
210+
TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
211+
212+
/*!
213+
* \brief greater
214+
*
215+
* \param a left operand
216+
* \param b right operand
217+
* \return The result expression.
218+
* \note this function does eager constant folding for
219+
* index types(int32, int64) when possible.
220+
*/
221+
TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
222+
223+
/*!
224+
* \brief greater_equal
225+
*
226+
* \param a left operand
227+
* \param b right operand
228+
* \return The result expression.
229+
* \note this function does eager constant folding for
230+
* index types(int32, int64) when possible.
231+
*/
232+
TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
233+
234+
/*!
235+
* \brief less
236+
*
237+
* \param a left operand
238+
* \param b right operand
239+
* \return The result expression.
240+
* \note this function does eager constant folding for
241+
* index types(int32, int64) when possible.
242+
*/
243+
TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
244+
245+
/*!
246+
* \brief less_equal
247+
*
248+
* \param a left operand
249+
* \param b right operand
250+
* \return The result expression.
251+
* \note this function does eager constant folding for
252+
* index types(int32, int64) when possible.
253+
*/
254+
TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
255+
256+
/*!
257+
* \brief equal
258+
*
259+
* \param a left operand
260+
* \param b right operand
261+
* \return The result expression.
262+
* \note this function does eager constant folding for
263+
* index types(int32, int64) when possible.
264+
*/
265+
TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
266+
267+
/*!
268+
* \brief not_equal
269+
*
270+
* \param a left operand
271+
* \param b right operand
272+
* \return The result expression.
273+
* \note this function does eager constant folding for
274+
* index types(int32, int64) when possible.
275+
*/
276+
TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
277+
278+
/*!
279+
* \brief and
280+
*
281+
* \param a left operand
282+
* \param b right operand
283+
* \return The result expression.
284+
* \note This operator does eager constant folding.
285+
*/
286+
TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
287+
288+
/*!
289+
* \brief or
290+
*
291+
* \param a left operand
292+
* \param b right operand
293+
* \return The result expression.
294+
* \note This operator does eager constant folding.
295+
*/
296+
TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
297+
298+
/*!
299+
* \brief not
300+
*
301+
* \param a left operand
302+
* \return The result expression.
303+
* \note This operator does eager constant folding.
304+
*/
305+
TVM_DLL PrimExpr operator!(PrimExpr a);
306+
307+
/*!
308+
* \brief take bitwise and of two values
309+
*
310+
* \param a left operand
311+
* \param b right operand
312+
* \return The result expression.
313+
* \note this function does eager constant folding for
314+
* index types(int32, int64) when possible.
315+
*/
316+
TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
317+
318+
/*!
319+
* \brief take bitwise or of two values
320+
*
321+
* \param a left operand
322+
* \param b right operand
323+
* \return The result expression.
324+
* \note this function does eager constant folding for
325+
* index types(int32, int64) when possible.
326+
*/
327+
TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
328+
329+
/*!
330+
* \brief take bitwise xor of two values
331+
*
332+
* \param a left operand
333+
* \param b right operand
334+
* \return The result expression.
335+
* \note this function does eager constant folding for
336+
* index types(int32, int64) when possible.
337+
*/
338+
TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
339+
340+
/*!
341+
* \brief take bitwise negation of two values
342+
*
343+
* \param a the input expression.
344+
* \return The result expression.
345+
* \note this function does eager constant folding for
346+
* index types(int32, int64) when possible.
347+
*/
348+
TVM_DLL PrimExpr operator~(PrimExpr a);
349+
136350
/*!
137351
* \brief Base node of all non-primitive expressions.
138352
*

0 commit comments

Comments
 (0)