Skip to content

Commit 0343a39

Browse files
glou-neswsmoses
andauthored
SpecialFunctions simple functions (#384)
* `SpecialFunctions` simple functions * review * missing Ext * reviews * format * feedbacks * remove usage of `ReactantFloat`, simplify signatures * fix * real bound * feedback * remove assert * update signature * add missing def for 1.10, increase ~ tolerence for MacOS * missing def int, julia 1.10 * format * format 2 * error * simplify rounding * Revert "simplify rounding" This reverts commit bd84cb2. * disable tests * revert * revert * test CI * good order * remove fancy call * test * new test * round &co need float * format --------- Co-authored-by: William Moses <[email protected]>
1 parent 85d8ba4 commit 0343a39

File tree

7 files changed

+246
-35
lines changed

7 files changed

+246
-35
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
2828
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2929
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3030
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
31+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3132

3233
[sources.ReactantCore]
3334
path = "lib/ReactantCore"
@@ -37,6 +38,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
3738
ReactantArrayInterfaceExt = "ArrayInterface"
3839
ReactantCUDAExt = "CUDA"
3940
ReactantNNlibExt = "NNlib"
41+
ReactantSpecialFunctionsExt = "SpecialFunctions"
4042
ReactantPythonCallExt = "PythonCall"
4143
ReactantRandom123Ext = "Random123"
4244
ReactantStatisticsExt = "Statistics"

ext/ReactantSpecialFunctionsExt.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
module ReactantSpecialFunctionsExt
2+
using SpecialFunctions
3+
using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt
4+
using Reactant.TracedRNumberOverrides: float
5+
6+
for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)]
7+
(fns, fno) = fn isa Tuple ? fn : (fn, fn)
8+
@eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt})
9+
return Ops.$fno(float(x))
10+
end)
11+
end
12+
13+
function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat})
14+
return exp(Ops.lgamma(float(x)))
15+
end
16+
17+
function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt})
18+
return round(gamma(float(n)))
19+
end
20+
21+
function SpecialFunctions.loggamma1p(x::TracedRNumber{<:ReactantFloat})
22+
return loggamma(1 + x)
23+
end
24+
25+
function SpecialFunctions.logfactorial(x::TracedRNumber{<:ReactantInt})
26+
return loggamma(1 + x)
27+
end
28+
29+
# SpecialFunctions.invdigamma
30+
31+
function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt})
32+
return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition
33+
end
34+
35+
function SpecialFunctions.polygamma(
36+
n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt}
37+
)
38+
return Ops.polygamma(float(n), float(x))
39+
end
40+
41+
# SpecialFunctions.gamma_inc
42+
43+
# SpecialFunctions.gamma_inc_inv
44+
45+
function SpecialFunctions.loggammadiv(
46+
a::TracedRNumber{T}, b::TracedRNumber{T}
47+
) where {T<:ReactantFloat}
48+
return log(gamma(b) / gamma(a + b))
49+
end
50+
51+
#SpecialFunctions.gamma ...
52+
53+
function SpecialFunctions.beta(
54+
x::TracedRNumber{T}, y::TracedRNumber{T}
55+
) where {T<:ReactantFloatInt}
56+
return gamma(x) * gamma(y) / gamma(x + y)
57+
end
58+
59+
function SpecialFunctions.logbeta(
60+
x::TracedRNumber{T}, y::TracedRNumber{T}
61+
) where {T<:ReactantFloatInt}
62+
return log(abs(beta(x, y)))
63+
end
64+
65+
#TODO: sign function
66+
#SpecialFunctions.logabsbeta
67+
#SpecialFunctions.logabsbinomial
68+
69+
#SpecialFunctions.beta...
70+
71+
#utilities...
72+
73+
function SpecialFunctions.erf(
74+
x::TracedRNumber{T}, y::TracedRNumber{T}
75+
) where {T<:ReactantFloatInt}
76+
return erf(y) - erf(x)
77+
end
78+
79+
#SpecialFunctions.erfcinv
80+
81+
function SpecialFunctions.logerf(
82+
x::TracedRNumber{T}, y::TracedRNumber{T}
83+
) where {T<:ReactantFloatInt}
84+
return log(erf(x, y))
85+
end
86+
87+
function SpecialFunctions.erfcx(x::TracedRNumber{<:ReactantFloatInt})
88+
return exp(float(x^2)) * erfc(x)
89+
end
90+
91+
function SpecialFunctions.logerfc(x::TracedRNumber{<:ReactantFloatInt})
92+
return log(erfc(x))
93+
end
94+
95+
function SpecialFunctions.logerfcx(x::TracedRNumber{<:ReactantFloatInt})
96+
return log(erfcx(x))
97+
end
98+
99+
#Unsupported complex
100+
#SpecialFunctions.erfi
101+
102+
#SpecialFunctions.erfinv
103+
#SpecialFunctions.dawson
104+
#SpecialFunctions.faddeeva
105+
106+
#Airy and Related Functions
107+
108+
#Bessel ...
109+
110+
#Elliptic Integrals
111+
112+
function SpecialFunctions.zeta(
113+
z::TracedRNumber{T}, s::TracedRNumber{T}
114+
) where {T<:ReactantFloatInt}
115+
return Ops.zeta(z, s)
116+
end
117+
118+
end # module ReactantSpecialFunctionsExt

