Skip to content

Commit 8ea467f

Browse files
jumerckxwsmosesgithub-actions[bot]
authored
TracedStepRangeLen (#960)
* initial implementations * steprangelen tracing * find TracedStepRangeLen * fixes * formatting * cleanup * make constructors less restrictive, also trace offset and len. * more changes * searchsortedfirst fix * add boolean operators between Traced and non-traced args. fix getindex for traced index. * formatting * finer bool operators with tracedrnumber Co-authored-by: William Moses <[email protected]> * fix getindex * add test * cleanup * Adapt.parent_type * Update src/TracedRNumber.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: William Moses <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 66e4a54 commit 8ea467f

File tree

5 files changed

+364
-5
lines changed

5 files changed

+364
-5
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,17 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
419419
return res
420420
end
421421

422+
import Reactant.TracedRNumberOverrides.TracedStepRangeLen
423+
424+
# Adapt a `TracedStepRangeLen` for kernel use: each of its four fields is passed through
# `Adapt.adapt` with a `ReactantKernelAdaptor`, and the range is rebuilt from the results.
function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
    adaptor = ReactantKernelAdaptor()
    ref = Adapt.adapt(adaptor, r.ref)
    step = Adapt.adapt(adaptor, r.step)
    len = Adapt.adapt(adaptor, r.len)
    offset = Adapt.adapt(adaptor, r.offset)
    return TracedStepRangeLen(ref, step, len, offset)
end
432+
422433
# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
423434
# and not the operation itself).
424435
struct LLVMFunc{F,tt}

src/Compiler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ end
3333
end
3434

3535
# Field access for array-like objects during tracing.
#
# Real struct fields are read with `Base.getfield` when the element type is a bits type,
# when the array's ancestor is an `RArray`, or when `obj` is a range; otherwise `field`
# is treated as an index into `obj`.
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
    if isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange
        return Base.getfield(obj, field)
    end
    return Base.getindex(obj, field)
end
3839
end
3940

@@ -1472,7 +1473,6 @@ function codegen_flatten!(
14721473
is_sharded &&
14731474
runtime isa Val{:PJRT} &&
14741475
(flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...))
1475-
14761476
return flatten_names, flatten_code
14771477
end
14781478

src/TracedRNumber.jl

Lines changed: 259 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ module TracedRNumberOverrides
33
using ..Reactant:
44
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
55
using ReactantCore
6+
using Adapt
7+
8+
import Base.TwicePrecision
69

710
ReactantCore.is_traced(::TracedRNumber, seen) = true
811
ReactantCore.is_traced(::TracedRNumber) = true
@@ -262,6 +265,42 @@ function Base.ifelse(
262265
end
263266
end
264267

268+
# Product of two `TwicePrecision` values whose components are traced numbers.
#
# Follows `Base.:*(::TwicePrecision{T}, ::TwicePrecision{T})`: the leading product is
# split by `Base.mul12` into `(zh, zl)`, the cross terms are folded in and the pair is
# renormalised with `Base.canonicalize2`. When `zh` is zero or non-finite the raw
# `(zh, zl)` pair is returned instead. `ifelse` is used rather than branching so the
# selection stays expressible on traced values.
function Base.:*(
    x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
) where {T<:TracedRNumber}
    zh, zl = Base.mul12(x.hi, y.hi)
    hi, lo = Base.canonicalize2(zh, (x.hi * y.lo + x.lo * y.hi) + zl)
    # Both halves must be selected by the same predicate on `zh`. Selecting `lo` on
    # `zl` would discard the cross terms whenever the leading product is exact
    # (`zl == 0`), silently dropping the low-order part of the result.
    degenerate = iszero(zh) | !isfinite(zh)
    hi = ifelse(degenerate, zh, hi)
    lo = ifelse(degenerate, zl, lo)

    return Base.TwicePrecision{T}(hi, lo)
end
278+
279+
# Sum of two `TwicePrecision` values whose components are traced numbers.
#
# `r` is the sum of the high parts. `s` collects the remaining terms (the difference
# between the exact and rounded high sum plus both low parts); the order of operations
# depends on which high part has the larger magnitude. The pair `(r, s)` is then
# renormalised with `Base.canonicalize2`.
# NOTE(review): `@trace` presumably turns the `if` into a traced conditional so the
# comparison can depend on traced values — confirm against ReactantCore.
function Base.:+(
280+
x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
281+
) where {T<:TracedRNumber}
282+
r = x.hi + y.hi
283+
@trace s = if abs(x.hi) > abs(y.hi)
284+
begin
285+
(((x.hi - r) + y.hi) + y.lo) + x.lo
286+
end
287+
else
288+
begin
289+
(((y.hi - r) + x.hi) + x.lo) + y.lo
290+
end
291+
end
292+
return Base.TwicePrecision(Base.canonicalize2(r, s)...)
293+
end
294+
295+
# Scale a `TwicePrecision` value by a traced number `v`.
#
# When `v == 0` both halves are multiplied by `v` directly; otherwise `v` is converted
# to the type of `x.hi * v`, wrapped in a `TwicePrecision`, and the
# `TwicePrecision * TwicePrecision` product is used.
# NOTE(review): `@trace` presumably makes the `v == 0` test a traced conditional —
# confirm against ReactantCore.
function Base.:*(x::TwicePrecision, v::TracedRNumber)
296+
@trace result = if v == 0
297+
TwicePrecision(x.hi * v, x.lo * v)
298+
else
299+
x * TwicePrecision(oftype(x.hi * v, v))
300+
end
301+
return result
302+
end
303+
265304
for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
266305
T = promote_type(T1, T2)
267306
@eval begin
@@ -271,18 +310,54 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
271310
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
272311
)
273312
end
313+
function Base.:&(x::TracedRNumber{<:$(T1)}, y::$(T2))
314+
return Ops.and(
315+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
316+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
317+
)
318+
end
319+
function Base.:&(x::$(T1), y::TracedRNumber{<:$(T2)})
320+
return Ops.and(
321+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
322+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
323+
)
324+
end
274325
function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
275326
return Ops.or(
276327
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
277328
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
278329
)
279330
end
331+
function Base.:|(x::TracedRNumber{<:$(T1)}, y::$(T2))
332+
return Ops.or(
333+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
334+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
335+
)
336+
end
337+
function Base.:|(x::$(T1), y::TracedRNumber{<:$(T2)})
338+
return Ops.or(
339+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
340+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
341+
)
342+
end
280343
function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
281344
return Ops.xor(
282345
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
283346
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
284347
)
285348
end
349+
function Base.xor(x::TracedRNumber{<:$(T1)}, y::$(T2))
350+
return Ops.xor(
351+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
352+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
353+
)
354+
end
355+
function Base.xor(x::$(T1), y::TracedRNumber{<:$(T2)})
356+
return Ops.xor(
357+
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
358+
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
359+
)
360+
end
286361
Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x)
287362
end
288363
end
@@ -424,9 +499,188 @@ function Base.getindex(
424499
return Base.unsafe_getindex(r, i)
425500
end
426501

502+
# Range type with element type `T`, laid out like `Base.StepRangeLen`.
#
# The element at index `i` is computed elsewhere in this file as
# `ref + (i - offset) * step`, so `ref` is the value at index `offset`, `step` is the
# increment between consecutive elements, and `len` is the number of elements.
# `len` and `offset` share the type parameter `L`.
struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
503+
ref::R
504+
step::S
505+
len::L
506+
offset::L
507+
end
508+
509+
# A `TracedStepRangeLen` is reported to Adapt as its own parent type.
Adapt.parent_type(::Type{TracedStepRangeLen{T,R,S,L}}) where {T,R,S,L} =
    TracedStepRangeLen{T,R,S,L}
512+
513+
# constructors and interface implementation copied from range.jl
514+
# Positional constructors. In each case the length/offset type parameter is taken from
# `typeof(len)` and `offset` defaults to 1.

# Element, reference and step types given explicitly.
function TracedStepRangeLen{T,R,S}(ref::R, step::S, len, offset=1) where {T,R,S}
    LenT = typeof(len)
    return TracedStepRangeLen{T,R,S,LenT}(ref, step, len, offset)
end

# Element type derived from `ref + zero(step)`.
function TracedStepRangeLen(ref::R, step::S, len, offset=1) where {R,S}
    ElT = typeof(ref + zero(step))
    LenT = typeof(len)
    return TracedStepRangeLen{ElT,R,S,LenT}(ref, step, len, offset)
end

# Only the element type given; restricted to `Integer` length and offset.
function TracedStepRangeLen{T}(
    ref::R, step::S, len::Integer, offset::Integer=1
) where {T,R,S}
    LenT = typeof(len)
    return TracedStepRangeLen{T,R,S,LenT}(ref, step, len, offset)
end
527+
528+
# Basic range interface for `TracedStepRangeLen`.
# `step_hp` returns the stored step unchanged; `first`/`last` go through
# `Base.unsafe_getindex` at index 1 and at the stored length respectively.
Base.length(r::TracedStepRangeLen) = r.len
Base.isempty(r::TracedStepRangeLen) = length(r) == 0
Base.step(r::TracedStepRangeLen) = r.step
Base.step_hp(r::TracedStepRangeLen) = r.step
Base.first(r::TracedStepRangeLen) = Base.unsafe_getindex(r, 1)
Base.last(r::TracedStepRangeLen) = Base.unsafe_getindex(r, r.len)
534+
# Iterate over a `TracedStepRangeLen`.
#
# The iteration state `i` is the index of the NEXT element to produce, starting at 1.
# Returns `nothing` once `i` exceeds `length(r)`, otherwise the element at `i` together
# with the following state.
#
# The index is advanced only after the element has been read: incrementing before the
# read (with a default state of 1) would start at index 2 and skip the first element.
function Base.iterate(r::TracedStepRangeLen, i::Integer=1)
    @inline
    length(r) < i && return nothing
    return Base.unsafe_getindex(r, i), i + oneunit(i)
end
540+
541+
# Shared implementation of unchecked scalar indexing: `ref + (i - offset) * step`,
# converted to the result element type.
#
# When the index is a `TracedRNumber`, the result type is promoted to
# `TracedRNumber{T}` (unless `T` is already traced) and the index is converted to a
# traced version of the offset type (unless the offset is already traced).
function _tracedsteprangelen_unsafe_getindex(
    r::AbstractRange{T}, i::Union{I,TracedRNumber{I}}
) where {T,I}
    traced_index = i isa TracedRNumber
    ElT = (traced_index && !(T <: TracedRNumber)) ? TracedRNumber{T} : T
    OffT = typeof(r.offset)
    if traced_index && !(OffT <: TracedRNumber)
        OffT = TracedRNumber{OffT}
    end
    delta = convert(OffT, i) - r.offset
    return ElT(r.ref + delta * r.step)
end
557+
# Scalar indexing entry points; both plain and traced integer indices forward to the
# shared helper. Indexing with a traced number performs no bounds check.
Base.unsafe_getindex(r::TracedStepRangeLen, i::Integer) =
    _tracedsteprangelen_unsafe_getindex(r, i)
Base.unsafe_getindex(r::TracedStepRangeLen, i::TracedRNumber{<:Integer}) =
    _tracedsteprangelen_unsafe_getindex(r, i)
function Base.getindex(r::TracedStepRangeLen, i::TracedRNumber)
    return Base.unsafe_getindex(r, i)
end
564+
# Index a `TracedStepRangeLen` with an integer-valued ordinal range `s`, returning a new
# `TracedStepRangeLen{T}` (adapted from the `StepRangeLen` method in Base's range.jl).
#
# The method is attached to `Base.getindex` explicitly, like the other indexing methods
# in this file; an unqualified `getindex` definition would not extend Base's function,
# so `r[s]` would never reach this code.
#
# - `S === Bool`: `s` has at most two elements. An empty or `false:false` selection
#   yields an empty range, `true:true` yields the first element, and a two-element
#   selection (`false:true`) yields the last element.
# - otherwise: the new offset is the position in `s` closest to `r.offset`, clamped to
#   the valid indices of `s`, and the new reference value is evaluated at that position
#   without rounding to `T`.
function Base.getindex(
    r::TracedStepRangeLen{T}, s::OrdinalRange{S}
) where {T,S<:Integer}
    @inline
    @boundscheck checkbounds(r, s)

    len = length(s)
    sstep = Base.step_hp(s)
    rstep = Base.step_hp(r)
    L = typeof(len)
    if S === Bool
        rstep *= one(sstep)
        if len == 0
            return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
        elseif len == 1
            if first(s)
                return TracedStepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L))
            else
                return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
            end
        else # len == 2
            return TracedStepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L))
        end
    else
        # Find closest approach to offset by s
        ind = LinearIndices(s)
        offset = L(
            max(min(1 + round(L, (r.offset - first(s)) / sstep), last(ind)), first(ind))
        )
        ref = Base._getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep)
        return TracedStepRangeLen{T}(ref, rstep * sstep, len, offset)
    end
