From e4925ffd91fbd2e2fd3a41ce824113c82e692f8a Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Wed, 22 Jan 2020 23:35:50 -0500 Subject: [PATCH 1/2] Fail if a GlobalRef is being assigned to a Dual number. Co-authored-by: "Yingbo Ma" Co-authored-by: "Shashi Gowda" --- src/custom_dispatch.jl | 26 +++++++++++++++++++++++++- test/api.jl | 7 +++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/custom_dispatch.jl b/src/custom_dispatch.jl index 9849e6a..a0eb95f 100644 --- a/src/custom_dispatch.jl +++ b/src/custom_dispatch.jl @@ -71,6 +71,20 @@ function newslot!(ir) Core.SlotNumber(length(ir.slotnames)) end +struct DifferentiationFailure <: Exception + msg::String +end + +Base.showerror(io::IO, d::DifferentiationFailure) = print(io, "Differentiation failure: ", d.msg) + +_fail_if_dual(::Dual) = throw(DifferentiationFailure("Differentiated variable being set to a global variable")) +_fail_if_dual(x) = x + +function checked_global_set(stmt) + stmt.args[2] = Expr(:call, _fail_if_dual, stmt.args[2]) + [stmt] +end + function rewrite_ir(ctx, ref) # turn # f(x...) @@ -88,10 +102,20 @@ function rewrite_ir(ctx, ref) (stmt, i) -> Base.Meta.isexpr(stmt, :call) ? 8 : nothing, (stmt, i) -> (s = newslot!(ir); rewrite_call(ctx, stmt, s, i))) + # Sometimes IR has y = f(x) as one statement, handle that case: Cassette.insert_statements!(ir.code, ir.codelocs, - (stmt, i) -> Base.Meta.isexpr(stmt, :(=)) && stmt.args[1] isa Core.SlotNumber && stmt.args[2] isa Expr && stmt.args[2].head == :call ? 8 : nothing, + (stmt, i) -> Base.Meta.isexpr(stmt, :(=)) && + stmt.args[1] isa Core.SlotNumber && + stmt.args[2] isa Expr && + stmt.args[2].head == :call ? 8 : nothing, (stmt, i) -> rewrite_call(ctx, stmt.args[2], stmt.args[1], i)) + # Error if a global variable is set to a Dual number + Cassette.insert_statements!(ir.code, ir.codelocs, + (stmt, i) -> Base.Meta.isexpr(stmt, :(=)) && + stmt.args[1] isa GlobalRef ? 1 : nothing, + (stmt, i) -> checked_global_set(stmt)) + ir.ssavaluetypes = length(ir.code) # Core.Compiler.validate_code(ir) diff --git a/test/api.jl b/test/api.jl index ac82137..a1f58e9 100644 --- a/test/api.jl +++ b/test/api.jl @@ -125,3 +125,10 @@ end end @test all(iszero, DI(vars_arg -> DI(theta_arg -> fun(vars_arg, theta_arg))(ones(5)))(ones(6))) end + +@testset "Differentiation Failure" begin + # issue #31 + y = 1 + _global_access(x) = global y = x + @test_throws ForwardDiff2.DifferentiationFailure D(_global_access)(1) * 1 +end From 1d1598deb7cd0b2234a27c769a6912dc6ef6c80e Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 23 Jan 2020 00:30:25 -0500 Subject: [PATCH 2/2] Fail if a closed variable is set to a Dual Co-authored-by: "Yingbo Ma" Co-authored-by: "Shashi Gowda" --- src/dual_context.jl | 13 ++++++++++++- test/api.jl | 22 +++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/dual_context.jl b/src/dual_context.jl index ffd5db3..b1b8e29 100644 --- a/src/dual_context.jl +++ b/src/dual_context.jl @@ -148,7 +148,7 @@ for pred in BINARY_PREDICATES end end -#### recursion early termination condition +#### recursion early termination conditions @inline Cassette.overdub(ctx::TaggedCtx, f::Core.Builtin, args...) = f(args...) @inline Cassette.overdub(ctx::TaggedCtx{T}, f::Union{typeof(value),typeof(partials)}, d::Dual{T}) where {T<:Tag{Nothing}} = f(d) @inline Cassette.overdub(ctx::TaggedCtx{T}, f::typeof(allpartials), d::AbstractArray{<:Dual{T}}) where {T<:Tag{Nothing}} = f(d) @@ -158,3 +158,14 @@ end ##### Inference Hacks @inline isinteresting(ctx::TaggedCtx, f::Union{typeof(Base.print_to_string),typeof(hash)}, args...) = false @noinline Cassette.overdub(ctx::TaggedCtx, f::Union{typeof(Base.print_to_string),typeof(hash),typeof(Core.throw)}, args...) = f(args...) + +##### Errors instead of silent incorrectness + +@inline function Cassette.overdub(ctx::TaggedCtx{T}, f::typeof(Core.setfield!), + thing::S, field, d::Dual) where {T,S} + DT = Dual{Tag{T}} + if isdefined(thing, field) && getfield(thing, field) isa DT && d isa DT + return f(thing, field, d) + end + throw(DifferentiationFailure("Setting a closed variable to a differentiated value is not allowed")) +end diff --git a/test/api.jl b/test/api.jl index a1f58e9..4f672d0 100644 --- a/test/api.jl +++ b/test/api.jl @@ -126,9 +126,29 @@ end @test all(iszero, DI(vars_arg -> DI(theta_arg -> fun(vars_arg, theta_arg))(ones(5)))(ones(6))) end -@testset "Differentiation Failure" begin +@testset "Dual leakage fixes" begin # issue #31 y = 1 _global_access(x) = global y = x @test_throws ForwardDiff2.DifferentiationFailure D(_global_access)(1) * 1 + + # issue #30 (part 2) + @test_throws ForwardDiff2.DifferentiationFailure D(x -> D(y -> x = y, x)*x, 1) + + mutable struct MyStruct{T} + x::T + end + + @test D() do x + s = MyStruct(x) + s.x +=1 + s.x *= 2x + end(1) * 1 === 6 + + # wut + #@test D(y -> (D() do x + # s = MyStruct(x) + # s.x += 1 + # s.x *= 2x + #end)(y) * 10)(1) * 1 === N end