1
1
"""
2
2
build_nn_function(eqs, nn, soutput)
3
3
4
- Build an executable function that can also depend on an output. It is then called with:
4
+ Build an executable function that can also depend on an output. The resulting `built_function` is then called with:
5
5
```julia
6
6
built_function(input, output, ps)
7
7
```
@@ -16,14 +16,32 @@ function build_nn_function(eqs, nn::AbstractSymbolicNeuralNetwork, soutput)
16
16
build_nn_function (eqs, params (nn), nn. input, soutput)
17
17
end
18
18
19
- function build_nn_function (eq:: EqT , sparams:: NeuralNetworkParameters , sinput:: Symbolics.Arr , soutput:: Symbolics.Arr )
19
+ function build_nn_function (eq:: EqT , sparams:: NeuralNetworkParameters , sinput:: Symbolics.Arr , soutput:: Symbolics.Arr ; reduce = hcat)
20
+ @assert ( (reduce == hcat) || (reduce == + ) ) " Keyword reduce either has to be + or hcat!"
20
21
gen_fun = _build_nn_function (eq, sparams, sinput, soutput)
21
- gen_fun_returned (input, output, ps) = mapreduce (k -> gen_fun (input, output, ps, k), + , axes (input, 2 ))
22
- gen_fun_returned (input:: AT , output:: AT , ps) where {AT <: Union{AbstractVector, Symbolics.Arr} } = gen_fun_returned (reshape (input, length (input), 1 ), reshape (output, length (output), 1 ), ps)
23
- gen_fun_returned (input:: AT , output:: AT , ps) where {T, AT <: AbstractArray{T, 3} } = gen_fun_returned (reshape (input, size (input, 1 ), size (input, 2 ) * size (input, 3 )), reshape (output, size (output, 1 ), size (output, 2 ) * size (output, 3 )), ps)
22
+ gen_fun_returned (input, output, ps) = mapreduce (k -> gen_fun (input, output, ps, k), reduce, axes (input, 2 ))
23
+ function gen_fun_returned (x:: AT , y:: AT , ps) where {AT <: Union{AbstractVector, Symbolics.Arr} }
24
+ output_not_reshaped = gen_fun_returned (reshape (x, length (x), 1 ), reshape (y, length (y), 1 ), ps)
25
+ # for vectors we do not reshape, as the output may be a matrix
26
+ output_not_reshaped
27
+ end
28
+ # check this! (definitely not correct in all cases!)
29
+ function gen_fun_returned (x:: AT , y:: AT , ps) where {AT <: AbstractArray{<:Number, 3} }
30
+ output_not_reshaped = gen_fun_returned (reshape (x, size (x, 1 ), size (x, 2 ) * size (x, 3 )), reshape (y, size (y, 1 ), size (y, 2 ) * size (y, 3 )), ps)
31
+ # if arrays are added together then don't reshape!
32
+ optional_reshape (output_not_reshaped, reduce, x)
33
+ end
24
34
gen_fun_returned
25
35
end
26
36
37
+ function optional_reshape (output_not_reshaped:: AbstractVecOrMat , :: typeof (+ ), :: AbstractArray{<:Number, 3} )
38
+ output_not_reshaped
39
+ end
40
+
41
+ function optional_reshape (output_not_reshaped:: AbstractVecOrMat , :: typeof (hcat), input:: AbstractArray{<:Number, 3} )
42
+ reshape (output_not_reshaped, size (output_not_reshaped, 1 ), size (input, 2 ), size (input, 3 ))
43
+ end
44
+
27
45
"""
28
46
_build_nn_function(eq, params, sinput, soutput)
29
47
0 commit comments