end
595+
# Element at index `i` in the precision of `ref`/`step`, i.e. without the final
# conversion to the range's element type.
function Base._getindex_hiprec(r::TracedStepRangeLen, i::Integer)
    steps_from_ref = oftype(r.offset, i) - r.offset
    return r.ref + steps_from_ref * r.step
end
599+
# Two ranges of the same concrete type are equal when both are empty, or when their
# first element, length and last element all agree. Non-short-circuiting `&`/`|` are
# used throughout.
function Base.:(==)(r::T, s::T) where {T<:TracedStepRangeLen}
    both_empty = isempty(r) & isempty(s)
    same_span = (first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s))
    return both_empty | same_span
end
603+
604+
# TODO: if there ever comes a ReactantStepRange:
605+
# ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T}
606+
607+
# Unary minus: negate the reference value and the step; length and offset are kept.
function Base.:-(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L}
    neg_ref = -r.ref
    neg_step = -r.step
    return TracedStepRangeLen{T,R,S,L}(neg_ref, neg_step, r.len, r.offset)
end
610+
611+
# TODO: promotion from StepRangeLen{T} to TracedStepRangeLen{T}?
612+
# Promotion between two `TracedStepRangeLen` types: reference, step and length types
# are promoted pairwise, and `Base.el_same` picks the result for the promoted element
# type.
function Base.promote_rule(
    ::Type{TracedStepRangeLen{T1,R1,S1,L1}}, ::Type{TracedStepRangeLen{T2,R2,S2,L2}}
) where {T1,T2,R1,R2,S1,S2,L1,L2}
    R = promote_type(R1, R2)
    S = promote_type(S1, S2)
    L = promote_type(L1, L2)
    ElT = promote_type(T1, T2)
    return Base.el_same(ElT, TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L})
