|  | 
|  | 1 | +//===- TransposeConv2D.cpp - Convolution transposition  -------------------===// | 
|  | 2 | +// | 
|  | 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | 4 | +// See https://llvm.org/LICENSE.txt for license information. | 
|  | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | 6 | +// | 
|  | 7 | +//===----------------------------------------------------------------------===// | 
|  | 8 | + | 
|  | 9 | +#include "mlir/Dialect/Func/IR/FuncOps.h" | 
|  | 10 | +#include "mlir/Dialect/Linalg/IR/Linalg.h" | 
|  | 11 | +#include "mlir/Dialect/MemRef/IR/MemRef.h" | 
|  | 12 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" | 
|  | 13 | +#include "mlir/IR/BuiltinTypes.h" | 
|  | 14 | +#include "mlir/IR/PatternMatch.h" | 
|  | 15 | +#include "mlir/IR/ValueRange.h" | 
|  | 16 | +#include "mlir/Support/LogicalResult.h" | 
|  | 17 | +#include "mlir/Transforms/DialectConversion.h" | 
|  | 18 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | 
|  | 19 | +#include "llvm/ADT/SmallVector.h" | 
|  | 20 | +#include "llvm/Support/ErrorHandling.h" | 
|  | 21 | +#include "llvm/Support/RWMutex.h" | 
|  | 22 | +#include <memory> | 
|  | 23 | +#include <numeric> | 
|  | 24 | + | 
|  | 25 | +namespace mlir { | 
|  | 26 | +namespace linalg { | 
|  | 27 | +namespace { | 
|  | 28 | +// clang-format off | 
|  | 29 | +/// Convolution converter that applies the following rewrite: | 
|  | 30 | +/// | 
|  | 31 | +/// Before: | 
|  | 32 | +/// | 
|  | 33 | +///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, | 
|  | 34 | +///                                               strides = dense<2> : tensor<2xi64>} | 
|  | 35 | +///      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) | 
|  | 36 | +///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> | 
|  | 37 | +/// | 
|  | 38 | +/// After: | 
|  | 39 | +/// | 
|  | 40 | +///    %cst = arith.constant 0.000000e+00 : f32 | 
|  | 41 | +///    %0 = tensor.empty() : tensor<2x2x6x8xf32> | 
|  | 42 | +///    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32> | 
|  | 43 | +///    %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>) | 
|  | 44 | +///                  permutation = [1, 2, 3, 0] | 
|  | 45 | +///    %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} | 
|  | 46 | +///         ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>) | 
|  | 47 | +///         -> tensor<1x2x2x8xf32> | 
|  | 48 | +/// | 
|  | 49 | +/// with an analogous example for the quantized case. | 
|  | 50 | +// clang-format on | 
|  | 51 | +template <typename FHWCConvOp, typename HWCFConvOp> | 
|  | 52 | +FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, | 
|  | 53 | +                                             FHWCConvOp op) { | 
|  | 54 | +  // Construct a permutation of the filter tensor dimensions. For a 2D | 
|  | 55 | +  // convolution this will be known statically as [1, 2, 3, 0]. | 
|  | 56 | +  SmallVector<int64_t> filterPerm({1, 2, 3, 0}); | 
|  | 57 | + | 
|  | 58 | +  // Create the type for the transposed filter tensor. | 
|  | 59 | +  auto filter = op->getOperand(1); | 
|  | 60 | +  auto filterTy = cast<ShapedType>(filter.getType()); | 
|  | 61 | +  SmallVector<int64_t> newFilterShape(filterPerm.size()); | 
|  | 62 | +  std::generate(std::begin(newFilterShape), std::end(newFilterShape), | 
|  | 63 | +                [dim = 0, &filterTy, &filterPerm]() mutable { | 
|  | 64 | +                  return filterTy.getShape()[filterPerm[dim++]]; | 
|  | 65 | +                }); | 
|  | 66 | + | 
|  | 67 | +  // Because linalg.transpose expects an "out" parameter we need to pass it a | 
|  | 68 | +  // tensor of zeros of the result type so here we construct that tensor. | 
|  | 69 | +  auto inputType = op->getOperand(0).getType(); | 
|  | 70 | +  auto elementTy = cast<ShapedType>(inputType).getElementType(); | 
|  | 71 | +  auto loc = op->getLoc(); | 
|  | 72 | + | 
|  | 73 | +  const auto isTensorOp = isa<TensorType>(inputType); | 
|  | 74 | +  Value input; | 
|  | 75 | +  if (isTensorOp) { | 
|  | 76 | + | 
|  | 77 | +    input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy) | 
|  | 78 | +                .getResult(); | 
|  | 79 | +  } else { | 
|  | 80 | +    input = rewriter | 
|  | 81 | +                .create<memref::AllocOp>( | 
|  | 82 | +                    loc, MemRefType::get(newFilterShape, elementTy)) | 
|  | 83 | +                .getResult(); | 
|  | 84 | +  } | 
|  | 85 | + | 
|  | 86 | +  // We can then construct the transposition on our filter. | 
|  | 87 | +  auto transpose = | 
|  | 88 | +      rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm); | 
|  | 89 | + | 
|  | 90 | +  Value newFilter; | 
|  | 91 | +  if (isTensorOp) { | 
|  | 92 | +    newFilter = transpose.getResult()[0]; | 
|  | 93 | +  } else { | 
|  | 94 | +    newFilter = input; | 
|  | 95 | +  } | 
|  | 96 | + | 
|  | 97 | +  SmallVector<Value> newInputs{op.getInputs()}; | 
|  | 98 | +  // The filter is always the second input argument, the other inputs can be | 
|  | 99 | +  // left as they are. | 
|  | 100 | +  newInputs[1] = newFilter; | 
|  | 101 | +  // It is possible the convolution doesn't define any results and its | 
|  | 102 | +  // out argument is just used instead. | 
|  | 103 | +  SmallVector<Type> resultTy; | 
|  | 104 | +  if (op.getNumResults()) { | 
|  | 105 | +    resultTy.push_back(op->getResult(0).getType()); | 
|  | 106 | +  } | 
|  | 107 | +  auto newConv = | 
|  | 108 | +      rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(), | 
|  | 109 | +                                  op.getStrides(), op.getDilations()); | 
|  | 110 | +  rewriter.replaceOp(op, newConv); | 
|  | 111 | +  return newConv.getOperation(); | 
|  | 112 | +} | 
|  | 113 | + | 
|  | 114 | +template <typename FHWCConvOp, typename HWCFConvOp> | 
|  | 115 | +class ConvConverter : public OpRewritePattern<FHWCConvOp> { | 
|  | 116 | +public: | 
|  | 117 | +  using OpRewritePattern<FHWCConvOp>::OpRewritePattern; | 
|  | 118 | +  LogicalResult matchAndRewrite(FHWCConvOp op, | 
|  | 119 | +                                PatternRewriter &rewriter) const final { | 
|  | 120 | +    if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) { | 
|  | 121 | +      return failure(); | 
|  | 122 | +    } | 
|  | 123 | +    return success(); | 
|  | 124 | +  } | 
|  | 125 | +}; | 
|  | 126 | +} // namespace | 
|  | 127 | + | 
|  | 128 | +FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, | 
|  | 129 | +                                       linalg::Conv2DNhwcFhwcOp op) { | 
|  | 130 | + | 
|  | 131 | +  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp, | 
|  | 132 | +                               linalg::Conv2DNhwcHwcfOp>(rewriter, op); | 
|  | 133 | +} | 
|  | 134 | + | 
|  | 135 | +FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, | 
|  | 136 | +                                       linalg::Conv2DNhwcFhwcQOp op) { | 
|  | 137 | + | 
|  | 138 | +  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp, | 
|  | 139 | +                               linalg::Conv2DNhwcHwcfQOp>(rewriter, op); | 
|  | 140 | +} | 
|  | 141 | + | 
|  | 142 | +void populateTranposeConv2DPatterns(RewritePatternSet &patterns) { | 
|  | 143 | +  MLIRContext *context = patterns.getContext(); | 
|  | 144 | +  patterns.insert< | 
|  | 145 | +      ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>, | 
|  | 146 | +      ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>( | 
|  | 147 | +      context); | 
|  | 148 | +} | 
|  | 149 | +} // namespace linalg | 
|  | 150 | +} // namespace mlir | 
0 commit comments