@ebelnikola
Contributor

This PR adds an rrule for the DiagonalTensorMap constructor and defines ProjectTo for DiagonalTensorMap.

@codecov

codecov bot commented Jan 20, 2025

Codecov Report

Attention: Patch coverage is 91.54930% with 6 lines in your changes missing coverage. Please review.

Project coverage is 82.44%. Comparing base (33f10bc) to head (618aaad).
Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
ext/TensorKitChainRulesCoreExt/constructors.jl 87.23% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #208      +/-   ##
==========================================
+ Coverage   82.29%   82.44%   +0.14%     
==========================================
  Files          43       43              
  Lines        5467     5536      +69     
==========================================
+ Hits         4499     4564      +65     
- Misses        968      972       +4     

Member

@lkdvos lkdvos left a comment

Thank you for taking the time to add this, this is really appreciated!
I left a few comments in the code, but overall looks great!

Comment on lines 17 to 22
D = DiagonalTensorMap(d, args...; kwargs...)
project_D = ProjectTo(D)
function DiagonalTensorMap_pullback(Δt)
    ∂d = project_D(unthunk(Δt)).data
    return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
end
Member

This implementation slightly surprises me: I would have expected the projection to be based on d. In the end, these things probably boil down to the same thing?

Suggested change
- D = DiagonalTensorMap(d, args...; kwargs...)
- project_D = ProjectTo(D)
- function DiagonalTensorMap_pullback(Δt)
-     ∂d = project_D(unthunk(Δt)).data
-     return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
- end
+ project_d = ProjectTo(d)
+ function DiagonalTensorMap_pullback(Δt)
+     ∂d = project_d(unthunk(Δt).data)
+     return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
+ end
+ D = DiagonalTensorMap(d, args...; kwargs...)

Contributor Author

@ebelnikola ebelnikola Jan 20, 2025

It seems to me that this is not the same. The issue here is that Δt may sometimes be of a non-diagonal type. I expect that in this situation your version will return an incorrect tangent (with a lot of zeros from the off-diagonal parts of Δt). Am I missing something?
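
The difference between the two orderings can be seen in a small analogy that uses LinearAlgebra's Diagonal rather than DiagonalTensorMap. This is only a sketch: the dense matrix Δ is a made-up stand-in for a non-diagonal cotangent, and the storage layout of a real TensorMap differs from that of a plain Matrix.

```julia
using LinearAlgebra

# Made-up stand-in for a dense (non-diagonal) cotangent reaching the pullback.
Δ = [1.0 0.2;
     0.3 4.0]

# Reading the raw storage first: all four entries come back, including
# the off-diagonal ones, so this is not a tangent for a length-2 diagonal.
raw = vec(Δ)

# Projecting onto the diagonal structure first, then reading its storage:
# only the two diagonal entries survive.
projected = Diagonal(Δ).diag

@show length(raw) length(projected) projected
```

Here Diagonal(Δ) plays the role of project_D(Δt), and .diag plays the role of .data.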

Member

Ah, good point, I think I missed this.
Somehow, I was expecting the input to the pullback to always be a DiagonalTensorMap, since this really is a constructor, and if not, there's probably a projector missing in some other rrule...
I also found some comments in the rrules for Diagonal where a similar discussion is taking place.
Long story short though, it seems like their solution is more similar to yours, so that's definitely okay for me.

Member

I am also undecided between the two options. I agree that in most cases, e.g. tensor contractions, ProjectTo(D) should already have been called on the adjoint variable Δt before it enters this rrule, and therefore not explicitly having project_D(Δt) in this rule might serve as a way to find missing projectors elsewhere.

On the other hand, from a user perspective, I can also see the advantage of just having project_D in here for safety.

While in principle that is independent of still having to call project_d = ProjectTo(d) on the output ∂d, it is probably true that project_d will not be doing anything (will act as the identity) if we already have projected Δt to a diagonal tensor with the right scalar type and storage type.

@lkdvos
Member

lkdvos commented Jan 21, 2025

I'll have a look tomorrow to help you out, the interplay with the symmetries is somewhat expected.
Did you have a specific use-case for this by the way?
The easiest way to circumvent this is to not define the rrule for this constructor, but instead define it for a conversion between a non-symmetric Diagonal and a symmetric DiagonalTensorMap, just like the TensorMap implementation. (Note that that one is also not defined in terms of the internal data representation, but instead with the general array). For abelian symmetries there should not be any difference, but for non-abelian ones there is, and then it becomes slightly confusing to implement this correctly.

@ebelnikola
Contributor Author

ebelnikola commented Jan 21, 2025

Thank you!

Yes, I have a use-case. I needed to differentiate through a tensor RG algorithm, and this was the last missing detail. For me, everything works perfectly already with what is provided in this PR. However, I also wanted to contribute something useful.

I believe the problem lies somewhere in the tests, not in the implementation. I checked the pullback result by hand, and it is correct. I noticed that test_rrule(DiagonalTensorMap, T.data, T.domain) passes if one removes the sqrtdim(c) factor and its inverse from the to_vec function in TensorKitFiniteDifferencesExt.jl. However, this doesn't make a lot of sense, as I understand you wanted to have an isometry there.

As for the Diagonal implementation: the problem here is that there's no constructor for DiagonalTensorMap that takes a Diagonal. Also, Zygote asked me for an rrule for this method.

@lkdvos
Member

lkdvos commented Jan 22, 2025

I didn't find time today, but it's still on my TODO.
The problem is indeed the tests, in the sense that while computing derivatives, there is a bit of a choice of metric: you can either consider a Euclidean metric on the parameters, or use the metric from the actual tensors. Usually the latter makes more sense, since it makes working with symmetric tensors equivalent to working without the symmetry, but here the rrule implements the former, because the constructor is working explicitly with the stored data. (Which I think is in fact correct.)
However, the FiniteDifferences support is set up in such a way to work with the latter, hence the discrepancy between the two.
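
To make the two metrics concrete, here is a toy sketch in plain Julia that does not use TensorKit. The sector dimensions, the data layout and the helper names tensor_dot, data_dot and to_vec_scaled are all invented for illustration. The only things taken from this thread are that the tensor inner product weighs each sector by its dimension and that to_vec rescales by sqrt(dim) to get an isometry.

```julia
using LinearAlgebra

# Toy model: two sectors with (made-up) dimensions 1 and 2,
# each storing a short vector of diagonal entries.
dims = [1.0, 2.0]
a = [[1.0, 2.0], [3.0]]
b = [[0.5, -1.0], [2.0]]

# Inner product induced by the tensors: every sector is weighted by its dimension.
tensor_dot(x, y) = sum(d * dot(xc, yc) for (d, xc, yc) in zip(dims, x, y))

# Euclidean inner product on the stored parameters: no weights.
data_dot(x, y) = sum(dot(xc, yc) for (xc, yc) in zip(x, y))

# Rescaling each sector by sqrt(dim) turns the weighted inner product
# into an ordinary dot product of flat vectors.
to_vec_scaled(x) = reduce(vcat, [sqrt(d) .* xc for (d, xc) in zip(dims, x)])

@show tensor_dot(a, b)                          # weighted: 10.5
@show dot(to_vec_scaled(a), to_vec_scaled(b))   # agrees with the weighted value up to rounding
@show data_dot(a, b)                            # unweighted: 4.5
```

When all dimensions equal 1, as for abelian symmetries, the two inner products coincide, so a finite-difference check set up for one metric cannot tell it apart from the other.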

This is more or less why I asked in what context you are encountering this: if you really have a diagonal array that you wish to convert into a tensormap, to then work with tensormaps, we should probably define the appropriate converters for that, which include the correct scalings. If not, I would be interested in the stacktrace, or a MWE, such that I might be able to spot where this is coming from.

Maybe we should also consider changing/expanding on the definitions of diag and diagm here, since that could be what you need (and should already be AD-compatible)?

@ebelnikola
Contributor Author

> This is more or less why I asked in what context you are encountering this: if you really have a diagonal array that you wish to convert into a tensormap, to then work with tensormaps, we should probably define the appropriate converters for that, which include the correct scalings. If not, I would be interested in the stacktrace, or a MWE, such that I might be able to spot where this is coming from.

Hmmm, in fact, now I cannot reproduce the problem I had.

Here is an example of the code where I encountered a problem initially:

using TensorKit
using Zygote
data=rand(8)
T=DiagonalTensorMap(data, Z2Space(4,4))
f(S)=tr(sqrt(S))
gradient(f,T)

This returns:

ERROR: Need an adjoint for constructor DiagonalTensorMap{Float64, GradedSpace{Z2Irrep, Tuple{Int64, Int64}}, Vector{Float64}}. Gradient is of type TensorMap{Float64, GradedSpace{Z2Irrep, Tuple{Int64, Int64}}, 1, 1, Vector{Float64}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{DiagonalTensorMap{…}, Nothing, false})(Δ::TensorMap{Float64, GradedSpace{…}, 1, 1, Vector{…}})
    @ Zygote ~/.julia/packages/Zygote/D9opX/src/lib/lib.jl:334
  [3] (::Zygote.var"#2229#back#333"{Zygote.Jnew{…}})(Δ::TensorMap{Float64, GradedSpace{…}, 1, 1, Vector{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [4] DiagonalTensorMap
    @ ~/.julia/dev/TensorKit/src/tensors/diagonal.jl:20 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Any})(Δ::TensorMap{Float64, GradedSpace{…}, 1, 1, Vector{…}})
    @ Zygote ~/.julia/packages/Zygote/D9opX/src/compiler/interface2.jl:0
  [6] DiagonalTensorMap
    @ ~/.julia/dev/TensorKit/src/tensors/diagonal.jl:48 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::TensorMap{Float64, GradedSpace{…}, 1, 1, Vector{…}})
    @ Zygote ~/.julia/packages/Zygote/D9opX/src/compiler/interface2.jl:0
  [8] DiagonalTensorMap
    @ ~/.julia/dev/TensorKit/src/tensors/diagonal.jl:52 [inlined]
  [9] sqrt
    @ ~/.julia/dev/TensorKit/src/tensors/diagonal.jl:311 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TensorMap{Float64, GradedSpace{…}, 1, 1, Vector{…}})
    @ Zygote ~/.julia/packages/Zygote/D9opX/src/compiler/interface2.jl:0
 [11] f
    @ ~/Codes/TensorRG/tmp2_local.jl:21 [inlined]
 [12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/D9opX/src/compiler/interface.jl:97
 [13] gradient(f::Function, args::DiagonalTensorMap{Float64, GradedSpace{Z2Irrep, Tuple{Int64, Int64}}, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/D9opX/src/compiler/interface.jl:154
 [14] top-level scope
    @ ~/Codes/TensorRG/tmp2_local.jl:23
Some type information was truncated. Use `show(err)` to see complete types.

To fix this, I defined an rrule for sqrt. A few days ago, even after this, Zygote asked for the constructor's adjoint. However, when I repeat this now, everything works fine without it. So yes, it now appears that this PR did not solve any particular global problem, except that ProjectTo had unexpected behavior before.

@lkdvos
Member

lkdvos commented Jan 23, 2025

Ah, I see now. The original problem comes from our specializations of matrix functions like sqrt for DiagonalTensorMap, which unpack the data, apply the function, and then repack it. The repacking step indeed requires an rrule, since that is the constructor of DiagonalTensorMap, and it should not include the dimensions, as implemented right here.
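
The unpack, apply, repack pattern can be sketched with LinearAlgebra's Diagonal as a stand-in. The helper name unpack_apply_repack and the test values are hypothetical and are not TensorKit's implementation. The sketch only checks that the derivative of tr(sqrt(D)) with respect to the stored entries is 1/(2 sqrt(d_i)), with no extra weights, which is the plain-data convention described above.

```julia
using LinearAlgebra

# Hypothetical stand-in for a matrix-function specialization:
# unpack the stored entries, apply f elementwise, repack into a Diagonal.
unpack_apply_repack(f, D::Diagonal) = Diagonal(f.(D.diag))

d = [4.0, 9.0, 16.0]
loss(v) = tr(unpack_apply_repack(sqrt, Diagonal(v)))

# Analytic gradient with respect to the stored entries.
analytic = 1 ./ (2 .* sqrt.(d))

# Central finite differences on the stored entries, for comparison.
h = 1e-6
numeric = map(eachindex(d)) do i
    e = zeros(length(d))
    e[i] = h
    (loss(d .+ e) - loss(d .- e)) / (2h)
end

@show analytic numeric
```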

The biggest annoyance is using ChainRulesTestUtils.jl, since going from tensors to vectors in that context is understood as "desymmetrizing", i.e. weighing by quantum dimensions. This would be resolved if we tested sqrt instead, since there the unpacking and repacking are combined and would cancel each other, again making everything consistent.

As a separate note, the unpacking step might actually also require such an rrule, since accessing fields yields strange tangent types, which we are avoiding everywhere since our rrules will not be able to handle them.

Finally, we could also define the rrules for the matrix functions directly, but again ChainRulesTestUtils is being annoying: sqrt does weird things:

julia> sqrt(rand(3,3))
3×3 Matrix{ComplexF64}:
 0.251394+0.530824im  0.644466-0.46924im   0.245255+0.000886708im
  0.28562+0.198812im  0.729492-0.175747im  0.278139+0.000332103im
 0.652492-1.06654im   0.235591+0.942803im  0.368666-0.00178158im

julia> sqrt(rand(3,3))
3×3 Matrix{Float64}:
  0.976625  0.341331  -0.00378496
 -0.117352  0.820173   0.654396
  0.520547  0.424884   0.495172

julia> sqrt(Diagonal(randn(3)))
3×3 Diagonal{Float64, Vector{Float64}}:
 0.547541            
          0.484821   
                   0.111613

julia> sqrt(Diagonal(randn(3)))
ERROR: DomainError with -0.31158918553265935:

And since the finite-difference tests will generate Jacobians for the number of real parameters, this yields Jacobians with weird sizes.
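
The size mismatch can be reproduced by counting real parameters before and after sqrt. In the sketch below the helper n_real and the two matrices are made up for illustration, and the element types noted in the comments are what I expect from the stable LinearAlgebra releases I know, not something checked against nightly.

```julia
using LinearAlgebra

# Hypothetical helper: number of real parameters stored in an array.
n_real(x) = length(x) * (eltype(x) <: Complex ? 2 : 1)

A_pos = [2.0 1.0; 1.0 2.0]   # eigenvalues 1 and 3: real square root expected
A_neg = [1.0 2.0; 2.0 1.0]   # eigenvalues 3 and -1: complex square root expected

S_pos = sqrt(A_pos)
S_neg = sqrt(A_neg)

# Same input type and size, but the number of real output parameters differs,
# so a Jacobian built from real parameter counts changes shape with the input values.
@show eltype(S_pos) n_real(A_pos) n_real(S_pos)
@show eltype(S_neg) n_real(A_neg) n_real(S_neg)
```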

I'm honestly not sure what the best approach would be.

@lkdvos
Member

lkdvos commented Jan 23, 2025

In principle, this should now no longer require the rrule for the DiagonalTensorMap constructor. I'm okay with leaving that in, since I am pretty sure it is correct, but I don't think battling ChainRulesTestUtils.jl to make that behave is worth it.

The comment about accessing the fields of tensors still remains, in the sense that we don't really support AD for that, but at least now these functions that access them have custom rules defined.

For sqrt(::DiagonalTensorMap), I think it would also be possible to use a forwards rule instead, but I don't really know what the performance implications are.

@lkdvos
Member

lkdvos commented Jan 31, 2025

It seems like the failing nightly tests are due to this commit, where now additionally everything has become even more problematic:

v1:

julia> d = Diagonal([1., -1.])
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0    
     -1.0

julia> sqrt(d)
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).

julia> sqrt(Array(d))
2×2 Matrix{ComplexF64}:
 1.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+1.0im

julia> sqrt(Array(d) + [0. 0.1; 0. 0.]) # array but not strictly diagonal
2×2 Matrix{ComplexF64}:
 1.0+0.0im  0.05-0.05im
 0.0+0.0im   0.0+1.0im

nightly:

julia> using LinearAlgebra

julia> d = Diagonal([1., -1.])
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0    
     -1.0

julia> sqrt(d)
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).

julia> sqrt(Array(d))
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).

julia> sqrt(Array(d) + [0. 0.1; 0. 0.]) # array but not strictly diagonal
2×2 Matrix{ComplexF64}:
 1.0+0.0im  0.05-0.05im
 0.0+0.0im   0.0+1.0im

@lkdvos
Member

lkdvos commented Jan 31, 2025

@lkdvos lkdvos force-pushed the ProjectTo-for-DiagonalTensorMap branch from dc03d28 to 58904f7 Compare February 5, 2025 19:06
test/ad.jl Outdated
d = DiagonalTensorMap{T}(undef, V[1])
randn!(d.data)
if T <: Real
d.data .= abs.(d.data)
Member

I assume this is only necessary for f === sqrt? Alternatively, you could write

logdata = randn(T, rdim(V[1]))
d = DiagonalTensorMap(exp.(logdata), V[1])

This will create a purely positive DiagonalTensorMap in the real case.

Member

Just realized that there literally is a randexp function to do that in one go... Thanks for the suggestion, I'll change that.

Member

I don't think randexp works with complex numbers.

Member

But for complex numbers I don't need them to be positive, so I can just use randn!. I can definitely change it if you want, anything goes I guess?

include("planar.jl")
if !(Sys.isapple()) # TODO: remove once we know why this is so slow on macOS
include("ad.jl")
# TODO: remove once we know AD is slow on macOS CI
Member

This is a very strange construct. Why not

if !(Sys.isapple() && get(ENV, "CI", false))
    include("ad.jl")
end

Member

It's mostly taken from here.
I think the biggest problem is that the environment variables are strings, so they really become "true" and "false", which makes for an annoying interface.
I don't think the construction above works, because "true" would not evaluate as a bool...

Member

ok, parse(Bool, get(ENV, "CI", "false"))

Member

Sorry I missed this, I tried get(ENV, "CI", "false") == "true". Does that work?
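
For reference, the two string-based variants discussed in this thread can be compared side by side. The helper names is_ci, is_ci_parsed and should_run_ad are hypothetical and are not part of the test suite. They take the environment as an argument only so that the behavior can be checked with a plain Dict.

```julia
# Hypothetical helper: true only when CI is set to exactly the string "true".
is_ci(env=ENV) = get(env, "CI", "false") == "true"

# Parse-based variant: also yields a Bool, but throws an ArgumentError
# for strings it cannot parse, such as "yes".
is_ci_parsed(env=ENV) = parse(Bool, get(env, "CI", "false"))

# Skip the AD tests only on macOS CI runners; run them everywhere else.
should_run_ad(env=ENV; apple=Sys.isapple()) = !(apple && is_ci(env))

ci_env = Dict("CI" => "true")
local_env = Dict{String,String}()

@show is_ci(ci_env) is_ci(local_env)
@show should_run_ad(ci_env; apple=true) should_run_ad(local_env; apple=true)
```

The string comparison never throws, which makes it the more forgiving choice when the variable is set to an unexpected value.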

@lkdvos
Member

lkdvos commented Feb 7, 2025

Seems like this finally worked. Ready to go for me!

@lkdvos lkdvos requested a review from Jutho February 7, 2025 23:05
@Jutho
Member

Jutho commented Feb 7, 2025

Looks good to me as well. I will merge tomorrow morning if the tests have successfully completed. Are we good to tag a patch update?

@Jutho Jutho merged commit 33bca87 into QuantumKitHub:master Feb 8, 2025
10 of 13 checks passed