@@ -13,6 +13,7 @@ import DiffOpt
 import Ipopt
 import ChainRulesCore
 import Flux
+import MLDatasets
 import Statistics
 import Base.Iterators: repeated
 using LinearAlgebra
@@ -22,67 +23,46 @@ using LinearAlgebra
 # Define a relu through an optimization problem solved by a quadratic solver.
 # Return the solution of the problem.
 function matrix_relu(
-    y::AbstractArray{T};
+    y::Matrix;
     model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
-) where T
-    _x = zeros(size(y))
-    N = length(y[:, 1])
+)
+    N, M = size(y)
     empty!(model)
     set_silent(model)
-    @variable(model, x[1:N] >= 0)
-    for i in 1:size(y, 2)
-        @objective(
-            model,
-            Min,
-            dot(x, x) - 2dot(y[:, i], x)
-        )
-        optimize!(model)
-        _x[:, i] = value.(x)
-    end
-    return _x
+    @variable(model, x[1:N, 1:M] >= 0)
+    @objective(model, Min, x[:]'x[:] - 2y[:]'x[:])
+    optimize!(model)
+    return value.(x)
 end
 
 
 # Define the backward differentiation rule, for the function we defined above.
-function ChainRulesCore.rrule(
-    ::typeof(matrix_relu),
-    y::AbstractArray{T};
+function ChainRulesCore.rrule(::typeof(matrix_relu), y::Matrix{T}) where T
     model = Model(() -> DiffOpt.diff_optimizer(Ipopt.Optimizer))
-) where T
     pv = matrix_relu(y, model = model)
     function pullback_matrix_relu(dl_dx)
         ## some value from the backpropagation (e.g., loss) is denoted by `l`
         ## so `dl_dy` is the derivative of `l` wrt `y`
         x = model[:x] ## load decision variable `x` into scope
         dl_dy = zeros(T, size(dl_dx))
-        dl_dq = zeros(T, size(dl_dx)) ## for step-by-step explanation
-        for i in 1:size(y, 2)
-            ## set sensitivities
-            MOI.set.(
-                model,
-                DiffOpt.BackwardInVariablePrimal(),
-                x,
-                dl_dx[:, i]
-            )
-            ## compute grad
-            DiffOpt.backward(model)
-            ## return gradient wrt objective function parameters
-            obj_exp = MOI.get(
-                model,
-                DiffOpt.BackwardOutObjective()
-            )
-            dl_dq[:, i] = JuMP.coefficient.(obj_exp, x) ## coeff of `x` in q'x = -2y'x
-            dq_dy = -2 ## dq/dy = -2
-            dl_dy[:, i] = dl_dq[:, i] * dq_dy
-        end
+        dl_dq = zeros(T, size(dl_dx))
+        ## set sensitivities
+        MOI.set.(model, DiffOpt.BackwardInVariablePrimal(), x[:], dl_dx[:])
+        ## compute grad
+        DiffOpt.backward(model)
+        ## return gradient wrt objective function parameters
+        obj_exp = MOI.get(model, DiffOpt.BackwardOutObjective())
+        ## coeff of `x` in q'x = -2y'x
+        dl_dq[:] .= JuMP.coefficient.(obj_exp, x[:])
+        dq_dy = -2 ## dq/dy = -2
+        dl_dy[:] .= dl_dq[:] * dq_dy
         return (ChainRulesCore.NoTangent(), dl_dy,)
     end
     return pv, pullback_matrix_relu
 end
 
 # For more details about backpropagation, visit [Introduction, ChainRulesCore.jl](https://juliadiff.org/ChainRulesCore.jl/dev/).
 # ## prepare data
-import MLDatasets
 N = 1000
 imgs = MLDatasets.MNIST.traintensor(1:N)
 labels = MLDatasets.MNIST.trainlabels(1:N);
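
Aside on the refactored layer above (not part of the diff): the vectorized objective `x[:]'x[:] - 2y[:]'x[:]` separates per entry, and minimizing `x_i^2 - 2y_i*x_i` over `x_i >= 0` gives `x_i = max(y_i, 0)`, so the QP layer is exactly an elementwise ReLU. Likewise, in the pullback, DiffOpt returns the sensitivity with respect to the linear coefficient `q = -2y`, which is why `dl_dy` is obtained by multiplying `dl_dq` by `dq_dy = -2`. The snippet below is only an illustrative sanity check under those observations, assuming the imports and `matrix_relu` defined above; it is not tutorial code.

```julia
# Illustrative check only: the QP solution should match the closed-form
# elementwise ReLU up to solver tolerance.
y_test = randn(5, 3)
x_qp = matrix_relu(y_test)       # solution of the QP defined above
x_closed = max.(y_test, 0.0)     # closed-form ReLU
maximum(abs.(x_qp - x_closed))   # expected to be small (Ipopt tolerance)
```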
@@ -99,7 +79,7 @@ test_Y = Flux.onehotbatch(MLDatasets.MNIST.testlabels(1:N), 0:9);
 
 # Network structure
 
-inner = 15
+inner = 10
 
 m = Flux.Chain(
     Flux.Dense(784, inner), # 784 being image linear dimension (28 x 28)
@@ -112,7 +92,8 @@ m = Flux.Chain(
 # The original data is repeated `epochs` times because `Flux.train!` only
 # loops through the data set once
 
-epochs = 5
+epochs = 50 # ~1 minute (i7 8th gen with 16gb RAM)
+## epochs = 100 # leads to 77.8% in about 2 minutes
 
 dataset = repeated((train_X, train_Y), epochs);
 
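
For context around this last hunk (the surrounding code is not shown in the diff): the custom layer is used as the activation inside the Flux chain and the repeated `dataset` is then passed to `Flux.train!`. The sketch below is an assumption about that surrounding tutorial code, not a quote of it; the loss name `custom_loss` and the `ADAM` optimiser are illustrative.

```julia
# Illustrative sketch, assuming the rest of the tutorial's chain and training loop.
m = Flux.Chain(
    Flux.Dense(784, inner),  # 28 x 28 images flattened to 784 inputs
    matrix_relu,             # the differentiable QP layer defined above
    Flux.Dense(inner, 10),   # 10 output classes (digits 0-9)
    Flux.softmax,
)
custom_loss(x, y) = Flux.crossentropy(m(x), y)  # hypothetical loss name
opt = Flux.ADAM()
Flux.train!(custom_loss, Flux.params(m), dataset, opt)
```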