diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 2733b52222e37..9b524d759c0e7 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1385,6 +1385,29 @@ julia> parent(V) """ parent(a::AbstractArray) = a +""" + parenttype(A) + +Returns the parent array that type `T` wraps. + +Return the underlying "parent array” type. This type directly corresponds to the instance +returned by `parent(A)`. Therefore, new types that define a `parent` method should also +define a corresponding `parenttype(::Type{<:MyArray}) = ParentType` method. + +See also: [`parent`](@ref) + +# Examples +```jldoctest +julia> A = [1 2; 3 4]; + +julia> parenttype(view(A, 1:2, :)) <: typeof(A) +true +``` +""" +parenttype(x) = parenttype(typeof(x)) +parenttype(::Type{T}) where {T} = T + + ## rudimentary aliasing detection ## """ Base.unalias(dest, A) diff --git a/base/exports.jl b/base/exports.jl index 84c53ca405e7d..ff6c7efb009e1 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -413,6 +413,7 @@ export ones, parent, parentindices, + parenttype, partialsort, partialsort!, partialsortperm, diff --git a/base/indices.jl b/base/indices.jl index 28028f23c72a3..a3840a3349f5f 100644 --- a/base/indices.jl +++ b/base/indices.jl @@ -351,6 +351,8 @@ struct Slice{T<:AbstractUnitRange} <: AbstractUnitRange{Int} indices::T end Slice(S::Slice) = S +parenttype(::Type{Slice{T}}) where {T} = T +parent(S::Slice) = S.indices axes(S::Slice) = (IdentityUnitRange(S.indices),) axes1(S::Slice) = IdentityUnitRange(S.indices) axes(S::Slice{<:OneTo}) = (S.indices,) diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index ea966c44efc38..8ca8b8adf49bb 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -45,6 +45,7 @@ function PermutedDimsArray(data::AbstractArray{T,N}, perm) where {T,N} PermutedDimsArray{T,N,(perm...,),(iperm...,),typeof(data)}(data) end +Base.parenttype(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A Base.parent(A::PermutedDimsArray) = A.parent Base.size(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} = genperm(size(parent(A)), perm) Base.axes(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} = genperm(axes(parent(A)), perm) diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index ad1e8b26c4461..e5e35b94c0382 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -279,6 +279,7 @@ eachindex(style::IndexSCartesian2, A::AbstractArray) = eachindex(style, parent(A ## AbstractArray interface +parenttype(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A parent(a::ReinterpretArray) = a.parent dataids(a::ReinterpretArray) = dataids(a.parent) unaliascopy(a::NonReshapedReinterpretArray{T}) where {T} = reinterpret(T, unaliascopy(a.parent)) diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index cabe3c9d10a58..beff83ce4ed69 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -207,6 +207,7 @@ end size(A::ReshapedArray) = A.dims similar(A::ReshapedArray, eltype::Type, dims::Dims) = similar(parent(A), eltype, dims) IndexStyle(::Type{<:ReshapedArrayLF}) = IndexLinear() +parenttype(::Type{<:Base.ReshapedArray{T,N,P}}) where {T,N,P} = P parent(A::ReshapedArray) = A.parent parentindices(A::ReshapedArray) = map(oneto, size(parent(A))) reinterpret(::Type{T}, A::ReshapedArray, dims::Dims) where {T} = reinterpret(T, parent(A), dims) diff --git a/base/subarray.jl b/base/subarray.jl index ff2408bb48534..8616ffaf3d0fb 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -75,6 +75,7 @@ function Base.copy(V::SubArray) return x end +parenttype(::Type{<:SubArray{T,N,P}}) where {T,N,P} = P parent(V::SubArray) = V.parent parentindices(V::SubArray) = V.indices diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index f5903f380ee53..636e0add9a92e 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -215,6 +215,7 @@ AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrapperop(A)(AbstractVector{T # sundry basic definitions parent(A::AdjOrTrans) = A.parent +Base.parenttype(::Type{AdjOrTrans{T,S}}) where {T,S} = S vec(v::TransposeAbsVec{<:Number}) = parent(v) vec(v::AdjointAbsVec{<:Real}) = parent(v) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 8a5a04a8523d7..31cbba4b5b35a 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -125,6 +125,7 @@ function Base.replace_in_print_matrix(A::Diagonal,i::Integer,j::Integer,s::Abstr i==j ? s : Base.replace_with_centered_mark(s) end +Base.parenttype(::Type{<:Diagonal{T,V}}) where {T,V} = V parent(D::Diagonal) = D.diag ishermitian(D::Diagonal{<:Real}) = true diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 7347dd6f78639..4fd968929441b 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -272,6 +272,7 @@ function Matrix(A::Hermitian) end Array(A::Union{Symmetric,Hermitian}) = convert(Matrix, A) +Base.parenttype(::Type{<:HermOrSym{T,S}}) where {T,S} = S parent(A::HermOrSym) = A.data Symmetric{T,S}(A::Symmetric{T,S}) where {T,S<:AbstractMatrix{T}} = A Symmetric{T,S}(A::Symmetric) where {T,S<:AbstractMatrix{T}} = Symmetric{T,S}(convert(S,A.data),A.uplo) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index bd0566a11b3f2..f31078cc67cd4 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -157,6 +157,7 @@ imag(A::UnitLowerTriangular) = LowerTriangular(tril!(imag(A.data),-1)) imag(A::UnitUpperTriangular) = UpperTriangular(triu!(imag(A.data),1)) Array(A::AbstractTriangular) = Matrix(A) +Base.parenttype(::Type{<:AbstractTriangular{T,S}}) where {T,S} = S parent(A::AbstractTriangular) = A.data # then handle all methods that requires specific handling of upper/lower and unit diagonal diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 8226ddf004a72..1aa230bfb73c0 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -288,6 +288,11 @@ end @test parent(Adjoint(intmat)) === intmat @test parent(Transpose(intvec)) === intvec @test parent(Transpose(intmat)) === intmat + + @test parenttype(Adjoint(intvec)) <: typeof(intvec) + @test parenttype(Adjoint(intmat)) <: typeof(intmat) + @test parenttype(Transpose(intvec)) <: typeof(intvec) + @test parenttype(Transpose(intmat)) <: typeof(intmat) end @testset "Adjoint and Transpose vector vec methods" begin diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 0b8c9f9872594..708b45a34f327 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -47,6 +47,7 @@ Random.seed!(1) @test Array(imag(D)) == imag(DM) @test parent(D) == dd + @test parenttype(D) == typeof(dd) @test D[1,1] == dd[1] @test D[1,2] == 0 diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 169dfb0071718..a48e94dfdec3a 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -98,6 +98,12 @@ end @testset "parent" begin @test asym === parent(Symmetric(asym)) @test aherm === parent(Hermitian(aherm)) + @test asym === parent(Symmetric(asym)) + @test aherm === parent(Hermitian(aherm)) + @test typeof(asym) === parenttype(Symmetric(asym)) + @test typeof(aherm) === parenttype(Hermitian(aherm)) + @test typeof(asym) === parenttype(Symmetric(asym)) + @test typeof(aherm) === parenttype(Hermitian(aherm)) end # Unary minus for Symmetric/Hermitian matrices @testset "Unary minus for Symmetric/Hermitian matrices" begin diff --git a/test/reinterpretarray.jl b/test/reinterpretarray.jl index a8cadd83c3dec..b382d825fdb3a 100644 --- a/test/reinterpretarray.jl +++ b/test/reinterpretarray.jl @@ -314,6 +314,7 @@ end # avoid nesting @test parent(reinterpret(eltype(A), reinterpret(eltype(B), A))) === A +@test parenttype(reinterpret(eltype(A), reinterpret(eltype(B), A))) <: typeof(A) # Test 0-dimensional Arrays A = zeros(UInt32) diff --git a/test/subarray.jl b/test/subarray.jl index cc8aab94e4c42..cac2078a8579c 100644 --- a/test/subarray.jl +++ b/test/subarray.jl @@ -321,6 +321,7 @@ end sA = view(A, 2:2, 1:5, :) @test @inferred(strides(sA)) == (1, 3, 15) @test parent(sA) == A + @test parenttype(sA) <: typeof(A) @test parentindices(sA) == (2:2, 1:5, Base.Slice(1:8)) @test size(sA) == (1, 5, 8) @test axes(sA) === (Base.OneTo(1), Base.OneTo(5), Base.OneTo(8))