@@ -3,6 +3,9 @@ module TracedRNumberOverrides
3
3
using .. Reactant:
4
4
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
5
5
using ReactantCore
6
+ using Adapt
7
+
8
+ import Base. TwicePrecision
6
9
7
10
ReactantCore. is_traced (:: TracedRNumber , seen) = true
8
11
ReactantCore. is_traced (:: TracedRNumber ) = true
@@ -262,6 +265,42 @@ function Base.ifelse(
262
265
end
263
266
end
264
267
268
+ function Base.:* (
269
+ x:: Base.TwicePrecision{T} , y:: Base.TwicePrecision{T}
270
+ ) where {T<: TracedRNumber }
271
+ zh, zl = Base. mul12 (x. hi, y. hi)
272
+ hi, lo = Base. canonicalize2 (zh, (x. hi * y. lo + x. lo * y. hi) + zl)
273
+ hi = ifelse (iszero (zh) | ! isfinite (zh), zh, hi)
274
+ lo = ifelse (iszero (zl) | ! isfinite (zl), zl, lo)
275
+
276
+ return Base. TwicePrecision {T} (hi, lo)
277
+ end
278
+
279
+ function Base.:+ (
280
+ x:: Base.TwicePrecision{T} , y:: Base.TwicePrecision{T}
281
+ ) where {T<: TracedRNumber }
282
+ r = x. hi + y. hi
283
+ @trace s = if abs (x. hi) > abs (y. hi)
284
+ begin
285
+ (((x. hi - r) + y. hi) + y. lo) + x. lo
286
+ end
287
+ else
288
+ begin
289
+ (((y. hi - r) + x. hi) + x. lo) + y. lo
290
+ end
291
+ end
292
+ return Base. TwicePrecision (Base. canonicalize2 (r, s)... )
293
+ end
294
+
295
+ function Base.:* (x:: TwicePrecision , v:: TracedRNumber )
296
+ @trace result = if v == 0
297
+ TwicePrecision (x. hi * v, x. lo * v)
298
+ else
299
+ x * TwicePrecision (oftype (x. hi * v, v))
300
+ end
301
+ return result
302
+ end
303
+
265
304
for (T1, T2) in zip ((Bool, Integer), (Bool, Integer))
266
305
T = promote_type (T1, T2)
267
306
@eval begin
@@ -271,18 +310,54 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
271
310
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
272
311
)
273
312
end
313
+ function Base.:& (x:: TracedRNumber{<:$(T1)} , y:: $ (T2))
314
+ return Ops. and (
315
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
316
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
317
+ )
318
+ end
319
+ function Base.:& (x:: $ (T1), y:: TracedRNumber{<:$(T2)} )
320
+ return Ops. and (
321
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
322
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
323
+ )
324
+ end
274
325
function Base.:| (x:: TracedRNumber{<:$(T1)} , y:: TracedRNumber{<:$(T2)} )
275
326
return Ops. or (
276
327
TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
277
328
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
278
329
)
279
330
end
331
+ function Base.:| (x:: TracedRNumber{<:$(T1)} , y:: $ (T2))
332
+ return Ops. or (
333
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
334
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
335
+ )
336
+ end
337
+ function Base.:| (x:: $ (T1), y:: TracedRNumber{<:$(T2)} )
338
+ return Ops. or (
339
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
340
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
341
+ )
342
+ end
280
343
function Base. xor (x:: TracedRNumber{<:$(T1)} , y:: TracedRNumber{<:$(T2)} )
281
344
return Ops. xor (
282
345
TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
283
346
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
284
347
)
285
348
end
349
+ function Base. xor (x:: TracedRNumber{<:$(T1)} , y:: $ (T2))
350
+ return Ops. xor (
351
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
352
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
353
+ )
354
+ end
355
+ function Base. xor (x:: $ (T1), y:: TracedRNumber{<:$(T2)} )
356
+ return Ops. xor (
357
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
358
+ TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
359
+ )
360
+ end
286
361
Base.:! (x:: TracedRNumber{<:$(T1)} ) = Ops. not (x)
287
362
end
288
363
end
@@ -424,9 +499,188 @@ function Base.getindex(
424
499
return Base. unsafe_getindex (r, i)
425
500
end
426
501
502
+ struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
503
+ ref:: R
504
+ step:: S
505
+ len:: L
506
+ offset:: L
507
+ end
508
+
509
+ function Adapt. parent_type (:: Type{TracedStepRangeLen{T,R,S,L}} ) where {T,R,S,L}
510
+ return TracedStepRangeLen{T,R,S,L}
511
+ end
512
+
513
+ # constructors and interface implementation copied from range.jl
514
+ function TracedStepRangeLen {T,R,S} (ref:: R , step:: S , len, offset= 1 ) where {T,R,S}
515
+ return TracedStepRangeLen {T,R,S,typeof(len)} (ref, step, len, offset)
516
+ end
517
+ function TracedStepRangeLen (ref:: R , step:: S , len, offset= 1 ) where {R,S}
518
+ return TracedStepRangeLen {typeof(ref + zero(step)),R,S,typeof(len)} (
519
+ ref, step, len, offset
520
+ )
521
+ end
522
+ function TracedStepRangeLen {T} (
523
+ ref:: R , step:: S , len:: Integer , offset:: Integer = 1
524
+ ) where {T,R,S}
525
+ return TracedStepRangeLen {T,R,S,typeof(len)} (ref, step, len, offset)
526
+ end
527
+
528
+ Base. isempty (r:: TracedStepRangeLen ) = length (r) == 0
529
+ Base. step (r:: TracedStepRangeLen ) = r. step
530
+ Base. step_hp (r:: TracedStepRangeLen ) = r. step
531
+ Base. length (r:: TracedStepRangeLen ) = r. len
532
+ Base. first (r:: TracedStepRangeLen ) = Base. unsafe_getindex (r, 1 )
533
+ Base. last (r:: TracedStepRangeLen ) = Base. unsafe_getindex (r, r. len)
534
+ function Base. iterate (r:: TracedStepRangeLen , i:: Integer = 1 )
535
+ @inline
536
+ i += oneunit (i)
537
+ length (r) < i && return nothing
538
+ return Base. unsafe_getindex (r, i), i
539
+ end
540
+
541
+ function _tracedsteprangelen_unsafe_getindex (
542
+ r:: AbstractRange{T} , i:: Union{I,TracedRNumber{I}}
543
+ ) where {T,I}
544
+ finalT = T
545
+ offsetT = typeof (r. offset)
546
+ if i isa TracedRNumber
547
+ if ! (T <: TracedRNumber )
548
+ finalT = TracedRNumber{T}
549
+ end
550
+ if ! (r. offset isa TracedRNumber)
551
+ offsetT = TracedRNumber{offsetT}
552
+ end
553
+ end
554
+ u = convert (offsetT, i) - r. offset
555
+ return finalT (r. ref + u * r. step)
556
+ end
557
+ function Base. unsafe_getindex (r:: TracedStepRangeLen , i:: Integer )
558
+ return _tracedsteprangelen_unsafe_getindex (r, i)
559
+ end
560
+ function Base. unsafe_getindex (r:: TracedStepRangeLen , i:: TracedRNumber{<:Integer} )
561
+ return _tracedsteprangelen_unsafe_getindex (r, i)
562
+ end
563
+ Base. getindex (r:: TracedStepRangeLen , i:: TracedRNumber ) = Base. unsafe_getindex (r, i)
564
+ function getindex (r:: TracedStepRangeLen{T} , s:: OrdinalRange{S} ) where {T,S<: Integer }
565
+ @inline
566
+ @boundscheck checkbounds (r, s)
567
+
568
+ len = length (s)
569
+ sstep = Base. step_hp (s)
570
+ rstep = Base. step_hp (r)
571
+ L = typeof (len)
572
+ if S === Bool
573
+ rstep *= one (sstep)
574
+ if len == 0
575
+ return TracedStepRangeLen {T} (first (r), rstep, zero (L), oneunit (L))
576
+ elseif len == 1
577
+ if first (s)
578
+ return TracedStepRangeLen {T} (first (r), rstep, oneunit (L), oneunit (L))
579
+ else
580
+ return TracedStepRangeLen {T} (first (r), rstep, zero (L), oneunit (L))
581
+ end
582
+ else # len == 2
583
+ return TracedStepRangeLen {T} (last (r), rstep, oneunit (L), oneunit (L))
584
+ end
585
+ else
586
+ # Find closest approach to offset by s
587
+ ind = LinearIndices (s)
588
+ offset = L (
589
+ max (min (1 + round (L, (r. offset - first (s)) / sstep), last (ind)), first (ind))
590
+ )
591
+ ref = Base. _getindex_hiprec (r, first (s) + (offset - oneunit (offset)) * sstep)
592
+ return TracedStepRangeLen {T} (ref, rstep * sstep, len, offset)
593
+ end
594
+ end
595
+ function Base. _getindex_hiprec (r:: TracedStepRangeLen , i:: Integer ) # without rounding by T
596
+ u = oftype (r. offset, i) - r. offset
597
+ return r. ref + u * r. step
598
+ end
599
+ function Base.:(== )(r:: T , s:: T ) where {T<: TracedStepRangeLen }
600
+ return (isempty (r) & isempty (s)) |
601
+ ((first (r) == first (s)) & (length (r) == length (s)) & (last (r) == last (s)))
602
+ end
603
+
604
+ # TODO : if there ever comes a ReactantStepRange:
605
+ # ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T}
606
+
607
+ function Base.:- (r:: TracedStepRangeLen{T,R,S,L} ) where {T,R,S,L}
608
+ return TracedStepRangeLen {T,R,S,L} (- r. ref, - r. step, r. len, r. offset)
609
+ end
610
+
611
+ # TODO : promotion from StepRangeLen{T} to TracedStepRangeLen{T}?
612
+ function Base. promote_rule (
613
+ :: Type{TracedStepRangeLen{T1,R1,S1,L1}} , :: Type{TracedStepRangeLen{T2,R2,S2,L2}}
614
+ ) where {T1,T2,R1,R2,S1,S2,L1,L2}
615
+ R, S, L = promote_type (R1, R2), promote_type (S1, S2), promote_type (L1, L2)
616
+ return Base. el_same (
617
+ promote_type (T1, T2), TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L}
618
+ )
619
+ end
620
+ TracedStepRangeLen {T,R,S,L} (r:: TracedStepRangeLen{T,R,S,L} ) where {T,R,S,L} = r
621
+ function TracedStepRangeLen {T,R,S,L} (r:: TracedStepRangeLen ) where {T,R,S,L}
622
+ return TracedStepRangeLen {T,R,S,L} (
623
+ convert (R, r. ref), convert (S, r. step), convert (L, r. len), convert (L, r. offset)
624
+ )
625
+ end
626
+ function TracedStepRangeLen {T} (r:: TracedStepRangeLen ) where {T}
627
+ return TracedStepRangeLen (convert (T, r. ref), convert (T, r. step), r. len, r. offset)
628
+ end
629
+ function Base. promote_rule (
630
+ a:: Type{TracedStepRangeLen{T,R,S,L}} , :: Type{OR}
631
+ ) where {T,R,S,L,OR<: AbstractRange }
632
+ return promote_rule (a, TracedStepRangeLen{eltype (OR),eltype (OR),eltype (OR),Int})
633
+ end
634
+ function TracedStepRangeLen {T,R,S,L} (r:: AbstractRange ) where {T,R,S,L}
635
+ return TracedStepRangeLen {T,R,S,L} (R (first (r)), S (step (r)), length (r))
636
+ end
637
+ function TracedStepRangeLen {T} (r:: AbstractRange ) where {T}
638
+ return TracedStepRangeLen (T (first (r)), T (step (r)), length (r))
639
+ end
640
+ TracedStepRangeLen (r:: AbstractRange ) = TracedStepRangeLen {eltype(r)} (r)
641
+
642
+ function Base. promote_rule (
643
+ :: Type{LinRange{A,L}} , b:: Type{TracedStepRangeLen{T2,R2,S2,L2}}
644
+ ) where {A,L,T2,R2,S2,L2}
645
+ return promote_rule (TracedStepRangeLen{A,A,A,L}, b)
646
+ end
647
+
648
+ function Base. _reverse (r:: TracedStepRangeLen , :: Colon )
649
+ # If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence
650
+ # invalid. As `reverse(r)` is also empty, any offset would work so we keep
651
+ # `r.offset`
652
+ offset = isempty (r) ? r. offset : length (r) - r. offset + 1
653
+ return typeof (r)(r. ref, negate (r. step), length (r), offset)
654
+ end
655
+
656
+ # TODO : +, - for TracedStepRangeLen (see Base._define_range_op)
657
+
658
+ function (:: Type{T} )(x:: TwicePrecision ) where {T<: Reactant.TracedRNumber }
659
+ return (T (x. hi) + T (x. lo)):: T
660
+ end
661
+
662
+ function (:: Type{T} )(x:: TwicePrecision ) where {T<: Reactant.ConcreteRNumber }
663
+ return Reactant. ConcreteRNumber (T (x. hi) - T (x. lo)):: T
664
+ end
665
+
666
+ Base. nbitslen (r:: TracedStepRangeLen ) = Base. nbitslen (eltype (r), length (r), r. offset)
667
+ function TracedStepRangeLen (
668
+ ref:: TwicePrecision{T} , step:: TwicePrecision{T} , len, offset= 1
669
+ ) where {T}
670
+ return TracedStepRangeLen {T,TwicePrecision{T},TwicePrecision{T}} (ref, step, len, offset)
671
+ end
672
+ function Base. step (r:: TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}} ) where {T}
673
+ return T (r. step)
674
+ end
675
+
427
676
# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
428
677
function Base. unsafe_getindex (
429
- r:: Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision} ,
678
+ r:: Union {
679
+ Base. StepRangeLen{T,<: Base.TwicePrecision ,<: Base.TwicePrecision },
680
+ TracedStepRangeLen{
681
+ T,<: Base.TwicePrecision ,<: Base.TwicePrecision ,<: Base.TwicePrecision
682
+ },
683
+ },
430
684
i:: TracedRNumber{<:Integer} ,
431
685
) where {T}
432
686
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@@ -449,7 +703,9 @@ function Base.unsafe_getindex(
449
703
end
450
704
451
705
function Base. searchsortedfirst (
452
- a:: AbstractRange{<:Real} , x:: TracedRNumber{<:Real} , o:: Base.DirectOrdering
706
+ a:: AbstractRange{<:Union{Real,TracedRNumber}} ,
707
+ x:: TracedRNumber{<:Real} ,
708
+ o:: Base.DirectOrdering ,
453
709
):: TracedRNumber{keytype(a)}
454
710
455
711
# require_one_based_indexing(a)
@@ -460,7 +716,7 @@ function Base.searchsortedfirst(
460
716
! Base. Order. lt (o, f, x),
461
717
1 ,
462
718
ifelse (
463
- h == 0 | | Base. Order. lt (o, l, x),
719
+ ( h == 0 ) | Base. Order. lt (o, l, x),
464
720
length (a) + 1 ,
465
721
ifelse (Base. Order. lt (o, a[n], x), n + 1 , n),
466
722
),
0 commit comments