@@ -43,6 +43,72 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.matmul.4d
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>,
+// CHECK-SAME:      %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+// CHECK:           %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32>
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32>
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_11:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 400 : index
+// CHECK:           %[[VAL_14:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_16:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_17:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_18:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_19:.*]] = arith.constant 400 : index
+// CHECK:           %[[VAL_20:.*]] = arith.constant 32 : i64
+// CHECK:           %[[VAL_21:.*]] = arith.constant 32 : i64
+// CHECK:           %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64
+// CHECK:           cf.assert %[[VAL_22]], "mismatching contracting dimension"
+// CHECK:           %[[VAL_23:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_24:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_25:.*]] = arith.constant 2 : i64
+// CHECK:           %[[VAL_26:.*]] = arith.constant 2 : i64
+// CHECK:           %[[VAL_27:.*]] = arith.constant 400 : i64
+// CHECK:           %[[VAL_28:.*]] = arith.constant 32 : i64
+// CHECK:           %[[VAL_29:.*]] = arith.constant 32 : i64
+// CHECK:           %[[VAL_30:.*]] = arith.constant 400 : i64
+// CHECK:           %[[VAL_31:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_32:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_33:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32>
+// CHECK:           %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32>
+// CHECK:           %[[VAL_36:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_37:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32>
+// CHECK:           %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32>
+// CHECK:           %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32>
+// CHECK:           %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32>
+// CHECK:           %[[VAL_43:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32>
+// CHECK:           %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK:           %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32>
+// CHECK:           %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32>
+// CHECK:           %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32>
+// CHECK:           %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32>
+// CHECK:           return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32>
+// CHECK:         }
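+// The 4-d matmul lowers by collapsing the two leading batch dims into one,
+// applying linalg.batch_matmul, and expanding the result back to 4-d.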
+func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> {
+  %0 = torch.aten.matmul %arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32>
+  return %0 : !torch.vtensor<[1,2,400,400],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
 // CHECK-NOT: assert
 func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>