1010#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
1111#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
1212
13- #include " mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
14- #include " mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
15- #include " mlir/IR/BuiltinAttributes.h" // from @llvm-project
16- #include " mlir/IR/BuiltinTypes.h" // from @llvm-project
17- #include " mlir/IR/PatternMatch.h" // from @llvm-project
18- #include " mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
19- #include " mlir/Support/LLVM.h" // from @llvm-project
13+ #include " mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
14+ #include " mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
15+ #include " mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
16+ #include " mlir/IR/BuiltinAttributes.h" // from @llvm-project
17+ #include " mlir/IR/BuiltinTypes.h" // from @llvm-project
18+ #include " mlir/IR/PatternMatch.h" // from @llvm-project
19+ #include " mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
20+ #include " mlir/Support/LLVM.h" // from @llvm-project
2021
2122namespace mlir {
2223namespace tosa {
@@ -45,6 +46,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4546Value getTosaConstTensorSingleF32 (PatternRewriter &rewriter, Operation *op,
4647 float val);
4748
49+ // Create an int8_t const tosa.mul shift tensor from an int
50+ Value getTosaMulShiftConstTensor (PatternRewriter &rewriter, Operation *op,
51+ int32_t shift);
52+
4853// Create a zero constant tensor of the desired type and shape.
4954std::optional<Value> getZerosLikeTensor (PatternRewriter &rewriter,
5055 Operation *op, Type type);
@@ -58,55 +63,24 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
5863 ArrayRef<T> vec, ArrayRef<int64_t > shape,
5964 std::optional<Type> dtype = {});
6065
61- LogicalResult tosaCastTensorToType (PatternRewriter &rewriter, Operation *op,
62- Value src, Type destType, Value &result);
63-
64- Value promoteType (PatternRewriter &rewriter, Value input, TensorType outType );
66+ // Default function to create tosa.cast op. This should be called instead of
67+ // directly calling rewriter.create<tosa::CastOp>.
68+ std::optional<Value> tosaCastTensorToType (PatternRewriter &rewriter, Value src,
69+ TensorType destType );
6570
6671// Creates a TOSA operation and performs shape inference on the individual
6772// op. This allows shape inference during the framework to TOSA lowering.
73+ template <typename TosaOp, typename ... Args>
74+ TosaOp CreateOpAndInfer (ImplicitLocOpBuilder &builder, Type result_ty,
75+ Args &&...args) {
76+ return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
77+ }
78+
6879template <typename TosaOp, typename ... Args>
6980TosaOp CreateOpAndInfer (PatternRewriter &rewriter, Location loc, Type result_ty,
7081 Args &&...args) {
71- auto op = rewriter.create <TosaOp>(loc, result_ty, args...);
72-
73- InferShapedTypeOpInterface shapeInterface =
74- dyn_cast<InferShapedTypeOpInterface>(op.getOperation ());
75- if (!shapeInterface)
76- return op;
77-
78- SmallVector<ShapedTypeComponents> returnedShapes;
79- if (shapeInterface
80- .inferReturnTypeComponents (op.getContext (), op.getLoc (),
81- op->getOperands (), op->getAttrDictionary (),
82- op->getPropertiesStorage (),
83- op->getRegions (), returnedShapes)
84- .failed ())
85- return op;
86-
87- // We need to use the element type of the existing result type to generate
88- // the new result shaped type. This is because rescale can include a cast to
89- // different bit-width types and does not have a TypeAttr to define the
90- // target type.
91- auto result = op->getResult (0 );
92- auto predictedShape = returnedShapes[0 ];
93- auto currentKnowledge = ValueKnowledge::getKnowledgeFromType (result_ty);
94-
95- // Compute the knowledge based on the inferred type.
96- auto inferredKnowledge = ValueKnowledge::getPessimisticValueState ();
97- inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType ();
98- inferredKnowledge.hasRank = predictedShape.hasRank ();
99- if (predictedShape.hasRank ()) {
100- for (auto dim : predictedShape.getDims ()) {
101- inferredKnowledge.sizes .push_back (dim);
102- }
103- }
104-
105- // Compute the new type based on the joined version.
106- auto newKnowledge = ValueKnowledge::join (currentKnowledge, inferredKnowledge);
107- auto new_ty = newKnowledge.getType ();
108- result.setType (new_ty);
109- return op;
82+ ImplicitLocOpBuilder builder (loc, rewriter);
83+ return CreateOpAndInfer<TosaOp>(builder, result_ty, args...);
11084}
11185
11286template <typename TosaOp, typename ... Args>
0 commit comments