Skip to content

Commit 6ba1110

Browse files
committed
[Relax][OP] More high-level operators (apache#18)
* relax.cumsum * Legalizer for expand_dims * relax.trilu * relax.cast * Legalizer for batch_norm and flatten * relax.take * relax.full * relax.split * relax.broadcast_to * relax.strided_slice * relax.image.resize2d * relax.nn.max_pool2d * relax.nn.adaptive_avg_pool2d
1 parent c19878a commit 6ba1110

File tree

21 files changed

+2728
-112
lines changed

21 files changed

+2728
-112
lines changed

include/tvm/relax/op_attr_types.h

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,171 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
403403
}
404404
}; // struct ReduceAttrs
405405

406+
/*! \brief Attributes used in cumsum operator */
407+
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
408+
Optional<Integer> axis;
409+
410+
TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") {
411+
TVM_ATTR_FIELD(axis).set_default(Optional<Integer>{NullOpt});
412+
}
413+
}; // struct CumsumAttrs
414+
415+
/*! \brief Attributes used in trilu operator */
416+
struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
417+
int k;
418+
bool is_upper;
419+
420+
TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") {
421+
TVM_ATTR_FIELD(k).describe(
422+
"The number of diagonals above or below the main diagonal to exclude or include.");
423+
TVM_ATTR_FIELD(is_upper).set_default(true).describe(
424+
"Whether to keep the upper or lower half of the diagonal.");
425+
}
426+
}; // struct TriluAttrs
427+
428+
/*! \brief Attributes used in cast operator */
429+
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
430+
DataType dtype;
431+
432+
TVM_DECLARE_ATTRS(CastAttrs, "relax.attrs.CastAttrs") {
433+
TVM_ATTR_FIELD(dtype).describe("Target data type");
434+
}
435+
}; // struct CastAttrs.
436+
437+
/*! \brief Attributes used in take operator */
438+
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
439+
Optional<Integer> axis;
440+
int batch_dims;
441+
String mode;
442+
443+
TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") {
444+
TVM_ATTR_FIELD(axis)
445+
.set_default(Optional<Integer>{NullOpt})
446+
.describe("The axis over which to select values.");
447+
TVM_ATTR_FIELD(batch_dims)
448+
.set_default(0)
449+
.describe("The batch_dims over which to select values.");
450+
TVM_ATTR_FIELD(mode).set_default("clip").describe(
451+
"Specify how out-of-bound indices will behave."
452+
"clip - clip to the range (default)"
453+
"wrap - wrap around the indices"
454+
"fast - no clip or wrap around (user must make sure indices are in-bound)");
455+
}
456+
}; // struct TakeAttrs
457+
458+
/*! \brief Attributes used in full operator */
459+
struct FullAttrs : public tvm::AttrsNode<FullAttrs> {
460+
DataType dtype;
461+
462+
TVM_DECLARE_ATTRS(FullAttrs, "relax.attrs.FullAttrs") {
463+
TVM_ATTR_FIELD(dtype).describe("Target data type.");
464+
}
465+
}; // struct FullAttrs
466+
467+
/*! \brief Attributes used in split operator */
468+
struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
469+
ObjectRef indices_or_sections;
470+
int axis;
471+
472+
TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") {
473+
TVM_ATTR_FIELD(indices_or_sections)
474+
.describe("The input array of indices or the number of split sections.");
475+
TVM_ATTR_FIELD(axis).describe("The axis to be splitted");
476+
}
477+
}; // struct SplitAttrs
478+
479+
/*! \brief Attributes used in strided_slice operator */
480+
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
481+
Array<PrimExpr> begin;
482+
Array<PrimExpr> end;
483+
Optional<Array<PrimExpr>> strides;
484+
Optional<Array<Integer>> axes;
485+
String slice_mode;
486+
487+
TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
488+
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
489+
TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
490+
TVM_ATTR_FIELD(strides).describe(
491+
"Stride values of the slice, a stride can be negative, which causes a reverse slice.");
492+
TVM_ATTR_FIELD(axes).describe(
493+
"Axes along which slicing is applied. When it is specified, the length of begin, end, "
494+
"strides, and axes must be equal.");
495+
TVM_ATTR_FIELD(slice_mode)
496+
.set_default("end")
497+
.describe(
498+
"The slice mode [end, size]."
499+
"end - The default slice mode, ending indices for the slice."
500+
"size - The input strides will be ignored, input end in this mode indicates the size"
501+
"of a slice starting at the location specified by begin. If end[i] is -1,"
502+
"all remaining elements in that dimension are included in the slice");
503+
}
504+
}; // struct StridedSliceAttrs
505+
506+
/*! \brief Attributes used in image resize2d operator */
507+
struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
508+
Array<PrimExpr> size;
509+
Array<FloatImm> roi;
510+
String layout;
511+
String method;
512+
String coordinate_transformation_mode;
513+
String rounding_method;
514+
double cubic_alpha;
515+
int cubic_exclude;
516+
double extrapolation_value;
517+
518+
TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") {
519+
TVM_ATTR_FIELD(size).describe("Output image size.");
520+
TVM_ATTR_FIELD(roi).describe(
521+
"Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
522+
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
523+
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
524+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
525+
"dimensions respectively. Resize is applied on the 'H' and"
526+
"'W' dimensions.");
527+
TVM_ATTR_FIELD(method).set_default("linear").describe(
528+
"Specify the mode to use for scaling."
529+
"nearest_neighbor - Nearest Neighbor"
530+
"linear - Bilinear Interpolation"
531+
"cubic - Bicubic Interpolation");
532+
TVM_ATTR_FIELD(coordinate_transformation_mode)
533+
.set_default("half_pixel")
534+
.describe(
535+
"Describes how to transform the coordinate in the resized tensor"
536+
"to the coordinate in the original tensor."
537+
"Refer to the ONNX Resize operator specification for details"
538+
"Available options are half_pixel, align_corners and asymmetric");
539+
TVM_ATTR_FIELD(rounding_method)
540+
.set_default("round")
541+
.describe(
542+
"indicates how to find the \"nearest\" pixel in nearest_neighbor method"
543+
"Available options are round, floor, and ceil.");
544+
TVM_ATTR_FIELD(cubic_alpha)
545+
.set_default(-0.5)
546+
.describe("Spline Coefficient for Bicubic Interpolation");
547+
TVM_ATTR_FIELD(cubic_exclude)
548+
.set_default(0)
549+
.describe("Flag to exclude exterior of the image during bicubic interpolation");
550+
TVM_ATTR_FIELD(extrapolation_value)
551+
.set_default(0.0)
552+
.describe("Value to return when roi is outside of the image");
553+
}
554+
}; // struct Resize2dAttrs
555+
556+
/*! \brief Attributes for 2d adaptive pool operator */
557+
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
558+
Optional<Array<PrimExpr>> output_size;
559+
String layout;
560+
561+
TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") {
562+
TVM_ATTR_FIELD(output_size).describe("Output height and width.");
563+
TVM_ATTR_FIELD(layout).describe(
564+
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
565+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
566+
"dimensions respectively. Pooling is applied on the 'H' and"
567+
"'W' dimensions.");
568+
}
569+
}; // struct AdaptivePool2DAttrs
570+
406571
} // namespace relax
407572
} // namespace tvm
408573
#endif // TVM_RELAX_OP_ATTR_TYPES_H_

python/tvm/relax/block_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ def _convert_te_arg_helper(arg):
213213
), "emit_te only supports dict with string as the key currently"
214214
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
215215
elif (
216-
isinstance(arg, (int, float, str, tir.IntImm, tvm.ir.Type, tvm.ir.Attrs))
216+
isinstance(
217+
arg, (int, float, str, tir.IntImm, tir.FloatImm, tvm.ir.Type, tvm.ir.Attrs)
218+
)
217219
or arg is None
218220
):
219221
return arg

0 commit comments

Comments
 (0)