From 2b7910761a869238fcb6e4f851eee0c59609fb0c Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Sun, 19 Jan 2025 21:17:30 +0100 Subject: [PATCH 01/26] adds `ProjectTo` for `DiagonalTensorMap` --- ext/TensorKitChainRulesCoreExt/utility.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index 5bdd4e4a0..acfd829ab 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -29,3 +29,15 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N end return y end + +function (::ProjectTo{T1})(x::T2) where {S,NumType,StorType, + T1<:DiagonalTensorMap{NumType,S,StorType}, + T2<:AbstractTensorMap{<:Any,S,1,1}} + T1 === T2 && return x + y = DiagonalTensorMap{NumType,S,StorType}(undef, space(x, 1)) + for (c, b) in blocks(y) + p = ProjectTo(b) + b .= p(block(x, c)) + end + return y +end From f27517516feee55f576417b97638e10e9cde4636 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Sun, 19 Jan 2025 21:41:03 +0100 Subject: [PATCH 02/26] adds an `rrule` for `DiagonalTensorMap` constructor --- ext/TensorKitChainRulesCoreExt/constructors.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index f48ed4393..af1f1ed92 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -12,6 +12,16 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg return TensorMap(d, args...; kwargs...), TensorMap_pullback end +function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; kwargs...) + D=TensorMap(d, args...; kwargs...) + project_D=ProjectTo(D) + function DiagonalTensorMap_pullback(Δt) + ∂d = project_D(unthunk(Δt)).data + return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... 
+ end + return D, DiagonalTensorMap_pullback +end + function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) copy_pullback(Δt) = NoTangent(), Δt return copy(t), copy_pullback From 2cf4cb4e9fd65564a514fc2bb5c5ccdd7cd92ee5 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 12:28:58 +0100 Subject: [PATCH 03/26] Corrects bug in the DiagonalTensorMap rrule, adds tests for the new code, adds a proper generator of random tangents for DiagonalTensorMap --- ext/TensorKitChainRulesCoreExt/constructors.jl | 7 ++++--- test/ad.jl | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index af1f1ed92..caa588561 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -12,9 +12,10 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg return TensorMap(d, args...; kwargs...), TensorMap_pullback end -function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; kwargs...) - D=TensorMap(d, args...; kwargs...) - project_D=ProjectTo(D) +function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; + kwargs...) + D = DiagonalTensorMap(d, args...; kwargs...) + project_D = ProjectTo(D) function DiagonalTensorMap_pullback(Δt) ∂d = project_D(unthunk(Δt)).data return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... 
diff --git a/test/ad.jl b/test/ad.jl index a684c4f83..5d77b5be2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -15,6 +15,9 @@ end function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap) return randn!(similar(x)) end +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap) + return DiagonalTensorMap(randn(eltype(x), dim(x.domain)), x.domain) +end ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent() function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, expected::AbstractTensorMap, msg=""; kwargs...) @@ -144,6 +147,20 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), fkwargs=(; tol=Inf)) end + @timedtestset "Basic utility (DiagonalTensor)" begin + for NumType in [Float64, ComplexF64] + for v in V + T1 = DiagonalTensorMap(randn(NumType, dim(v)), v) + T2 = TensorMap(T1) + + P1 = ProjectTo(T1) + @test P1(T2) == T1 + + test_rrule(DiagonalTensorMap, T1.data, T1.domain) + end + end + end + @timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64) A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = randn(T, space(A)) From c43535f20dc3b82fbc3e229a95ebacd04d1cd125 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 12:32:10 +0100 Subject: [PATCH 04/26] @test missing in the constructor test added... 
--- test/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ad.jl b/test/ad.jl index 5d77b5be2..ae59ceb01 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -156,7 +156,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), P1 = ProjectTo(T1) @test P1(T2) == T1 - test_rrule(DiagonalTensorMap, T1.data, T1.domain) + @test test_rrule(DiagonalTensorMap, T1.data, T1.domain) end end end From 46531132ec2b23061c7412844e9b0ab388adef3a Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 14:10:45 +0100 Subject: [PATCH 05/26] wait no, @test did not belong there --- test/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ad.jl b/test/ad.jl index ae59ceb01..5d77b5be2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -156,7 +156,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), P1 = ProjectTo(T1) @test P1(T2) == T1 - @test test_rrule(DiagonalTensorMap, T1.data, T1.domain) + test_rrule(DiagonalTensorMap, T1.data, T1.domain) end end end From c81e43ce1c32662c8e5802630cf8971f6e5b9577 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 14:51:24 +0100 Subject: [PATCH 06/26] Update ext/TensorKitChainRulesCoreExt/utility.jl Co-authored-by: Lukas Devos --- ext/TensorKitChainRulesCoreExt/utility.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index acfd829ab..af9b7ca5b 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -30,11 +30,9 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N return y end -function (::ProjectTo{T1})(x::T2) where {S,NumType,StorType, - T1<:DiagonalTensorMap{NumType,S,StorType}, - T2<:AbstractTensorMap{<:Any,S,1,1}} +function (::ProjectTo{T1})(x::T2) where {T1<:DiagonalTensorMap,T2<:AbstractTensorMap} T1 === T2 && return x - y = DiagonalTensorMap{NumType,S,StorType}(undef, space(x, 1)) + y = 
DiagonalTensorMap{scalartype(T1),spacetype(T1),storagetype(T1)}(undef, space(x, 1)) for (c, b) in blocks(y) p = ProjectTo(b) b .= p(block(x, c)) From 41d11134fe6c21b3bfe5c7b93750bff26964479c Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 16:35:47 +0100 Subject: [PATCH 07/26] mixed type tests for ProjectTo --- test/ad.jl | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 5d77b5be2..85f4ecee8 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -148,16 +148,33 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), end @timedtestset "Basic utility (DiagonalTensor)" begin - for NumType in [Float64, ComplexF64] - for v in V - T1 = DiagonalTensorMap(randn(NumType, dim(v)), v) - T2 = TensorMap(T1) - - P1 = ProjectTo(T1) - @test P1(T2) == T1 - - test_rrule(DiagonalTensorMap, T1.data, T1.domain) - end + for v in V + D1 = DiagonalTensorMap(randn(dim(v)), v) + D2 = DiagonalTensorMap(randn(dim(v)), v) + D = D1 + im * D2 + T1 = TensorMap(D1) + T2 = TensorMap(D2) + T = T1 + im * T2 + + # real -> real + P1 = ProjectTo(D1) + @test P1(D1) == D1 + @test P1(T1) == D1 + + # complex -> complex + P2 = ProjectTo(D) + @test P2(D) == D + @test P2(T) == D + + # real -> complex + @test P2(D1) == D1 + 0 * im * D1 + @test P2(T1) == D1 + 0 * im * D1 + + # complex -> real + @test P1(D) == D1 + @test P1(T) == D1 + + test_rrule(DiagonalTensorMap, D1.data, D1.domain) end end From 53b399c58495e5f80a957b906624ebdf81e2194b Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 16:36:51 +0100 Subject: [PATCH 08/26] + rrule test on complex tensors. 
--- test/ad.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/ad.jl b/test/ad.jl index 85f4ecee8..71348713c 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -175,6 +175,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @test P1(T) == D1 test_rrule(DiagonalTensorMap, D1.data, D1.domain) + test_rrule(DiagonalTensorMap, D.data, D.domain) end end From 8b51b8bf56863362f8f77810b9e14c4c5bca5f9f Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 23:30:13 +0100 Subject: [PATCH 09/26] correct data length for DiagonalTensor in tests --- test/ad.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 71348713c..d4985969d 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -149,8 +149,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @timedtestset "Basic utility (DiagonalTensor)" begin for v in V - D1 = DiagonalTensorMap(randn(dim(v)), v) - D2 = DiagonalTensorMap(randn(dim(v)), v) + comp_num = sum(values(v.dims)) + D1 = DiagonalTensorMap(randn(comp_num), v) + D2 = DiagonalTensorMap(randn(comp_num), v) D = D1 + im * D2 T1 = TensorMap(D1) T2 = TensorMap(D2) From 2905c0baa3fb14e750a14ca7c440ae6f5ff96626 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Mon, 20 Jan 2025 23:36:45 +0100 Subject: [PATCH 10/26] correct data length in DiagonalTensorMap for random tangents --- test/ad.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/ad.jl b/test/ad.jl index d4985969d..93fbdd821 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -16,7 +16,8 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap return randn!(similar(x)) end function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap) - return DiagonalTensorMap(randn(eltype(x), dim(x.domain)), x.domain) + S = x.domain + return DiagonalTensorMap(randn(eltype(x), sum(values(S.dims))), S) end ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent() function 
ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, From 79a38b6af4092caa7a0587b41d29ea3c000113ef Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Tue, 21 Jan 2025 00:40:35 +0100 Subject: [PATCH 11/26] Comment on the test failure --- test/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 93fbdd821..8346d68d2 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -176,8 +176,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @test P1(D) == D1 @test P1(T) == D1 - test_rrule(DiagonalTensorMap, D1.data, D1.domain) - test_rrule(DiagonalTensorMap, D.data, D.domain) + test_rrule(DiagonalTensorMap, D1.data, D1.domain) # test fails when dimension dim of the representation is not 1. + test_rrule(DiagonalTensorMap, D.data, D.domain) # the finite diff result is larger than the exact result exactly in dim times. It should be something with how diagonal tensor transforms into a vector. end end From 689b4eaa84eb7cca2d9777b83bab58e4de06ab65 Mon Sep 17 00:00:00 2001 From: ebelnikola Date: Thu, 23 Jan 2025 12:17:12 +0100 Subject: [PATCH 12/26] Jutho's corrections --- ext/TensorKitChainRulesCoreExt/utility.jl | 8 +++++--- test/ad.jl | 10 +++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index af9b7ca5b..fa1733c92 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -30,9 +30,11 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N return y end -function (::ProjectTo{T1})(x::T2) where {T1<:DiagonalTensorMap,T2<:AbstractTensorMap} - T1 === T2 && return x - y = DiagonalTensorMap{scalartype(T1),spacetype(T1),storagetype(T1)}(undef, space(x, 1)) +function (::ProjectTo{DiagonalTensorMap{T,S,A}})(x::AbstractTensorMap) where {T,S,A} + x isa DiagonalTensorMap{T,S,A} && return x + V = space(x, 1) + space(x) == (V ← V) || throw(SpaceMismatch()) + y = 
DiagonalTensorMap{T,S,A}(undef, V) for (c, b) in blocks(y) p = ProjectTo(b) b .= p(block(x, c)) diff --git a/test/ad.jl b/test/ad.jl index 8346d68d2..f8a606ace 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -16,8 +16,8 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap return randn!(similar(x)) end function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap) - S = x.domain - return DiagonalTensorMap(randn(eltype(x), sum(values(S.dims))), S) + V = x.domain + return DiagonalTensorMap(randn(eltype(x), reduceddim(V)), V) end ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent() function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, @@ -150,9 +150,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @timedtestset "Basic utility (DiagonalTensor)" begin for v in V - comp_num = sum(values(v.dims)) - D1 = DiagonalTensorMap(randn(comp_num), v) - D2 = DiagonalTensorMap(randn(comp_num), v) + rdim = reduceddim(v) + D1 = DiagonalTensorMap(randn(rdim), v) + D2 = DiagonalTensorMap(randn(rdim), v) D = D1 + im * D2 T1 = TensorMap(D1) T2 = TensorMap(D2) From e1fe3be3de4540f06ad8750141f35c8c13d6b55e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Jan 2025 14:24:31 -0500 Subject: [PATCH 13/26] Add `DiagonalTensorMap(::AbstractTensorMap)` --- src/tensors/diagonal.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index ebae85ccc..9d3b4d9d8 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -73,6 +73,16 @@ end TensorMap(d::DiagonalTensorMap) = copy!(similar(d), d) Base.convert(::Type{TensorMap}, d::DiagonalTensorMap) = TensorMap(d) +# similar to Diagonal: simply take diagonal +function DiagonalTensorMap(t::AbstractTensorMap) + numin(t) == numout(t) == 1 && domain(t) == codomain(t) || throw(SpaceMismatch()) + d = DiagonalTensorMap{scalartype(t)}(undef, space(t, 1)) + for (c, b) in blocks(t) + copy!(block(d, c), 
Diagonal(b)) + end + return d +end + function Base.convert(::Type{DiagonalTensorMap{T,S,A}}, d::DiagonalTensorMap{T,S,A}) where {T,S,A} return d From ccb7c93285c95c50275c233bf6752994c238faf2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Jan 2025 14:24:47 -0500 Subject: [PATCH 14/26] Specialize `to_vec(::DiagonalTensorMap)` --- ext/TensorKitFiniteDifferencesExt.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ext/TensorKitFiniteDifferencesExt.jl b/ext/TensorKitFiniteDifferencesExt.jl index f62642560..63c9711e3 100644 --- a/ext/TensorKitFiniteDifferencesExt.jl +++ b/ext/TensorKitFiniteDifferencesExt.jl @@ -23,6 +23,14 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap) end FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t)) +function FiniteDifferences.to_vec(t::DiagonalTensorMap) + x_vec, back = to_vec(TensorMap(t)) + function DiagonalTensorMap_from_vec(x_vec) + return DiagonalTensorMap(back(x_vec)) + end + return x_vec, DiagonalTensorMap_from_vec +end + end # TODO: Investigate why the approach below doesn't work From f8ca1fdec58303e1e21c1bd8fb7b4cf2ac16e4a6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Jan 2025 14:25:41 -0500 Subject: [PATCH 15/26] Add rrules matrix functions --- ext/TensorKitChainRulesCoreExt/linalg.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index c13694673..c76ac7390 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -106,3 +106,21 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap) end return a_imag, imag_pullback end + +# define rrules for matrix functions for DiagonalTensorMap, since they access data directly. 
+for f in + (:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt, + :log, :asin, :acos, :acosh, :atanh, :acoth) + f_pullback = Symbol(f, :_pullback) + @eval function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof($f), + t::DiagonalTensorMap) + P = ProjectTo(t) # unsure if this is necessary, should already be in pullback + d, pullback = rrule_via_ad(cfg, broadcast, $f, t.data) + function $f_pullback(Δd_) + Δd = P(unthunk(Δd_)) + _, _, ∂data = pullback(Δd.data) + return NoTangent(), DiagonalTensorMap(∂data, t.domain) + end + return DiagonalTensorMap(d, t.domain), $f_pullback + end +end From 98dea08c2e7f80bdbbbaf48c694a0055f646a7e4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Jan 2025 14:41:10 -0500 Subject: [PATCH 16/26] Add tests AD of matrixfunctions --- test/ad.jl | 22 ++++++++++++++++++++++ test/runtests.jl | 8 ++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index f8a606ace..1f8f59d75 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -3,6 +3,7 @@ using ChainRulesTestUtils using FiniteDifferences: FiniteDifferences using Random using LinearAlgebra +using Zygote const _repartition = @static if isdefined(Base, :get_extension) Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition @@ -220,6 +221,27 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(LinearAlgebra.dot, A, B) end + @timedtestset "Matrix functions ($T)" for T in (Float64, ComplexF64) + for f in (sqrt, exp) + check_inferred = false # !(T <: Real) # not type-stable for real functions + t1 = randn(T, V[1] ← V[1]) + t2 = randn(T, V[2] ← V[2]) + d = DiagonalTensorMap{T}(undef, V[1]) + randn!(d.data) + if T <: Real + d.data .= abs.(d.data) + end + d2 = DiagonalTensorMap{T}(undef, V[1]) + randn!(d2.data) + if T <: Real + d2.data .= abs.(d2.data) + end + test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) + test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) + 
test_rrule(f, d; check_inferred, output_tangent=d2) + end + end + @timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64) atol = precision(T) rtol = precision(T) diff --git a/test/runtests.jl b/test/runtests.jl index 1f06191cc..080c01d82 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,9 +60,13 @@ include("spaces.jl") include("tensors.jl") include("diagonal.jl") include("planar.jl") -if !(Sys.isapple()) # TODO: remove once we know why this is so slow on macOS - include("ad.jl") +# TODO: remove once we know AD is slow on macOS CI +test_ad = try + !(Sys.isapple() && ENV["CI"] == true) +catch + true end +test_ad && include("ad.jl") include("bugfixes.jl") Tf = time() printstyled("Finished all tests in ", From 5eec371f183504ceb24dcfebd8c9e465766ee613 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Feb 2025 07:03:29 -0500 Subject: [PATCH 17/26] Remove duplicate methods --- src/tensors/diagonal.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 59a11e88d..a13918d03 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -87,20 +87,6 @@ end TensorMap(d::DiagonalTensorMap) = copy!(similar(d), d) Base.convert(::Type{TensorMap}, d::DiagonalTensorMap) = TensorMap(d) -# similar to Diagonal: simply take diagonal -function DiagonalTensorMap(t::AbstractTensorMap) - numin(t) == numout(t) == 1 && domain(t) == codomain(t) || throw(SpaceMismatch()) - d = DiagonalTensorMap{scalartype(t)}(undef, space(t, 1)) - for (c, b) in blocks(t) - copy!(block(d, c), Diagonal(b)) - end - return d -end - -function Base.convert(::Type{DiagonalTensorMap{T,S,A}}, - d::DiagonalTensorMap{T,S,A}) where {T,S,A} - return d -end function Base.convert(D::Type{<:DiagonalTensorMap}, d::DiagonalTensorMap) return (d isa D) ? 
d : DiagonalTensorMap(convert(storagetype(D), d.data), d.domain) end From 9e834c4f66ab1725fccc177e5414429535fd59df Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Feb 2025 07:10:06 -0500 Subject: [PATCH 18/26] disable broken tests --- test/ad.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index c6bac27cd..a5247d041 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -177,8 +177,12 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @test P1(D) == D1 @test P1(T) == D1 - test_rrule(DiagonalTensorMap, D1.data, D1.domain) # test fails when dimension dim of the representation is not 1. - test_rrule(DiagonalTensorMap, D.data, D.domain) # the finite diff result is larger than the exact result exactly in dim times. It should be something with how diagonal tensor transforms into a vector. + # These tests fail because here the data vector are the actual parameters, + # not a vectorized version of the tensor. (off by a quantum dimension factor) + if FusionStyle(sectortype(D)) == UniqueFusion() + test_rrule(DiagonalTensorMap, D1.data, D1.domain) + test_rrule(DiagonalTensorMap, D.data, D.domain) + end end end From a5ba3406bcb1fb764835f3eed660f7621c9f43a7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Feb 2025 15:07:32 -0500 Subject: [PATCH 19/26] Fix CI check --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 080c01d82..9ad5ee921 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,7 +62,7 @@ include("diagonal.jl") include("planar.jl") # TODO: remove once we know AD is slow on macOS CI test_ad = try - !(Sys.isapple() && ENV["CI"] == true) + !(Sys.isapple() && ENV["CI"] == "true") catch true end From 3e45ff28a25837b6b4169ae90955171c1e2ab0ab Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Feb 2025 20:05:22 -0500 Subject: [PATCH 20/26] Adapt rrules for constructors and getproperty to include qdims --- .../constructors.jl | 73 
+++++++++++++++++-- test/ad.jl | 15 ++-- 2 files changed, 75 insertions(+), 13 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index caa588561..08c1e7f6b 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -4,7 +4,7 @@ @non_differentiable TensorKit.isometry(args...) @non_differentiable TensorKit.unitary(args...) -function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...) +function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs...) function TensorMap_pullback(Δt) ∂d = convert(Array, unthunk(Δt)) return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... @@ -12,17 +12,76 @@ function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwarg return TensorMap(d, args...; kwargs...), TensorMap_pullback end -function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, d::DenseVector, args...; +# these are not the conversion from array, but actually take in data parameters +# -- as a result, requires quantum dimensions +function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector, + V::TensorMapSpace) where {T} + t = TensorMap{T}(data, V) + P = ProjectTo(data) + function TensorMap_pullback(Δt_) + Δt = copy(unthunk(Δt_)) + for (c, b) in blocks(Δt) + scale!(b, TensorKit.sqrtdim(c)) + end + ∂data = P(Δt.data) + return NoTangent(), ∂data, NoTangent() + end + return t, TensorMap_pullback +end + +function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, args...; kwargs...) - D = DiagonalTensorMap(d, args...; kwargs...) - project_D = ProjectTo(D) - function DiagonalTensorMap_pullback(Δt) - ∂d = project_D(unthunk(Δt)).data - return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... + D = DiagonalTensorMap(data, args...; kwargs...) 
+ P = ProjectTo(data) + function DiagonalTensorMap_pullback(Δt_) + # unclear if we're allowed to modify/take ownership of the input + Δt = copy(unthunk(Δt_)) + for (c, b) in blocks(Δt) + scale!(b, TensorKit.sqrtdim(c)) + end + ∂data = P(Δt.data) + return NoTangent(), ∂data, NoTangent() end return D, DiagonalTensorMap_pullback end +function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Symbol) + if prop === :data + function getdata_pullback(Δdata) + # unclear if we're allowed to modify/take ownership of the input + t′ = typeof(t)(copy(unthunk(Δdata)), t.space) + for (c, b) in blocks(t′) + scale!(b, TensorKit.invsqrtdim(c)) + end + return NoTangent(), t′, NoTangent() + end + return t.data, getdata_pullback + elseif prop === :space + return t.space, Returns((NoTangent(), ZeroTangent(), NoTangent())) + else + throw(ArgumentError("unknown property $prop")) + end +end + +function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap, + prop::Symbol) + if prop === :data + function getdata_pullback(Δdata) + # unclear if we're allowed to modify/take ownership of the input + t′ = typeof(t)(copy(unthunk(Δdata)), t.domain) + for (c, b) in blocks(t′) + scale!(b, TensorKit.invsqrtdim(c)) + end + return NoTangent(), t′, NoTangent() + end + return t.data, getdata_pullback + elseif prop === :domain + return t.domain, Returns((NoTangent(), ZeroTangent(), NoTangent())) + else + throw(ArgumentError("unknown property $prop")) + end +end + function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) copy_pullback(Δt) = NoTangent(), Δt return copy(t), copy_pullback diff --git a/test/ad.jl b/test/ad.jl index a5247d041..9a01e7acb 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -147,6 +147,11 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(convert, Array, T1) test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1); fkwargs=(; tol=Inf)) + + test_rrule(Base.getproperty, T1, :data) + test_rrule(TensorMap{scalartype(T1)}, 
T1.data, T1.space) + test_rrule(Base.getproperty, T2, :data) + test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space) end @timedtestset "Basic utility (DiagonalTensor)" begin @@ -177,12 +182,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @test P1(D) == D1 @test P1(T) == D1 - # These tests fail because here the data vector are the actual parameters, - # not a vectorized version of the tensor. (off by a quantum dimension factor) - if FusionStyle(sectortype(D)) == UniqueFusion() - test_rrule(DiagonalTensorMap, D1.data, D1.domain) - test_rrule(DiagonalTensorMap, D.data, D.domain) - end + test_rrule(DiagonalTensorMap, D1.data, D1.domain) + test_rrule(DiagonalTensorMap, D.data, D.domain) + test_rrule(Base.getproperty, D, :data) + test_rrule(Base.getproperty, D1, :data) end end From f9ed03ecc6e6c7c6d206b85763ca240ac8bbda02 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 6 Feb 2025 20:28:31 -0500 Subject: [PATCH 21/26] exchange sqrt and invsqrt in hope of fixing without thinking --- ext/TensorKitChainRulesCoreExt/constructors.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index 08c1e7f6b..d09763882 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -21,7 +21,7 @@ function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector, function TensorMap_pullback(Δt_) Δt = copy(unthunk(Δt_)) for (c, b) in blocks(Δt) - scale!(b, TensorKit.sqrtdim(c)) + scale!(b, TensorKit.invsqrtdim(c)) end ∂data = P(Δt.data) return NoTangent(), ∂data, NoTangent() @@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, ar # unclear if we're allowed to modify/take ownership of the input Δt = copy(unthunk(Δt_)) for (c, b) in blocks(Δt) - scale!(b, TensorKit.sqrtdim(c)) + scale!(b, TensorKit.invsqrtdim(c)) end ∂data = P(Δt.data) return NoTangent(), ∂data, 
NoTangent() @@ -51,7 +51,7 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Sy # unclear if we're allowed to modify/take ownership of the input t′ = typeof(t)(copy(unthunk(Δdata)), t.space) for (c, b) in blocks(t′) - scale!(b, TensorKit.invsqrtdim(c)) + scale!(b, TensorKit.sqrtdim(c)) end return NoTangent(), t′, NoTangent() end @@ -70,7 +70,7 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap, # unclear if we're allowed to modify/take ownership of the input t′ = typeof(t)(copy(unthunk(Δdata)), t.domain) for (c, b) in blocks(t′) - scale!(b, TensorKit.invsqrtdim(c)) + scale!(b, TensorKit.sqrtdim(c)) end return NoTangent(), t′, NoTangent() end From 643aba7b21953438dffcf9273f7e3999c10e33fc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 7 Feb 2025 09:28:45 -0500 Subject: [PATCH 22/26] Actually think to fix the problem --- ext/TensorKitChainRulesCoreExt/constructors.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl index d09763882..4b9dcd6b1 100644 --- a/ext/TensorKitChainRulesCoreExt/constructors.jl +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -12,8 +12,10 @@ function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs. 
return TensorMap(d, args...; kwargs...), TensorMap_pullback end -# these are not the conversion from array, but actually take in data parameters -# -- as a result, requires quantum dimensions +# these are not the conversion to/from array, but actually take in data parameters +# -- as a result, requires quantum dimensions to keep inner product the same: +# ⟨Δdata, ∂data⟩ = ⟨Δtensor, ∂tensor⟩ = ∑_c d_c ⟨Δtensor_c, ∂tensor_c⟩ +# ⟹ Δdata = d_c Δtensor_c function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector, V::TensorMapSpace) where {T} t = TensorMap{T}(data, V) @@ -21,7 +23,7 @@ function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector, function TensorMap_pullback(Δt_) Δt = copy(unthunk(Δt_)) for (c, b) in blocks(Δt) - scale!(b, TensorKit.invsqrtdim(c)) + scale!(b, dim(c)) end ∂data = P(Δt.data) return NoTangent(), ∂data, NoTangent() @@ -37,7 +39,7 @@ function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, ar # unclear if we're allowed to modify/take ownership of the input Δt = copy(unthunk(Δt_)) for (c, b) in blocks(Δt) - scale!(b, TensorKit.invsqrtdim(c)) + scale!(b, dim(c)) end ∂data = P(Δt.data) return NoTangent(), ∂data, NoTangent() @@ -51,7 +53,7 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Sy # unclear if we're allowed to modify/take ownership of the input t′ = typeof(t)(copy(unthunk(Δdata)), t.space) for (c, b) in blocks(t′) - scale!(b, TensorKit.sqrtdim(c)) + scale!(b, inv(dim(c))) end return NoTangent(), t′, NoTangent() end @@ -70,7 +72,7 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap, # unclear if we're allowed to modify/take ownership of the input t′ = typeof(t)(copy(unthunk(Δdata)), t.domain) for (c, b) in blocks(t′) - scale!(b, TensorKit.sqrtdim(c)) + scale!(b, inv(dim(c))) end return NoTangent(), t′, NoTangent() end From 06d3784bd430b7fb377423dfc47e831e695db9f9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 7 Feb 2025 16:30:14 
-0500 Subject: [PATCH 23/26] Simplify positive data generation --- test/ad.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 9a01e7acb..f60a5464c 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -235,15 +235,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), t1 = randn(T, V[1] ← V[1]) t2 = randn(T, V[2] ← V[2]) d = DiagonalTensorMap{T}(undef, V[1]) - randn!(d.data) - if T <: Real - d.data .= abs.(d.data) - end + (T <: Real && f === sqrt) ? randexp!(d.data) : randn!(d.data) d2 = DiagonalTensorMap{T}(undef, V[1]) - randn!(d2.data) - if T <: Real - d2.data .= abs.(d2.data) - end + (T <: Real && f === sqrt) ? randexp!(d2.data) : randn!(d2.data) test_rrule(f, t1; rrule_f=Zygote.rrule_via_ad, check_inferred) test_rrule(f, t2; rrule_f=Zygote.rrule_via_ad, check_inferred) test_rrule(f, d; check_inferred, output_tangent=d2) From 2732f66d289238d4289900972d13fd119e03fce1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 7 Feb 2025 16:30:30 -0500 Subject: [PATCH 24/26] simplify CI detection --- test/runtests.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9ad5ee921..d0cd9945b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,12 +61,9 @@ include("tensors.jl") include("diagonal.jl") include("planar.jl") # TODO: remove once we know AD is slow on macOS CI -test_ad = try - !(Sys.isapple() && ENV["CI"] == "true") -catch - true +if !(Sys.isapple() && get(ENV, "CI", "false") == "true") + include("ad.jl") end -test_ad && include("ad.jl") include("bugfixes.jl") Tf = time() printstyled("Finished all tests in ", From 8c08866c4242eb076fecf0868ee440d82c3b74a5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 7 Feb 2025 16:51:08 -0500 Subject: [PATCH 25/26] Fix bad merge --- test/ad.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index e123bf6aa..044e89534 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ 
-153,9 +153,10 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), if symmetricbraiding test_rrule(TensorKit.permutedcopy_oftype, T1, ComplexF64, ((3, 1), (2, 4))) - test_rrule(convert, Array, T1) - test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1); - fkwargs=(; tol=Inf)) + test_rrule(convert, Array, T1) + test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1); + fkwargs=(; tol=Inf)) + end test_rrule(Base.getproperty, T1, :data) test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space) From 618aaad4931cc93be12a5050914326aa49195265 Mon Sep 17 00:00:00 2001 From: Jutho Date: Sat, 8 Feb 2025 00:18:29 +0100 Subject: [PATCH 26/26] uncomment non-ad tests --- test/runtests.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ddef75243..d0cd9945b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,11 +55,11 @@ sectorlist = (Z2Irrep, Z3Irrep, Z4Irrep, Z3Irrep ⊠ Z4Irrep, Z2Irrep ⊠ FibonacciAnyon ⊠ FibonacciAnyon) Ti = time() -# include("fusiontrees.jl") -# include("spaces.jl") -# include("tensors.jl") -# include("diagonal.jl") -# include("planar.jl") +include("fusiontrees.jl") +include("spaces.jl") +include("tensors.jl") +include("diagonal.jl") +include("planar.jl") # TODO: remove once we know AD is slow on macOS CI if !(Sys.isapple() && get(ENV, "CI", "false") == "true") include("ad.jl")