Skip to content

Commit 85ccf30

Browse files
Fix irrational promotion (#924)
* Fix irrational promotion * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixup * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update basic.jl * fix * fix test --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 1ca7386 commit 85ccf30

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

src/TracedRNumber.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T}
1616
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
1717
end
1818

19+
function Base.rtoldefault(T::Type{<:TracedRNumber})
20+
return T(Base.rtoldefault(unwrapped_eltype(T)))
21+
end
22+
1923
function Base.isfinite(x::TracedRNumber{<:Complex})
2024
return isfinite(real(x)) & isfinite(imag(x))
2125
end
@@ -59,6 +63,18 @@ function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{S}) where {T,S}
5963
return TracedRNumber{Base.promote_type(T, S)}
6064
end
6165

66+
function Base.promote_rule(
67+
T::Type{<:AbstractIrrational}, ::Type{Reactant.TracedRNumber{S}}
68+
) where {S}
69+
return TracedRNumber{Base.promote_type(T, S)}
70+
end
71+
72+
function Base.promote_rule(
73+
::Type{Reactant.TracedRNumber{S}}, T::Type{<:AbstractIrrational}
74+
) where {S}
75+
return TracedRNumber{Base.promote_type(T, S)}
76+
end
77+
6278
# NOTE: This is inconsistent with the behavior of `convert` but we do it since it is a very
6379
# common usecase
6480
TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x

src/TracedUtils.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,13 @@ function broadcast_to_size(arg::Base.RefValue, rsize)
651651
return arg
652652
end
653653

654-
broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize)))
654+
function broadcast_to_size(arg::AbstractIrrational, rsize)
655+
return broadcast_to_size(Base.convert(Float64, arg), rsize)
656+
end
657+
658+
function broadcast_to_size(arg::ReactantPrimitive, rsize)
659+
return Ops.constant(Base.fill(arg, Tuple(rsize)))
660+
end
655661

656662
function broadcast_to_size(arg::TracedRNumber{T}, rsize) where {T}
657663
length(rsize) == 0 && return arg

test/basic.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,3 +998,11 @@ end
998998
@test res[2] == 215
999999
@test res[3] == 216
10001000
end
1001+
1002+
mulpi(x) = π * x
1003+
1004+
@testset "Irrational promotion" begin
1005+
x = Reactant.to_rarray(ones(2))
1006+
y = @jit mulpi(x)
1007+
@test all(Array(y) .≈ π)
1008+
end

0 commit comments

Comments
 (0)