end
620+
# Conversions between `TracedStepRangeLen` parameterisations.

# Same parameters: return the range itself.
function TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L}
    return r
end

# Different parameters: convert every field to the requested field type.
function TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen) where {T,R,S,L}
    ref = convert(R, r.ref)
    step = convert(S, r.step)
    len = convert(L, r.len)
    offset = convert(L, r.offset)
    return TracedStepRangeLen{T,R,S,L}(ref, step, len, offset)
end

# Only the element type given: convert `ref` and `step` to `T`, keep `len` and `offset`.
function TracedStepRangeLen{T}(r::TracedStepRangeLen) where {T}
    ref = convert(T, r.ref)
    step = convert(T, r.step)
    return TracedStepRangeLen(ref, step, r.len, r.offset)
end
629+
# Promotion against any other range type `OR`: `OR` is treated as a
# `TracedStepRangeLen` whose element, reference and step types are all `eltype(OR)` and
# whose length type is `Int`, then the traced/traced rule is applied.
function Base.promote_rule(
    a::Type{TracedStepRangeLen{T,R,S,L}}, ::Type{OR}
) where {T,R,S,L,OR<:AbstractRange}
    E = eltype(OR)
    return promote_rule(a, TracedStepRangeLen{E,E,E,Int})
end
634+
# Convert an arbitrary range to a fully parameterised `TracedStepRangeLen`.
#
# The reference value is `first(r)`, so the offset is 1. The offset is passed
# explicitly because the struct defines no inner constructor with a default `offset`:
# with all four type parameters given, only the four-argument field constructor
# exists, and a three-argument call would raise a `MethodError`.
function TracedStepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L}
    return TracedStepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r), 1)
