Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2b79107
adds `ProjectTo` for `DiagonalTensorMap`
ebelnikola Jan 19, 2025
f275175
adds an `rrule` for `DiagonalTensorMap` constructor
ebelnikola Jan 19, 2025
2cf4cb4
Corrects bug in the DiagonalTensorMap rrule, adds
ebelnikola Jan 20, 2025
c43535f
@test missing in the constructor test added...
ebelnikola Jan 20, 2025
4653113
wait no, @test did not belong there
ebelnikola Jan 20, 2025
c81e43c
Update ext/TensorKitChainRulesCoreExt/utility.jl
ebelnikola Jan 20, 2025
41d1113
mixed type tests for ProjectTo
ebelnikola Jan 20, 2025
53b399c
+ rrule test on complex tensors.
ebelnikola Jan 20, 2025
8b51b8b
correct data length for DiagonalTensor in tests
ebelnikola Jan 20, 2025
2905c0b
correct data length in DiagonalTensorMap for random tangents
ebelnikola Jan 20, 2025
79a38b6
Comment on the test failure
ebelnikola Jan 20, 2025
689b4ea
Jutho's corrections
ebelnikola Jan 23, 2025
e1fe3be
Add `DiagonalTensorMap(::AbstractTensorMap)`
lkdvos Jan 23, 2025
ccb7c93
Specialize `to_vec(::DiagonalTensorMap)`
lkdvos Jan 23, 2025
f8ca1fd
Add rrules matrix functions
lkdvos Jan 23, 2025
98dea08
Add tests AD of matrixfunctions
lkdvos Jan 23, 2025
58904f7
Merge branch 'master' into ProjectTo-for-DiagonalTensorMap
lkdvos Feb 5, 2025
5eec371
Remove duplicate methods
lkdvos Feb 6, 2025
9e834c4
disable broken tests
lkdvos Feb 6, 2025
a5ba340
Fix CI check
lkdvos Feb 6, 2025
3e45ff2
Adapt rrules for constructors and getproperty to include qdims
lkdvos Feb 7, 2025
f9ed03e
exchange sqrt and invsqrt in hope of fixing without thinking
lkdvos Feb 7, 2025
b195929
Merge branch 'master' into ProjectTo-for-DiagonalTensorMap
lkdvos Feb 7, 2025
643aba7
Actually think to fix the problem
lkdvos Feb 7, 2025
06d3784
Simplify positive data generation
lkdvos Feb 7, 2025
2732f66
simplify CI detection
lkdvos Feb 7, 2025
455bcc2
Merge branch 'master' into ProjectTo-for-DiagonalTensorMap
lkdvos Feb 7, 2025
8c08866
Fix bad merge
lkdvos Feb 7, 2025
618aaad
uncomment non-ad tests
Jutho Feb 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev/KrylovKit
Submodule KrylovKit added at 8bccac
74 changes: 73 additions & 1 deletion ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,86 @@
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)

function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...)
# Reverse rule for building a `TensorMap` from a dense array `d`.
# The pullback converts the (unthunked) output cotangent back to an `Array` for `d`;
# the constructor type and every trailing positional argument receive `NoTangent()`.
function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs...)
    tensor = TensorMap(d, args...; kwargs...)
    function TensorMap_pullback(Δt)
        ∂array = convert(Array, unthunk(Δt))
        # one `NoTangent` per element of `args`
        ∂args = map(_ -> NoTangent(), args)
        return NoTangent(), ∂array, ∂args...
    end
    return tensor, TensorMap_pullback
end

# these are not the conversion to/from array, but actually take in data parameters
# -- as a result, requires quantum dimensions to keep inner product the same:
# ⟨Δdata, ∂data⟩ = ⟨Δtensor, ∂tensor⟩ = ∑_c d_c ⟨Δtensor_c, ∂tensor_c⟩
# ⟹ Δdata = d_c Δtensor_c
# Reverse rule for `TensorMap{T}(data, V)`, where `data` is the raw parameter vector.
# The pullback rescales every block of the cotangent by `dim(c)` before reading off its
# data vector and projecting it back onto the type of `data`.
function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector,
                              V::TensorMapSpace) where {T}
    tensor = TensorMap{T}(data, V)
    project_data = ProjectTo(data)
    function TensorMap_pullback(Δt_)
        # scale a copy so the incoming cotangent is left untouched
        Δscaled = copy(unthunk(Δt_))
        for (sector, blk) in blocks(Δscaled)
            scale!(blk, dim(sector))
        end
        return NoTangent(), project_data(Δscaled.data), NoTangent()
    end
    return tensor, TensorMap_pullback
