Skip to content

Commit 74d765a

Browse files
jumerckxgiordanowsmoses
authored
TracedUnitRange (#1004)
* copy-paste definitions from range.jl * add tracing methods * fixes * test indexing with Julia integer * fixes and test for tracedrnumber getindex * formatting * only convert tracedranges if the type eltype actually changes * Define `unitrange_last` with `ifelse` * Define rounding method for integers * fix creation of traced values * disambiguate getindex * RNumber should've been ReactantPrimitive * `Adapt.parent_type` * fix * same fix for TraceStepRangeLen * more getindex * fix for 1.10 * Formatting --------- Co-authored-by: Mosè Giordano <[email protected]> Co-authored-by: William S. Moses <[email protected]>
1 parent bd00a29 commit 74d765a

File tree

4 files changed

+191
-71
lines changed

4 files changed

+191
-71
lines changed

src/TracedRNumber.jl

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,14 @@ function Base.float(x::TracedRNumber{T}) where {T}
437437
return TracedUtils.promote_to(TracedRNumber{float(T)}, x)
438438
end
439439

440-
using Reactant: ReactantFloat
440+
using Reactant: ReactantFloat, ReactantInt
441441

442442
Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A)
443+
Base.round(A::TracedRNumber{<:ReactantInt}) = A
443444
Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A)
445+
Base.floor(A::TracedRNumber{<:ReactantInt}) = A
444446
Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A)
447+
Base.ceil(A::TracedRNumber{<:ReactantInt}) = A
445448

