Skip to content

Commit 9204e39

Browse files
authored
Fix implementation of mod (#758)
* Support `xor` * Support `signbit` * Support `copysign` * Fix implementation of `mod`
1 parent 7b79953 commit 9204e39

File tree

3 files changed

+75
-8
lines changed

3 files changed

+75
-8
lines changed

src/Ops.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,18 @@ end
548548
# )
549549
# return TracedRArray{T,N}((), res, size(x))
550550
# end
551+
@noinline function bitcast_convert(
552+
::Type{U},
553+
x::TracedRNumber{T};
554+
location=mlir_stacktrace("bitcast_convert", @__FILE__, @__LINE__),
555+
) where {T,U}
556+
res = MLIR.IR.result(
557+
stablehlo.bitcast_convert(
558+
x.mlir_data; result_0=mlir_type(TracedRArray{U,0}, ()), location
559+
),
560+
)
561+
return TracedRNumber{U}((), res)
562+
end
551563

552564
@noinline function fft(
553565
x::TracedRArray{T,N};

src/TracedRNumber.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ for (jlop, hloop) in (
9494
(:(Base.:*), :multiply),
9595
(:(Base.:/), :divide),
9696
(:(Base.:^), :power),
97-
(:(Base.mod), :remainder),
9897
(:(Base.rem), :remainder),
9998
)
10099
@eval function $(jlop)(
@@ -109,13 +108,30 @@ function Base.rem(
109108
) where {T}
110109
return Ops.remainder(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
111110
end
112-
113111
function Base.rem(
114112
@nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T})
115113
) where {T}
116114
return Ops.remainder(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs)
117115
end
118116

117+
# Based on https://github.com/JuliaLang/julia/blob/39255d47db7657950ff1c82137ecec5a70bae622/base/float.jl#L608-L617
118+
function Base.mod(
119+
@nospecialize(x::Reactant.TracedRNumber{T}), @nospecialize(y::Reactant.TracedRNumber{T})
120+
) where {T}
121+
r = rem(x, y)
122+
return ifelse(r == 0, copysign(r, y), ifelse((r > 0) (y > 0), r + y, r))
123+
end
124+
function Base.mod(
125+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number)
126+
) where {T}
127+
return mod(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
128+
end
129+
function Base.mod(
130+
@nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T})
131+
) where {T}
132+
return mod(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs)
133+
end
134+
119135
function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer}
120136
return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
121137
end
@@ -224,6 +240,12 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
224240
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
225241
)
226242
end
243+
function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
244+
return Ops.xor(
245+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
246+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
247+
)
248+
end
227249
Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x)
228250
end
229251
end
@@ -391,4 +413,20 @@ function Base.typed_hvncat(
391413
return Base.typed_hvncat(T, dims, row_first, xs...)
392414
end
393415

416+
for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64))
417+
@eval begin
418+
Base.signbit(x::TracedRNumber{$(Ti)}) = x < 0
419+
Base.signbit(x::TracedRNumber{$(Tf)}) = signbit(Ops.bitcast_convert($(Ti), x))
420+
end
421+
end
422+
Base.signbit(::TracedRNumber{<:Unsigned}) = ConcreteRNumber(false)
423+
424+
Base.copysign(x::TracedRNumber, y::TracedRNumber) = ifelse(signbit(y), -1, 1) * abs(x)
425+
function Base.copysign(x::TracedRNumber{T}, y::S) where {T,S<:Number}
426+
return copysign(x, TracedUtils.promote_to(TracedRNumber{S}, y))
394427
end
428+
function Base.copysign(x::S, y::TracedRNumber{T}) where {S<:Number,T}
429+
return copysign(TracedUtils.promote_to(TracedRNumber{S}, x), y)
430+
end
431+
432+
end # module TracedRNumberOverrides

test/basic.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -829,20 +829,37 @@ end
829829
a = [-1.1, 7.7, -3.3, 9.9, -5.5]
830830
b = [6.6, -2.2, -8.8, 4.4, -10.1]
831831

832-
# Currently broken because `mod` is JIT-ed to an HLO operator with same semantic as
833-
# Julia's `rem`, rather than `mod`.
834832
expected_mod = mod.(a, b)
835-
@test_broken Reactant.@jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b)))
836-
expected_mod
837-
@test_broken Reactant.@jit(mod.(a, Reactant.to_rarray(b))) expected_mod
838-
@test_broken Reactant.@jit(mod.(Reactant.to_rarray(a), b)) expected_mod
833+
@test Reactant.@jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
834+
@test Reactant.@jit(mod.(a, Reactant.to_rarray(b))) expected_mod
835+
@test Reactant.@jit(mod.(Reactant.to_rarray(a), b)) expected_mod
839836

840837
expected_rem = rem.(a, b)
841838
@test Reactant.@jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
842839
@test Reactant.@jit(rem.(a, Reactant.to_rarray(b))) expected_rem
843840
@test Reactant.@jit(rem.(Reactant.to_rarray(a), b)) expected_rem
844841
end
845842

843+
@testset "xor" begin
844+
for a in (true, false), b in (true, false)
845+
@test @jit(xor(ConcreteRNumber(a), ConcreteRNumber(b))) == xor(a, b)
846+
end
847+
end
848+
849+
@testset "signbit" begin
850+
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
851+
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x)
852+
end
853+
end
854+
855+
@testset "copysign" begin
856+
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
857+
# Make sure also the return type is correct
858+
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) ===
859+
copysign(a, b)
860+
end
861+
end
862+
846863
@testset "reduce integers" begin
847864
x = rand(Bool, 100)
848865
x_ra = Reactant.to_rarray(x)

0 commit comments

Comments
 (0)