Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/custom_dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ function newslot!(ir)
Core.SlotNumber(length(ir.slotnames))
end

struct DifferentiationFailure <: Exception
msg::String
end

Base.showerror(io::IO, d::DifferentiationFailure) = print(io, "Differentiation failure: ", d.msg)

_fail_if_dual(::Dual) = throw(DifferentiationFailure("Differentiated variable being set to a global variable"))
_fail_if_dual(x) = x

function checked_global_set(stmt)
stmt.args[2] = Expr(:call, _fail_if_dual, stmt.args[2])
[stmt]
end

function rewrite_ir(ctx, ref)
# turn
# f(x...)
Expand All @@ -88,10 +102,20 @@ function rewrite_ir(ctx, ref)
(stmt, i) -> Base.Meta.isexpr(stmt, :call) ? 8 : nothing,
(stmt, i) -> (s = newslot!(ir); rewrite_call(ctx, stmt, s, i)))

# Sometimes IR has y = f(x) as one statement, handle that case:
Cassette.insert_statements!(ir.code, ir.codelocs,
(stmt, i) -> Base.Meta.isexpr(stmt, :(=)) && stmt.args[1] isa Core.SlotNumber && stmt.args[2] isa Expr && stmt.args[2].head == :call ? 8 : nothing,
(stmt, i) -> Base.Meta.isexpr(stmt, :(=)) &&
stmt.args[1] isa Core.SlotNumber &&
stmt.args[2] isa Expr &&
stmt.args[2].head == :call ? 8 : nothing,
(stmt, i) -> rewrite_call(ctx, stmt.args[2], stmt.args[1], i))

# Error if a global variable is set to a Dual number
Cassette.insert_statements!(ir.code, ir.codelocs,
(stmt, i) -> Base.Meta.isexpr(stmt, :(=)) &&
stmt.args[1] isa GlobalRef ? 1 : nothing,
(stmt, i) -> checked_global_set(stmt))

ir.ssavaluetypes = length(ir.code)

# Core.Compiler.validate_code(ir)
Expand Down
13 changes: 12 additions & 1 deletion src/dual_context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ for pred in BINARY_PREDICATES
end
end

#### recursion early termination condition
#### recursion early termination conditions
@inline Cassette.overdub(ctx::TaggedCtx, f::Core.Builtin, args...) = f(args...)
@inline Cassette.overdub(ctx::TaggedCtx{T}, f::Union{typeof(value),typeof(partials)}, d::Dual{T}) where {T<:Tag{Nothing}} = f(d)
@inline Cassette.overdub(ctx::TaggedCtx{T}, f::typeof(allpartials), d::AbstractArray{<:Dual{T}}) where {T<:Tag{Nothing}} = f(d)
Expand All @@ -158,3 +158,14 @@ end
##### Inference Hacks
@inline isinteresting(ctx::TaggedCtx, f::Union{typeof(Base.print_to_string),typeof(hash)}, args...) = false
@noinline Cassette.overdub(ctx::TaggedCtx, f::Union{typeof(Base.print_to_string),typeof(hash),typeof(Core.throw)}, args...) = f(args...)

##### Errors instead of silent incorrectness

@inline function Cassette.overdub(ctx::TaggedCtx{T}, f::typeof(Core.setfield!),
thing::S, field, d::Dual) where {T,S}
DT = Dual{Tag{T}}
if isdefined(thing, field) && getfield(thing, field) isa DT && d isa DT
return f(thing, field, d)
end
throw(DifferentiationFailure("Setting a closed variable to a differentiated value is not allowed"))
end
27 changes: 27 additions & 0 deletions test/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,30 @@ end
end
@test all(iszero, DI(vars_arg -> DI(theta_arg -> fun(vars_arg, theta_arg))(ones(5)))(ones(6)))
end

@testset "Dual leakage fixes" begin
# issue #31
y = 1
_global_access(x) = global y = x
@test_throws ForwardDiff2.DifferentiationFailure D(_global_access)(1) * 1

# issue #30 (part 2)
@test_throws ForwardDiff2.DifferentiationFailure D(x -> D(y -> x = y, x)*x, 1)

mutable struct MyStruct{T}
x::T
end

@test D() do x
s = MyStruct(x)
s.x +=1
s.x *= 2x
end(1) * 1 === 6

# wut
#@test D(y -> (D() do x
# s = MyStruct(x)
# s.x += 1
# s.x *= 2x
#end)(y) * 10)(1) * 1 === N
end