Skip to content

Commit 72c1e7f

Browse files
committed
add length type parameter to StepRangeLen
Also be more careful about using additive identity instead of multiplicative, and be more consistent about types in a few places. Fixes #41517
1 parent 2893de7 commit 72c1e7f

File tree

4 files changed

+149
-116
lines changed

4 files changed

+149
-116
lines changed

base/broadcast.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,19 +1121,20 @@ end
11211121

11221122
## scalar-range broadcast operations ##
11231123
# DefaultArrayStyle and \ are not available at the time of range.jl
1124-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange) = r
1125-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen) = r
1126-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange) = r
1124+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange) = r
11271125

1128-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
1126+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange) = range(-first(r), step=-step(r), length=length(r))
1127+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), -last(r), step=-step(r))
11291128
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset)
11301129
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r))
11311130

1132-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r))
1133-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r))
11341131
# For #18336 we need to prevent promotion of the step type:
11351132
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r))
11361133
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r))
1134+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange, x::Real) = range(first(r) + x, last(r) + x, step=step(r))
1135+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::Real) = range(x + first(r), x + last(r), step=step(r))
1136+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, last(r) + x)
1137+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), x + last(r))
11371138
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T =
11381139
StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset)
11391140
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T =
@@ -1142,9 +1143,11 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRa
11421143
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r))
11431144
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2
11441145

1145-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r))
1146-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r))
1147-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r))
1146+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r) - x, step=step(r), length=length(r))
1147+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x - first(r), step=-step(r), length=length(r))
1148+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange, x::Real) = range(first(r) - x, last(r) - x, step=step(r))
1149+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Real, r::OrdinalRange) = range(x - first(r), x - last(r), step=-step(r))
1150+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Real) = range(first(r) - x, last(r) - x)
11481151
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T =
11491152
StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset)
11501153
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T =

base/range.jl

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
_colon(::Ordered, ::Any, start::T, step, stop::T) where {T} = StepRange(start, step, stop)
2525
# for T<:Union{Float16,Float32,Float64} see twiceprecision.jl
2626
_colon(::Ordered, ::ArithmeticRounds, start::T, step, stop::T) where {T} =
27-
StepRangeLen(start, step, floor(Int, (stop-start)/step)+1)
27+
StepRangeLen(start, step, floor(Integer, (stop-start)/step)+1)
2828
_colon(::Any, ::Any, start::T, step, stop::T) where {T} =
29-
StepRangeLen(start, step, floor(Int, (stop-start)/step)+1)
29+
StepRangeLen(start, step, floor(Integer, (stop-start)/step)+1)
3030

