@@ -407,3 +407,95 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}

// -----
// Checks that we use NaN as the neutral element for the maxnumf op.
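// Why NaN works as the neutral element: arith.maxnumf follows IEEE-754 maxNum
// semantics and returns the other operand when exactly one operand is NaN, so
// maxnumf(x, NaN) == x for every x. Seeding the accumulator with NaN therefore
// leaves the reduction result unchanged.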
func.func @generic_split_maxnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                        affine_map<(d0) -> ()>],
        iterator_types = ["reduction"]}
  ins(%in : tensor<32xf32>)
  outs(%out : tensor<f32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %y = arith.maxnumf %arg1, %arg2 : f32
    linalg.yield %y : f32
  } -> tensor<f32>
  return %r : tensor<f32>
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL: func @generic_split_maxnumf
// The float value 0xFFC00000 that is filled into the init tensor represents negative NaN.
// CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
// CHECK: arith.maxnumf
// CHECK: linalg.yield
// CHECK: } -> tensor<4xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
// CHECK: arith.maxnumf {{.*}}
// CHECK: linalg.yield
// CHECK: } -> tensor<f32>
// CHECK: return %[[R]] : tensor<f32>

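// How the transform below rewrites the reduction (a summary of the CHECK lines
// above): split_factor = 4 with inner_parallel reshapes the size-32 reduction
// into 8x4, where the inner dimension of size 4 iterates in parallel. The first
// generic reduces over the outer dimension into a tensor<4xf32> of partial
// results (insert_split_dimension = 0 places the new parallel dimension at
// position 0), and a second generic reduces those partials to the final scalar.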
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel }
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----
// Checks that we use NaN as the neutral element for the minnumf op.
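// Symmetrically to maxnumf above, arith.minnumf returns the other operand when
// exactly one operand is NaN, so minnumf(x, NaN) == x and NaN is again the
// neutral element; only the NaN bit pattern used for the fill differs.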
func.func @generic_split_minnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                        affine_map<(d0) -> ()>],
        iterator_types = ["reduction"]}
  ins(%in : tensor<32xf32>)
  outs(%out : tensor<f32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %y = arith.minnumf %arg1, %arg2 : f32
    linalg.yield %y : f32
  } -> tensor<f32>
  return %r : tensor<f32>
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
// CHECK-LABEL: func @generic_split_minnumf
// The float value 0x7FC00000 that is filled into the init tensor represents positive NaN.
// CHECK-DAG: %[[ID:.*]] = arith.constant 0x7FC00000 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
// CHECK: arith.minnumf
// CHECK: linalg.yield
// CHECK: } -> tensor<4xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
// CHECK: arith.minnumf {{.*}}
// CHECK: linalg.yield
// CHECK: } -> tensor<f32>
// CHECK: return %[[R]] : tensor<f32>

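// The transform module is repeated verbatim: with -split-input-file each
// "// -----" split is processed independently, so every test carries its own
// copy of the named sequence.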
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel }
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}