end
637+
# Convert an arbitrary range, using `T` for both the reference value and the step.
function TracedStepRangeLen{T}(r::AbstractRange) where {T}
    ref = T(first(r))
    stp = T(step(r))
    return TracedStepRangeLen(ref, stp, length(r))
end

# Convert an arbitrary range, keeping its element type.
function TracedStepRangeLen(r::AbstractRange)
    return TracedStepRangeLen{eltype(r)}(r)
end
641+
642+
# Promotion of a `LinRange{A,L}` against a `TracedStepRangeLen`: the `LinRange` is
# treated as `TracedStepRangeLen{A,A,A,L}` and the traced/traced rule is applied.
function Base.promote_rule(
    ::Type{LinRange{A,L}}, b::Type{TracedStepRangeLen{T2,R2,S2,L2}}
) where {A,L,T2,R2,S2,L2}
    lin_as_traced = TracedStepRangeLen{A,A,A,L}
    return promote_rule(lin_as_traced, b)
end
647+
648+
# Reverse a `TracedStepRangeLen`: same reference value, negated step, mirrored offset.
#
# `negate` is an unexported Base function, so it is called as `Base.negate`; the bare
# name is not brought into scope by the imports visible in this module.
function Base._reverse(r::TracedStepRangeLen, ::Colon)
    # If `r` is empty, `length(r) - r.offset + 1` will be nonpositive hence
    # invalid. As `reverse(r)` is also empty, any offset would work so we keep
    # `r.offset`
    offset = isempty(r) ? r.offset : length(r) - r.offset + 1
    return typeof(r)(r.ref, Base.negate(r.step), length(r), offset)