end

# Reverse rule for constructing a `DiagonalTensorMap` from its data vector.
#
# `data` holds raw parameters (not an array representation of the tensor), so the pullback
# scales every block of the cotangent by `dim(c)` before extracting its data vector, and
# then projects the result back onto the type of `data`.
#
# Returns the constructed tensor and a pullback yielding
# `(NoTangent(), ∂data, NoTangent()...)`, with exactly one `NoTangent()` per element of
# `args`.
function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, args...;
                              kwargs...)
    D = DiagonalTensorMap(data, args...; kwargs...)
    P = ProjectTo(data)
    function DiagonalTensorMap_pullback(Δt_)
        # unclear if we're allowed to modify/take ownership of the input
        Δt = copy(unthunk(Δt_))
        for (c, b) in blocks(Δt)
            scale!(b, dim(c))
        end
        ∂data = P(Δt.data)
        # The number of returned tangents must match the number of primal inputs, so emit
        # one `NoTangent` per trailing positional argument instead of assuming exactly one.
        return NoTangent(), ∂data, ntuple(_ -> NoTangent(), length(args))...
    end
    return D, DiagonalTensorMap_pullback
end

# Reverse rule for `getproperty` on a `TensorMap`.
# - `:data`  returns `t.data`; the pullback wraps a copy of the data cotangent in a tensor of
#            the same type and space as `t` and scales each block by `inv(dim(c))`.
# - `:space` returns `t.space` with a constant pullback (`ZeroTangent()` for `t`).
# Any other property name throws an `ArgumentError`.
function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Symbol)
if prop === :data
function getdata_pullback(Δdata)
# unclear if we're allowed to modify/take ownership of the input
t′ = typeof(t)(copy(unthunk(Δdata)), t.space)
# inverse of the `dim(c)` scaling used by the data-vector constructor rule
for (c, b) in blocks(t′)
scale!(b, inv(dim(c)))
end
return NoTangent(), t′, NoTangent()
end
return t.data, getdata_pullback
elseif prop === :space
return t.space, Returns((NoTangent(), ZeroTangent(), NoTangent()))

Check warning on line 62 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
else
throw(ArgumentError("unknown property $prop"))

Check warning on line 64 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L64

Added line #L64 was not covered by tests
end
end

# Reverse rule for `getproperty` on a `DiagonalTensorMap`.
# - `:data`   returns `t.data`; the pullback wraps a copy of the data cotangent in a tensor
#             of the same type and domain as `t` and scales each block by `inv(dim(c))`.
# - `:domain` returns `t.domain` with a constant pullback (`ZeroTangent()` for `t`).
# Any other property name throws an `ArgumentError`.
function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap,
prop::Symbol)
if prop === :data
function getdata_pullback(Δdata)
# unclear if we're allowed to modify/take ownership of the input
t′ = typeof(t)(copy(unthunk(Δdata)), t.domain)
# inverse of the `dim(c)` scaling used by the data-vector constructor rule
for (c, b) in blocks(t′)
scale!(b, inv(dim(c)))
end
return NoTangent(), t′, NoTangent()
end
return t.data, getdata_pullback
elseif prop === :domain
return t.domain, Returns((NoTangent(), ZeroTangent(), NoTangent()))

Check warning on line 81 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L80-L81

Added lines #L80 - L81 were not covered by tests
else
throw(ArgumentError("unknown property $prop"))

Check warning on line 83 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L83

Added line #L83 was not covered by tests
end
end

function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
copy_pullback(Δt) = NoTangent(), Δt
return copy(t), copy_pullback
Expand Down
21 changes: 20 additions & 1 deletion ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ function ChainRulesCore.rrule(::typeof(imag), a::AbstractTensorMap)
return a_imag, imag_pullback
end

