Skip to content
Open
52 changes: 38 additions & 14 deletions src/DeepSplitting.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# import package
#using DifferentialEquations
#using Flux

Base.copy(t::Tuple) = t # required for below
function Base.copy(opt::O) where O<:Flux.Optimise.AbstractOptimiser
return O([copy(getfield(opt,f)) for f in fieldnames(typeof(opt))]...)
Expand Down Expand Up @@ -149,22 +153,41 @@ function solve(
u = splitting_model(y0, y1, z, t)
return sum(u.^2) / batch_size
end

# calculating SDE trajectories
function sde_loop!(y0, y1, dWall)
randn!(dWall) # points normally distributed for brownian motion
x0_sample!(y1) # points for initial conditions
for i in 1:size(dWall,3)
t = ts[N + 1 - i]
dW = @view dWall[:,:,i]
y0 .= y1
y1 .= y0 .+ μ(y0,p,t) .* dt .+ σ(y0,p,t) .* sqrt(dt) .* dW
if !isnothing(neumann_bc)
y1 .= _reflect(y0, y1, neumann_bc[1], neumann_bc[2])
end
end


# calculating the SDE trajectories - use the SDESolver - it works
#function sde_loop!(y0,y1,dWall)
# x0_sample!(y1) #initial condition
# randn!(dWall) #points normally distributed for brownian motion
# for i in 1:size(dWall,3)
# t = ts[N + 1 - i] #this is dt
# dW = @view dWall[:,:,i]
# y0 .= y1
# #y1 .= y0 .+ μ(y0,p,t) .* dt .+ σ(y0,p,t) .* sqrt(dt) .* dW
# prob = SDEProblem(y0 .+ μ(y0,p,t) .* dt,σ(y0,p,t) .* sqrt(dt),x0_sample!(y1),t)
# sol = solve(prob,EM(),dt=dt)
# if !isnothing(neumann_bc)
# y1 .= _reflect(y0, y1, neumann_bc[1], neumann_bc[2])
# end
# end
#end

# calculating the SDE trajectories - use the SDESolver
function sde_loop!(y0,y1,dWall)
x0_sample!(y1) #initial condition
randn!(dWall) #points normally distributed for brownian motion
t = ts[N+1:-1:1] #this is dt
y0 .= y1
y1 .= y0 .+ μ(y0,p,t) .* dt .+ σ(y0,p,t) .* sqrt(dt) .* dW
prob = SDEProblem(y0 .+ μ(y0,p,t) .* dt,σ(y0,p,t) .* sqrt(dt),x0_sample!(y1),t)
ensembleprob = EnsembleProblem(prob)
Comment on lines +180 to +183
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not written as functions, for this won't run. Was this tested?

sol = solve(ensembleprob, EnsembleSerial(), trajectories = 3)
if !isnothing(neumann_bc)
y1 .= _reflect(y0, y1, neumann_bc[1], neumann_bc[2])
end
end


for net in 1:N
# preallocate dWall
dWall = similar(x0, d, batch_size, N + 1 - net) # for SDE
Expand All @@ -174,6 +197,7 @@ function solve(
# first of maxiters used for first nn, second used for the other nn
_maxiters = length(maxiters) > 1 ? maxiters[min(net,2)] : maxiters[]

#modifying the sde_loop by replacing with StochasticDiffEq
for λ in λs
opt_net = copy(opt) # starting with a new optimiser state at each time step
opt_net.eta = λ
Expand Down