Skip to content

Commit bd00a29

Browse files
Disambiguate offsetarray getindex methods (#1013)
* diambiguate * format suggestion * missing an OffsetVector method * Update ReactantOffsetArraysExt.jl * Apply suggestions from code review * remove extra arg * disambiguate all * fix code style * fix code style * more fix code style * more fix code style --------- Co-authored-by: Mosè Giordano <[email protected]>
1 parent cb49c07 commit bd00a29

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

ext/ReactantOffsetArraysExt.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReactantOffsetArraysExt
22

33
using OffsetArrays
4-
using OffsetArrays: OffsetArray
4+
using OffsetArrays: OffsetArray, OffsetVector
55
using Reactant: Reactant, MLIR, Ops, TracedRArray, AbstractConcreteArray
66

77
Base.@nospecializeinfer function Reactant.traced_type_inner(
@@ -27,12 +27,23 @@ function Base.getindex(
2727
return getindex(parent(x), offset_indices...)
2828
end
2929
function Base.getindex(
30-
x::OffsetArray{T,N,<:AbstractConcreteArray}, args::Vararg{Int,N}
30+
x::OffsetArray{T,N,<:AbstractConcreteArray},
31+
args::Vararg{Union{Int,AbstractUnitRange{Int}},N},
3132
) where {T,N}
3233
offset_indices = [arg .- x.offsets[i] for (i, arg) in enumerate(args)]
3334
return getindex(parent(x), offset_indices...)
3435
end
3536

37+
function Base.getindex(x::OffsetVector{T,<:AbstractConcreteArray}, index::Int) where {T}
38+
return getindex(parent(x), index - x.offsets[1])
39+
end
40+
function Base.getindex(
41+
x::OffsetVector{T,<:AbstractConcreteArray}, indices::AbstractUnitRange{Int}
42+
) where {T}
43+
offset_indices = indices .- x.offsets[1]
44+
return getindex(parent(x), offset_indices)
45+
end
46+
3647
parentindex(r::OffsetArrays.IdOffsetRange, i) = i .- r.offset
3748
function Base.getindex(
3849
a::OffsetArray{<:Reactant.TracedRNumber,N}, indices::Vararg{Union{Int,AbstractArray},N}

0 commit comments

Comments
 (0)