@@ -94,7 +94,6 @@ for (jlop, hloop) in (
94
94
(:(Base.:* ), :multiply ),
95
95
(:(Base.:/ ), :divide ),
96
96
(:(Base.:^ ), :power ),
97
- (:(Base. mod), :remainder ),
98
97
(:(Base. rem), :remainder ),
99
98
)
100
99
@eval function $ (jlop)(
@@ -109,13 +108,30 @@ function Base.rem(
109
108
) where {T}
110
109
return Ops. remainder (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
111
110
end
112
-
113
111
function Base. rem (
114
112
@nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRNumber{T} )
115
113
) where {T}
116
114
return Ops. remainder (TracedUtils. promote_to (TracedRNumber{T}, lhs), rhs)
117
115
end
118
116
117
+ # Based on https://github.com/JuliaLang/julia/blob/39255d47db7657950ff1c82137ecec5a70bae622/base/float.jl#L608-L617
118
+ function Base. mod (
119
+ @nospecialize (x:: Reactant.TracedRNumber{T} ), @nospecialize (y:: Reactant.TracedRNumber{T} )
120
+ ) where {T}
121
+ r = rem (x, y)
122
+ return ifelse (r == 0 , copysign (r, y), ifelse ((r > 0 ) ⊻ (y > 0 ), r + y, r))
123
+ end
124
+ function Base. mod (
125
+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: Number )
126
+ ) where {T}
127
+ return mod (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
128
+ end
129
+ function Base. mod (
130
+ @nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRNumber{T} )
131
+ ) where {T}
132
+ return mod (TracedUtils. promote_to (TracedRNumber{T}, lhs), rhs)
133
+ end
134
+
119
135
function Base. div (@nospecialize (lhs:: TracedRNumber{T} ), rhs) where {T<: Integer }
120
136
return Ops. divide (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
121
137
end
@@ -224,6 +240,12 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
224
240
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
225
241
)
226
242
end
243
+ function Base. xor (x:: TracedRNumber{<:$(T1)} , y:: TracedRNumber{<:$(T2)} )
244
+ return Ops. xor (
245
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
246
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
247
+ )
248
+ end
227
249
Base.:! (x:: TracedRNumber{<:$(T1)} ) = Ops. not (x)
228
250
end
229
251
end
@@ -391,4 +413,20 @@ function Base.typed_hvncat(
391
413
return Base. typed_hvncat (T, dims, row_first, xs... )
392
414
end
393
415
416
+ for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64))
417
+ @eval begin
418
+ Base. signbit (x:: TracedRNumber{$(Ti)} ) = x < 0
419
+ Base. signbit (x:: TracedRNumber{$(Tf)} ) = signbit (Ops. bitcast_convert ($ (Ti), x))
420
+ end
421
+ end
422
+ Base. signbit (:: TracedRNumber{<:Unsigned} ) = ConcreteRNumber (false )
423
+
424
+ Base. copysign (x:: TracedRNumber , y:: TracedRNumber ) = ifelse (signbit (y), - 1 , 1 ) * abs (x)
425
+ function Base. copysign (x:: TracedRNumber{T} , y:: S ) where {T,S<: Number }
426
+ return copysign (x, TracedUtils. promote_to (TracedRNumber{S}, y))
394
427
end
428
+ function Base. copysign (x:: S , y:: TracedRNumber{T} ) where {S<: Number ,T}
429
+ return copysign (TracedUtils. promote_to (TracedRNumber{S}, x), y)
430
+ end
431
+
432
+ end # module TracedRNumberOverrides
0 commit comments