446449
function Base.unsafe_trunc(
447450
T::Type{<:Reactant.ReactantInt}, x::TracedRNumber{<:Reactant.ReactantFloat}
@@ -499,6 +502,81 @@ function Base.getindex(
499502
return Base.unsafe_getindex(r, i)
500503
end
501504

505+
function unitrange_last(start::Integer, stop::Integer)
506+
return ifelse(stop >= start, stop, convert(typeof(stop), start - oneunit(start - stop)))
507+
end
508+
function unitrange_last(start, stop)
509+
return ifelse(
510+
stop >= start,
511+
convert(typeof(stop), start + floor(stop - start)),
512+
convert(typeof(stop), start - oneunit(start - stop)),
513+
)
514+
end
515+
516+
struct TracedUnitRange{T} <: AbstractUnitRange{T}
517+
start::T
518+
stop::T
519+
function TracedUnitRange{T}(start::T, stop::T) where {T}
520+
return new(start, unitrange_last(start, stop))
521+
end
522+
end
523+
function Adapt.parent_type(::Type{TracedUnitRange{T}}) where {T}
524+
return TracedUnitRange{T}
525+
end
526+
function TracedUnitRange{T}(start, stop) where {T}
527+
return TracedUnitRange{T}(convert(T, start), convert(T, stop))
528+
end
529+
TracedUnitRange(start::T, stop::T) where {T} = TracedUnitRange{T}(start, stop)
530+
function TracedUnitRange(start, stop)
531+
startstop_promoted = promote(start, stop)
532+
not_sametype((start, stop), startstop_promoted)
533+
return TracedUnitRange(startstop_promoted...)
534+
end
535+
function Base._in_unit_range(
536+
v::TracedUnitRange, val, i::Union{Integer,TracedRNumber{<:Integer}}
537+
)
538+
return (i > 0) & (val <= v.stop) & (val >= v.start)
539+
end
540+
541+
function _traced_unitrange_getindex(v::TracedUnitRange{T}, i) where {T}
542+
val = convert(T, v.start + (i - oneunit(i)))
543+
# TODO: we should have error messages at some point.
544+
# @boundscheck Base._in_unit_range(v, val, i) || throw_boundserror(v, i)
545+
return val
546+
end
547+
548+
function Base._getindex(v::TracedUnitRange, i::TracedRNumber{<:Integer})
549+
return _traced_unitrange_getindex(v, i)
550+
end
551+
Base.getindex(v::TracedUnitRange, i::Integer) = _traced_unitrange_getindex(v, i)
552+
Base.getindex(r::TracedUnitRange, i::TracedRNumber) = Base._getindex(r, i)
553+
function Base.getindex(r::Base.UnitRange, i::I) where {I<:TracedRNumber{<:Integer}}
554+
val = convert(I, r.start + (i - oneunit(i)))
555+
# TODO: we should have error messages at some point.
556+
# @boundscheck Base._in_unit_range(v, val, i) || throw_boundserror(v, i)
557+
return val
558+
end
559+
560+
function Base.promote_rule(
561+
a::Type{TracedUnitRange{T1}}, b::Type{TracedUnitRange{T2}}
562+
) where {T1,T2}
563+
return el_same(promote_type(T1, T2), a, b)
564+
end
565+
TracedUnitRange{T}(r::TracedUnitRange{T}) where {T<:Real} = r
566+
TracedUnitRange{T}(r::TracedUnitRange) where {T<:Real} = TracedUnitRange{T}(r.start, r.stop)
567+
568+
function Base.promote_rule(
569+
a::Type{TracedUnitRange{T1}}, ::Type{UR}
570+
) where {T1,UR<:AbstractUnitRange}
571+
return promote_rule(a, TracedUnitRange{eltype(UR)})
572+
end
573+
function TracedUnitRange{T}(r::AbstractUnitRange) where {T<:Real}
574+
return TracedUnitRange{T}(first(r), last(r))
575+
end
576+
TracedUnitRange(r::AbstractUnitRange) = TracedUnitRange(first(r), last(r))
577+
578+
AbstractUnitRange{T}(r::TracedUnitRange) where {T} = TracedUnitRange{T}(r)
579+
502580
struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
503581
ref::R
504582
step::S

src/TracedUtils.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -312,23 +312,29 @@ function make_mlir_fn(
312312

313313
seen_results = OrderedIdDict()
314314

315-
traced_result = Reactant.make_tracer(
316-
seen_results,
317-
result,
318-
(:result,),
319-
concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath;
320-
runtime,
321-
)
322-
323-
# marks buffers to be donated
324-
for i in 1:N
325-
Reactant.make_tracer(
315+
MLIR.IR.activate!(fnbody)
316+
traced_result = try
317+
traced_result = Reactant.make_tracer(
326318
seen_results,
327-
traced_args[i],
328-
concretein ? (:resargs, i) : (),
329-
Reactant.NoStopTracedTrack;
319+
result,
320+
(:result,),
321+
concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath;
330322
runtime,
331323
)
324+
325+
# marks buffers to be donated
326+
for i in 1:N
327+
Reactant.make_tracer(
328+
seen_results,
329+
traced_args[i],
330+
concretein ? (:resargs, i) : (),
331+
Reactant.NoStopTracedTrack;
332+
runtime,
333+
)
334+
end
335+
traced_result
336+
finally
337+
MLIR.IR.deactivate!(fnbody)
332338
end
333339

334340
linear_results = Reactant.TracedType[]

src/Tracing.jl

Lines changed: 76 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,33 +1868,68 @@ end
18681868
end
18691869

18701870
function Reactant.traced_type_inner(
1871-
@nospecialize(RT::Type{<:StepRangeLen}),
1871+
@nospecialize(RT::Type{<:UnitRange{<:ReactantPrimitive}}),
18721872
seen,
18731873
mode::Reactant.TraceMode,
18741874
track_numbers::Type,
18751875
sharding,
18761876
runtime,
18771877
)
1878-
if !(Number <: track_numbers)
1879-
modified_track_numbers = Number
1878+
(T,) = RT.parameters
1879+
newT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding, runtime)
1880+
if T == newT
1881+
return RT
18801882
else
1881-
modified_track_numbers = track_numbers
1883+
return TracedRNumberOverrides.TracedUnitRange{newT}
18821884
end
1885+
end
1886+
1887+
function Reactant.make_tracer(
1888+
seen,
1889+
@nospecialize(prev::UnitRange),
1890+
@nospecialize(path),
1891+
mode;
1892+
@nospecialize(sharding = Sharding.NoSharding()),
1893+
kwargs...,
1894+
)
1895+
Reactant.Sharding.is_sharded(sharding) && error("Cannot specify sharding for UnitRange")
1896+
if mode == Reactant.TracedToTypes
1897+
push!(path, Core.Typeof(prev))
1898+
make_tracer(seen, prev.start, path, mode; kwargs...)
1899+
make_tracer(seen, prev.stop, path, mode; kwargs...)
1900+
return nothing
1901+
end
1902+
newstart = Reactant.make_tracer(
1903+
seen, prev.start, Reactant.append_path(path, :start), mode; kwargs...
1904+
)
1905+
newstop = Reactant.make_tracer(
1906+
seen, prev.stop, Reactant.append_path(path, :stop), mode; kwargs...
1907+
)
1908+
if typeof(newstart) == typeof(prev.start) && typeof(newstop) == typeof(prev.stop)
1909+
return prev
1910+
else
1911+
return TracedRNumberOverrides.TracedUnitRange(newstart, newstop)
1912+
end
1913+
end
1914+
1915+
function Reactant.traced_type_inner(
1916+
@nospecialize(RT::Type{<:StepRangeLen}),
1917+
seen,
1918+
mode::Reactant.TraceMode,
1919+
track_numbers::Type,
1920+
sharding,
1921+
runtime,
1922+
)
18831923
T, R, S, L = RT.parameters
1884-
return TracedRNumberOverrides.TracedStepRangeLen{
1885-
Reactant.traced_type_inner(
1886-
T, seen, mode, modified_track_numbers, sharding, runtime
1887-
),
1888-
Reactant.traced_type_inner(
1889-
R, seen, mode, modified_track_numbers, sharding, runtime
1890-
),
1891-
Reactant.traced_type_inner(
1892-
S, seen, mode, modified_track_numbers, sharding, runtime
1893-
),
1894-
Reactant.traced_type_inner(
1895-
L, seen, mode, modified_track_numbers, sharding, runtime
1896-
),
1897-
}
1924+
newT = Reactant.traced_type_inner(T, seen, mode, track_numbers, sharding, runtime)
1925+
newR = Reactant.traced_type_inner(R, seen, mode, track_numbers, sharding, runtime)
1926+
newS = Reactant.traced_type_inner(S, seen, mode, track_numbers, sharding, runtime)
1927+
newL = Reactant.traced_type_inner(L, seen, mode, track_numbers, sharding, runtime)
1928+
if T == newT && R == newR && S == newS && L == newL
1929+
return RT
1930+
else
1931+
return TracedRNumberOverrides.TracedStepRangeLen{newT,newR,newS,newL}
1932+
end
18981933
end
18991934

19001935
function Reactant.make_tracer(
@@ -1909,44 +1944,30 @@ function Reactant.make_tracer(
19091944
error("Cannot specify sharding for StepRangeLen")
19101945
if mode == Reactant.TracedToTypes
19111946
push!(path, Core.Typeof(prev))
1912-
make_tracer(seen, prev.ref, path, mode; kwargs...)
1913-
make_tracer(seen, prev.step, path, mode; kwargs...)
1914-
make_tracer(seen, prev.len, path, mode; kwargs...)
1915-
make_tracer(seen, prev.offset, path, mode; kwargs...)
1947+
make_tracer(seen, prev.ref, path, mode; sharding, kwargs...)
1948+
make_tracer(seen, prev.step, path, mode; sharding, kwargs...)
1949+
make_tracer(seen, prev.len, path, mode; sharding, kwargs...)
1950+
make_tracer(seen, prev.offset, path, mode; sharding, kwargs...)
19161951
return nothing
19171952
end
1918-
return TracedRNumberOverrides.TracedStepRangeLen(
1919-
Reactant.make_tracer(
1920-
seen,
1921-
prev.ref,
1922-
Reactant.append_path(path, :ref),
1923-
mode;
1924-
kwargs...,
1925-
track_numbers=Number,
1926-
),
1927-
Reactant.make_tracer(
1928-
seen,
1929-
prev.step,
1930-
Reactant.append_path(path, :step),
1931-
mode;
1932-
kwargs...,
1933-
track_numbers=Number,
1934-
),
1935-
Reactant.make_tracer(
1936-
seen,
1937-
prev.len,
1938-
Reactant.append_path(path, :len),
1939-
mode;
1940-
kwargs...,
1941-
track_numbers=Number,
1942-
),
1943-
Reactant.make_tracer(
1944-
seen,
1945-
prev.offset,
1946-
Reactant.append_path(path, :offset),
1947-
mode;
1948-
kwargs...,
1949-
track_numbers=Number,
1950-
),
1953+
newref = Reactant.make_tracer(
1954+
seen, prev.ref, Reactant.append_path(path, :ref), mode; sharding, kwargs...
1955+
)
1956+
newstep = Reactant.make_tracer(
1957+
seen, prev.step, Reactant.append_path(path, :step), mode; sharding, kwargs...
1958+
)
1959+
newlen = Reactant.make_tracer(
1960+
seen, prev.len, Reactant.append_path(path, :len), mode; sharding, kwargs...
1961+
)
1962+
newoffset = Reactant.make_tracer(
1963+
seen, prev.offset, Reactant.append_path(path, :offset), mode; sharding, kwargs...
19511964
)
1965+
if typeof(newref) == typeof(prev.ref) &&
1966+
typeof(newstep) == typeof(prev.step) &&
1967+
typeof(newlen) == typeof(prev.len) &&
1968+
typeof(newoffset) == typeof(prev.offset)
1969+
return prev
1970+
else
1971+
return TracedRNumberOverrides.TracedStepRangeLen(newref, newstep, newlen, newoffset)
1972+
end
19521973
end

test/basic.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,20 +993,35 @@ end
993993

994994
@testset "Fractional index" begin
995995
times = 0:0.01:4.5
996+
@test times isa Base.StepRangeLen
996997
res = @jit fractional_idx(times, ConcreteRNumber(2.143))
997998
@test res[1] == 0.29999999999997334
998999
@test res[2] == 215
9991000
@test res[3] == 216
10001001
end
10011002

10021003
@testset "Traced fractional index" begin
1003-
times = Reactant.to_rarray(0:0.01:4.5)
1004+
times = Reactant.to_rarray(0:0.01:4.5; track_numbers=Number)
1005+
@test times isa Reactant.TracedRNumberOverrides.TracedStepRangeLen
10041006
res = @jit fractional_idx(times, ConcreteRNumber(2.143))
10051007
@test res[1] == 0.29999999999997334
10061008
@test res[2] == 215
10071009
@test res[3] == 216
10081010
end
10091011

1012+
function unitrange_test(r, i)
1013+
return r[i]
1014+
end
1015+
@testset "Unitrange" begin
1016+
x = 2:10
1017+
@test (@jit unitrange_test(x, 3)) == 4
1018+
@test (@jit unitrange_test(x, Reactant.ConcreteRNumber(4))) == 5
1019+
1020+
x = Reactant.to_rarray(2:10; track_numbers=Number)
1021+
@test (@jit unitrange_test(x, 3)) == 4
1022+
@test (@jit unitrange_test(x, Reactant.ConcreteRNumber(4))) == 5
1023+
end
1024+
10101025
mulpi(x) = π * x
10111026

10121027
@testset "Irrational promotion" begin

0 commit comments

Comments
 (0)