# # Polyhedral QP layer

#md # [](@__REPO_ROOT_URL__/docs/src/examples/polyhedral_project.jl)

# We use DiffOpt to define a custom network layer which, given an input matrix `y`,
# computes its Euclidean projection onto a polytope defined by a fixed number of
# inequalities `w_i^T x ≥ b_i`.
# A neural network integrating this quadratic optimization layer is created using
# Flux.jl and trained on the MNIST dataset.
#
# The QP is solved in the forward pass, and its DiffOpt derivative is used in the
# backward pass through a `ChainRulesCore.rrule`.

# This example is similar to the custom ReLU layer, except that the layer is
# parameterized by the hyperplanes `(w, b)` rather than being a simple stateless function.
# This also means that `ChainRulesCore.rrule` must return the derivatives of the output
# with respect to the layer parameters to allow for backpropagation.
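
# Concretely, for an input matrix `y` the layer solves the quadratic program
# ```math
# \min_{x} \; \lVert x - y \rVert_2^2 \quad \text{s.t.} \quad \langle w_i, x \rangle \ge b_i, \quad i = 1, \dots, N,
# ```
# where each `w_i` has the same shape as `y`, `⟨⋅,⋅⟩` is the Frobenius inner product,
# and `N` is the number of hyperplanes. This simply restates the objective and
# constraints of the JuMP model defined below.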

using JuMP
import DiffOpt
import Ipopt
import ChainRulesCore
import Flux
import MLDatasets
import Statistics
using Base.Iterators: repeated
using LinearAlgebra
using Random

Random.seed!(42)

# ## The Polytope representation and its derivative

struct Polytope{N}
    w::NTuple{N, Matrix{Float64}}
    b::Vector{Float64}
end

Polytope(w::NTuple{N}) where {N} = Polytope{N}(w, randn(N))

# We define a "call" operation on the polytope, making it a so-called functor.
# Calling the polytope with a matrix `y` performs the Euclidean projection of this
# matrix onto the polytope.
function (polytope::Polytope)(
    y::AbstractMatrix;
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
)
    N, M = size(y)
    empty!(model)
    set_silent(model)
    @variable(model, x[1:N, 1:M])
    ## one inequality constraint per hyperplane
    @constraint(
        model,
        greater_than_cons[idx in 1:length(polytope.w)],
        dot(polytope.w[idx], x) ≥ polytope.b[idx]
    )
    ## minimize the squared Euclidean distance to `y`
    @objective(model, Min, dot(x - y, x - y))
    optimize!(model)
    return JuMP.value.(x)
end
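
# As a quick illustration (not part of the training pipeline below; the names
# `polytope_example`, `y_example`, and `x_example` are introduced here only for this
# check), we project a random 3×2 matrix onto a polytope with two hyperplanes and
# verify that the result satisfies both inequalities up to the solver tolerance.

polytope_example = Polytope((randn(3, 2), randn(3, 2)))
y_example = randn(3, 2)
x_example = polytope_example(y_example)
## the projection keeps the shape of the input and lies inside the polytope
all(dot(polytope_example.w[i], x_example) ≥ polytope_example.b[i] - 1e-6 for i in 1:2)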

# The `@functor` macro from Flux implements auxiliary functions for collecting the
# parameters of our custom layer, so that they can be updated during backpropagation.
Flux.@functor Polytope
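
# With the example polytope defined above, `Flux.params` now collects the two weight
# matrices and the offset vector, making them trainable parameters (again, this
# inspection is only illustrative).

Flux.params(polytope_example)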

# We now define the reverse differentiation rule for the projection function defined above.
# Flux uses ChainRules primitives to implement reverse-mode differentiation of the whole network.
# To make the layer learnable, the gradient with respect to the `Polytope` fields is
# returned in a `ChainRulesCore.Tangent`, the type used to represent derivatives with
# respect to structs.
# For more details about backpropagation, see the
# [ChainRulesCore.jl introduction](https://juliadiff.org/ChainRulesCore.jl/dev/).

function ChainRulesCore.rrule(polytope::Polytope, y::AbstractMatrix)
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
        dl_dx = ChainRulesCore.unthunk(dl_dx)
        ## `dl_dy` is the derivative of `l` wrt `y`
        x = model[:x]
        ## grad wrt input parameters
        dl_dy = zeros(size(dl_dx))
        ## grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)
        ## set sensitivities
        MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x, dl_dx)
        ## compute grad
        DiffOpt.backward(model)
        ## compute gradient wrt objective function parameter y
        obj_expr = MOI.get(model, DiffOpt.BackwardOutObjective())
        dl_dy .= -2 * JuMP.coefficient.(obj_expr, x)
        greater_than_cons = model[:greater_than_cons]
        for idx in eachindex(dl_dw)
            cons_expr = MOI.get(model, DiffOpt.BackwardOutConstraint(), greater_than_cons[idx])
            dl_db[idx] = -JuMP.constant(cons_expr)
            dl_dw[idx] .= JuMP.coefficient.(cons_expr, x)
        end
        dself = ChainRulesCore.Tangent{typeof(polytope)}(; w = dl_dw, b = dl_db)
        return (dself, dl_dy)
    end
    return xv, pullback_matrix_projection
end
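
# As a sanity check (illustrative only, reusing `polytope_example` and `y_example` from
# above), we can differentiate a scalar function of the projection directly. Zygote picks
# up the rule above, and the returned gradient mirrors the `Polytope` fields: one matrix
# per hyperplane in `w` and one entry per constraint in `b`.

grads = Flux.gradient(p -> sum(p(y_example)), polytope_example)
## gradient with respect to the layer parameters
grads[1].w, grads[1].b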

# ## Prepare data
N = 500
imgs = MLDatasets.MNIST.traintensor(1:N)
labels = MLDatasets.MNIST.trainlabels(1:N);

# Preprocessing
train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) ## stack all the images
train_Y = Flux.onehotbatch(labels, 0:9);
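
# Each image is flattened into a column of 784 = 28 × 28 pixels, so `train_X` is a
# 784 × N matrix and `train_Y` a 10 × N one-hot matrix:

size(train_X), size(train_Y)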

test_imgs = MLDatasets.MNIST.testtensor(1:N)
test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), N))
test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);

# ## Define the Network

inner = 20

m = Flux.Chain(
    Flux.Dense(784, inner), ## 784 being image linear dimension (28 x 28)
    Polytope((randn(inner, N), randn(inner, N), randn(inner, N))),
    Flux.Dense(inner, 10), ## 10 being the number of outcomes (0 to 9)
    Flux.softmax,
)
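
# Thanks to `Flux.@functor`, the polytope's hyperplanes and offsets are collected
# alongside the `Dense` layers' weights, so the optimization layer itself will be
# trained (this inspection is only illustrative):

length(Flux.params(m))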

# Define input data.
# The original data is repeated `epochs` times because `Flux.train!` only
# loops through the data set once.

epochs = 50

dataset = repeated((train_X, train_Y), epochs);

# Parameters for the network training

# Training loss function, Flux optimizer, and an evaluation callback
custom_loss(x, y) = Flux.crossentropy(m(x), y)
opt = Flux.ADAM()
evalcb = () -> @show(custom_loss(train_X, train_Y))

# Train to optimize network parameters

Flux.train!(custom_loss, Flux.params(m), dataset, opt, cb = Flux.throttle(evalcb, 5));

# Although our custom implementation takes more time, it is able to reach an accuracy
# similar to that of the usual ReLU layer implementation.

# Fraction of correct guesses
accuracy(x, y) = Statistics.mean(Flux.onecold(m(x)) .== Flux.onecold(y));

# Training accuracy

accuracy(train_X, train_Y)

# Test accuracy

accuracy(test_X, test_Y)

# Note that the accuracy is low due to the simplified training.
# It can be improved by increasing the number of samples `N`,
# the number of epochs `epochs`, and the hidden layer size `inner`.