
Commit aa7e827

Polytope layer (#209)
* svm layer
* new draft
* test
* fixed
* more details
* rem useless file
* rem dep
1 parent 57f01d1 commit aa7e827

File tree

3 files changed: +157 -3 lines changed


docs/src/examples/custom-relu.jl

Lines changed: 2 additions & 2 deletions
@@ -82,9 +82,9 @@ test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);
 inner = 10
 
 m = Flux.Chain(
-    Flux.Dense(784, inner), #784 being image linear dimension (28 x 28)
+    Flux.Dense(784, inner), ## 784 being image linear dimension (28 x 28)
     matrix_relu,
-    Flux.Dense(inner, 10), # 10 beinf the number of outcomes (0 to 9)
+    Flux.Dense(inner, 10), ## 10 being the number of outcomes (0 to 9)
     Flux.softmax,
 )
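
As a quick sanity check of the chain touched by this hunk, here is a sketch (not part of the commit) that assumes `m` and its `matrix_relu` layer from the custom-relu example are already defined: the network maps a 784×batch matrix of flattened 28×28 images to a 10×batch matrix of class probabilities.

```julia
## Hypothetical smoke test; requires `m` and `matrix_relu` from the
## custom-relu example to be in scope.
batch = rand(Float32, 784, 32)   # 32 flattened 28x28 images
probs = m(batch)                 # Dense -> matrix_relu -> Dense -> softmax
size(probs)                      # (10, 32)
sum(probs; dims = 1)             # each column sums to 1 after softmax
```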

docs/src/examples/polyhedral_project.jl

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
# # Polyhedral QP layer

#md # [![](https://img.shields.io/badge/show-github-579ACA.svg)](@__REPO_ROOT_URL__/docs/src/examples/polyhedral_project.jl)

# We use DiffOpt to define a custom network layer which, given an input matrix `y`,
# computes its projection onto a polytope defined by a fixed number of inequalities:
# `a_i^T x ≥ b_i`.
# A neural network is created using Flux.jl and trained on the MNIST dataset,
# integrating this quadratic optimization layer.
#
# The QP is solved in the forward pass, and its DiffOpt derivative is used in the backward pass, expressed with `ChainRulesCore.rrule`.

# This example is similar to the custom ReLU layer, except that the layer is parameterized
# by the hyperplanes `(w,b)` and not a simple stateless function.
# This also means that `ChainRulesCore.rrule` must return the derivatives of the output with respect to the
# layer parameters to allow for backpropagation.

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
using Base.Iterators: repeated
using LinearAlgebra
using Random

Random.seed!(42)

# ## The Polytope representation and its derivative

struct Polytope{N}
    w::NTuple{N, Matrix{Float64}}
    b::Vector{Float64}
end

Polytope(w::NTuple{N}) where {N} = Polytope{N}(w, randn(N))

# We define a "call" operation on the polytope, making it a so-called functor.
# Calling the polytope with a matrix `y` performs a Euclidean projection of this matrix onto the polytope.
function (polytope::Polytope)(y::AbstractMatrix; model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer)))
    N, M = size(y)
    empty!(model)
    set_silent(model)
    @variable(model, x[1:N, 1:M])
    @constraint(model, greater_than_cons[idx in 1:length(polytope.w)], dot(polytope.w[idx], x) ≥ polytope.b[idx])
    @objective(model, Min, dot(x - y, x - y))
    optimize!(model)
    return JuMP.value.(x)
end

# The `@functor` macro from Flux implements auxiliary functions for collecting the parameters of
# our custom layer and performing backpropagation.
Flux.@functor Polytope

# Define the reverse differentiation rule for the function we defined above.
# Flux uses ChainRules primitives to implement reverse-mode differentiation of the whole network.
# To learn the current layer (the polytope it contains),
# the gradient is computed with respect to the `Polytope` fields in a `ChainRulesCore.Tangent` type,
# which is used to represent derivatives with respect to structs.
# For more details about backpropagation, visit [Introduction, ChainRulesCore.jl](https://juliadiff.org/ChainRulesCore.jl/dev/).

function ChainRulesCore.rrule(polytope::Polytope, y::AbstractMatrix)
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
        dl_dx = ChainRulesCore.unthunk(dl_dx)
        ## `dl_dy` is the derivative of `l` wrt `y`
        x = model[:x]
        ## grad wrt input parameters
        dl_dy = zeros(size(dl_dx))
        ## grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)
        ## set sensitivities
        MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x, dl_dx)
        ## compute grad
        DiffOpt.backward(model)
        ## compute gradient wrt objective function parameter y
        obj_expr = MOI.get(model, DiffOpt.BackwardOutObjective())
        dl_dy .= -2 * JuMP.coefficient.(obj_expr, x)
        greater_than_cons = model[:greater_than_cons]
        for idx in eachindex(dl_dw)
            cons_expr = MOI.get(model, DiffOpt.BackwardOutConstraint(), greater_than_cons[idx])
            dl_db[idx] = -JuMP.constant(cons_expr)
            dl_dw[idx] .= JuMP.coefficient.(cons_expr, x)
        end
        dself = ChainRulesCore.Tangent{typeof(polytope)}(; w = dl_dw, b = dl_db)
        return (dself, dl_dy)
    end
    return xv, pullback_matrix_projection
end

# ## Prepare data
N = 500
imgs = MLDatasets.MNIST.traintensor(1:N)
labels = MLDatasets.MNIST.trainlabels(1:N);

# Preprocessing
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) ## stack all the images
train_Y = Flux.onehotbatch(labels, 0:9);

test_imgs = MLDatasets.MNIST.testtensor(1:N)
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), N))
test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);

# ## Define the Network

inner = 20

m = Flux.Chain(
    Flux.Dense(784, inner), ## 784 being image linear dimension (28 x 28)
    Polytope((randn(inner, N), randn(inner, N), randn(inner, N))),
    Flux.Dense(inner, 10), ## 10 being the number of outcomes (0 to 9)
    Flux.softmax,
)

# Define input data
# The original data is repeated `epochs` times because `Flux.train!` only
# loops through the data set once.

epochs = 50

dataset = repeated((train_X, train_Y), epochs);

# Parameters for the network training

# training loss function, Flux optimizer
custom_loss(x, y) = Flux.crossentropy(m(x), y)
opt = Flux.ADAM()
evalcb = () -> @show(custom_loss(train_X, train_Y))

# Train to optimize network parameters

Flux.train!(custom_loss, Flux.params(m), dataset, opt, cb = Flux.throttle(evalcb, 5));

# Although our custom implementation takes time, it is able to reach an accuracy
# similar to that of the usual ReLU function implementation.

# Average of correct guesses
accuracy(x, y) = Statistics.mean(Flux.onecold(m(x)) .== Flux.onecold(y));

# Training accuracy

accuracy(train_X, train_Y)

# Test accuracy

accuracy(test_X, test_Y)

# Note that the accuracy is low due to simplified training.
# It is possible to increase the number of samples `N`,
# the number of epochs `epochs`, and the connectivity `inner`.
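
To see the new layer and its pullback in isolation, here is a minimal sketch (not part of the commit) that assumes the definitions from the file above are loaded; the names `poly`, `y`, and `seed` are illustrative only.

```julia
## Minimal sketch: a polytope with two hyperplanes acting on a 4x3 input,
## assuming the Polytope struct and its rrule above are defined.
poly = Polytope((randn(4, 3), randn(4, 3)))       # constraints dot(w[i], x) ≥ b[i]
y = randn(4, 3)

x_proj, pullback = ChainRulesCore.rrule(poly, y)  # projection plus pullback closure

seed = ones(size(x_proj))                         # incoming sensitivity dl/dx
dpoly, dl_dy = pullback(seed)                     # Tangent wrt (w, b) and gradient wrt y
size(dl_dy) == size(y)                            # gradient wrt the layer input
```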

src/diff_opt.jl

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ For instance, to get the tangent of the objective function corresponding to
 the tangent given to `BackwardInVariablePrimal`, do the
 following:
 ```julia
-func = MOI.get(model, DiffOpt.BackwardOutObjective)
+func = MOI.get(model, DiffOpt.BackwardOutObjective())
 ```
 Then, to get the sensitivity of the linear term with variable `x`, do
 ```julia
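
The corrected `BackwardOutObjective()` call above is the same pattern the new polyhedral example drives in its pullback. A condensed sketch of that pattern, assuming a solved DiffOpt/JuMP `model` with variables `x` and an incoming sensitivity `dl_dx` (as in the `Polytope` rrule above):

```julia
## Backward-pass pattern used by the example in this commit; `model`, `x`, and
## `dl_dx` are assumed to come from a solved DiffOpt model.
MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x, dl_dx)  # seed output sensitivities
DiffOpt.backward(model)                                        # run backward differentiation

obj_expr = MOI.get(model, DiffOpt.BackwardOutObjective())      # tangent of the objective
## for the projection objective dot(x - y, x - y), the sensitivity wrt y is:
dl_dy = -2 * JuMP.coefficient.(obj_expr, x)
```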
