Skip to content

Commit 3bc350b

Browse files
refactor: move ChainRulesCore to an extension
1 parent 844f5a5 commit 3bc350b

File tree

3 files changed

+50
-33
lines changed

3 files changed

+50
-33
lines changed

Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,24 @@ repo = "https://github.com/JuliaAlgebra/MultivariatePolynomials.jl"
55
version = "0.5.12"
66

77
[deps]
8-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
98
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
1211

12+
[weakdeps]
13+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
14+
15+
[extensions]
16+
MultivariatePolynomialsChainRulesCoreExt = "ChainRulesCore"
17+
1318
[compat]
1419
ChainRulesCore = "1"
1520
DataStructures = "0.19"
1621
MutableArithmetics = "0.3, 1"
1722
julia = "1.10"
23+
24+
[extras]
25+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
26+
27+
[targets]
28+
test = ["ChainRulesCore"]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module MultivariatePolynomialsChainRulesCoreExt
2+
3+
import ChainRulesCore
4+
using MultivariatePolynomials
5+
using MultivariatePolynomials: _APL
6+
7+
ChainRulesCore.@scalar_rule +(x::_APL) true
8+
ChainRulesCore.@scalar_rule -(x::_APL) -1
9+
10+
ChainRulesCore.@scalar_rule +(x::_APL, y::_APL) (true, true)
11+
ChainRulesCore.@scalar_rule -(x::_APL, y::_APL) (true, -1)
12+
13+
function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::_APL, q::_APL)
14+
return p * q, MA.add_mul!!(p * Δq, q, Δp)
15+
end
16+
function ChainRulesCore.rrule(::typeof(*), p::_APL, q::_APL)
17+
function times_pullback2(ΔΩ̇)
18+
#ΔΩ = ChainRulesCore.unthunk(Ω̇)
19+
#return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(p)(ΔΩ * q'), ChainRulesCore.ProjectTo(q)(p' * ΔΩ))
20+
return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇)
21+
end
22+
return p * q, times_pullback2
23+
end
24+
25+
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
26+
return differentiate(p, x), differentiate(Δp, x)
27+
end
28+
function pullback(Δdpdx, x)
29+
return ChainRulesCore.NoTangent(),
30+
x * differentiate(x * Δdpdx, x),
31+
ChainRulesCore.NoTangent()
32+
end
33+
function ChainRulesCore.rrule(::typeof(differentiate), p, x)
34+
dpdx = differentiate(p, x)
35+
return dpdx, Base.Fix2(pullback, x)
36+
end
37+
38+
end

src/chain_rules.jl

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +0,0 @@
1-
import ChainRulesCore
2-
3-
ChainRulesCore.@scalar_rule +(x::_APL) true
4-
ChainRulesCore.@scalar_rule -(x::_APL) -1
5-
6-
ChainRulesCore.@scalar_rule +(x::_APL, y::_APL) (true, true)
7-
ChainRulesCore.@scalar_rule -(x::_APL, y::_APL) (true, -1)
8-
9-
function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::_APL, q::_APL)
10-
return p * q, MA.add_mul!!(p * Δq, q, Δp)
11-
end
12-
function ChainRulesCore.rrule(::typeof(*), p::_APL, q::_APL)
13-
function times_pullback2(ΔΩ̇)
14-
#ΔΩ = ChainRulesCore.unthunk(Ω̇)
15-
#return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(p)(ΔΩ * q'), ChainRulesCore.ProjectTo(q)(p' * ΔΩ))
16-
return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇)
17-
end
18-
return p * q, times_pullback2
19-
end
20-
21-
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
22-
return differentiate(p, x), differentiate(Δp, x)
23-
end
24-
function pullback(Δdpdx, x)
25-
return ChainRulesCore.NoTangent(),
26-
x * differentiate(x * Δdpdx, x),
27-
ChainRulesCore.NoTangent()
28-
end
29-
function ChainRulesCore.rrule(::typeof(differentiate), p, x)
30-
dpdx = differentiate(p, x)
31-
return dpdx, Base.Fix2(pullback, x)
32-
end

0 commit comments

Comments
 (0)