Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)
if frule_call.rt !== Const(nothing)
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(primal_call.rt, primal_call.exct, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
else
return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
end
else
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end
Expand Down
7 changes: 0 additions & 7 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ end
Base.lastindex(x::Core.Compiler.InstructionStream) =
Core.Compiler.length(x)

# Solves an error after https://github.com/JuliaLang/julia/pull/46961
# as does https://github.com/FluxML/IRTools.jl/pull/101
if isdefined(Core.Compiler, :CallInfo)
Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo()
end


"""
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)

Expand Down
8 changes: 4 additions & 4 deletions src/stage1/recurse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Core.IR
using Core.Compiler:
Argument, BasicBlock, CFG, CodeInfo, GotoIfNot, GotoNode, IRCode, IncrementalCompact,
Instruction, MethodInstance, NewInstruction, NewvarNode, OldSSAValue, PhiNode,
ReturnNode, SSAValue, SlotNumber, StmtRange,
BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction,
NoCallInfo, OldSSAValue, StmtRange,
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
insert_node_here!, effect_free_and_nothrow, non_dce_finish!, quoted, retrieve_code_info,
Expand Down Expand Up @@ -266,7 +266,7 @@ function optic_transform!(ci, mi, nargs, N)

meta = Expr[]
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
Any[nothing for i = 1:length(code)],
CallInfo[NoCallInfo() for i = 1:length(code)],
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
Any[Any for i = 1:2], meta, sptypes(sparams))

Expand Down
171 changes: 133 additions & 38 deletions src/stage2/abstractinterpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using .CC: Const, isconstType, argtypes_to_type, tuple_tfunc, Const,
getfield_tfunc, _methods_by_ftype, VarTable, nfields_tfunc,
ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method,
PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc,
StmtInfo
StmtInfo, NoCallInfo
using Core: PartialStruct
using Base.Meta

Expand Down Expand Up @@ -41,7 +41,11 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
else
rt2 = obtype
end
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt2, call.exct, call.effects, RecurseInfo(call.info))
else
return CallMeta(rt2, call.effects, RecurseInfo(call.info))
end
end

# Check if there is a rrule for this function
Expand All @@ -56,7 +60,12 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
end
call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1)
if call.rt != Const(nothing)
return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info))
newrt = getfield_tfunc(call.rt, Const(1))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
else
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
end
end
end
end
Expand All @@ -74,26 +83,39 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
return ret
end

function abstract_accum(interp::AbstractInterpreter, args::Vector{Any}, sv::InferenceState)
args = filter(x->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), args)
function abstract_accum(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
argtypes = filter(@nospecialize(x)->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), argtypes)

if length(args) == 0
return CallMeta(ZeroTangent, Effects(), nothing)
if length(argtypes) == 0
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(ZeroTangent, Any, Effects(), NoCallInfo())
else
return CallMeta(ZeroTangent, Effects(), NoCallInfo())
end
end

if length(args) == 1
return CallMeta(args[1], Effects(), nothing)
if length(argtypes) == 1
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(argtypes[1], Any, Effects(), NoCallInfo())
else
return CallMeta(argtypes[1], Effects(), NoCallInfo())
end
end

rtype = reduce(tmerge, args)
rtype = reduce(tmerge, argtypes)
if widenconst(rtype) <: Tuple
targs = Any[]
for i = 1:nfields_tfunc(rtype).val
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in args], sv).rt)
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in argtypes], sv).rt)
end
rt = tuple_tfunc(targs)
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), NoCallInfo())
else
return CallMeta(rt, Effects(), NoCallInfo())
end
return CallMeta(tuple_tfunc(targs), nothing)
end
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), args...],
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), argtypes...],
sv::InferenceState)
return call
end
Expand Down Expand Up @@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
ft = argextype(inst.args[1], primal, primal.sptypes)
f = singleton_type(ft)
if isa(f, Core.Builtin)
call = CallMeta(backwards_tfunc(f, primal, inst, Δ), nothing)
rt = backwards_tfunc(f, primal, inst, Δ)
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), NoCallInfo())
else
call = CallMeta(rt, Effects(), NoCallInfo())
end
else
bail!(inst)
continue
Expand All @@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
arg = getfield_tfunc(Δ, Const(1))
call = abstract_call(interp, nothing, Any[clos, arg], sv)
# No derivative wrt the functor
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
rt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), ReifyInfo(call.info))
else
call = CallMeta(rt, Effects(), ReifyInfo(call.info))
end
else
(level, close) = derive_closure_type(call_info)
call = abstract_call(change_level(interp, level), ArgInfo(nothing, Any[close, Δ]), sv)
Expand All @@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp

if isa(info, UnionSplitApplyCallInfo)
argts = Any[argextype(inst.args[i], primal, primal.sptypes) for i = 4:length(inst.args)]
call = CallMeta(repackage_apply_rt(info, call.rt, argts),
UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
rt = repackage_apply_rt(info, call.rt, argts)
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(rt, Any, Effects(), newinfo)
else
call = CallMeta(rt, Effects(), newinfo)
end
end

if isa(call_info, ReifyInfo)
new_rt = tuple_tfunc(Any[derive_closure_type(call.info)[2]; call.rt])
call = CallMeta(new_rt, RecurseInfo(call.info))
newinfo = RecurseInfo(call.info)
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(new_rt, Any, Effects(), newinfo)
else
call = CallMeta(new_rt, Effects(), newinfo)
end
end

if call.rt === Union{}
Expand Down Expand Up @@ -312,15 +354,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
accum_call = abstract_accum(interp, this_arg_typs, sv)
if accum_call.rt == Union{}
@show accum_call.rt
return CallMeta(Union{}, false)
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(Union{}, Any, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Effects(), NoCallInfo())
end
end
push!(arg_accums, accum_call)
tup_push!(tup_elemns, accum_call.rt)
end
end

rt = tuple_tfunc(Any[tup_elemns...])
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))
else
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
end
end

function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospecialize(cc_Δ), sv::InferenceState)
Expand Down Expand Up @@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe

if isa(inst, ReturnNode)
rt = accum_arg(inst.val)
return CallMeta(rt, CompClosInfo(cc, ssa_infos))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))
else
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
end
end

args = Any[]
Expand Down Expand Up @@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
arg = getfield_tfunc(Δ, Const(2))
call = abstract_call(interp, nothing, Any[clos, arg], sv)
# No derivative wrt the functor
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
newrt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(newrt, Any, Effects(), ReifyInfo(call.info))
else
call = CallMeta(newrt, Effects(), ReifyInfo(call.info))
end
#error()
else
(level, clos) = derive_closure_type(call_info)
Expand All @@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe

if isa(call_info, ReifyInfo)
new_rt = tuple_tfunc(Any[call.rt; derive_closure_type(call.info)[2]])
call = CallMeta(new_rt, RecurseInfo())
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(new_rt, Any, Effects(), RecurseInfo())
else
call = CallMeta(new_rt, Effects(), RecurseInfo())
end
end

if isa(info, UnionSplitApplyCallInfo)
call = CallMeta(call.rt, UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
@static if VERSION ≥ v"1.11.0-DEV.945"
call = CallMeta(call.rt, call.exct, Effects(), newinfo)
else
call = CallMeta(call.rt, Effects(), newinfo)
end
end

accums[i] = call.rt
Expand All @@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos
end

function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecialize(Δ), sv::InferenceState)
@show ("enter", pc)

if pc.seq == 1
call = abstract_call(change_level(interp, pc.order), nothing, Any[pc.dual, Δ], sv)
rt = call.rt
@show (pc, Δ, rt)
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(call.rt, call.exct, Effects(), newinfo)
else
return CallMeta(call.rt, Effects(), newinfo)
end
elseif pc.seq == 2
ni = change_level(interp, pc.order)
mi′ = specialize_method(pc.info_below.results.matches[1], true)
Expand All @@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
call = infer_comp_closure(ni, cc, Δ, sv)
rt = getfield_tfunc(call.rt, Const(2))
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif pc.seq == 3
ni = change_level(interp, pc.order)
mi′ = specialize_method(pc.info_carried.info.results.matches[1], true)
Expand All @@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 0
info = pc.info_below
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)

# Add back gradient w.r.t. rrule
Δ = tuple_tfunc(Any[NoTangent, tuple_type_fields(Δ)...])
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
rt = getfield_tfunc(call.rt, Const(1))
@show (pc, Δ, rt)
return CallMeta(rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 1
info = pc.info_carried
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[pc.dual, Δ])], sv)
rt = call.rt
@show (pc, Δ, rt)
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 2
info = pc.info_below
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
rt = getfield_tfunc(call.rt, Const(2))
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
elseif mod(pc.seq, 4) == 3
info = pc.info_carried
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
@show (pc, Δ, rt)
return CallMeta(rt,
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
@static if VERSION ≥ v"1.11.0-DEV.945"
return CallMeta(rt, Any, Effects(), newinfo)
else
return CallMeta(rt, Effects(), newinfo)
end
end
error()
end
Expand All @@ -556,8 +652,7 @@ function CC.abstract_call_opaque_closure(interp::ADInterpreter,
if isa(closure.source, AbstractCompClosure)
(;argtypes) = arginfo
if length(argtypes) !== 2
error()
return CallMeta(Union{}, false)
error("bad argtypes")
end
return infer_comp_closure(interp, closure.source, argtypes[2], sv)
elseif isa(closure.source, PrimClosure)
Expand Down