src/Reactant.jl

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,42 +18,21 @@ using Enzyme
1818
struct ReactantABI <: Enzyme.EnzymeCore.ABI end
1919

2020
@static if isdefined(Core, :BFloat16)
21-
const ReactantPrimitive = Union{
22-
Bool,
23-
Int8,
24-
UInt8,
25-
Int16,
26-
UInt16,
27-
Int32,
28-
UInt32,
29-
Int64,
30-
UInt64,
31-
Float16,
32-
Core.BFloat16,
33-
Float32,
34-
Float64,
35-
Complex{Float32},
36-
Complex{Float64},
37-
}
21+
const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64}
3822
else
39-
const ReactantPrimitive = Union{
40-
Bool,
41-
Int8,
42-
UInt8,
43-
Int16,
44-
UInt16,
45-
Int32,
46-
UInt32,
47-
Int64,
48-
UInt64,
49-
Float16,
50-
Float32,
51-
Float64,
52-
Complex{Float32},
53-
Complex{Float64},
54-
}
23+
const ReactantFloat = Union{Float16,Float32,Float64}
5524
end
5625

26+
const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}
27+
28+
const ReactantFloatInt = Union{
29+
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
30+
}
31+
32+
const ReactantPrimitive = Union{
33+
Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64}
34+
}
35+
5736
abstract type RNumber{T<:ReactantPrimitive} <: Number end
5837

5938
abstract type RArray{T,N} <: AbstractArray{T,N} end

src/TracedRNumber.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,6 @@ for (jlop, hloop) in (
210210
(:(Base.log), :log),
211211
(:(Base.log1p), :log_plus_one),
212212
(:(Base.sqrt), :sqrt),
213-
(:(Base.ceil), :ceil),
214-
(:(Base.floor), :floor),
215213
)
216214
@eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs)
217215
end
@@ -237,6 +235,12 @@ function Base.float(x::TracedRNumber{T}) where {T}
237235
return TracedUtils.promote_to(TracedRNumber{float(T)}, x)
238236
end
239237

238+
using Reactant: ReactantFloat
239+
240+
Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A)
241+
Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A)
242+
Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A)
243+
240244
# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
241245
Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...)
242246
function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T}

test/basic.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,12 @@ end
608608
end
609609
end
610610

