1
1
module ReactantOffsetArraysExt
2
2
3
3
using OffsetArrays
4
- using OffsetArrays: OffsetArray
4
+ using OffsetArrays: OffsetArray, OffsetVector
5
5
using Reactant: Reactant, MLIR, Ops, TracedRArray, AbstractConcreteArray
6
6
7
7
Base. @nospecializeinfer function Reactant. traced_type_inner (
@@ -27,12 +27,23 @@ function Base.getindex(
27
27
return getindex (parent (x), offset_indices... )
28
28
end
29
29
function Base. getindex (
30
- x:: OffsetArray{T,N,<:AbstractConcreteArray} , args:: Vararg{Int,N}
30
+ x:: OffsetArray{T,N,<:AbstractConcreteArray} ,
31
+ args:: Vararg{Union{Int,AbstractUnitRange{Int}},N} ,
31
32
) where {T,N}
32
33
offset_indices = [arg .- x. offsets[i] for (i, arg) in enumerate (args)]
33
34
return getindex (parent (x), offset_indices... )
34
35
end
35
36
37
+ function Base. getindex (x:: OffsetVector{T,<:AbstractConcreteArray} , index:: Int ) where {T}
38
+ return getindex (parent (x), index - x. offsets[1 ])
39
+ end
40
+ function Base. getindex (
41
+ x:: OffsetVector{T,<:AbstractConcreteArray} , indices:: AbstractUnitRange{Int}
42
+ ) where {T}
43
+ offset_indices = indices .- x. offsets[1 ]
44
+ return getindex (parent (x), offset_indices)
45
+ end
46
+
36
47
parentindex (r:: OffsetArrays.IdOffsetRange , i) = i .- r. offset
37
48
function Base. getindex (
38
49
a:: OffsetArray{<:Reactant.TracedRNumber,N} , indices:: Vararg{Union{Int,AbstractArray},N}
0 commit comments