function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorMap)
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(exp),
A::AbstractTensorMap)
domain(A) == codomain(A) ||
error("Exponential of a tensor only exist when domain == codomain.")
P_A = ProjectTo(A)
Expand All @@ -133,3 +134,21 @@ function ChainRulesCore.rrule(cfg::RuleConfig, ::typeof(exp), A::AbstractTensorM
end
return C, exp_pullback
end

# define rrules for matrix functions for DiagonalTensorMap, since they access data directly.
# Generate one reverse rule per matrix function: each rule differentiates the elementwise
# application of the function to `t.data` via `rrule_via_ad` and wraps the results back
# into `DiagonalTensorMap`s over `t.domain`.
for fname in (:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot,
              :asinh, :sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
    pullback_name = Symbol(fname, :_pullback)
    @eval function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode},
                                        ::typeof($fname), t::DiagonalTensorMap)
        project_t = ProjectTo(t) # unsure if this is necessary, should already be in pullback
        fdata, data_pullback = rrule_via_ad(cfg, broadcast, $fname, t.data)
        function $pullback_name(Δd_)
            Δdata = project_t(unthunk(Δd_)).data
            # the pullback yields tangents for (broadcast, function, data); keep the last
            ∂data = data_pullback(Δdata)[3]
            return NoTangent(), DiagonalTensorMap(∂data, t.domain)
        end
        return DiagonalTensorMap(fdata, t.domain), $pullback_name
    end
end
12 changes: 12 additions & 0 deletions ext/TensorKitChainRulesCoreExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{<:Any,S,N
end
return y
end

# Project an arbitrary tensor map onto `DiagonalTensorMap{T,S,A}`.
# A tensor that already has exactly this type is returned as-is. Otherwise `x` must live in
# `V ← V` with `V = space(x, 1)` (else a `SpaceMismatch` is thrown), and every block of a
# freshly allocated diagonal tensor is filled with the projection of the matching block
# of `x`.
function (::ProjectTo{DiagonalTensorMap{T,S,A}})(x::AbstractTensorMap) where {T,S,A}
    if x isa DiagonalTensorMap{T,S,A}
        return x
    end
    V = space(x, 1)
    if space(x) != (V ← V)
        throw(SpaceMismatch())
    end
    projected = DiagonalTensorMap{T,S,A}(undef, V)
    for (sector, blk) in blocks(projected)
        blk .= ProjectTo(blk)(block(x, sector))
    end
    return projected
end
8 changes: 8 additions & 0 deletions ext/TensorKitFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ function FiniteDifferences.to_vec(t::AbstractTensorMap)
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))

# Vectorise a `DiagonalTensorMap` by going through its `TensorMap` conversion.
# The returned closure rebuilds the tensor from a vector and converts it back to a
# `DiagonalTensorMap`.
function FiniteDifferences.to_vec(t::DiagonalTensorMap)
    flat, tensor_from_vec = to_vec(TensorMap(t))
    DiagonalTensorMap_from_vec(v) = DiagonalTensorMap(tensor_from_vec(v))
    return flat, DiagonalTensorMap_from_vec
end

end

# TODO: Investigate why the approach below doesn't work
Expand Down
60 changes: 60 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ChainRulesTestUtils
using FiniteDifferences: FiniteDifferences
using Random
using LinearAlgebra
using Zygote

const _repartition = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition
Expand All @@ -15,6 +16,10 @@ end
# Random tangent for a tensor map: a tensor similar to `x` filled with normally
# distributed entries. The entries are drawn from the supplied `rng` (rather than the
# global default RNG) so that callers passing a seeded generator get reproducible tangents.
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
    return randn!(rng, similar(x))
end
# Random tangent for a diagonal tensor map: a `DiagonalTensorMap` over the same domain
# whose `reduceddim(V)` data entries are normally distributed. The entries are drawn from
# the supplied `rng` (rather than the global default RNG) so that seeded generators give
# reproducible tangents.
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
    V = x.domain
    return DiagonalTensorMap(randn(rng, eltype(x), reduceddim(V)), V)
