Skip to content

Commit fe2e302

Browse files
authored
fix: unwanted promotion in copysign (#1590)
1 parent e201029 commit fe2e302

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.158"
4+
version = "0.2.159"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/TracedRNumber.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,27 @@ function Base.mod(
171171
r = rem(x, y)
172172
return ifelse(r == 0, copysign(r, y), ifelse((r > 0) (y > 0), r + y, r))
173173
end
174-
function Base.mod(
175-
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number)
174+
175+
function Base.mod1(
176+
@nospecialize(x::Reactant.TracedRNumber{T}), @nospecialize(y::Reactant.TracedRNumber{T})
176177
) where {T}
177-
return mod(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
178+
m = mod(x, y)
179+
return ifelse(m == 0, y, m)
178180
end
179-
function Base.mod(
180-
@nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T})
181-
) where {T}
182-
return mod(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs)
181+
182+
for op in (:mod, :mod1)
183+
@eval begin
184+
function Base.$op(
185+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number)
186+
) where {T}
187+
return mod(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
188+
end
189+
function Base.$op(
190+
@nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T})
191+
) where {T}
192+
return mod(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs)
193+
end
194+
end
183195
end
184196

185197
function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer}
@@ -938,7 +950,9 @@ for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64))
938950
end
939951
Base.signbit(::TracedRNumber{<:Unsigned}) = ConcretePJRTNumber(false)
940952

941-
Base.copysign(x::TracedRNumber, y::TracedRNumber) = ifelse(signbit(y), -1, 1) * abs(x)
953+
function Base.copysign(x::TracedRNumber, y::TracedRNumber)
954+
return ifelse(signbit(y), -one(x), one(x)) * abs(x)
955+
end
942956
function Base.copysign(x::TracedRNumber{T}, y::S) where {T,S<:Number}
943957
return copysign(x, TracedUtils.promote_to(TracedRNumber{S}, y))
944958
end

test/basic.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,3 +1514,21 @@ stack_numbers(x) = stack([sum(x[:, i]) for i in axes(x, 2)])
15141514

15151515
@test @jit(stack_numbers(x_ra)) stack_numbers(x)
15161516
end
1517+
1518+
@testset "copysign/mod type check" begin
1519+
x = ConcreteRNumber(Int32(5))
1520+
y = ConcreteRNumber(Int32(3))
1521+
@test @jit(copysign(x, y)) isa ConcreteRNumber{Int32}
1522+
@test @jit(mod(x, y)) isa ConcreteRNumber{Int32}
1523+
end
1524+
1525+
@testset "mod1" begin
1526+
x = collect(Int32, 1:12)
1527+
y = Int32(10)
1528+
1529+
@testset for xᵢ in x
1530+
res = @jit mod1(ConcreteRNumber(xᵢ), ConcreteRNumber(y))
1531+
@test res isa ConcreteRNumber{Int32}
1532+
@test res == mod1(xᵢ, y)
1533+
end
1534+
end

0 commit comments

Comments
 (0)