diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9ffb7c1dc0f3..89ec9a599e4b 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -505,9 +505,11 @@ class ConvertAtenMatmulOp : public OpConversionPattern&lt;AtenMatmulOp&gt; {

     // Broadcast the batch dimensions of both the matrices.
     Value broadcastedLhs, broadcastedRhs;
-    // TODO: Improve usage of static shape information.
-    SmallVector&lt;int64_t&gt; lhsTargetShape(lhsBroadcastToShape.size(),
-                                        ShapedType::kDynamic);
+    SmallVector&lt;int64_t&gt; lhsTargetShape =
+        llvm::to_vector(llvm::map_range(lhsBroadcastToShape, [](Value v) {
+          return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+        }));
+
     auto lhsBroadcastType = RankedTensorType::get(
         lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
@@ -516,8 +518,10 @@
       return rewriter.notifyMatchFailure(
          op, "unable to perform broadcast operation");
    }
-    SmallVector&lt;int64_t&gt; rhsTargetShape(rhsBroadcastToShape.size(),
-                                        ShapedType::kDynamic);
+    SmallVector&lt;int64_t&gt; rhsTargetShape =
+        llvm::to_vector(llvm::map_range(rhsBroadcastToShape, [](Value v) {
+          return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+        }));
     auto rhsBroadcastType = RankedTensorType::get(
         rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index 262f6e646bdd..0caa6e0c6980 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -43,6 +43,22 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor&lt;[8,16],f32&gt;, %arg1: !torch

// -----

+// CHECK-LABEL: func.func @torch.aten.matmul.4d
+// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor&lt;[1,2,32,400],f32&gt; -> tensor&lt;1x2x32x400xf32&gt;
+// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor&lt;[1,2,400,32],f32&gt; -> tensor&lt;1x2x400x32xf32&gt;
+// CHECK-DAG: %[[LHS_CAST:.*]] = tensor.cast %[[LHS]] : tensor&lt;1x2x32x400xf32&gt; to tensor&lt;1x2x32x400xf32&gt;
+// CHECK-DAG: %[[RHS_CAST:.*]] = tensor.cast %[[RHS]] : tensor&lt;1x2x400x32xf32&gt; to tensor&lt;1x2x400x32xf32&gt;
+// CHECK-DAG: %[[COLLAPSED_LHS:.+]] = tensor.collapse_shape %[[LHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor&lt;1x2x32x400xf32&gt; into tensor&lt;2x32x400xf32&gt;
+// CHECK-DAG: %[[COLLAPSED_RHS:.+]] = tensor.collapse_shape %[[RHS_CAST]] {{\[\[}}0, 1], [2], [3]] : tensor&lt;1x2x400x32xf32&gt; into tensor&lt;2x400x32xf32&gt;
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[COLLAPSED_RHS]], %[[COLLAPSED_LHS]] : tensor&lt;2x400x32xf32&gt;, tensor&lt;2x32x400xf32&gt;) outs(%{{.*}} : tensor&lt;2x400x400xf32&gt;) -> tensor&lt;2x400x400xf32&gt;
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor&lt;2x400x400xf32&gt; into tensor&lt;1x2x400x400xf32&gt;
+func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor&lt;[1,2,32,400],f32&gt;, %arg1: !torch.vtensor&lt;[1,2,400,32],f32&gt;) -> !torch.vtensor&lt;[1,2,400,400],f32&gt; {
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor&lt;[1,2,400,32],f32&gt;, !torch.vtensor&lt;[1,2,32,400],f32&gt; -> !torch.vtensor&lt;[1,2,400,400],f32&gt;
+  return %0 : !torch.vtensor&lt;[1,2,400,400],f32&gt;
+}
+
+// -----
+
// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
// CHECK-NOT: assert
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor&lt;[?,?],f32&gt;, %arg1: !torch.vtensor&lt;[?,?],f32&gt;) -> !torch.vtensor&lt;[?,2],f32&gt;