Description
Problem: The current implementation only solves the full multistage problem in one go. The existing code contains a simplified, manual `rrule` for a per-stage `get_next_state` function, but it does not compute accurate sensitivities. The new DiffOpt API (PR #281 for implicit differentiation and PR #303 for objective sensitivity) allows us to obtain exact gradients with respect to parameter variables.
Proposed solution:

- Only allow parameters to be `MOI.Parameter`s: i.e., the only option should be what is currently under `:Param` in `variable_to_parameter` (line 1 at 5f413f5):

  ```julia
  function variable_to_parameter(model::JuMP.Model, variable::JuMP.VariableRef; initial_value=0.0, deficit=nothing, param_type=:Param)
  ```
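  One plausible reduction of that function, keeping only the `MOI.Parameter` path, could look like this sketch (the `deficit` handling is elided, and the way the variable is tied to the parameter is an assumption, not necessarily what the current code does):

  ```julia
  using JuMP

  # Sketch: every parameter becomes a JuMP variable constrained to an
  # MOI.Parameter set (the current `:Param` branch); the other
  # `param_type`s and the `deficit` keyword are dropped here.
  function variable_to_parameter(model::JuMP.Model, variable::JuMP.VariableRef;
          initial_value = 0.0)
      p = @variable(model, set = MOI.Parameter(initial_value))
      # Fix the original decision variable to the parameter's value via a
      # linear constraint, so DiffOpt can differentiate through it.
      @constraint(model, variable == p)
      return p
  end
  ```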
- Per-stage solve: make sure the function `simulate_stage` is well implemented and that it solves a single stage, with parameters (incoming state, realized uncertainty, target state) exposed via DiffOpt's parameter interface.
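  For example, `simulate_stage` could look like the sketch below; the argument names and the assumption that the parameter references are passed in alongside the stage model are illustrative, not the package's actual API:

  ```julia
  using JuMP

  # Sketch of a per-stage solve. `state_in_p`, `uncertainty_p`, and
  # `target_p` are assumed to be vectors of variables constrained to
  # MOI.Parameter sets; the real signature may differ.
  function simulate_stage(stage::JuMP.Model,
          state_in_p, uncertainty_p, target_p,
          state_in::Vector{Float64}, uncertainty::Vector{Float64},
          target::Vector{Float64})
      # Load this stage's realization into the parameters.
      set_parameter_value.(state_in_p, state_in)
      set_parameter_value.(uncertainty_p, uncertainty)
      set_parameter_value.(target_p, target)
      optimize!(stage)
      return objective_value(stage)
  end
  ```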
- Fix the `get_next_state` `rrule`: use DiffOpt's reverse-mode differentiation API inside the `_pullback`. This ensures the gradient of the realized state with respect to the target and incoming state is exact.
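  Assuming the stage model is built with `DiffOpt.diff_optimizer`, the pullback could be sketched roughly as follows (the `get_next_state` signature and parameter bookkeeping are illustrative, and the exact DiffOpt attribute names should be checked against the merged PRs):

  ```julia
  using ChainRulesCore, DiffOpt, JuMP

  # Sketch of the corrected pullback. `next_state_vars` are the JuMP
  # variables holding the realized state; `params` are the MOI.Parameter
  # variables (incoming state, target state).
  function ChainRulesCore.rrule(::typeof(get_next_state), model, params, next_state_vars)
      y = get_next_state(model, params, next_state_vars)
      function pullback(ȳ)
          # Seed the reverse pass with the cotangent of the realized state.
          for (v, dv) in zip(next_state_vars, ȳ)
              MOI.set(model, DiffOpt.ReverseVariablePrimal(), v, dv)
          end
          DiffOpt.reverse_differentiate!(model)
          # Exact sensitivities w.r.t. each parameter.
          dp = [MOI.get(model, DiffOpt.ReverseConstraintSet(),
                        ParameterRef(p)).value for p in params]
          return (NoTangent(), NoTangent(), dp, NoTangent())
      end
      return y, pullback
  end
  ```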
- Use objective sensitivity: with DiffOpt's forthcoming objective-sensitivity API, compute the derivative of the stage objective with respect to the parameter variables (target state, penalty weight) directly; i.e., change `DecisionRules.jl/src/simulate_multistage.jl` line 203 (at 5f413f5) from

  ```julia
  return MOI.get(JuMP.owner_model(v), POI.ParameterDual(), v)
  ```

  to just:

  ```julia
  dual(v)
  ```
- Testing & examples: Provide a test comparing the new stagewise gradient computation to the existing dual-based method on small instances (e.g., the battery example). Include a demonstration of training a policy using the stagewise approach.
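Such a test could be sketched as follows; `build_battery_stage`, `stage_gradient_diffopt`, and `stage_gradient_dual` are hypothetical helpers standing in for the battery example and the two gradient paths:

```julia
using Test

# Sketch: compare the new stagewise DiffOpt gradient against the
# existing dual-based computation on a tiny battery instance.
@testset "stagewise gradient matches dual-based gradient" begin
    stage = build_battery_stage()                  # hypothetical small instance
    target = [0.5]                                 # target-state parameter value
    g_stagewise = stage_gradient_diffopt(stage, target)  # new stagewise path
    g_dual = stage_gradient_dual(stage, target)          # existing dual-based path
    @test g_stagewise ≈ g_dual atol = 1e-6
end
```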