diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 662889e49e..908bf429ef 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -200,7 +200,7 @@ for jlop in ( :(Base.:^), :(Base.:(==)), ), - T in (AbstractConcreteNumber, AbstractConcreteArray{<:Any,0}) + T in (AbstractConcreteNumber, AbstractConcreteArray{<:Number,0}) @eval begin $(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 916d0d13ee..565d158122 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -23,6 +23,33 @@ Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...) Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear() +Base.elsize(::Type{TracedRArray{T,N}}) where {T,N} = sizeof(T) + +const ArrayTypesAlias = ( + :(TracedRArray), + :(SubArray{<:TracedRNumber,<:Any,<:TracedRArray}), + :(Base.ReshapedArray{<:TracedRNumber,<:Any,<:TracedRArray}), + :(SubArray{<:TracedRNumber,<:Any,<:Base.ReshapedArray{<:TracedRNumber,<:Any,<:TracedRArray}}), + :(Base.ReshapedArray{<:TracedRNumber,<:Any,<:SubArray{<:TracedRNumber,<:Any,<:TracedRArray}}), +) +for ArrayType1 in ArrayTypesAlias + for ArrayType2 in ArrayTypesAlias + @eval Base.mightalias(::$ArrayType1, ::$ArrayType2) = false + end +end + +# Base.mightalias(::TracedRArray, ::TracedRArray) = false +# Base.mightalias( +# ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray}, +# ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray}, +# ) = false +# Base.mightalias( +# ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray}, ::TracedRArray +# ) = false +# Base.mightalias( +# ::TracedRArray, ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray} +# ) = false + # This is required otherwise we will copy a tracedrarray each time # we use it Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x) @@ -265,8 +292,15 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC return print(io, "TracedRArray{", T, ",", N, "N}(", X.paths, ", size=", size(X), ")") end -function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} - return @opcall transpose(materialize_traced_array(A), Int64[perm...]) +for ArrayType in ( + :(AnyTracedRArray{T,N}), + :(TracedRArray{T,N}), + :(SubArray{<:TracedRNumber{T},N,<:TracedRArray}), + :(Base.ReshapedArray{<:TracedRNumber{T},N,<:TracedRArray}) + ) + @eval function Base.permutedims(A::$ArrayType, perm) where {T,N} + return @opcall transpose(materialize_traced_array(A), Int64[perm...]) + end end for (jlop, hloop, hlocomp, merge) in diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..ab92991786 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -2,7 +2,7 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end abstract type AbstractConcreteNumber{T} <: RNumber{T} end -abstract type RArray{T,N} <: AbstractArray{T,N} end +abstract type RArray{T,N} <: DenseArray{T,N} end abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end @@ -52,6 +52,11 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end +Base.elsize(::Type{TracedRNumber{T}}) where {T} = sizeof(T) +Base.elsize(::Type{RNumber{T}}) where {T} = sizeof(T) +Base.elsize(::Type{<:AbstractConcreteNumber{T}}) where {T} = sizeof(T) +Base.elsize(::Type{<:AbstractConcreteArray{T}}) where {T} = sizeof(T) + function repath(x::TracedRNumber{T}, paths) where {T} return TracedRNumber{T}(paths, x.mlir_data) end