end
655+
656+
# TODO: +, - for TracedStepRangeLen (see Base._define_range_op)
657+
658+
# Collapse a `TwicePrecision` into a traced number of type `T`: convert both halves to
# `T` and add them.
function (::Type{T})(x::TwicePrecision) where {T<:Reactant.TracedRNumber}
    hi, lo = T(x.hi), T(x.lo)
    return (hi + lo)::T
end
661+
662+
# Collapse a `TwicePrecision` into a concrete number of type `T`.
#
# A `TwicePrecision` represents the value `hi + lo`, so the two halves are added,
# matching the `TracedRNumber` conversion in this file. Subtracting `lo` would move the
# result away from the represented value by `2 * lo`.
function (::Type{T})(x::TwicePrecision) where {T<:Reactant.ConcreteRNumber}
    return Reactant.ConcreteRNumber(T(x.hi) + T(x.lo))::T
end
665+
666+
# Forward to Base's three-argument `nbitslen` with the range's element type, length
# and offset.
function Base.nbitslen(r::TracedStepRangeLen)
    return Base.nbitslen(eltype(r), length(r), r.offset)
end

# Construct from `TwicePrecision` reference and step; the element type is the
# underlying `T`.
function TracedStepRangeLen(
    ref::TwicePrecision{T}, step::TwicePrecision{T}, len, offset=1
) where {T}
    HP = TwicePrecision{T}
    return TracedStepRangeLen{T,HP,HP}(ref, step, len, offset)
end

# For `TwicePrecision`-backed ranges, `step` is reported in the element type `T`.
function Base.step(
    r::TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}
) where {T}
    high_precision_step = r.step
    return T(high_precision_step)
end
675+
427676
# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
428677
function Base.unsafe_getindex(
429-
r::Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
678+
r::Union{
679+
Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
680+
TracedStepRangeLen{
681+
T,<:Base.TwicePrecision,<:Base.TwicePrecision,<:Base.TwicePrecision
682+
},
683+
},
430684
i::TracedRNumber{<:Integer},
431685
) where {T}
432686
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@@ -449,7 +703,9 @@ function Base.unsafe_getindex(
449703
end
450704

451705
function Base.searchsortedfirst(
452-
a::AbstractRange{<:Real}, x::TracedRNumber{<:Real}, o::Base.DirectOrdering
706+
a::AbstractRange{<:Union{Real,TracedRNumber}},
707+
x::TracedRNumber{<:Real},
708+
o::Base.DirectOrdering,
453709
)::TracedRNumber{keytype(a)}
454710

455711
# require_one_based_indexing(a)
@@ -460,7 +716,7 @@ function Base.searchsortedfirst(
460716
!Base.Order.lt(o, f, x),
461717
1,
462718
ifelse(
463-
h == 0 || Base.Order.lt(o, l, x),
719+
(h == 0) | Base.Order.lt(o, l, x),
464720
length(a) + 1,
465721
ifelse(Base.Order.lt(o, a[n], x), n + 1, n),
466722
),

0 commit comments

Comments
 (0)