Commit 57f01d1

Tutorials (#206)
* align svm example with paper
* simplify expression
* more efficient relu

Co-authored-by: Mathieu Besançon <[email protected]>
1 parent 26c5b16 commit 57f01d1

2 files changed: +30 −66 lines changed

docs/src/examples/custom-relu.jl

Lines changed: 23 additions & 42 deletions
@@ -13,6 +13,7 @@ import DiffOpt
 import Ipopt
 import ChainRulesCore
 import Flux
+import MLDatasets
 import Statistics
 import Base.Iterators: repeated
 using LinearAlgebra
@@ -22,67 +23,46 @@ using LinearAlgebra
 # Define a relu through an optimization problem solved by a quadratic solver.
 # Return the solution of the problem.
 function matrix_relu(
-    y::AbstractArray{T};
+    y::Matrix;
     model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
-) where T
-    _x = zeros(size(y))
-    N = length(y[:, 1])
+)
+    N, M = size(y)
     empty!(model)
     set_silent(model)
-    @variable(model, x[1:N] >= 0)
-    for i in 1:size(y, 2)
-        @objective(
-            model,
-            Min,
-            dot(x, x) -2dot(y[:, i], x)
-        )
-        optimize!(model)
-        _x[:, i] = value.(x)
-    end
-    return _x
+    @variable(model, x[1:N, 1:M] >= 0)
+    @objective(model, Min, x[:]'x[:] -2y[:]'x[:])
+    optimize!(model)
+    return value.(x)
 end
 
 
 # Define the backward differentiation rule, for the function we defined above.
-function ChainRulesCore.rrule(
-    ::typeof(matrix_relu),
-    y::AbstractArray{T};
+function ChainRulesCore.rrule(::typeof(matrix_relu), y::Matrix{T}) where T
     model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
-) where T
     pv = matrix_relu(y, model = model)
     function pullback_matrix_relu(dl_dx)
         ## some value from the backpropagation (e.g., loss) is denoted by `l`
         ## so `dl_dy` is the derivative of `l` wrt `y`
         x = model[:x] ## load decision variable `x` into scope
         dl_dy = zeros(T, size(dl_dx))
-        dl_dq = zeros(T, size(dl_dx)) ## for step-by-step explanation
-        for i in 1:size(y, 2)
-            ## set sensitivities
-            MOI.set.(
-                model,
-                DiffOpt.BackwardInVariablePrimal(),
-                x,
-                dl_dx[:, i]
-            )
-            ## compute grad
-            DiffOpt.backward(model)
-            ## return gradient wrt objective function parameters
-            obj_exp = MOI.get(
-                model,
-                DiffOpt.BackwardOutObjective()
-            )
-            dl_dq[:, i] = JuMP.coefficient.(obj_exp, x) ## coeff of `x` in q'x = -2y'x
-            dq_dy = -2 ## dq/dy = -2
-            dl_dy[:, i] = dl_dq[:, i] * dq_dy
-        end
+        dl_dq = zeros(T, size(dl_dx))
+        ## set sensitivities
+        MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x[:], dl_dx[:])
+        ## compute grad
+        DiffOpt.backward(model)
+        ## return gradient wrt objective function parameters
+        obj_exp = MOI.get(model, DiffOpt.BackwardOutObjective())
+        ## coeff of `x` in q'x = -2y'x
+        dl_dq[:] .= JuMP.coefficient.(obj_exp, x[:])
+        dq_dy = -2 ## dq/dy = -2
+        dl_dy[:] .= dl_dq[:] * dq_dy
         return (ChainRulesCore.NoTangent(), dl_dy,)
     end
     return pv, pullback_matrix_relu
 end
 
 # For more details about backpropagation, visit [Introduction, ChainRulesCore.jl](https://juliadiff.org/ChainRulesCore.jl/dev/).
 # ## prepare data
-import MLDatasets
 N = 1000
 imgs = MLDatasets.MNIST.traintensor(1:N)
 labels = MLDatasets.MNIST.trainlabels(1:N);
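Why the loop could be dropped: the batched objective `x[:]'x[:] - 2y[:]'x[:]` is separable across the columns of `y`, and up to the constant `y[:]'y[:]` it equals the squared distance between `x` and `y`, so the single QP computes the Euclidean projection of `y` onto the nonnegative orthant, which is exactly the elementwise relu. A minimal sanity check, assuming the definitions from the hunk above are loaded (a sketch, not part of the commit):

    y_test = randn(4, 3)
    ## projection onto the nonnegative orthant equals max.(y_test, 0),
    ## up to the solver's convergence tolerance
    @assert isapprox(matrix_relu(y_test), max.(y_test, 0); atol = 1e-4)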
@@ -99,7 +79,7 @@ test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);
 
 # Network structure
 
-inner = 15
+inner = 10
 
 m = Flux.Chain(
     Flux.Dense(784, inner), #784 being image linear dimension (28 x 28)
@@ -112,7 +92,8 @@ m = Flux.Chain(
 # The original data is repeated `epochs` times because `Flux.train!` only
 # loops through the data set once
 
-epochs = 5
+epochs = 50 # ~1 minute (i7 8th gen with 16gb RAM)
+## epochs = 100 # leads to 77.8% in about 2 minutes
 
 dataset = repeated((train_X, train_Y), epochs);
 
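The pullback above maps `dl_dx` to `dl_dy` through the linear objective coefficients `q = -2y`, hence the factor `dq_dy = -2`. A hedged finite-difference check of the new `rrule`, assuming `matrix_relu` and the rule from the earlier hunk are defined (a sketch, not part of the commit):

    y0 = randn(3, 2)
    pv, pullback = ChainRulesCore.rrule(matrix_relu, y0)
    _, dl_dy = pullback(ones(size(pv))) ## gradient of l = sum(matrix_relu(y0))
    h = 1e-5 ## perturb every entry of y0 by h
    fd = (sum(matrix_relu(y0 .+ h)) - sum(matrix_relu(y0))) / h
    ## the directional derivative along the all-ones direction is sum(dl_dy)
    @assert isapprox(sum(dl_dy), fd; atol = 1e-2)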
docs/src/examples/sensitivity-analysis-svm.jl

Lines changed: 7 additions & 24 deletions
@@ -57,19 +57,15 @@ MOI.set(model, MOI.Silent(), true)
 
 # Add the constraints.
 
-@constraint(
-    model,
-    cons[i in 1:N],
+@constraint(model, con[i in 1:N],
     y[i] * (dot(X[i, :], w) + b) >= 1 - ξ[i]
 );
 
 
 # Define the objective and solve
 
-@objective(
-    model,
-    Min,
-    λ * dot(w, w) + sum(ξ),
+@objective(model,
+    Min, λ * dot(w, w) + sum(ξ),
 )
 
 optimize!(model)
@@ -109,27 +105,14 @@ for i in 1:N
     for j in 1:N
         if i == j
             ## we consider identical perturbations on all x_i coordinates
-            MOI.set(
-                model,
-                DiffOpt.ForwardInConstraint(),
-                cons[j],
-                y[j] * sum(w),
-            )
+            MOI.set(model, DiffOpt.ForwardInConstraint(), con[j], y[j] * sum(w))
         else
-            MOI.set(model, DiffOpt.ForwardInConstraint(), cons[j], 0.0)
+            MOI.set(model, DiffOpt.ForwardInConstraint(), con[j], 0.0)
         end
     end
     DiffOpt.forward(model)
-    dw = MOI.get.(
-        model,
-        DiffOpt.ForwardOutVariablePrimal(),
-        w,
-    )
-    db = MOI.get(
-        model,
-        DiffOpt.ForwardOutVariablePrimal(),
-        b,
-    )
+    dw = MOI.get.(model, DiffOpt.ForwardOutVariablePrimal(), w)
+    db = MOI.get(model, DiffOpt.ForwardOutVariablePrimal(), b)
     ∇[i] = norm(dw) + norm(db)
 end
 
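For the perturbation set in the hunk above: the constraint function is `y[j] * (dot(X[j, :], w) + b)`, so shifting every coordinate of `X[j, :]` by the same `t` changes it by `t * y[j] * sum(w)`; that slope is the derivative passed to `DiffOpt.ForwardInConstraint()`. A hedged check at the optimum, assuming the SVM model above has been solved (a sketch, not part of the commit):

    ŵ, b̂ = value.(w), value(b) ## optimal primal values
    j, t = 1, 1e-3
    g(s) = y[j] * (dot(X[j, :] .+ s, ŵ) + b̂) ## affine in s
    @assert isapprox((g(t) - g(0)) / t, y[j] * sum(ŵ); atol = 1e-8)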