611+
@testset "$op" for op in [:round, :ceil, :floor]
612+
for x in (rand(Float32, (3, 3)), rand(Float64))
613+
@eval @test @jit($op.(ConcreteRNumber.($x))) == $op.($x)
614+
end
615+
end
616+
611617
@testset "dynamic indexing" begin
612618
x = randn(5, 3)
613619
x_ra = Reactant.to_rarray(x)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
using SpecialFunctions, Reactant
2+
3+
macro (a, b)
4+
return quote
5+
isapprox($a, $b; atol=1e-14)
6+
end
7+
end
8+
9+
@testset "gamma" begin
10+
@test SpecialFunctions.gamma(0.5) @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5)))
11+
@test SpecialFunctions.gamma(2) @jit(SpecialFunctions.gamma(ConcreteRNumber(2)))
12+
end
13+
14+
@testset "loggamma" begin
15+
@test SpecialFunctions.loggamma(0.5)
16+
@jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5)))
17+
@test SpecialFunctions.loggamma(2) @jit(SpecialFunctions.loggamma(ConcreteRNumber(2)))
18+
end
19+
20+
@testset "digamma" begin
21+
@test SpecialFunctions.digamma(0.5)
22+
@jit(SpecialFunctions.digamma(ConcreteRNumber(0.5)))
23+
@test SpecialFunctions.digamma(2) @jit(SpecialFunctions.digamma(ConcreteRNumber(2)))
24+
end
25+
26+
@testset "trigamma" begin
27+
@test SpecialFunctions.trigamma(0.5)
28+
@jit(SpecialFunctions.trigamma(ConcreteRNumber(0.5)))
29+
@test SpecialFunctions.trigamma(2) @jit(SpecialFunctions.trigamma(ConcreteRNumber(2)))
30+
end
31+
32+
@testset "beta" begin
33+
@test SpecialFunctions.beta(0.5, 0.6)
34+
@jit(SpecialFunctions.beta(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
35+
@test SpecialFunctions.beta(2, 4)
36+
@jit(SpecialFunctions.beta(ConcreteRNumber(2), ConcreteRNumber(4)))
37+
end
38+
39+
@testset "logbeta" begin
40+
@test SpecialFunctions.logbeta(0.5, 0.6)
41+
@jit(SpecialFunctions.logbeta(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
42+
@test SpecialFunctions.logbeta(2, 4)
43+
@jit(SpecialFunctions.logbeta(ConcreteRNumber(2), ConcreteRNumber(4)))
44+
end
45+
46+
@testset "erf" begin
47+
@test SpecialFunctions.erf(0.5) @jit(SpecialFunctions.erf(ConcreteRNumber(0.5)))
48+
@test SpecialFunctions.erf(2) @jit(SpecialFunctions.erf(ConcreteRNumber(2)))
49+
end
50+
51+
@testset "erf with 2 arguments" begin
52+
@test SpecialFunctions.erf(0.5, 0.6)
53+
@jit(SpecialFunctions.erf(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
54+
@test SpecialFunctions.erf(2, 4)
55+
@jit(SpecialFunctions.erf(ConcreteRNumber(2), ConcreteRNumber(4)))
56+
end
57+
58+
@testset "erfc" begin
59+
@test SpecialFunctions.erfc(0.5) @jit(SpecialFunctions.erfc(ConcreteRNumber(0.5)))
60+
@test SpecialFunctions.erfc(2) @jit(SpecialFunctions.erfc(ConcreteRNumber(2)))
61+
end
62+
63+
@testset "logerf" begin
64+
@test SpecialFunctions.logerf(0.5, 0.6)
65+
@jit(SpecialFunctions.logerf(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
66+
@test SpecialFunctions.logerf(2, 4)
67+
@jit(SpecialFunctions.logerf(ConcreteRNumber(2), ConcreteRNumber(4)))
68+
end
69+
70+
@testset "erfcx" begin
71+
@test SpecialFunctions.erfcx(0.5) @jit(SpecialFunctions.erfcx(ConcreteRNumber(0.5)))
72+
@test SpecialFunctions.erfcx(2) @jit(SpecialFunctions.erfcx(ConcreteRNumber(2)))
73+
end
74+
75+
@testset "logerfc" begin
76+
@test SpecialFunctions.logerfc(0.5)
77+
@jit(SpecialFunctions.logerfc(ConcreteRNumber(0.5)))
78+
@test SpecialFunctions.logerfc(2) @jit(SpecialFunctions.logerfc(ConcreteRNumber(2)))
79+
end
80+
81+
@testset "logerfcx" begin
82+
@test SpecialFunctions.logerfcx(0.5)
83+
@jit(SpecialFunctions.logerfcx(ConcreteRNumber(0.5)))
84+
@test SpecialFunctions.logerfcx(2) @jit(SpecialFunctions.logerfcx(ConcreteRNumber(2)))
85+
end
86+
87+
@testset "loggamma1p" begin
88+
@test SpecialFunctions.loggamma1p(0.5)
89+
@jit SpecialFunctions.loggamma1p(ConcreteRNumber(0.5))
90+
end
91+
92+
@testset "loggammadiv" begin
93+
@test SpecialFunctions.loggammadiv(150, 20)
94+
@jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20))
95+
end
96+
97+
@testset "zeta" begin
98+
s = ConcreteRArray([1.0, 2.0, 50.0])
99+
z = ConcreteRArray([1e-8, 0.001, 2.0])
100+
@test SpecialFunctions.zeta.(Array(s), Array(z)) @jit SpecialFunctions.zeta.(s, z)
101+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6363
# @safetestset "CUDA" include("integration/cuda.jl")
6464
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
6565
@safetestset "AbstractFFTs" include("integration/fft.jl")
66+
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
6667
@safetestset "Random" include("integration/random.jl")
6768
@safetestset "Python" include("integration/python.jl")
6869
end

0 commit comments

Comments
 (0)