From 0c1b1727c66c1bc6d55f1ad011f4fc7dea80d87b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:52:31 -0500 Subject: [PATCH 1/4] extra reducedim_init methods --- base/reducedim.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/base/reducedim.jl b/base/reducedim.jl index dc34b4feb1f6a..d4beccc5f9daf 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -124,7 +124,9 @@ function _reducedim_init(f, op, fv, fop, A, region) end # initialization when computing minima and maxima requires a little care -for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin)) +for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin), + (:(FastMath.min_fast), :(FastMath.max_fast), :Inf, :typemax), + (:(FastMath.max_fast), :(FastMath.min_fast), :(-Inf), :typemin)) @eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region) # First compute the reduce indices. This will throw an ArgumentError # if any region is invalid From cd56650b5f1f0557c0446c99be79f7d9e84b3446 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:53:05 -0500 Subject: [PATCH 2/4] FastMath.maximum_fast + tests --- base/fastmath.jl | 12 +++++++++++- test/fastmath.jl | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/base/fastmath.jl b/base/fastmath.jl index 5f905b86554f4..0cf22457c158a 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -84,7 +84,10 @@ const fast_op = :sinh => :sinh_fast, :sqrt => :sqrt_fast, :tan => :tan_fast, - :tanh => :tanh_fast) + :tanh => :tanh_fast, + # reductions + :maximum => :maximum_fast, + :minimum => :minimum_fast) const rewrite_op = Dict(:+= => :+, @@ -366,4 +369,11 @@ for f in (:^, :atan, :hypot, :log) end end +# Reductions +maximum_fast(a; kw...) = Base.reduce(max_fast, a; kw...) +minimum_fast(a; kw...) = Base.reduce(min_fast, a; kw...) + +maximum_fast(f, a; kw...) = Base.mapreduce(f, max_fast, a; kw...) +minimum_fast(f, a; kw...) = Base.mapreduce(f, min_fast, a; kw...) + end diff --git a/test/fastmath.jl b/test/fastmath.jl index e93fb93330b4f..14b70ecdd40c0 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -207,6 +207,21 @@ end @test @fastmath(cis(third)) ≈ cis(third) end end + +@testset "reductions" begin + @test @fastmath(maximum([1,2,3])) == 3 + @test @fastmath(minimum([1,2,3])) == 1 + @test @fastmath(maximum(abs2, [1,2,3+0im])) == 9 + @test @fastmath(minimum(sqrt, [1,2,3])) == 1 + @test @fastmath(maximum(Float32[4 5 6; 7 8 9])) == 9.0f0 + @test @fastmath(minimum(Float32[4 5 6; 7 8 9])) == 4.0f0 + + @test @fastmath(maximum(Float32[4 5 6; 7 8 9]; dims=1)) == Float32[7.0 8.0 9.0] + @test @fastmath(minimum(Float32[4 5 6; 7 8 9]; dims=2)) == Float32[4.0; 7.0;;] + @test @fastmath(maximum(abs, [4+im -5 6-im; -7 8 -9]; dims=1)) == [7.0 8.0 9.0] + @test @fastmath(minimum(cbrt, [4 -5 6; -7 8 -9]; dims=2)) == cbrt.([-5; -9;;]) +end + @testset "issue #10544" begin a = fill(1.,2,2) b = fill(1.,2,2) From 162bb6ca42d845c0f165607433628a38cc3483d3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 6 Jan 2023 06:57:52 -0500 Subject: [PATCH 3/4] simpler reducedim_init overload --- base/fastmath.jl | 5 +++++ base/reducedim.jl | 4 +--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/base/fastmath.jl b/base/fastmath.jl index 0cf22457c158a..fdcea21c35396 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -370,10 +370,15 @@ for f in (:^, :atan, :hypot, :log) end # Reductions + maximum_fast(a; kw...) = Base.reduce(max_fast, a; kw...) minimum_fast(a; kw...) = Base.reduce(min_fast, a; kw...) maximum_fast(f, a; kw...) = Base.mapreduce(f, max_fast, a; kw...) minimum_fast(f, a; kw...) = Base.mapreduce(f, min_fast, a; kw...) +Base.reducedim_init(f, ::typeof(max_fast), A::AbstractArray, region) = + Base.reducedim_init(f, max, A::AbstractArray, region) +Base.reducedim_init(f, ::typeof(min_fast), A::AbstractArray, region) = + Base.reducedim_init(f, min, A::AbstractArray, region) end diff --git a/base/reducedim.jl b/base/reducedim.jl index d4beccc5f9daf..dc34b4feb1f6a 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -124,9 +124,7 @@ function _reducedim_init(f, op, fv, fop, A, region) end # initialization when computing minima and maxima requires a little care -for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin), - (:(FastMath.min_fast), :(FastMath.max_fast), :Inf, :typemax), - (:(FastMath.max_fast), :(FastMath.min_fast), :(-Inf), :typemin)) +for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin)) @eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region) # First compute the reduce indices. This will throw an ArgumentError # if any region is invalid From def00d02cdadb1f02c985ecf52dcf94da3588943 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 6 Jan 2023 06:58:17 -0500 Subject: [PATCH 4/4] in-place versions --- base/fastmath.jl | 15 ++++++++++++++- test/fastmath.jl | 10 ++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/base/fastmath.jl b/base/fastmath.jl index fdcea21c35396..a969bcaaa6ae0 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -87,7 +87,9 @@ const fast_op = :tanh => :tanh_fast, # reductions :maximum => :maximum_fast, - :minimum => :minimum_fast) + :minimum => :minimum_fast, + :maximum! => :maximum!_fast, + :minimum! => :minimum!_fast) const rewrite_op = Dict(:+= => :+, @@ -381,4 +383,15 @@ Base.reducedim_init(f, ::typeof(max_fast), A::AbstractArray, region) = Base.reducedim_init(f, max, A::AbstractArray, region) Base.reducedim_init(f, ::typeof(min_fast), A::AbstractArray, region) = Base.reducedim_init(f, min, A::AbstractArray, region) + +maximum!_fast(r::AbstractArray, A::AbstractArray; kw...) = + maximum!_fast(identity, r, A; kw...) +minimum!_fast(r::AbstractArray, A::AbstractArray; kw...) = + minimum!_fast(identity, r, A; kw...) + +maximum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = + Base.mapreducedim!(f, max_fast, Base.initarray!(r, f, max, init, A), A) +minimum!_fast(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = + Base.mapreducedim!(f, min_fast, Base.initarray!(r, f, min, init, A), A) + end diff --git a/test/fastmath.jl b/test/fastmath.jl index 14b70ecdd40c0..8755e727db092 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -220,6 +220,16 @@ end @test @fastmath(minimum(Float32[4 5 6; 7 8 9]; dims=2)) == Float32[4.0; 7.0;;] @test @fastmath(maximum(abs, [4+im -5 6-im; -7 8 -9]; dims=1)) == [7.0 8.0 9.0] @test @fastmath(minimum(cbrt, [4 -5 6; -7 8 -9]; dims=2)) == cbrt.([-5; -9;;]) + + x = randn(3,4,5) + x1 = sum(x; dims=1) + x23 = sum(x; dims=(2,3)) + @test @fastmath(maximum!(x1, x)) ≈ maximum(x; dims=1) + @test x1 ≈ maximum(x; dims=1) + @test @fastmath(minimum!(x23, x)) ≈ minimum(x; dims=(2,3)) + @test x23 ≈ minimum(x; dims=(2,3)) + @test @fastmath(maximum!(abs, x23, x .+ im)) ≈ maximum(abs, x .+ im; dims=(2,3)) + @test @fastmath(minimum!(abs2, x1, x .+ im)) ≈ minimum(abs2, x .+ im; dims=1) end @testset "issue #10544" begin