@@ -171,15 +171,27 @@ function Base.mod(
171
171
r = rem (x, y)
172
172
return ifelse (r == 0 , copysign (r, y), ifelse ((r > 0 ) ⊻ (y > 0 ), r + y, r))
173
173
end
174
- function Base. mod (
175
- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: Number )
174
+
175
+ function Base. mod1 (
176
+ @nospecialize (x:: Reactant.TracedRNumber{T} ), @nospecialize (y:: Reactant.TracedRNumber{T} )
176
177
) where {T}
177
- return mod (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
178
+ m = mod (x, y)
179
+ return ifelse (m == 0 , y, m)
178
180
end
179
- function Base. mod (
180
- @nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRNumber{T} )
181
- ) where {T}
182
- return mod (TracedUtils. promote_to (TracedRNumber{T}, lhs), rhs)
181
+
182
+ for op in (:mod , :mod1 )
183
+ @eval begin
184
+ function Base. $op (
185
+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: Number )
186
+ ) where {T}
187
+ return mod (lhs, TracedUtils. promote_to (TracedRNumber{T}, rhs))
188
+ end
189
+ function Base. $op (
190
+ @nospecialize (lhs:: Number ), @nospecialize (rhs:: TracedRNumber{T} )
191
+ ) where {T}
192
+ return mod (TracedUtils. promote_to (TracedRNumber{T}, lhs), rhs)
193
+ end
194
+ end
183
195
end
184
196
185
197
function Base. div (@nospecialize (lhs:: TracedRNumber{T} ), rhs) where {T<: Integer }
@@ -938,7 +950,9 @@ for (Ti, Tf) in ((Int16, Float16), (Int32, Float32), (Int64, Float64))
938
950
end
939
951
Base. signbit (:: TracedRNumber{<:Unsigned} ) = ConcretePJRTNumber (false )
940
952
941
- Base. copysign (x:: TracedRNumber , y:: TracedRNumber ) = ifelse (signbit (y), - 1 , 1 ) * abs (x)
953
+ function Base. copysign (x:: TracedRNumber , y:: TracedRNumber )
954
+ return ifelse (signbit (y), - one (x), one (x)) * abs (x)
955
+ end
942
956
function Base. copysign (x:: TracedRNumber{T} , y:: S ) where {T,S<: Number }
943
957
return copysign (x, TracedUtils. promote_to (TracedRNumber{S}, y))
944
958
end
0 commit comments