end
# Vector spaces always get `NoTangent()` as their random tangent.
function ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace)
    return NoTangent()
end
function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
expected::AbstractTensorMap, msg=""; kwargs...)
Expand Down Expand Up @@ -152,6 +157,46 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1);
fkwargs=(; tol=Inf))
end

test_rrule(Base.getproperty, T1, :data)
test_rrule(TensorMap{scalartype(T1)}, T1.data, T1.space)
test_rrule(Base.getproperty, T2, :data)
test_rrule(TensorMap{scalartype(T2)}, T2.data, T2.space)
end

@timedtestset "Basic utility (DiagonalTensor)" begin
    for v in V
        # real diagonal tensors, their complex combination, and the dense counterparts
        n = reduceddim(v)
        Dre = DiagonalTensorMap(randn(n), v)
        Dim = DiagonalTensorMap(randn(n), v)
        Dcplx = Dre + im * Dim
        Tre = TensorMap(Dre)
        Tim = TensorMap(Dim)
        Tcplx = Tre + im * Tim

        project_real = ProjectTo(Dre)
        project_cplx = ProjectTo(Dcplx)

        # real -> real
        @test project_real(Dre) == Dre
        @test project_real(Tre) == Dre

        # complex -> complex
        @test project_cplx(Dcplx) == Dcplx
        @test project_cplx(Tcplx) == Dcplx

        # real -> complex
        @test project_cplx(Dre) == Dre + 0 * im * Dre
        @test project_cplx(Tre) == Dre + 0 * im * Dre

        # complex -> real
        @test project_real(Dcplx) == Dre
        @test project_real(Tcplx) == Dre

        # constructor and `data` accessor rules, for real and complex tensors
        test_rrule(DiagonalTensorMap, Dre.data, Dre.domain)
        test_rrule(DiagonalTensorMap, Dcplx.data, Dcplx.domain)
        test_rrule(Base.getproperty, Dcplx, :data)
        test_rrule(Base.getproperty, Dre, :data)
    end
end

@timedtestset "Basic Linear Algebra with scalartype $T" for T in eltypes
Expand Down Expand Up @@ -196,6 +241,21 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(LinearAlgebra.dot, A, B)
end

@timedtestset "Matrix functions ($T)" for T in eltypes
    for f in (sqrt, exp)
        check_inferred = false # !(T <: Real) # not type-stable for real functions
        # real `sqrt` gets `randexp!` data, every other case gets `randn!` data
        fill_data! = (T <: Real && f === sqrt) ? randexp! : randn!

        dense1 = randn(T, V[1] ← V[1])
        dense2 = randn(T, V[2] ← V[2])
        diag_in = DiagonalTensorMap{T}(undef, V[1])
        fill_data!(diag_in.data)
        diag_tangent = DiagonalTensorMap{T}(undef, V[1])
        fill_data!(diag_tangent.data)

        for dense in (dense1, dense2)
            test_rrule(f, dense; rrule_f=Zygote.rrule_via_ad, check_inferred)
        end
        test_rrule(f, diag_in; check_inferred, output_tangent=diag_tangent)
    end
end

symmetricbraiding &&
@timedtestset "TensorOperations with scalartype $T" for T in eltypes
atol = precision(T)
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ include("spaces.jl")
include("tensors.jl")
include("diagonal.jl")
include("planar.jl")
if !(Sys.isapple()) # TODO: remove once we know why this is so slow on macOS
# TODO: remove once we know why AD is so slow on macOS CI
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very strange construct. Why not

if !(Sys.isapple() && get(ENV, "CI", false))
    include("ad.jl")
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mostly taken from here.
I think the biggest problem is that the environment variables are strings, so they really become "true" and "false", which makes for an annoying interface.
I don't think the construction above works, because "true" would not evaluate as a bool...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, parse(Bool, get(ENV, "CI", "false"))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this, I tried get(ENV, "CI", "false") == "true". Does that work?

# run the AD tests everywhere except on macOS when the `CI` environment variable is "true"
if !Sys.isapple() || get(ENV, "CI", "false") != "true"
    include("ad.jl")
end
include("bugfixes.jl")
Expand Down
Loading