Skip to content

Commit abc207d

Browse files
fix: fix getindex edge cases
1 parent 5621372 commit abc207d

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/types.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,13 +3004,25 @@ Base.@propagate_inbounds function _getindex(arr::BasicSymbolic{T}, idxs::Union{B
30043004
sh = shape(arr)
30053005
type = promote_symtype(getindex, symtype(arr), symtype.(idxs)...)
30063006
newshape = promote_shape(getindex, sh, shape.(idxs)...)
3007-
for (oldidx, idx) in zip(Iterators.drop(args, 1), idxs)
3008-
if idx isa Colon
3007+
idxs_i = 1
3008+
for oldidx in Iterators.drop(args, 1)
3009+
oldidx_sh = shape(oldidx)
3010+
if !_is_array_shape(oldidx_sh)
3011+
push!(newargs, oldidx)
3012+
continue
3013+
end
3014+
idx = idxs[idxs_i]
3015+
idxs_i += 1
3016+
# special case when `oldidx` is `Colon()`
3017+
if length(oldidx_sh) == 1 && oldidx_sh[1] == 1:0
3018+
push!(newargs, Const{T}(idx))
3019+
elseif idx isa Colon
30093020
push!(newargs, oldidx)
30103021
else
30113022
push!(newargs, Const{T}(unwrap_const(oldidx)[idx]))
30123023
end
30133024
end
3025+
@assert idxs_i == length(idxs) + 1
30143026
return BSImpl.Term{T}(f, newargs; type, shape = newshape)
30153027
end
30163028
_ => nothing
@@ -3121,6 +3133,15 @@ end
31213133
function Base.getindex(x::AbstractArray, i1, idx::BasicSymbolic{T}, idxs...) where {T}
31223134
getindex(Const{T}(x), i1, idx, idxs...)
31233135
end
3136+
function Base.getindex(x::AbstractArray, i1::BasicSymbolic{T}, idx::BasicSymbolic{T}, idxs...) where {T}
3137+
getindex(Const{T}(x), i1, idx, idxs...)
3138+
end
31243139
function Base.getindex(x::AbstractArray, i1, i2, idx::BasicSymbolic{T}, idxs...) where {T}
31253140
getindex(Const{T}(x), i1, i2, idx, idxs...)
31263141
end
3142+
function Base.getindex(x::AbstractArray, i1, i2::BasicSymbolic{T}, idx::BasicSymbolic{T}, idxs...) where {T}
3143+
getindex(Const{T}(x), i1, i2, idx, idxs...)
3144+
end
3145+
function Base.getindex(x::AbstractArray, i1::BasicSymbolic{T}, i2::BasicSymbolic{T}, idx::BasicSymbolic{T}, idxs...) where {T}
3146+
getindex(Const{T}(x), i1, i2, idx, idxs...)
3147+
end

0 commit comments

Comments
 (0)