
Commit d35aadf

use upstream utils functions and add lit tests
1 parent cf49ea0 commit d35aadf

2 files changed, 72 insertions(+), 13 deletions(-)

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 8 additions & 13 deletions
@@ -76,15 +76,6 @@ static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
   return transpose;
 }
 
-static int64_t getDimFromValue(Value dimValue) {
-  if (auto constOp = dimValue.getDefiningOp<arith::ConstantOp>()) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
-      return intAttr.getInt();
-    }
-  }
-  return ShapedType::kDynamic;
-}
-
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -514,8 +505,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
 
     // Broadcast the batch dimensions of both the matrices.
     Value broadcastedLhs, broadcastedRhs;
-    SmallVector<int64_t> lhsTargetShape = llvm::to_vector(
-        llvm::map_range(lhsBroadcastToShape, getDimFromValue));
+    SmallVector<int64_t> lhsTargetShape =
+        llvm::to_vector(llvm::map_range(lhsBroadcastToShape, [](Value v) {
+          return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+        }));
 
     auto lhsBroadcastType = RankedTensorType::get(
         lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
@@ -525,8 +518,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
       return rewriter.notifyMatchFailure(
           op, "unable to perform broadcast operation");
     }
-    SmallVector<int64_t> rhsTargetShape = llvm::to_vector(
-        llvm::map_range(rhsBroadcastToShape, getDimFromValue));
+    SmallVector<int64_t> rhsTargetShape =
+        llvm::to_vector(llvm::map_range(rhsBroadcastToShape, [](Value v) {
+          return getConstantIntValue(v).value_or(ShapedType::kDynamic);
+        }));
     auto rhsBroadcastType = RankedTensorType::get(
         rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(

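The change swaps the file-local getDimFromValue helper for the upstream utility getConstantIntValue (declared in mlir/Dialect/Utils/StaticValueUtils.h), which returns std::optional<int64_t>; chaining value_or(ShapedType::kDynamic) collapses the non-constant case to a dynamic extent, exactly what the deleted helper did for arith.constant operands. A minimal self-contained sketch of the new pattern follows; the wrapper name staticShapeFromValues is hypothetical, not from the commit:

// Illustrative sketch, not part of the commit; assumes an MLIR build tree.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Fold each shape operand to a static extent when it is a compile-time
// integer constant, and to ShapedType::kDynamic otherwise.
static SmallVector<int64_t> staticShapeFromValues(ArrayRef<Value> dims) {
  return llvm::to_vector(llvm::map_range(dims, [](Value v) {
    return getConstantIntValue(v).value_or(ShapedType::kDynamic);
  }));
}

A side benefit, if I read the upstream helper right: it matches constants via matchPattern/m_ConstantInt, so any ConstantLike op that folds to an integer is recognized, not only arith::ConstantOp as in the removed helper.
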
test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 64 additions & 0 deletions
@@ -43,6 +43,70 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.matmul.4d
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 400 : index
+// CHECK: %[[VAL_14:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_15:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_17:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_18:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_19:.*]] = arith.constant 400 : index
+// CHECK: %[[VAL_20:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_21:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64
+// CHECK: cf.assert %[[VAL_22]], "mismatching contracting dimension"
+// CHECK: %[[VAL_23:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_24:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_25:.*]] = arith.constant 2 : i64
+// CHECK: %[[VAL_26:.*]] = arith.constant 2 : i64
+// CHECK: %[[VAL_27:.*]] = arith.constant 400 : i64
+// CHECK: %[[VAL_28:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_29:.*]] = arith.constant 32 : i64
+// CHECK: %[[VAL_30:.*]] = arith.constant 400 : i64
+// CHECK: %[[VAL_31:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_32:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_33:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32>
+// CHECK: %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK: %[[VAL_36:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32>
+// CHECK: %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK: %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK: %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK: %[[VAL_43:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32>
+// CHECK: %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK: %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
+// CHECK: %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32>
+// CHECK: %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32>
+// CHECK: return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32>
+// CHECK: }
+func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+  return %0 : !torch.vtensor<[1,2,400,400],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
 // CHECK-NOT: assert
 func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>

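The new test case is exercised by the RUN line at the top of basic.mlir, which this diff does not touch; the // ----- separators indicate the file runs in split-input mode. Reconstructed from memory, so treat the exact tool name and flags as an assumption rather than a quote of the file, the RUN line looks roughly like:

// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s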