3131
"""
3232
(:)(start, [step], stop)
@@ -415,8 +415,9 @@ oneto(r) = OneTo(r)
415415
## Step ranges parameterized by length
416416

417417
"""
418-
StepRangeLen{T,R,S}(ref::R, step::S, len, [offset=1]) where {T,R,S}
419-
StepRangeLen( ref::R, step::S, len, [offset=1]) where { R,S}
418+
StepRangeLen( ref::R, step::S, len, [offset=1]) where { R,S}
419+
StepRangeLen{T,R,S}( ref::R, step::S, len, [offset=1]) where {T,R,S}
420+
StepRangeLen{T,R,S,L}(ref::R, step::S, len, [offset=1]) where {T,R,S,L}
420421
421422
A range `r` where `r[i]` produces values of type `T` (in the second
422423
form, `T` is deduced automatically), parameterized by a `ref`erence
@@ -426,26 +427,30 @@ value `r[1]`, but alternatively you can supply it as the value of
426427
with `TwicePrecision` this can be used to implement ranges that are
427428
free of roundoff error.
428429
"""
429-
struct StepRangeLen{T,R,S} <: AbstractRange{T}
430+
struct StepRangeLen{T,R,S,L} <: AbstractRange{T}
430431
ref::R # reference value (might be smallest-magnitude value in the range)
431432
step::S # step value
432-
len::Int # length of the range
433-
offset::Int # the index of ref
433+
len::L # length of the range
434+
offset::L # the index of ref
434435

435-
function StepRangeLen{T,R,S}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S}
436+
function StepRangeLen{T,R,S,L}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S,L}
436437
if T <: Integer && !isinteger(ref + step)
437438
throw(ArgumentError("StepRangeLen{<:Integer} cannot have non-integer step"))
438439
end
440+
len = convert(L, len)
439441
len >= 0 || throw(ArgumentError("length cannot be negative, got $len"))
440-
1 <= offset <= max(1,len) || throw(ArgumentError("StepRangeLen: offset must be in [1,$len], got $offset"))
441-
new(ref, step, len, offset)
442+
offset = convert(L, offset)
443+
1 <= offset <= max(1, len) || throw(ArgumentError("StepRangeLen: offset must be in [1,$len], got $offset"))
444+
return new(ref, step, len, offset)
442445
end
443446
end
444447

448+
StepRangeLen{T,R,S}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S} =
449+
StepRangeLen{T,R,S,promote_type(Int,typeof(len))}(ref, step, len, offset)
445450
StepRangeLen(ref::R, step::S, len::Integer, offset::Integer = 1) where {R,S} =
446-
StepRangeLen{typeof(ref+zero(step)),R,S}(ref, step, len, offset)
451+
StepRangeLen{typeof(ref+zero(step)),R,S,promote_type(Int,typeof(len))}(ref, step, len, offset)
447452
StepRangeLen{T}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S} =
448-
StepRangeLen{T,R,S}(ref, step, len, offset)
453+
StepRangeLen{T,R,S,promote_type(Int,typeof(len))}(ref, step, len, offset)
449454

450455
## range with computed step
451456

@@ -621,6 +626,7 @@ step(r::StepRangeLen) = r.step
621626
step(r::StepRangeLen{T}) where {T<:AbstractFloat} = T(r.step)
622627
step(r::LinRange) = (last(r)-first(r))/r.lendiv
623628

629+
# high-precision step
624630
step_hp(r::StepRangeLen) = r.step
625631
step_hp(r::AbstractRange) = step(r)
626632

@@ -648,7 +654,7 @@ function checked_length(r::OrdinalRange{T}) where T
648654
diff = checked_sub(stop, start)
649655
end
650656
a = Integer(div(diff, s))
651-
return checked_add(a, one(a))
657+
return checked_add(a, oneunit(a))
652658
end
653659

654660
function checked_length(r::AbstractUnitRange{T}) where T
@@ -657,7 +663,7 @@ function checked_length(r::AbstractUnitRange{T}) where T
657663
return Integer(first(r) - first(r))
658664
end
659665
a = Integer(checked_add(checked_sub(last(r), first(r))))
660-
return checked_add(a, one(a))
666+
return checked_add(a, oneunit(a))
661667
end
662668

663669
function length(r::OrdinalRange{T}) where T
@@ -675,14 +681,14 @@ function length(r::OrdinalRange{T}) where T
675681
diff = stop - start
676682
end
677683
a = Integer(div(diff, s))
678-
return a + one(a)
684+
return a + oneunit(a)
679685
end
680686

681687

682688
function length(r::AbstractUnitRange{T}) where T
683689
@_inline_meta
684690
a = Integer(last(r) - first(r)) # even when isempty, by construction (with overflow)
685-
return a + one(a)
691+
return a + oneunit(a)
686692
end
687693

688694
length(r::OneTo) = Integer(r.stop - zero(r.stop))
@@ -710,7 +716,7 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
710716
else
711717
a = div(unsigned(diff), s) % typeof(diff)
712718
end
713-
return Integer(a) + one(a)
719+
return Integer(a) + oneunit(a)
714720
end
715721
function checked_length(r::OrdinalRange{T}) where T<:bigints
716722
s = step(r)
@@ -729,7 +735,7 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
729735
else
730736
a = div(checked_sub(start, stop), -s)
731737
end
732-
return checked_add(a, one(a))
738+
return checked_add(a, oneunit(a))
733739
end
734740
end
735741

@@ -803,7 +809,13 @@ copy(r::AbstractRange) = r
803809

804810
## iteration
805811

806-
function iterate(r::Union{LinRange,StepRangeLen}, i::Int=1)
812+
function iterate(r::StepRangeLen, i::Integer=1)
813+
@_inline_meta
814+
length(r) < i && return nothing
815+
unsafe_getindex(r, i), i + 1
816+
end
817+
818+
function iterate(r::LinRange, i::Int=1)
807819
@_inline_meta
808820
length(r) < i && return nothing
809821
unsafe_getindex(r, i), i + 1
@@ -897,7 +909,7 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
897909
@boundscheck checkbounds(r, s)
898910

899911
if T === Bool
900-
range(first(s) ? first(r) : last(r), length = Int(last(s)))
912+
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
901913
else
902914
f = first(r)
903915
st = oftype(f, f + first(s)-1)
@@ -916,7 +928,7 @@ function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
916928
@boundscheck checkbounds(r, s)
917929

918930
if T === Bool
919-
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Int(last(s)))
931+
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
920932
else
921933
st = oftype(first(r), first(r) + s.start-1)
922934
return range(st, step=step(s), length=length(s))
@@ -949,24 +961,29 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
949961
@_inline_meta
950962
@boundscheck checkbounds(r, s)
951963

964+
len = length(s)
965+
sstep = step_hp(s)
966+
rstep = step_hp(r)
967+
L = typeof(len)
952968
if S === Bool
953-
if length(s) == 0
954-
return StepRangeLen{T}(first(r), step(r), 0, 1)
955-
elseif length(s) == 1
969+
rstep *= one(sstep)
970+
if len == 0
971+
return StepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
972+
elseif len == 1
956973
if first(s)
957-
return StepRangeLen{T}(first(r), step(r), 1, 1)
974+
return StepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L))
958975
else
959-
return StepRangeLen{T}(first(r), step(r), 0, 1)
976+
return StepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
960977
end
961-
else # length(s) == 2
962-
return StepRangeLen{T}(last(r), step(r), 1, 1)
978+
else # len == 2
979+
return StepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L))
963980
end
964981
else
965982
# Find closest approach to offset by s
966983
ind = LinearIndices(s)
967-
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
968-
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
969-
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
984+
offset = L(max(min(1 + round(L, (r.offset - first(s))/sstep), last(ind)), first(ind)))
985+
ref = _getindex_hiprec(r, first(s) + (offset-1)*sstep)
986+
return StepRangeLen{T}(ref, rstep*sstep, len, offset)
970987
end
971988
end
972989

@@ -1153,8 +1170,8 @@ issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
11531170
## linear operations on ranges ##
11541171

11551172
-(r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
1156-
-(r::StepRangeLen{T,R,S}) where {T,R,S} =
1157-
StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset)
1173+
-(r::StepRangeLen{T,R,S,L}) where {T,R,S,L} =
1174+
StepRangeLen{T,R,S,L}(-r.ref, -r.step, r.len, r.offset)
11581175
function -(r::LinRange)
11591176
start = -r.start
11601177
LinRange{typeof(start)}(start, -r.stop, length(r))
@@ -1206,20 +1223,20 @@ StepRange(r::AbstractUnitRange{T}) where {T} =
12061223
StepRange{T,T}(first(r), step(r), last(r))
12071224
(StepRange{T1,T2} where T1)(r::AbstractRange) where {T2} = StepRange{eltype(r),T2}(r)
12081225

1209-
promote_rule(::Type{StepRangeLen{T1,R1,S1}},::Type{StepRangeLen{T2,R2,S2}}) where {T1,T2,R1,R2,S1,S2} =
1226+
promote_rule(::Type{StepRangeLen{T1,R1,S1,L1}},::Type{StepRangeLen{T2,R2,S2,L2}}) where {T1,T2,R1,R2,S1,S2,L1,L2} =
12101227
el_same(promote_type(T1,T2),
1211-
StepRangeLen{T1,promote_type(R1,R2),promote_type(S1,S2)},
1212-
StepRangeLen{T2,promote_type(R1,R2),promote_type(S1,S2)})
1213-
StepRangeLen{T,R,S}(r::StepRangeLen{T,R,S}) where {T,R,S} = r
1214-
StepRangeLen{T,R,S}(r::StepRangeLen) where {T,R,S} =
1215-
StepRangeLen{T,R,S}(convert(R, r.ref), convert(S, r.step), length(r), r.offset)
1228+
StepRangeLen{T1,promote_type(R1,R2),promote_type(S1,S2),promote_type(L1,L2)},
1229+
StepRangeLen{T2,promote_type(R1,R2),promote_type(S1,S2),promote_type(L1,L2)})
1230+
StepRangeLen{T,R,S,L}(r::StepRangeLen{T,R,S,L}) where {T,R,S,L} = r
1231+
StepRangeLen{T,R,S,L}(r::StepRangeLen) where {T,R,S,L} =
1232+
StepRangeLen{T,R,S,L}(convert(R, r.ref), convert(S, r.step), convert(L, r.len), convert(L, r.offset))
12161233
StepRangeLen{T}(r::StepRangeLen) where {T} =
1217-
StepRangeLen(convert(T, r.ref), convert(T, r.step), length(r), r.offset)
1234+
StepRangeLen(convert(T, r.ref), convert(T, r.step), r.len, r.offset)
12181235

1219-
promote_rule(a::Type{StepRangeLen{T,R,S}}, ::Type{OR}) where {T,R,S,OR<:AbstractRange} =
1220-
promote_rule(a, StepRangeLen{eltype(OR), eltype(OR), eltype(OR)})
1221-
StepRangeLen{T,R,S}(r::AbstractRange) where {T,R,S} =
1222-
StepRangeLen{T,R,S}(R(first(r)), S(step(r)), length(r))
1236+
promote_rule(a::Type{StepRangeLen{T,R,S,L}}, ::Type{OR}) where {T,R,S,L,OR<:AbstractRange} =
1237+
promote_rule(a, StepRangeLen{eltype(OR), eltype(OR), eltype(OR), Int})
1238+
StepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L} =
1239+
StepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r))
12231240
StepRangeLen{T}(r::AbstractRange) where {T} =
12241241
StepRangeLen(T(first(r)), T(step(r)), length(r))
12251242
StepRangeLen(r::AbstractRange) = StepRangeLen{eltype(r)}(r)
@@ -1233,8 +1250,8 @@ LinRange(r::AbstractRange{T}) where {T} = LinRange{T}(r)
12331250
promote_rule(a::Type{LinRange{T}}, ::Type{OR}) where {T,OR<:OrdinalRange} =
12341251
promote_rule(a, LinRange{eltype(OR)})
12351252

1236-
promote_rule(::Type{LinRange{L}}, b::Type{StepRangeLen{T,R,S}}) where {L,T,R,S} =
1237-
promote_rule(StepRangeLen{L,L,L}, b)
1253+
promote_rule(::Type{LinRange{A}}, b::Type{StepRangeLen{T,R,S,L}}) where {A,T,R,S,L} =
1254+
promote_rule(StepRangeLen{A,A,A,Int}, b)
12381255

12391256
## concatenation ##
12401257

@@ -1261,7 +1278,7 @@ function _reverse(r::StepRangeLen, ::Colon)
12611278
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
12621279
# `r.offset`
12631280
offset = isempty(r) ? r.offset : length(r)-r.offset+1
1264-
StepRangeLen(r.ref, -r.step, length(r), offset)
1281+
return typeof(r)(r.ref, -r.step, length(r), offset)
12651282
end
12661283
_reverse(r::LinRange{T}, ::Colon) where {T} = LinRange{T}(r.stop, r.start, length(r))
12671284

0 commit comments

Comments
 (0)