From 902ced95330b80d2a08ae3d42a1a1ecb346d8618 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 1 May 2025 22:30:37 -0500 Subject: [PATCH 01/87] generate --- src/Compiler.jl | 43 ++++++++++ src/Interpreter.jl | 107 ++++++++++++++++++++++++ src/Overlay.jl | 12 +++ src/mlir/Dialects/Enzyme.jl | 162 ++++++++++++++++++++++++++++++++++++ test/probprog.jl | 32 +++++++ 5 files changed, 356 insertions(+) create mode 100644 test/probprog.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index db02746a07..bff1046c9d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1489,6 +1489,49 @@ function compile_mlir!( ), "after_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + if raise_first + [ + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + jit, + ] + else + [ + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + jit, + ] + end, + ',', + ), + "probprog", + ) elseif optimize === :canonicalize run_pass_pipeline!(mod, "canonicalize", "canonicalize") elseif optimize === :just_batch diff --git a/src/Interpreter.jl b/src/Interpreter.jl index ee299ca4c1..46c1f675e5 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -539,3 +539,110 @@ function overload_autodiff( end end end + +function overload_generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function overload_sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix = gensym("samplearg") + resprefix = gensym("sampleresult") + resargprefix = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end diff --git a/src/Overlay.jl b/src/Overlay.jl index c97a06664d..cfef42541f 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -21,6 +21,18 @@ end return overload_autodiff(rmode, f, rt, args...) end +@reactant_overlay @noinline function Enzyme.generate( + f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + return overload_generate(f, args...) +end + +@reactant_overlay @noinline function Enzyme.sample( + f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + return overload_sample(f, args...) +end + # Random.jl overlays @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index e4306b06a1..f558ee0468 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -151,6 +151,33 @@ function fwddiff( ) end +""" +`generate` + +Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. +""" +function generate( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.generate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function genericAdjoint( inputs::Vector{Value}, outputs::Vector{Value}; @@ -323,4 +350,139 @@ function set(gradient::Value, value::Value; location=Location()) ) end +""" +`simulate` + +Simulate a probabilistic function to generate execution trace +by replacing all SampleOps with distribution calls and inserting +sampled values into the choice map. +""" +function simulate( + inputs::Vector{Value}; newTrace::IR.Type, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[newTrace,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.simulate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`trace` + +Execute a probabilistic function specified by a symbol reference using the provided arguments, +and a set of constraints on the sampled variables (if provided). Return the execution trace +(if provided) and the log-likelihood of the execution trace. +""" +function trace( + inputs::Vector{Value}, + oldTrace=nothing::Union{Nothing,Value}; + constraints=nothing::Union{Nothing,Value}, + newTrace::IR.Type, + weights::Vector{IR.Type}, + fn, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[newTrace, weights...] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(oldTrace) && push!(operands, oldTrace) + !isnothing(constraints) && push!(operands, constraints) + push!( + attributes, + operandsegmentsizes([ + length(inputs), if (oldTrace == nothing) + 0 + elseif 1(constraints == nothing) + 0 + else + 1 + end + ]), + ) + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.trace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`addSampleToTrace` + +Add a sampled value into the execution trace. +""" +function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[trace, sample] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.addSampleToTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`insertChoiceToMap` + +Insert a constraint on a sampled variable into the choice map. +""" +function insertChoiceToMap( + choiceMap::Value, + choice::Value; + newChoiceMap::IR.Type, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[newChoiceMap,] + operands = Value[choiceMap, choice] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.insertChoiceToMap", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + end # enzyme diff --git a/test/probprog.jl b/test/probprog.jl new file mode 100644 index 0000000000..e3f64faf30 --- /dev/null +++ b/test/probprog.jl @@ -0,0 +1,32 @@ +using Enzyme, Reactant, Test, Random, StableRNGs, Statistics + +normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) + +function model(mean, stddev) + s = Enzyme.sample(normal, StableRNG(0), mean, stddev) + t = Enzyme.sample(normal, StableRNG(0), s, stddev) + return t +end + +@testset "ProbProg" begin + @testset "normal_hlo" begin + hlo = @code_hlo Enzyme.generate( + model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + @test contains(repr(hlo), "enzyme.generate") + @test contains(repr(hlo), "enzyme.sample") + # println(hlo) + + lowered = Reactant.Compiler.run_pass_pipeline_on_source(repr(hlo), "probprog") + println(lowered) + end + + @testset "normal_generate" begin + X = Array( + @jit optimize = :probprog Enzyme.generate( + model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + ) + @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 + end +end From e2c77e402f41fd39084933bef7fac89e08eeee01 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 2 May 2025 15:40:51 -0500 Subject: [PATCH 02/87] refactor --- src/Interpreter.jl | 107 ----------------------------------------- src/Overlay.jl | 12 ----- src/ProbProg.jl | 115 +++++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 1 + test/probprog.jl | 11 +++-- test/runtests.jl | 1 + 6 files changed, 123 insertions(+), 124 deletions(-) create mode 100644 src/ProbProg.jl diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 46c1f675e5..ee299ca4c1 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -539,110 +539,3 @@ function overload_autodiff( end end end - -function overload_generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) - - residx = 1 - for a in linear_results - resv = MLIR.IR.result(gen_op, residx) - residx += 1 - for path in a.paths - if length(path) == 0 - continue - end - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - end - end - - return result -end - -function overload_sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix = gensym("samplearg") - resprefix = gensym("sampleresult") - resargprefix = gensym("sampleresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - idx -= fnwrap ? 1 : 0 - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - sym = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - - sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) - - ridx = 1 - for a in linear_results - val = MLIR.IR.result(sample_op, ridx) - ridx += 1 - - for path in a.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], val) - elseif path[1] == argprefix - idx = path[2]::Int - (fnwrap ? 1 : 0) - TracedUtils.set!(args[idx], path[3:end], val) - end - end - end - - return result -end diff --git a/src/Overlay.jl b/src/Overlay.jl index cfef42541f..c97a06664d 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -21,18 +21,6 @@ end return overload_autodiff(rmode, f, rt, args...) end -@reactant_overlay @noinline function Enzyme.generate( - f::Function, args::Vararg{Any,Nargs} -) where {Nargs} - return overload_generate(f, args...) -end - -@reactant_overlay @noinline function Enzyme.sample( - f::Function, args::Vararg{Any,Nargs} -) where {Nargs} - return overload_sample(f, args...) -end - # Random.jl overlays @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) diff --git a/src/ProbProg.jl b/src/ProbProg.jl new file mode 100644 index 0000000000..b80fb2f628 --- /dev/null +++ b/src/ProbProg.jl @@ -0,0 +1,115 @@ +module ProbProg + +using ..Reactant: Reactant, XLA, MLIR, TracedUtils +using ReactantCore: ReactantCore + +using Enzyme + +function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix = gensym("samplearg") + resprefix = gensym("sampleresult") + resargprefix = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end + +end \ No newline at end of file diff --git a/src/Reactant.jl b/src/Reactant.jl index 090a8d6b90..d9f5d908b8 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -174,6 +174,7 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +include("ProbProg.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/test/probprog.jl b/test/probprog.jl index e3f64faf30..a493fcee4b 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -1,16 +1,17 @@ -using Enzyme, Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) function model(mean, stddev) - s = Enzyme.sample(normal, StableRNG(0), mean, stddev) - t = Enzyme.sample(normal, StableRNG(0), s, stddev) + s = ProbProg.sample(normal, StableRNG(0), mean, stddev) + t = ProbProg.sample(normal, StableRNG(0), s, stddev) return t end @testset "ProbProg" begin @testset "normal_hlo" begin - hlo = @code_hlo Enzyme.generate( + hlo = @code_hlo ProbProg.generate( model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) @test contains(repr(hlo), "enzyme.generate") @@ -23,7 +24,7 @@ end @testset "normal_generate" begin X = Array( - @jit optimize = :probprog Enzyme.generate( + @jit optimize = :probprog ProbProg.generate( model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) ) diff --git a/test/runtests.jl b/test/runtests.jl index b93fb9ae20..383aa44cf1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") @safetestset "Autodiff" include("autodiff.jl") + @safetestset "ProbProg" include("probprog.jl") @safetestset "Complex" include("complex.jl") @safetestset "Broadcast" include("bcast.jl") @safetestset "Struct" include("struct.jl") From d611ae4f818ec0aee692b71805a5b6041583d96b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 6 May 2025 20:14:21 -0500 Subject: [PATCH 03/87] add probprog pass to :all --- src/Compiler.jl | 45 ++------------------------------------------- 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 84db740901..e2c9fd93c7 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1285,6 +1285,7 @@ function compile_mlir!( raise_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1299,6 +1300,7 @@ function compile_mlir!( opt_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1506,49 +1508,6 @@ function compile_mlir!( ), "after_enzyme", ) - elseif optimize === :probprog - run_pass_pipeline!( - mod, - join( - if raise_first - [ - opt_passes, - kern, - raise_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - "probprog", - enzyme_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - opt_passes2, - jit, - ] - else - [ - opt_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - "probprog", - enzyme_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - opt_passes2, - kern, - raise_passes, - jit, - ] - end, - ',', - ), - "probprog", - ) elseif optimize === :canonicalize run_pass_pipeline!(mod, "mark-func-memory-effects,canonicalize", "canonicalize") elseif optimize === :just_batch From 3672d83caa1b53d66bb640cfee6901ece906b89a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 6 May 2025 20:14:28 -0500 Subject: [PATCH 04/87] improve test --- test/probprog.jl | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/test/probprog.jl b/test/probprog.jl index a493fcee4b..6272ec3312 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -3,29 +3,39 @@ using Reactant: ProbProg normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) -function model(mean, stddev) - s = ProbProg.sample(normal, StableRNG(0), mean, stddev) - t = ProbProg.sample(normal, StableRNG(0), s, stddev) +function model(rng, mean, stddev) + s = ProbProg.sample(normal, rng, mean, stddev) + t = ProbProg.sample(normal, rng, s, stddev) return t end @testset "ProbProg" begin @testset "normal_hlo" begin - hlo = @code_hlo ProbProg.generate( - model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + rng = StableRNG(0) + before = @code_hlo optimize = :none ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) - @test contains(repr(hlo), "enzyme.generate") - @test contains(repr(hlo), "enzyme.sample") - # println(hlo) + @test contains(repr(before), "enzyme.generate") + @test contains(repr(before), "enzyme.sample") - lowered = Reactant.Compiler.run_pass_pipeline_on_source(repr(hlo), "probprog") - println(lowered) + # println("Before") + # println(repr(before)) + + after = @code_hlo optimize = :all ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + @test !contains(repr(after), "enzyme.generate") + @test !contains(repr(after), "enzyme.sample") + + # println("After") + # println(repr(after)) end @testset "normal_generate" begin + rng = StableRNG(1) X = Array( - @jit optimize = :probprog ProbProg.generate( - model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + @jit optimize = :all ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) ) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 From b70843e34d96bff7fdfb6a4e6c83f19e60c7d9b9 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 8 May 2025 12:22:05 -0500 Subject: [PATCH 05/87] only probprog opt mode --- src/Compiler.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index e2c9fd93c7..690c964a01 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1424,6 +1424,22 @@ function compile_mlir!( ), "only_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + [ + "mark-func-memory-effects", + "enzyme-batch", + "probprog", + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ], + ',', + ), + "probprog", + ) elseif optimize === :only_enzyme run_pass_pipeline!( mod, From 597fa89b0d4d009dfe0a463e63eb693f7105acd0 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 8 May 2025 12:22:15 -0500 Subject: [PATCH 06/87] fix up test --- test/probprog.jl | 60 +++++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/test/probprog.jl b/test/probprog.jl index 6272ec3312..b3cfe75970 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -1,43 +1,55 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) +normal(rng, μ, σ) = μ .+ σ .* randn(rng, 10000) -function model(rng, mean, stddev) - s = ProbProg.sample(normal, rng, mean, stddev) - t = ProbProg.sample(normal, rng, s, stddev) - return t +function generate_model(seed, μ, σ) + function model(seed, μ, σ) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ) + t = ProbProg.sample(normal, rng, s, σ) + return t + end + + return ProbProg.generate(model, seed, μ, σ) end @testset "ProbProg" begin + @testset "normal_deterministic" begin + seed1 = Reactant.to_rarray(UInt64[1, 4]) + seed2 = Reactant.to_rarray(UInt64[1, 4]) + μ1 = Reactant.ConcreteRArray(0.0) + μ2 = Reactant.ConcreteRArray(1000.0) + σ1 = Reactant.ConcreteRArray(1.0) + σ2 = Reactant.ConcreteRArray(1.0) + model_compiled = @compile generate_model(seed1, μ1, σ1) + + @test Array(model_compiled(seed1, μ1, σ1)) ≈ Array(model_compiled(seed1, μ1, σ1)) + @test mean(Array(model_compiled(seed1, μ1, σ1))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + @test !(all( + Array(model_compiled(seed1, μ1, σ1)) .≈ Array(model_compiled(seed2, μ2, σ2)) + )) + end @testset "normal_hlo" begin - rng = StableRNG(0) - before = @code_hlo optimize = :none ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + before = @code_hlo optimize = :none generate_model(seed, μ, σ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - # println("Before") - # println(repr(before)) - - after = @code_hlo optimize = :all ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") - - # println("After") - # println(repr(after)) end @testset "normal_generate" begin - rng = StableRNG(1) - X = Array( - @jit optimize = :all ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) - ) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From e6c2c0a2a37dec3be89051798ce8335896c99f5b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 12 May 2025 09:38:53 -0500 Subject: [PATCH 07/87] move --- test/{probprog.jl => probprog/generate.jl} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/{probprog.jl => probprog/generate.jl} (98%) diff --git a/test/probprog.jl b/test/probprog/generate.jl similarity index 98% rename from test/probprog.jl rename to test/probprog/generate.jl index b3cfe75970..5a488479d4 100644 --- a/test/probprog.jl +++ b/test/probprog/generate.jl @@ -15,7 +15,7 @@ function generate_model(seed, μ, σ) return ProbProg.generate(model, seed, μ, σ) end -@testset "ProbProg" begin +@testset "Generate" begin @testset "normal_deterministic" begin seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) From 9b9395e361ea22db9e730c84a4a53da295335025 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 12 May 2025 16:10:31 -0500 Subject: [PATCH 08/87] simplify --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index b80fb2f628..c73fceb6ef 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -11,7 +11,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + f, args, (), string(f), false; argprefix, resprefix, resargprefix ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -69,7 +69,7 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + f, args, (), string(f), false; argprefix, resprefix, resargprefix ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped From b3ba4779d709620ff345793659dac33a1ede1361 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 14 May 2025 16:52:47 -0500 Subject: [PATCH 09/87] fix up --- src/ProbProg.jl | 6 +- src/mlir/Dialects/Enzyme.jl | 453 ++++++++++++------------------------ test/runtests.jl | 1 - 3 files changed, 158 insertions(+), 302 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index c73fceb6ef..68d0ca3a3f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -64,9 +64,9 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix = gensym("samplearg") - resprefix = gensym("sampleresult") - resargprefix = gensym("sampleresarg") + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( f, args, (), string(f), false; argprefix, resprefix, resargprefix diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index f558ee0468..3863cc567c 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -1,18 +1,10 @@ module enzyme using ...IR -import ...IR: - NamedAttribute, - Value, - Location, - Block, - Region, - Attribute, - create_operation, - context, - IndexType +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API + """ `addTo` @@ -20,75 +12,49 @@ TODO """ function addTo(values::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[values...,] + operands = Value[values..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.addTo", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addTo", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function autodiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.autodiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.autodiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function batch( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), namedattribute("batch_shape", batch_shape) - ] - - return create_operation( - "enzyme.batch", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] + + create_operation( + "enzyme.batch", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -101,53 +67,34 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. """ function broadcast(input::Value; output::IR.Type, shape, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[input,] + op_ty_results = IR.Type[output, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("shape", shape),] - - return create_operation( - "enzyme.broadcast", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("shape", shape), ] + + create_operation( + "enzyme.broadcast", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function fwddiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.fwddiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.fwddiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -156,197 +103,151 @@ end Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. """ -function generate( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] +function generate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.generate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.generate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function genericAdjoint( - inputs::Vector{Value}, - outputs::Vector{Value}; - result_tensors::Vector{IR.Type}, - indexing_maps, - iterator_types, - doc=nothing, - library_call=nothing, - region::Region, - location=Location(), -) - op_ty_results = IR.Type[result_tensors...,] - operands = Value[inputs..., outputs...] - owned_regions = Region[region,] + +function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) + op_ty_results = IR.Type[result_tensors..., ] + operands = Value[inputs..., outputs..., ] + owned_regions = Region[region, ] successors = Block[] - attributes = NamedAttribute[ - namedattribute("indexing_maps", indexing_maps), - namedattribute("iterator_types", iterator_types), - ] - push!(attributes, operandsegmentsizes([length(inputs), length(outputs)])) + attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) - !isnothing(library_call) && - push!(attributes, namedattribute("library_call", library_call)) - - return create_operation( - "enzyme.genericAdjoint", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) + + create_operation( + "enzyme.genericAdjoint", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] - operands = Value[gradient,] + op_ty_results = IR.Type[result_0, ] + operands = Value[gradient, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.get", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.get", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result_0, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.init", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.init", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function placeholder(; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] + op_ty_results = IR.Type[output, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.placeholder", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.placeholder", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function pop(cache::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[cache,] + op_ty_results = IR.Type[output, ] + operands = Value[cache, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.pop", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.pop", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function push(cache::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[cache, value] + operands = Value[cache, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.push", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.push", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function sample( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function sample(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.sample", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.sample", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[gradient, value] + operands = Value[gradient, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.set", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.set", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -357,25 +258,19 @@ Simulate a probabilistic function to generate execution trace by replacing all SampleOps with distribution calls and inserting sampled values into the choice map. """ -function simulate( - inputs::Vector{Value}; newTrace::IR.Type, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[newTrace,] - operands = Value[inputs...,] +function simulate(inputs::Vector{Value}; trace::IR.Type, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[trace, ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.simulate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.simulate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -386,46 +281,22 @@ Execute a probabilistic function specified by a symbol reference using the provi and a set of constraints on the sampled variables (if provided). Return the execution trace (if provided) and the log-likelihood of the execution trace. """ -function trace( - inputs::Vector{Value}, - oldTrace=nothing::Union{Nothing,Value}; - constraints=nothing::Union{Nothing,Value}, - newTrace::IR.Type, - weights::Vector{IR.Type}, - fn, - name=nothing, - location=Location(), -) - op_ty_results = IR.Type[newTrace, weights...] - operands = Value[inputs...,] +function trace(inputs::Vector{Value}, oldTrace=nothing::Union{Nothing, Value}; constraints=nothing::Union{Nothing, Value}, newTrace::IR.Type, weights::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[newTrace, weights..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(oldTrace) && push!(operands, oldTrace) !isnothing(constraints) && push!(operands, constraints) - push!( - attributes, - operandsegmentsizes([ - length(inputs), if (oldTrace == nothing) - 0 - elseif 1(constraints == nothing) - 0 - else - 1 - end - ]), - ) + push!(attributes, operandsegmentsizes([length(inputs), (oldTrace==nothing) ? 0 : 1(constraints==nothing) ? 0 : 1])) !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.trace", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.trace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -436,21 +307,17 @@ Add a sampled value into the execution trace. """ function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) op_ty_results = IR.Type[] - operands = Value[trace, sample] + operands = Value[trace, sample, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.addSampleToTrace", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addSampleToTrace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -459,29 +326,19 @@ end Insert a constraint on a sampled variable into the choice map. """ -function insertChoiceToMap( - choiceMap::Value, - choice::Value; - newChoiceMap::IR.Type, - name=nothing, - location=Location(), -) - op_ty_results = IR.Type[newChoiceMap,] - operands = Value[choiceMap, choice] +function insertChoiceToMap(choiceMap::Value, choice::Value; newChoiceMap::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[newChoiceMap, ] + operands = Value[choiceMap, choice, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.insertChoiceToMap", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.insertChoiceToMap", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end diff --git a/test/runtests.jl b/test/runtests.jl index 489731eff5..a52159b4a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,6 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") @safetestset "Autodiff" include("autodiff.jl") - @safetestset "ProbProg" include("probprog.jl") @safetestset "Complex" include("complex.jl") @safetestset "Broadcast" include("bcast.jl") @safetestset "Struct" include("struct.jl") From 47e9fe312e2a3e0de63a7e243bd82490a099ab26 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 15 May 2025 18:10:34 -0500 Subject: [PATCH 10/87] saving changes --- src/ProbProg.jl | 22 +++++++++++++++++++--- src/mlir/Dialects/Enzyme.jl | 8 ++++---- test/probprog/generate.jl | 33 +++++++++++++++++++-------------- 3 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 68d0ca3a3f..afa3a0f5b0 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -5,13 +5,21 @@ using ReactantCore: ReactantCore using Enzyme -function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f), false; argprefix, resprefix, resargprefix + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -69,7 +77,15 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f), false; argprefix, resprefix, resargprefix + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 3863cc567c..54065e0136 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -258,8 +258,8 @@ Simulate a probabilistic function to generate execution trace by replacing all SampleOps with distribution calls and inserting sampled values into the choice map. """ -function simulate(inputs::Vector{Value}; trace::IR.Type, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[trace, ] +function simulate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] @@ -326,8 +326,8 @@ end Insert a constraint on a sampled variable into the choice map. """ -function insertChoiceToMap(choiceMap::Value, choice::Value; newChoiceMap::IR.Type, name=nothing, location=Location()) - op_ty_results = IR.Type[newChoiceMap, ] +function insertChoiceToMap(choiceMap::Value, choice::Value; outputs::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs, ] operands = Value[choiceMap, choice, ] owned_regions = Region[] successors = Block[] diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 5a488479d4..8f0ddfcaa0 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -1,55 +1,60 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -normal(rng, μ, σ) = μ .+ σ .* randn(rng, 10000) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function generate_model(seed, μ, σ) - function model(seed, μ, σ) +function generate_model(seed, μ, σ, shape) + function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ) - t = ProbProg.sample(normal, rng, s, σ) + s = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, s, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ) + return ProbProg.generate(model, seed, μ, σ, shape) end @testset "Generate" begin @testset "normal_deterministic" begin + shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) μ1 = Reactant.ConcreteRArray(0.0) μ2 = Reactant.ConcreteRArray(1000.0) σ1 = Reactant.ConcreteRArray(1.0) σ2 = Reactant.ConcreteRArray(1.0) - model_compiled = @compile generate_model(seed1, μ1, σ1) - @test Array(model_compiled(seed1, μ1, σ1)) ≈ Array(model_compiled(seed1, μ1, σ1)) - @test mean(Array(model_compiled(seed1, μ1, σ1))) ≈ 0.0 atol = 0.05 rtol = 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + model_compiled = @compile generate_model(seed1, μ1, σ1, shape) + + @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) + @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = 0.05 @test !(all( - Array(model_compiled(seed1, μ1, σ1)) .≈ Array(model_compiled(seed2, μ2, σ2)) + Array(model_compiled(seed1, μ1, σ1, shape)) .≈ Array(model_compiled(seed2, μ2, σ2, shape)) )) end @testset "normal_hlo" begin + shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) - before = @code_hlo optimize = :none generate_model(seed, μ, σ) + + before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog generate_model(seed, μ, σ) + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @testset "normal_generate" begin + shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) - X = Array(@jit optimize = :probprog generate_model(seed, μ, σ)) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From a6fcca3dcb18c751ee954b58f4ae0eadbd892c76 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:37:46 -0500 Subject: [PATCH 11/87] fix sample op --- src/ProbProg.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index afa3a0f5b0..23be48f1d4 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -16,7 +16,7 @@ using Enzyme (), string(f), false; - args_in_result=:result, + args_in_result=:result_and_mutated, argprefix, resprefix, resargprefix, @@ -71,7 +71,7 @@ using Enzyme return result end -function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") @@ -82,7 +82,7 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} (), string(f), false; - args_in_result=:result, + args_in_result=:result_and_mutated, argprefix, resprefix, resargprefix, From e51e04bb5646f1833a622814045c34d255729f6d Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:45:08 -0500 Subject: [PATCH 12/87] save tests --- test/probprog/generate.jl | 6 ++--- test/probprog/sample.jl | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 test/probprog/sample.jl diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 8f0ddfcaa0..cc73d15b12 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -7,8 +7,8 @@ function generate_model(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape) - t = ProbProg.sample(normal, rng, s, σ, shape) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + t = ProbProg.sample!(normal, rng, s, σ, shape) return t end @@ -25,7 +25,7 @@ end σ1 = Reactant.ConcreteRArray(1.0) σ2 = Reactant.ConcreteRArray(1.0) - model_compiled = @compile generate_model(seed1, μ1, σ1, shape) + model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..93d411f9a5 --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,50 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg + +@noinline normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function sample1(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + return s + end + + return ProbProg.generate(model, seed, μ, σ, shape) +end + +function sample2(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + t = ProbProg.sample!(normal, rng, μ, σ, shape) + return t + end + + return ProbProg.generate(model, seed, μ, σ, shape) +end + +@testset "test" begin + @testset "sample_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + before = @code_hlo optimize = false sample2(seed, μ, σ, shape) + @test contains(repr(before), "enzyme.sample") + after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "sample_normal" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape)) + Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape)) + @test !all(X .≈ Y) + end +end From ce68f6a3da18d8da60bc01a971f03e5512b81b3f Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:45:47 -0500 Subject: [PATCH 13/87] temporarily removing probprog pass from :all as MLIR pass is not merged yet --- src/Compiler.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 74f3cd829d..4cee6f2446 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1290,7 +1290,6 @@ function compile_mlir!( raise_passes, "enzyme-batch", opt_passes2, - "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1306,7 +1305,6 @@ function compile_mlir!( opt_passes, "enzyme-batch", opt_passes2, - "probprog", enzyme_pass, opt_passes2, "canonicalize", From d31bba636aeb7e1537c1dc90d985553dda84e1a6 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:53:48 -0500 Subject: [PATCH 14/87] undo enzyme binding change --- src/mlir/Dialects/Enzyme.jl | 425 +++++++++++++++++------------------- 1 file changed, 203 insertions(+), 222 deletions(-) diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 54065e0136..e4306b06a1 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -1,10 +1,18 @@ module enzyme using ...IR -import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API - """ `addTo` @@ -12,49 +20,75 @@ TODO """ function addTo(values::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[values..., ] + operands = Value[values...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.addTo", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.addTo", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] +function autodiff( + inputs::Vector{Value}; + outputs::Vector{IR.Type}, + fn, + activity, + ret_activity, + width=nothing, + location=Location(), +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + attributes = NamedAttribute[ + namedattribute("fn", fn), + namedattribute("activity", activity), + namedattribute("ret_activity", ret_activity), + ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - create_operation( - "enzyme.autodiff", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.autodiff", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] +function batch( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] - - create_operation( - "enzyme.batch", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[ + namedattribute("fn", fn), namedattribute("batch_shape", batch_shape) + ] + + return create_operation( + "enzyme.batch", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -67,278 +101,225 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. """ function broadcast(input::Value; output::IR.Type, shape, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[input, ] + op_ty_results = IR.Type[output,] + operands = Value[input,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("shape", shape), ] - - create_operation( - "enzyme.broadcast", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[namedattribute("shape", shape),] + + return create_operation( + "enzyme.broadcast", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] +function fwddiff( + inputs::Vector{Value}; + outputs::Vector{IR.Type}, + fn, + activity, + ret_activity, + width=nothing, + location=Location(), +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] + attributes = NamedAttribute[ + namedattribute("fn", fn), + namedattribute("activity", activity), + namedattribute("ret_activity", ret_activity), + ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - create_operation( - "enzyme.fwddiff", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end -""" -`generate` - -Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. -""" -function generate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.generate", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.fwddiff", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) - op_ty_results = IR.Type[result_tensors..., ] - operands = Value[inputs..., outputs..., ] - owned_regions = Region[region, ] +function genericAdjoint( + inputs::Vector{Value}, + outputs::Vector{Value}; + result_tensors::Vector{IR.Type}, + indexing_maps, + iterator_types, + doc=nothing, + library_call=nothing, + region::Region, + location=Location(), +) + op_ty_results = IR.Type[result_tensors...,] + operands = Value[inputs..., outputs...] + owned_regions = Region[region,] successors = Block[] - attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] - push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) + attributes = NamedAttribute[ + namedattribute("indexing_maps", indexing_maps), + namedattribute("iterator_types", iterator_types), + ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs)])) !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) - !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) - - create_operation( - "enzyme.genericAdjoint", location; - operands, owned_regions, successors, attributes, + !isnothing(library_call) && + push!(attributes, namedattribute("library_call", library_call)) + + return create_operation( + "enzyme.genericAdjoint", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0, ] - operands = Value[gradient, ] + op_ty_results = IR.Type[result_0,] + operands = Value[gradient,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.get", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.get", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0, ] + op_ty_results = IR.Type[result_0,] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.init", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.init", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function placeholder(; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output, ] + op_ty_results = IR.Type[output,] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.placeholder", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.placeholder", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function pop(cache::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[cache, ] + op_ty_results = IR.Type[output,] + operands = Value[cache,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.pop", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzyme.pop", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function push(cache::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[cache, value, ] + operands = Value[cache, value] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzyme.push", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -function sample(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.sample", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.push", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function set(gradient::Value, value::Value; location=Location()) - op_ty_results = IR.Type[] - operands = Value[gradient, value, ] +function sample( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[] - - create_operation( - "enzyme.set", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -""" -`simulate` - -Simulate a probabilistic function to generate execution trace -by replacing all SampleOps with distribution calls and inserting -sampled values into the choice map. -""" -function simulate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] + attributes = NamedAttribute[namedattribute("fn", fn),] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.simulate", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end -""" -`trace` - -Execute a probabilistic function specified by a symbol reference using the provided arguments, -and a set of constraints on the sampled variables (if provided). Return the execution trace -(if provided) and the log-likelihood of the execution trace. -""" -function trace(inputs::Vector{Value}, oldTrace=nothing::Union{Nothing, Value}; constraints=nothing::Union{Nothing, Value}, newTrace::IR.Type, weights::Vector{IR.Type}, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[newTrace, weights..., ] - operands = Value[inputs..., ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(oldTrace) && push!(operands, oldTrace) - !isnothing(constraints) && push!(operands, constraints) - push!(attributes, operandsegmentsizes([length(inputs), (oldTrace==nothing) ? 0 : 1(constraints==nothing) ? 0 : 1])) - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.trace", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.sample", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end -""" -`addSampleToTrace` - -Add a sampled value into the execution trace. -""" -function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) +function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[trace, sample, ] + operands = Value[gradient, value] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.addSampleToTrace", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -""" -`insertChoiceToMap` -Insert a constraint on a sampled variable into the choice map. -""" -function insertChoiceToMap(choiceMap::Value, choice::Value; outputs::IR.Type, name=nothing, location=Location()) - op_ty_results = IR.Type[outputs, ] - operands = Value[choiceMap, choice, ] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[] - !isnothing(name) && push!(attributes, namedattribute("name", name)) - - create_operation( - "enzyme.insertChoiceToMap", location; - operands, owned_regions, successors, attributes, + return create_operation( + "enzyme.set", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end From 573fa021e147f6ef2eac49e46b01edb8baaadd45 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 16:58:21 -0500 Subject: [PATCH 15/87] format --- test/probprog/generate.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index cc73d15b12..47f7beff45 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -27,11 +27,15 @@ end model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) - @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) - @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ + Array(model_compiled(seed1, μ1, σ1, shape)) + @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = + 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = + 0.05 @test !(all( - Array(model_compiled(seed1, μ1, σ1, shape)) .≈ Array(model_compiled(seed2, μ2, σ2, shape)) + Array(model_compiled(seed1, μ1, σ1, shape)) .≈ + Array(model_compiled(seed2, μ2, σ2, shape)), )) end @testset "normal_hlo" begin From 0264a3dc40ae1be90460db8e5cae7bdd3633f381 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 17:01:42 -0500 Subject: [PATCH 16/87] format --- src/ProbProg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 23be48f1d4..a5301a999f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -128,4 +128,4 @@ function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} return result end -end \ No newline at end of file +end From 2e18bdf8d878212510c4a2b70202183e1085b126 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 18:50:58 -0500 Subject: [PATCH 17/87] improve --- src/ProbProg.jl | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index a5301a999f..22824bee7b 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -44,14 +44,10 @@ using Enzyme gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) - residx = 1 - for a in linear_results - resv = MLIR.IR.result(gen_op, residx) - residx += 1 - for path in a.paths - if length(path) == 0 - continue - end + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(gen_op, i) + for path in res.paths + isempty(path) && continue if path[1] == resprefix TracedUtils.set!(result, path[2:end], resv) elseif path[1] == argprefix @@ -109,18 +105,23 @@ function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) - ridx = 1 - for a in linear_results - val = MLIR.IR.result(sample_op, ridx) - ridx += 1 + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(sample_op, i) - for path in a.paths + for path in res.paths isempty(path) && continue if path[1] == resprefix - TracedUtils.set!(result, path[2:end], val) + TracedUtils.set!(result, path[2:end], resv) elseif path[1] == argprefix - idx = path[2]::Int - (fnwrap ? 1 : 0) - TracedUtils.set!(args[idx], path[3:end], val) + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end end end end From 1f1997976987733303c33e841099f91a376493ef Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 20 May 2025 18:51:34 -0500 Subject: [PATCH 18/87] improve --- src/ProbProg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 22824bee7b..14bca0db4a 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -67,7 +67,7 @@ using Enzyme return result end -function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") From 096d790abbd32fa235f7f123ee5aaed3e2b2f352 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 22 May 2025 17:22:36 -0500 Subject: [PATCH 19/87] get rid of result_and_mutated too --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 14bca0db4a..8cfa5bec35 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -16,7 +16,7 @@ using Enzyme (), string(f), false; - args_in_result=:result_and_mutated, + args_in_result=:all, argprefix, resprefix, resargprefix, @@ -78,7 +78,7 @@ end (), string(f), false; - args_in_result=:result_and_mutated, + args_in_result=:all, argprefix, resprefix, resargprefix, From 9ac653555b69a695a33421e331d0471e266999fb Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 00:10:15 -0500 Subject: [PATCH 20/87] working trace object pointer hacks + tests --- src/Compiler.jl | 3 + src/ProbProg.jl | 149 +++++++++++++++++++++++++++++++++++++- test/probprog/simulate.jl | 46 ++++++++++++ 3 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 test/probprog/simulate.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index 35bf0c4e9b..6a2becb4e2 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1488,6 +1488,7 @@ function compile_mlir!( blas_int_width = sizeof(BLAS.BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ blas_int_width=$blas_int_width}" + lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}" if optimize === :all run_pass_pipeline!( @@ -1651,6 +1652,8 @@ function compile_mlir!( "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", + lower_enzyme_probprog_pass, + jit ], ',', ), diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 8cfa5bec35..7f83ac96c3 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,10 +1,50 @@ module ProbProg -using ..Reactant: Reactant, XLA, MLIR, TracedUtils +using ..Reactant: Reactant, XLA, MLIR, TracedUtils, TracedRArray, ConcretePJRTArray using ReactantCore: ReactantCore +using Libdl: Libdl using Enzyme +const Trace = Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) + +function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) + trace_ptr = unsafe_load(trace_ptr_ptr) + @assert reinterpret(UInt64, trace_ptr) == 42 + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(Trace)) + + return nothing +end + +function addSampleToTraceLowered( + trace_ptr_ptr::Ptr{Ptr{Cvoid}}, + symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, + sample_ptr_ptr::Ptr{Cvoid}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) + + trace[symbol] = 888 + + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + add_sample_to_trace_ptr = @cfunction( + addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end + @noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -67,7 +107,9 @@ using Enzyme return result end -@noinline function sample!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function sample!( + f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") +) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") @@ -83,7 +125,7 @@ end resprefix, resargprefix, ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res + (; result, linear_args, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f @@ -103,7 +145,17 @@ end sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + symbol_ptr = pointer_from_objref(symbol) + symbol_addr = reinterpret(UInt64, symbol_ptr) + + addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr]) + + sample_op = MLIR.Dialects.enzyme.sample( + MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=addr_attr), 1), + batch_inputs; + outputs=out_tys, + fn=fn_attr, + ) for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(sample_op, i) @@ -129,4 +181,93 @@ end return result end +@noinline function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("simulatearg") + resprefix::Symbol = gensym("simulateresult") + resargprefix::Symbol = gensym("simulateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, + args, + (), + string(f), + false; + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = MLIR.IR.Type[] + supress_rest = false + for res in linear_results + if TracedUtils.has_idx(res, resprefix) && !supress_rest + push!(out_tys, MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64))) + supress_rest = true + else + # push!(out_tys, MLIR.IR.type(TracedUtils.get_mlir_data(res))) + end + end + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + simulate_op = MLIR.Dialects.enzyme.simulate(batch_inputs; outputs=out_tys, fn=fname) + + result = nothing + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(simulate_op, i) + + if TracedUtils.has_idx(res, resprefix) + # casted = MLIR.IR.result( + # MLIR.Dialects.builtin.unrealized_conversion_cast( + # resv; to=MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64)) + # ), + # 1, + # ) + # result = TracedRArray(casted) + result = TracedRArray(resv) + break + # continue + end + + # for path in res.paths + # isempty(path) && continue + # if path[1] == argprefix + # idx = path[2]::Int + # if idx == 1 && fnwrap + # TracedUtils.set!(f, path[3:end], resv) + # else + # if fnwrap + # idx -= 1 + # end + # TracedUtils.set!(args[idx], path[3:end], resv) + # end + # end + # end + end + + return result +end + +function getTrace(t::ConcretePJRTArray) + return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1])) +end + end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..de1abc05df --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,46 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg +using Libdl: Libdl + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function simulate_model(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape) + t = ProbProg.sample!(normal, rng, s, σ, shape) + return t + end + + return ProbProg.simulate(model, seed, μ, σ, shape) +end + + +@testset "Simulate" begin + @testset "normal_hlo" begin + shape = (10000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + + before = @code_hlo optimize = :no_enzyme simulate_model(seed, μ, σ, shape) + @test contains(repr(before), "enzyme.simulate") + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog simulate_model(seed, μ, σ, shape) + @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.sample") + @test contains(repr(after), "enzyme_probprog_add_sample_to_trace") + @test contains(repr(after), "enzyme_probprog_init_trace") + end + + @testset "normal_simulate" begin + shape = (10000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape)) + @test X[:_integrity_check] == 0x123456789abcdef + end +end From b24766f0ccd570e443a92b4e09f901860806e2ec Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 01:38:31 -0500 Subject: [PATCH 21/87] Assuming scalar samples for now; simple Bayesian linear regression test --- src/ProbProg.jl | 4 ++-- test/probprog/blr.jl | 28 ++++++++++++++++++++++++++++ test/probprog/simulate.jl | 4 ++-- 3 files changed, 32 insertions(+), 4 deletions(-) create mode 100644 test/probprog/blr.jl diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 7f83ac96c3..834e09910f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -20,12 +20,12 @@ end function addSampleToTraceLowered( trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, - sample_ptr_ptr::Ptr{Cvoid}, + sample_ptr::Ptr{Cvoid}, ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - trace[symbol] = 888 + trace[symbol] = unsafe_load(reinterpret(Ptr{Float64}, sample_ptr)) return nothing end diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl new file mode 100644 index 0000000000..619af0b5c5 --- /dev/null +++ b/test/probprog/blr.jl @@ -0,0 +1,28 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg +using Libdl: Libdl + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function blr(seed, xs) + function model(seed, xs) + rng = Random.default_rng() + Random.seed!(rng, seed) + slope = ProbProg.sample!(normal, rng, 0, 2, (1,); symbol=:slope) + intercept = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:intercept) + for (i, x) in enumerate(xs) + ProbProg.sample!(normal, rng, slope * x + intercept, 1, (1,); symbol=Symbol("y-$i")) + end + return intercept + end + + return ProbProg.simulate(model, seed, xs) +end + +@testset "BLR" begin + xs = [1, 2, 3, 4, 5] + seed = Reactant.to_rarray(UInt64[1, 4]) + X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, xs)) + @test X[:_integrity_check] == 0x123456789abcdef + @show X +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index de1abc05df..80c820a861 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -8,8 +8,8 @@ function simulate_model(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) - t = ProbProg.sample!(normal, rng, s, σ, shape) + s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol = :s) + t = ProbProg.sample!(normal, rng, s, σ, shape; symbol = :t) return t end From 3c52b39a913f2f529eaff0d7e88ef4d4e6736062 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 17:14:16 -0500 Subject: [PATCH 22/87] exclamation mark --- src/ProbProg.jl | 4 ++-- test/probprog/generate.jl | 2 +- test/probprog/sample.jl | 4 ++-- test/probprog/simulate.jl | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 834e09910f..5fdde5489b 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -45,7 +45,7 @@ function __init__() return nothing end -@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") @@ -181,7 +181,7 @@ end return result end -@noinline function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function simulate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") resargprefix::Symbol = gensym("simulateresarg") diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 47f7beff45..64297a6ce2 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -12,7 +12,7 @@ function generate_model(seed, μ, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ, shape) + return ProbProg.generate!(model, seed, μ, σ, shape) end @testset "Generate" begin diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 93d411f9a5..8d79488566 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -11,7 +11,7 @@ function sample1(seed, μ, σ, shape) return s end - return ProbProg.generate(model, seed, μ, σ, shape) + return ProbProg.generate!(model, seed, μ, σ, shape) end function sample2(seed, μ, σ, shape) @@ -23,7 +23,7 @@ function sample2(seed, μ, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ, shape) + return ProbProg.generate!(model, seed, μ, σ, shape) end @testset "test" begin diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 80c820a861..5ec3f1d031 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -13,7 +13,7 @@ function simulate_model(seed, μ, σ, shape) return t end - return ProbProg.simulate(model, seed, μ, σ, shape) + return ProbProg.simulate!(model, seed, μ, σ, shape) end From af3d055e2166666d711f570aba418d5883c35ff1 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 23:38:59 -0500 Subject: [PATCH 23/87] sample metadata --- src/ProbProg.jl | 70 +++++++++++++++++++++++++++++++++++---- test/probprog/simulate.jl | 3 +- 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 5fdde5489b..d374e2efa4 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -6,7 +6,20 @@ using Libdl: Libdl using Enzyme -const Trace = Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) +struct SampleMetadata + shape::NTuple{N,Int} where {N} + element_type::Type + is_scalar::Bool + + function SampleMetadata( + shape::NTuple{N,Int}, element_type::Type, is_scalar::Bool + ) where {N} + return new(shape, element_type, is_scalar) + end +end + +const SAMPLE_METADATA_CACHE = IdDict{Symbol,SampleMetadata}() +const Trace = IdDict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) trace_ptr = unsafe_load(trace_ptr_ptr) @@ -18,14 +31,28 @@ function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) end function addSampleToTraceLowered( - trace_ptr_ptr::Ptr{Ptr{Cvoid}}, - symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, - sample_ptr::Ptr{Cvoid}, + trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, sample_ptr::Ptr{Cvoid} ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - trace[symbol] = unsafe_load(reinterpret(Ptr{Float64}, sample_ptr)) + @assert haskey(SAMPLE_METADATA_CACHE, symbol) "Symbol $symbol not found in metadata cache" + + metadata = SAMPLE_METADATA_CACHE[symbol] + shape = metadata.shape + element_type = metadata.element_type + is_scalar = metadata.is_scalar + + if is_scalar + value = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) + else + value = unsafe_wrap( + Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + ) + value = reshape(value, shape) # TODO: GC'd? + end + + trace[symbol] = value return nothing end @@ -145,9 +172,22 @@ end sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + if !isempty(linear_results) + sample_result = linear_results[1] # TODO: consider multiple results + sample_mlir_data = TracedUtils.get_mlir_data(sample_result) + @assert sample_mlir_data isa MLIR.IR.Value "Sample $sample_result is not a MLIR.IR.Value" + + sample_type = MLIR.IR.type(sample_mlir_data) + sample_shape = size(sample_type) + sample_element_type = MLIR.IR.julia_type(eltype(sample_type)) + + SAMPLE_METADATA_CACHE[symbol] = SampleMetadata( + sample_shape, sample_element_type, length(sample_shape) == 0 + ) + end + symbol_ptr = pointer_from_objref(symbol) symbol_addr = reinterpret(UInt64, symbol_ptr) - addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr]) sample_op = MLIR.Dialects.enzyme.sample( @@ -270,4 +310,22 @@ function getTrace(t::ConcretePJRTArray) return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1])) end +function print_trace(trace::IdDict) + println("Probabilistic Program Trace:") + for (symbol, sample) in trace + symbol == :_integrity_check && continue + metadata = SAMPLE_METADATA_CACHE[symbol] + + println(" $symbol:") + println(" Sample: $(sample)") + println(" Shape: $(metadata.shape)") + println(" Element Type: $(metadata.element_type)") + end +end + +function clear_sample_metadata_cache!() + empty!(SAMPLE_METADATA_CACHE) + return nothing +end + end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 5ec3f1d031..94fde7e55a 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -36,11 +36,12 @@ end end @testset "normal_simulate" begin - shape = (10000,) + shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape)) @test X[:_integrity_check] == 0x123456789abcdef + ProbProg.print_trace(X) end end From 6c7ffa3e4e692de353f9848789dd759f55dd8dda Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 23:54:42 -0500 Subject: [PATCH 24/87] fix up copy --- src/ProbProg.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index d374e2efa4..0c9a95b9bd 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -44,16 +44,15 @@ function addSampleToTraceLowered( is_scalar = metadata.is_scalar if is_scalar - value = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) + trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) else - value = unsafe_wrap( - Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + trace[symbol] = Base.deepcopy( + unsafe_wrap( + Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + ), ) - value = reshape(value, shape) # TODO: GC'd? end - trace[symbol] = value - return nothing end From 4e017d0476f33d41ac61e96490f2fce2a76ec6b2 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 5 Jun 2025 23:58:33 -0500 Subject: [PATCH 25/87] fix up copy --- src/ProbProg.jl | 9 +++++++-- test/probprog/simulate.jl | 7 +++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 0c9a95b9bd..80c934f95e 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -47,8 +47,13 @@ function addSampleToTraceLowered( trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) else trace[symbol] = Base.deepcopy( - unsafe_wrap( - Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape) + reshape( + unsafe_wrap( + Array{element_type}, + reinterpret(Ptr{element_type}, sample_ptr), + prod(shape), + ), + shape, ), ) end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 94fde7e55a..97910443e2 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -8,15 +8,14 @@ function simulate_model(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol = :s) - t = ProbProg.sample!(normal, rng, s, σ, shape; symbol = :t) + s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) return t end return ProbProg.simulate!(model, seed, μ, σ, shape) end - @testset "Simulate" begin @testset "normal_hlo" begin shape = (10000,) @@ -36,7 +35,7 @@ end end @testset "normal_simulate" begin - shape = (10,) + shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) From e53fc7cf58d9d93047c055a9d6adc9e7c32f0487 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 6 Jun 2025 00:20:39 -0500 Subject: [PATCH 26/87] working vectorized blr test --- src/ProbProg.jl | 1 + test/probprog/blr.jl | 38 +++++++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 80c934f95e..46d9b4f786 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -152,6 +152,7 @@ end string(f), false; args_in_result=:all, + do_transpose=false, # TODO: double check transpose argprefix, resprefix, resargprefix, diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 619af0b5c5..4f0836b76c 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -3,26 +3,38 @@ using Reactant: ProbProg using Libdl: Libdl normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) -function blr(seed, xs) - function model(seed, xs) +function blr(seed, N, K) + function model(seed, N, K) rng = Random.default_rng() Random.seed!(rng, seed) - slope = ProbProg.sample!(normal, rng, 0, 2, (1,); symbol=:slope) - intercept = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:intercept) - for (i, x) in enumerate(xs) - ProbProg.sample!(normal, rng, slope * x + intercept, 1, (1,); symbol=Symbol("y-$i")) - end - return intercept + + # α ~ Normal(0, 10, size = 1) + α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) + + # β ~ Normal(0, 2.5, size = K) + β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) + + # X ~ Normal(0, 10, size = (N, K)) + X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose + + # μ = α .+ X * β + μ = α .+ X * β + + ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) + + return μ end - return ProbProg.simulate(model, seed, xs) + return ProbProg.simulate!(model, seed, N, K) end @testset "BLR" begin - xs = [1, 2, 3, 4, 5] + N = 5 # number of observations + K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, xs)) - @test X[:_integrity_check] == 0x123456789abcdef - @show X + + X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K)) + ProbProg.print_trace(X) end From 1dbf5c73702b068d2ad338021f5d37912351c3b4 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 11 Jun 2025 17:35:40 -0500 Subject: [PATCH 27/87] fix test warning --- test/probprog/generate.jl | 16 ++++++++-------- test/probprog/sample.jl | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 64297a6ce2..605b375805 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -20,10 +20,10 @@ end shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) - μ1 = Reactant.ConcreteRArray(0.0) - μ2 = Reactant.ConcreteRArray(1000.0) - σ1 = Reactant.ConcreteRArray(1.0) - σ2 = Reactant.ConcreteRArray(1.0) + μ1 = Reactant.ConcreteRNumber(0.0) + μ2 = Reactant.ConcreteRNumber(1000.0) + σ1 = Reactant.ConcreteRNumber(1.0) + σ2 = Reactant.ConcreteRNumber(1.0) model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) @@ -41,8 +41,8 @@ end @testset "normal_hlo" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) @test contains(repr(before), "enzyme.generate") @@ -56,8 +56,8 @@ end @testset "normal_generate" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 8d79488566..9c711241d8 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -30,8 +30,8 @@ end @testset "sample_hlo" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = false sample2(seed, μ, σ, shape) @test contains(repr(before), "enzyme.sample") after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape) @@ -41,8 +41,8 @@ end @testset "sample_normal" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape)) Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape)) @test !all(X .≈ Y) From dd9dcabe5bd683b384dc6bf40b6520d9c99fae18 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 11 Jun 2025 17:36:15 -0500 Subject: [PATCH 28/87] hacks to temporarily remove world age issue in tests --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 46d9b4f786..5b0d33a88f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -81,7 +81,7 @@ end resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") - mlir_fn_res = TracedUtils.make_mlir_fn( + mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, f, args, (), @@ -145,7 +145,7 @@ end resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") - mlir_fn_res = TracedUtils.make_mlir_fn( + mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, f, args, (), From a34472613d541527cb54b126704bd4345f533e89 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 12 Jun 2025 17:20:15 -0500 Subject: [PATCH 29/87] partial refactoring --- src/ProbProg.jl | 150 +++++++++++++++++++----------------------------- 1 file changed, 58 insertions(+), 92 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 5b0d33a88f..d9f0672071 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,9 +1,6 @@ module ProbProg -using ..Reactant: Reactant, XLA, MLIR, TracedUtils, TracedRArray, ConcretePJRTArray -using ReactantCore: ReactantCore -using Libdl: Libdl - +using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray using Enzyme struct SampleMetadata @@ -18,16 +15,10 @@ struct SampleMetadata end end -const SAMPLE_METADATA_CACHE = IdDict{Symbol,SampleMetadata}() -const Trace = IdDict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) - -function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}}) - trace_ptr = unsafe_load(trace_ptr_ptr) - @assert reinterpret(UInt64, trace_ptr) == 42 - - unsafe_store!(trace_ptr_ptr, pointer_from_objref(Trace)) +const SAMPLE_METADATA_CACHE = Dict{Symbol,SampleMetadata}() - return nothing +function createTrace() + return Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) end function addSampleToTraceLowered( @@ -46,7 +37,7 @@ function addSampleToTraceLowered( if is_scalar trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) else - trace[symbol] = Base.deepcopy( + trace[symbol] = copy( reshape( unsafe_wrap( Array{element_type}, @@ -62,10 +53,6 @@ function addSampleToTraceLowered( end function __init__() - init_trace_ptr = @cfunction(initTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}},)) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} - )::Cvoid add_sample_to_trace_ptr = @cfunction( addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}) ) @@ -81,7 +68,8 @@ end resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") - mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, f, args, (), @@ -139,13 +127,17 @@ end end @noinline function sample!( - f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + trace::Union{Dict,Nothing}=nothing, ) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") - mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn, + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, f, args, (), @@ -191,47 +183,44 @@ end ) end - symbol_ptr = pointer_from_objref(symbol) - symbol_addr = reinterpret(UInt64, symbol_ptr) - addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr]) + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) sample_op = MLIR.Dialects.enzyme.sample( - MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=addr_attr), 1), - batch_inputs; - outputs=out_tys, - fn=fn_attr, + batch_inputs; outputs=out_tys, fn=fn_attr, symbol=symbol_addr ) for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(sample_op, i) - - for path in res.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + else + if fnwrap + idx -= 1 end + TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) end + else + TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) end end return result end -@noinline function simulate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function simulate!( + f::Function, args::Vararg{Any,Nargs}; trace::Dict +) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") resargprefix::Symbol = gensym("simulateresarg") - mlir_fn_res = TracedUtils.make_mlir_fn( + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, f, args, (), @@ -242,10 +231,14 @@ end resprefix, resargprefix, ) - (; linear_args, linear_results) = mlir_fn_res + (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + batch_inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) @@ -259,63 +252,36 @@ end end end - out_tys = MLIR.IR.Type[] - supress_rest = false - for res in linear_results - if TracedUtils.has_idx(res, resprefix) && !supress_rest - push!(out_tys, MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64))) - supress_rest = true - else - # push!(out_tys, MLIR.IR.type(TracedUtils.get_mlir_data(res))) - end - end + trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - simulate_op = MLIR.Dialects.enzyme.simulate(batch_inputs; outputs=out_tys, fn=fname) + simulate_op = MLIR.Dialects.enzyme.simulate( + batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + ) - result = nothing for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(simulate_op, i) - if TracedUtils.has_idx(res, resprefix) - # casted = MLIR.IR.result( - # MLIR.Dialects.builtin.unrealized_conversion_cast( - # resv; to=MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64)) - # ), - # 1, - # ) - # result = TracedRArray(casted) - result = TracedRArray(resv) - break - # continue + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + end + else + TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) end - - # for path in res.paths - # isempty(path) && continue - # if path[1] == argprefix - # idx = path[2]::Int - # if idx == 1 && fnwrap - # TracedUtils.set!(f, path[3:end], resv) - # else - # if fnwrap - # idx -= 1 - # end - # TracedUtils.set!(args[idx], path[3:end], resv) - # end - # end - # end end - return result -end - -function getTrace(t::ConcretePJRTArray) - return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1])) + return trace, result end -function print_trace(trace::IdDict) +function print_trace(trace::Dict) println("Probabilistic Program Trace:") for (symbol, sample) in trace symbol == :_integrity_check && continue From ef2e77064851226259ab813e3131a3e9b119e82c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 14 Jun 2025 16:51:59 -0500 Subject: [PATCH 30/87] fixed tracing infra --- src/Compiler.jl | 92 ++++++++++++++++++++++---- src/ProbProg.jl | 133 ++++++++++++++++---------------------- test/probprog/simulate.jl | 54 +++++++++------- 3 files changed, 165 insertions(+), 114 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index dc92821853..a4e2ff44a5 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1740,21 +1740,91 @@ function compile_mlir!( ), "only_enzyme", ) + elseif optimize === :probprog_no_lowering + run_pass_pipeline!( + mod, + join( + if raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + ] + end, + ",", + ), + "probprog_no_lowering", + ) elseif optimize === :probprog run_pass_pipeline!( mod, join( - [ - "mark-func-memory-effects", - "enzyme-batch", - "probprog", - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - lower_enzyme_probprog_pass, - jit - ], - ',', + if raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + end, + ",", ), "probprog", ) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index d9f0672071..74c014de82 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -3,50 +3,38 @@ module ProbProg using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray using Enzyme -struct SampleMetadata - shape::NTuple{N,Int} where {N} - element_type::Type - is_scalar::Bool - - function SampleMetadata( - shape::NTuple{N,Int}, element_type::Type, is_scalar::Bool - ) where {N} - return new(shape, element_type, is_scalar) - end -end - -const SAMPLE_METADATA_CACHE = Dict{Symbol,SampleMetadata}() - function createTrace() - return Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef) + return Dict{Symbol,Any}() end function addSampleToTraceLowered( - trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, sample_ptr::Ptr{Cvoid} + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr::Ptr{Any}, + num_dims_ptr::Ptr{Int64}, + shape_array_ptr::Ptr{Int64}, + datatype_width_ptr::Ptr{Int64}, ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - @assert haskey(SAMPLE_METADATA_CACHE, symbol) "Symbol $symbol not found in metadata cache" + num_dims = unsafe_load(num_dims_ptr) + shape_array = unsafe_wrap(Array, shape_array_ptr, num_dims) + datatype_width = unsafe_load(datatype_width_ptr) - metadata = SAMPLE_METADATA_CACHE[symbol] - shape = metadata.shape - element_type = metadata.element_type - is_scalar = metadata.is_scalar + julia_type = if datatype_width == 32 + Float32 + elseif datatype_width == 64 + Float64 + else + error("Unsupported datatype width: $datatype_width") + end - if is_scalar - trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr)) + typed_ptr = Ptr{julia_type}(sample_ptr) + if num_dims == 0 + trace[symbol] = unsafe_load(typed_ptr) else - trace[symbol] = copy( - reshape( - unsafe_wrap( - Array{element_type}, - reinterpret(Ptr{element_type}, sample_ptr), - prod(shape), - ), - shape, - ), - ) + trace[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) end return nothing @@ -54,7 +42,9 @@ end function __init__() add_sample_to_trace_ptr = @cfunction( - addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}) + addSampleToTraceLowered, + Cvoid, + (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Any}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}) ) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} @@ -105,21 +95,21 @@ end for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(gen_op, i) - for path in res.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + else + if fnwrap + idx -= 1 end + TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) end + else + TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) end end @@ -127,10 +117,7 @@ end end @noinline function sample!( - f::Function, - args::Vararg{Any,Nargs}; - symbol::Symbol=gensym("sample"), - trace::Union{Dict,Nothing}=nothing, + f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") ) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") @@ -169,24 +156,21 @@ end sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - if !isempty(linear_results) - sample_result = linear_results[1] # TODO: consider multiple results - sample_mlir_data = TracedUtils.get_mlir_data(sample_result) - @assert sample_mlir_data isa MLIR.IR.Value "Sample $sample_result is not a MLIR.IR.Value" - - sample_type = MLIR.IR.type(sample_mlir_data) - sample_shape = size(sample_type) - sample_element_type = MLIR.IR.julia_type(eltype(sample_type)) - - SAMPLE_METADATA_CACHE[symbol] = SampleMetadata( - sample_shape, sample_element_type, length(sample_shape) == 0 - ) + traced_output_indices = Int[] + for (i, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end end symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) sample_op = MLIR.Dialects.enzyme.sample( - batch_inputs; outputs=out_tys, fn=fn_attr, symbol=symbol_addr + batch_inputs; + outputs=out_tys, + fn=fn_attr, + symbol=symbol_addr, + traced_output_indices=traced_output_indices, ) for (i, res) in enumerate(linear_results) @@ -213,7 +197,7 @@ end end @noinline function simulate!( - f::Function, args::Vararg{Any,Nargs}; trace::Dict + f::Function, args::Vararg{Any,Nargs}; trace::Dict{Symbol,Any} ) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") @@ -278,25 +262,16 @@ end end end - return trace, result + return result end -function print_trace(trace::Dict) - println("Probabilistic Program Trace:") +function print_trace(trace::Dict{Symbol,Any}) + println("### Probabilistic Program Trace ###") for (symbol, sample) in trace - symbol == :_integrity_check && continue - metadata = SAMPLE_METADATA_CACHE[symbol] - println(" $symbol:") println(" Sample: $(sample)") - println(" Shape: $(metadata.shape)") - println(" Element Type: $(metadata.element_type)") end + println("### End of Trace ###") end -function clear_sample_metadata_cache!() - empty!(SAMPLE_METADATA_CACHE) - return nothing -end - -end +end \ No newline at end of file diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 97910443e2..59bbfe0509 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -1,46 +1,52 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -using Libdl: Libdl -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +@testset "Simulate" begin + normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function simulate_model(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) - t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) - return t - end + function simulate_model(trace, seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) + return t + end - return ProbProg.simulate!(model, seed, μ, σ, shape) -end - -@testset "Simulate" begin + result = ProbProg.simulate!(model, seed, μ, σ, shape; trace) + return result + end @testset "normal_hlo" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace = ProbProg.createTrace() - before = @code_hlo optimize = :no_enzyme simulate_model(seed, μ, σ, shape) + before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape) @test contains(repr(before), "enzyme.simulate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog simulate_model(seed, μ, σ, shape) + after = @code_hlo optimize = :probprog simulate_model(trace, seed, μ, σ, shape) @test !contains(repr(after), "enzyme.simulate") @test !contains(repr(after), "enzyme.sample") @test contains(repr(after), "enzyme_probprog_add_sample_to_trace") - @test contains(repr(after), "enzyme_probprog_init_trace") end @testset "normal_simulate" begin shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRArray(0.0) - σ = Reactant.ConcreteRArray(1.0) - X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape)) - @test X[:_integrity_check] == 0x123456789abcdef - ProbProg.print_trace(X) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace = ProbProg.createTrace() + + result = Array( + @jit optimize = :probprog sync = true simulate_model(trace, seed, μ, σ, shape) + ) + + ProbProg.print_trace(trace) + @test size(result) == shape end end From 46e0f6b9f6f47ee69df4148bc665c4ea998e388c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 16 Jun 2025 17:54:15 -0500 Subject: [PATCH 31/87] transpose fix up --- src/ProbProg.jl | 33 ++++++++++++++++++--------------- test/probprog/generate.jl | 22 +++++++++++++++++++--- test/probprog/sample.jl | 2 +- test/probprog/simulate.jl | 29 +++++++++++++++++++++++++---- 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 74c014de82..bfc6b7d942 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -27,7 +27,8 @@ function addSampleToTraceLowered( elseif datatype_width == 64 Float64 else - error("Unsupported datatype width: $datatype_width") + @ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid + return nothing end typed_ptr = Ptr{julia_type}(sample_ptr) @@ -65,6 +66,7 @@ end (), string(f), false; + do_transpose=false, args_in_result=:all, argprefix, resprefix, @@ -97,19 +99,19 @@ end resv = MLIR.IR.result(gen_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(result, path[2:end], resv) elseif TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(f, path[3:end], resv) else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(args[idx], path[3:end], resv) end else - TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) + TracedUtils.set!(res, (), resv) end end @@ -130,8 +132,8 @@ end (), string(f), false; + do_transpose=false, args_in_result=:all, - do_transpose=false, # TODO: double check transpose argprefix, resprefix, resargprefix, @@ -177,19 +179,19 @@ end resv = MLIR.IR.result(sample_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(result, path[2:end], resv) elseif TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(f, path[3:end], resv) else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(args[idx], path[3:end], resv) end else - TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) + TracedUtils.set!(res, (), resv) end end @@ -210,6 +212,7 @@ end (), string(f), false; + do_transpose=false, args_in_result=:all, argprefix, resprefix, @@ -246,19 +249,19 @@ end resv = MLIR.IR.result(simulate_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(result, path[2:end], resv) elseif TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(f, path[3:end], resv) else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv)) + TracedUtils.set!(args[idx], path[3:end], resv) end else - TracedUtils.set!(res, (), TracedUtils.transpose_val(resv)) + TracedUtils.set!(res, (), resv) end end @@ -271,7 +274,7 @@ function print_trace(trace::Dict{Symbol,Any}) println(" $symbol:") println(" Sample: $(sample)") end - println("### End of Trace ###") + return println("### End of Trace ###") end end \ No newline at end of file diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 605b375805..cefd648a93 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -16,7 +16,7 @@ function generate_model(seed, μ, σ, shape) end @testset "Generate" begin - @testset "normal_deterministic" begin + @testset "deterministic" begin shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) @@ -38,7 +38,7 @@ end Array(model_compiled(seed2, μ2, σ2, shape)), )) end - @testset "normal_hlo" begin + @testset "hlo" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) @@ -53,7 +53,7 @@ end @test !contains(repr(after), "enzyme.sample") end - @testset "normal_generate" begin + @testset "normal" begin shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) @@ -61,4 +61,20 @@ end X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end + + @testset "correctness" begin + op(x, y) = x * y' + + function fake_model(x, y) + return ProbProg.sample!(op, x, y) + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test Array(@jit optimize = :probprog ProbProg.generate!(fake_model, x_ra, y_ra)) == + op(x, y) + end end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 9c711241d8..aabf476f94 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -18,7 +18,7 @@ function sample2(seed, μ, σ, shape) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) + _ = ProbProg.sample!(normal, rng, μ, σ, shape) t = ProbProg.sample!(normal, rng, μ, σ, shape) return t end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 59bbfe0509..7b44e8bcd9 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -42,11 +42,32 @@ using Reactant: ProbProg trace = ProbProg.createTrace() - result = Array( - @jit optimize = :probprog sync = true simulate_model(trace, seed, μ, σ, shape) - ) + result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape)) - ProbProg.print_trace(trace) @test size(result) == shape + @test haskey(trace, :s) + @test haskey(trace, :t) + @test size(trace[:s]) == shape + @test size(trace[:t]) == shape + end + + @testset "correctness" begin + op(x, y) = x * y' + function fake_model(x, y) + return ProbProg.sample!(op, x, y; symbol=:matmul) + end + + trace = ProbProg.createTrace() + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test Array( + @jit optimize = :probprog ProbProg.simulate!(fake_model, x_ra, y_ra; trace) + ) == op(x, y) + + @test haskey(trace, :matmul) + @test trace[:matmul] == op(x, y) end end From 1c5297cab8f1e09ea6861fe54397997581d8fc51 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 17 Jun 2025 13:12:45 -0500 Subject: [PATCH 32/87] minor changes --- src/ProbProg.jl | 10 ++++---- test/probprog/blr.jl | 49 +++++++++++++++++++++------------------ test/probprog/simulate.jl | 6 ++--- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index bfc6b7d942..f876d43122 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -3,10 +3,6 @@ module ProbProg using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray using Enzyme -function createTrace() - return Dict{Symbol,Any}() -end - function addSampleToTraceLowered( trace_ptr_ptr::Ptr{Ptr{Any}}, symbol_ptr_ptr::Ptr{Ptr{Any}}, @@ -26,6 +22,8 @@ function addSampleToTraceLowered( Float32 elseif datatype_width == 64 Float64 + elseif datatype_width == 1 + Bool else @ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid return nothing @@ -268,6 +266,10 @@ end return result end +function create_trace() + return Dict{Symbol,Any}() +end + function print_trace(trace::Dict{Symbol,Any}) println("### Probabilistic Program Trace ###") for (symbol, sample) in trace diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 4f0836b76c..3e0e040963 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -1,33 +1,33 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random using Reactant: ProbProg -using Libdl: Libdl -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +function normal(rng, μ, σ, shape) + return μ .+ σ .* randn(rng, shape) +end -function blr(seed, N, K) - function model(seed, N, K) - rng = Random.default_rng() - Random.seed!(rng, seed) +function bernoulli_logit(rng, logit, shape) + return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +end - # α ~ Normal(0, 10, size = 1) - α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) +function blr(seed, N, K) + rng = Random.default_rng() + Random.seed!(rng, seed) - # β ~ Normal(0, 2.5, size = K) - β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) + # α ~ Normal(0, 10, size = 1) + α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) - # X ~ Normal(0, 10, size = (N, K)) - X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose + # β ~ Normal(0, 2.5, size = K) + β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) - # μ = α .+ X * β - μ = α .+ X * β + # X ~ Normal(0, 10, size = (N, K)) + X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) - ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) + # μ = α .+ X * β + μ = α .+ X * β - return μ - end + Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) - return ProbProg.simulate!(model, seed, N, K) + return Y end @testset "BLR" begin @@ -35,6 +35,11 @@ end K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K)) - ProbProg.print_trace(X) + trace = ProbProg.create_trace() + + @test size( + Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace)) + ) == (N,) + + ProbProg.print_trace(trace) end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 7b44e8bcd9..403505e1b0 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -22,7 +22,7 @@ using Reactant: ProbProg μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.createTrace() + trace = ProbProg.create_trace() before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape) @test contains(repr(before), "enzyme.simulate") @@ -40,7 +40,7 @@ using Reactant: ProbProg μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.createTrace() + trace = ProbProg.create_trace() result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape)) @@ -57,7 +57,7 @@ using Reactant: ProbProg return ProbProg.sample!(op, x, y; symbol=:matmul) end - trace = ProbProg.createTrace() + trace = ProbProg.create_trace() x = reshape(collect(Float64, 1:12), (4, 3)) y = reshape(collect(Float64, 1:12), (4, 3)) x_ra = Reactant.to_rarray(x) From d707053c7f553828dea03f458584b1e5181fe356 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 17 Jun 2025 16:01:00 -0500 Subject: [PATCH 33/87] reorder --- src/ProbProg.jl | 91 ++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index f876d43122..230ce3c2af 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -52,10 +52,12 @@ function __init__() return nothing end -@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") +@noinline function sample!( + f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") +) where {Nargs} + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, @@ -70,31 +72,45 @@ end resprefix, resargprefix, ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res + (; result, linear_args, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - batch_inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap TracedUtils.push_val!(batch_inputs, f, path[3:end]) else - if fnwrap - idx -= 1 - end + idx -= fnwrap ? 1 : 0 TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) end end - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + traced_output_indices = Int[] for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(gen_op, i) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end + end + + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + + sample_op = MLIR.Dialects.enzyme.sample( + batch_inputs; + outputs=out_tys, + fn=fn_attr, + symbol=symbol_addr, + traced_output_indices=traced_output_indices, + ) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(sample_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -116,12 +132,10 @@ end return result end -@noinline function sample!( - f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") -) where {Nargs} - argprefix::Symbol = gensym("samplearg") - resprefix::Symbol = gensym("sampleresult") - resargprefix::Symbol = gensym("sampleresarg") +@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, @@ -136,45 +150,31 @@ end resprefix, resargprefix, ) - (; result, linear_args, linear_results) = mlir_fn_res + (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + batch_inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap TracedUtils.push_val!(batch_inputs, f, path[3:end]) else - idx -= fnwrap ? 1 : 0 + if fnwrap + idx -= 1 + end TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) end end - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - sym = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - - traced_output_indices = Int[] - for (i, res) in enumerate(linear_results) - if TracedUtils.has_idx(res, resprefix) - push!(traced_output_indices, i - 1) - end - end - - symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) - - sample_op = MLIR.Dialects.enzyme.sample( - batch_inputs; - outputs=out_tys, - fn=fn_attr, - symbol=symbol_addr, - traced_output_indices=traced_output_indices, - ) + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(sample_op, i) + resv = MLIR.IR.result(gen_op, i) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -278,5 +278,4 @@ function print_trace(trace::Dict{Symbol,Any}) end return println("### End of Trace ###") end - -end \ No newline at end of file +end From 91a0850ed245da954b3c294475418149ed14476a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 20 Jun 2025 14:17:03 -0500 Subject: [PATCH 34/87] API change --- src/ProbProg.jl | 50 ++++++++++++++++++----------- src/Reactant.jl | 2 +- test/probprog/blr.jl | 16 ++++------ test/probprog/generate.jl | 36 +++++++++++---------- test/probprog/sample.jl | 44 +++++++++++-------------- test/probprog/simulate.jl | 67 +++++++++++++++++---------------------- 6 files changed, 106 insertions(+), 109 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 230ce3c2af..0dc2bdeffb 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,8 +1,18 @@ module ProbProg -using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray +using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber +using ..Compiler: @jit using Enzyme +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + + function ProbProgTrace() + return new(Dict{Symbol,Any}(), nothing) + end +end + function addSampleToTraceLowered( trace_ptr_ptr::Ptr{Ptr{Any}}, symbol_ptr_ptr::Ptr{Ptr{Any}}, @@ -31,9 +41,9 @@ function addSampleToTraceLowered( typed_ptr = Ptr{julia_type}(sample_ptr) if num_dims == 0 - trace[symbol] = unsafe_load(typed_ptr) + trace.choices[symbol] = unsafe_load(typed_ptr) else - trace[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) + trace.choices[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) end return nothing @@ -52,7 +62,7 @@ function __init__() return nothing end -@noinline function sample!( +function sample( f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") ) where {Nargs} argprefix::Symbol = gensym("samplearg") @@ -132,7 +142,12 @@ end return result end -@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + res = @jit optimize = :probprog generate_internal(f, args...) + return res isa AbstractConcreteArray ? Array(res) : res +end + +function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") @@ -196,8 +211,18 @@ end return result end -@noinline function simulate!( - f::Function, args::Vararg{Any,Nargs}; trace::Dict{Symbol,Any} +function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = ProbProgTrace() + + res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace) + + trace.retval = res isa AbstractConcreteArray ? Array(res) : res + + return trace +end + +function simulate_internal( + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace ) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") @@ -266,16 +291,5 @@ end return result end -function create_trace() - return Dict{Symbol,Any}() -end -function print_trace(trace::Dict{Symbol,Any}) - println("### Probabilistic Program Trace ###") - for (symbol, sample) in trace - println(" $symbol:") - println(" Sample: $(sample)") - end - return println("### End of Trace ###") -end end diff --git a/src/Reactant.jl b/src/Reactant.jl index f0b6c044f1..48874e8f95 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -176,7 +176,6 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") -include("ProbProg.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} @@ -189,6 +188,7 @@ export OptimizeCommunicationOptions include("Compiler.jl") include("Overlay.jl") +include("ProbProg.jl") using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile export ConcreteRArray, diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 3e0e040963..7c53aaafd7 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -14,18 +14,18 @@ function blr(seed, N, K) Random.seed!(rng, seed) # α ~ Normal(0, 10, size = 1) - α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) + α = ProbProg.sample(normal, rng, 0, 10, (1,); symbol=:α) # β ~ Normal(0, 2.5, size = K) - β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) + β = ProbProg.sample(normal, rng, 0, 2.5, (K,); symbol=:β) # X ~ Normal(0, 10, size = (N, K)) - X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) + X = ProbProg.sample(normal, rng, 0, 10, (N, K); symbol=:X) # μ = α .+ X * β μ = α .+ X * β - Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) + Y = ProbProg.sample(bernoulli_logit, rng, μ, (N,); symbol=:Y) return Y end @@ -35,11 +35,9 @@ end K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - trace = ProbProg.create_trace() + trace = ProbProg.simulate(blr, seed, N, K) - @test size( - Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace)) - ) == (N,) + @test size(Array(trace.retval)) == (N,) - ProbProg.print_trace(trace) + println(trace) end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index cefd648a93..9c93c7a6f5 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -1,18 +1,14 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random, Statistics using Reactant: ProbProg normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function generate_model(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) - t = ProbProg.sample!(normal, rng, s, σ, shape) - return t - end - - return ProbProg.generate!(model, seed, μ, σ, shape) +function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, s, σ, shape) + return t end @testset "Generate" begin @@ -25,6 +21,9 @@ end σ1 = Reactant.ConcreteRNumber(1.0) σ2 = Reactant.ConcreteRNumber(1.0) + generate_model(seed, μ, σ, shape) = + ProbProg.generate_internal(model, seed, μ, σ, shape) + model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ @@ -44,11 +43,15 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) + before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( + model, seed, μ, σ, shape + ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.generate_internal( + model, seed, μ, σ, shape + ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @@ -58,7 +61,7 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) + X = ProbProg.generate(model, seed, μ, σ, shape) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end @@ -66,7 +69,7 @@ end op(x, y) = x * y' function fake_model(x, y) - return ProbProg.sample!(op, x, y) + return ProbProg.sample(op, x, y) end x = reshape(collect(Float64, 1:12), (4, 3)) @@ -74,7 +77,6 @@ end x_ra = Reactant.to_rarray(x) y_ra = Reactant.to_rarray(y) - @test Array(@jit optimize = :probprog ProbProg.generate!(fake_model, x_ra, y_ra)) == - op(x, y) + @test ProbProg.generate(fake_model, x_ra, y_ra) == op(x, y) end end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index aabf476f94..9541b2feb8 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -1,29 +1,21 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random using Reactant: ProbProg -@noinline normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function sample1(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape) - return s - end - - return ProbProg.generate!(model, seed, μ, σ, shape) +function one_sample(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape) + return s end -function sample2(seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - _ = ProbProg.sample!(normal, rng, μ, σ, shape) - t = ProbProg.sample!(normal, rng, μ, σ, shape) - return t - end - - return ProbProg.generate!(model, seed, μ, σ, shape) +function two_samples(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + _ = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, μ, σ, shape) + return t end @testset "test" begin @@ -32,19 +24,19 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false sample2(seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.generate_internal(one_sample, seed, μ, σ, shape) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.generate_internal(two_samples, seed, μ, σ, shape) @test !contains(repr(after), "enzyme.sample") end - @testset "sample_normal" begin + @testset "rng_state" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape)) - Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape)) + X = ProbProg.generate(one_sample, seed, μ, σ, shape) + Y = ProbProg.generate(two_samples, seed, μ, σ, shape) @test !all(X .≈ Y) end end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 403505e1b0..a97fc5ae8d 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -1,37 +1,32 @@ -using Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random using Reactant: ProbProg -@testset "Simulate" begin - normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) - function simulate_model(trace, seed, μ, σ, shape) - function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s) - t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t) - return t - end +function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t) + return t +end - result = ProbProg.simulate!(model, seed, μ, σ, shape; trace) - return result - end - @testset "normal_hlo" begin - shape = (10000,) +@testset "Simulate" begin + @testset "simulate_hlo" begin + shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.create_trace() - - before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.simulate_internal( + model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + ) @test contains(repr(before), "enzyme.simulate") - @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog simulate_model(trace, seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.simulate_internal( + model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + ) @test !contains(repr(after), "enzyme.simulate") - @test !contains(repr(after), "enzyme.sample") - @test contains(repr(after), "enzyme_probprog_add_sample_to_trace") end @testset "normal_simulate" begin @@ -40,34 +35,30 @@ using Reactant: ProbProg μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.create_trace() - - result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape)) + trace = ProbProg.simulate(model, seed, μ, σ, shape) - @test size(result) == shape - @test haskey(trace, :s) - @test haskey(trace, :t) - @test size(trace[:s]) == shape - @test size(trace[:t]) == shape + @test size(trace.retval) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s]) == shape + @test size(trace.choices[:t]) == shape end @testset "correctness" begin op(x, y) = x * y' function fake_model(x, y) - return ProbProg.sample!(op, x, y; symbol=:matmul) + return ProbProg.sample(op, x, y; symbol=:matmul) end - trace = ProbProg.create_trace() x = reshape(collect(Float64, 1:12), (4, 3)) y = reshape(collect(Float64, 1:12), (4, 3)) x_ra = Reactant.to_rarray(x) y_ra = Reactant.to_rarray(y) - @test Array( - @jit optimize = :probprog ProbProg.simulate!(fake_model, x_ra, y_ra; trace) - ) == op(x, y) + trace = ProbProg.simulate(fake_model, x_ra, y_ra) - @test haskey(trace, :matmul) - @test trace[:matmul] == op(x, y) + @test Array(trace.retval) == op(x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul] == op(x, y) end end From 561b051b6c0801160a5e8a5df824547924171651 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 20 Jun 2025 14:19:51 -0500 Subject: [PATCH 35/87] better print --- src/ProbProg.jl | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 0dc2bdeffb..3373831a53 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -291,5 +291,71 @@ function simulate_internal( return result end +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end end From 99d7608c10cb14d3c033864f2996c6320a21458a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Jun 2025 17:47:19 -0500 Subject: [PATCH 36/87] unconstrained real generate op --- src/ProbProg.jl | 109 +++++++++++++++++++++++++++++++++----- test/probprog/generate.jl | 57 ++++---------------- 2 files changed, 105 insertions(+), 61 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 3373831a53..1b7986d3ab 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -1,15 +1,22 @@ module ProbProg -using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber +using ..Reactant: + MLIR, + TracedUtils, + AbstractConcreteArray, + AbstractConcreteNumber, + AbstractRNG, + TracedRArray using ..Compiler: @jit using Enzyme mutable struct ProbProgTrace choices::Dict{Symbol,Any} retval::Any + weight::Any function ProbProgTrace() - return new(Dict{Symbol,Any}(), nothing) + return new(Dict{Symbol,Any}(), nothing, nothing) end end @@ -63,7 +70,10 @@ function __init__() end function sample( - f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample") + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, ) where {Nargs} argprefix::Symbol = gensym("samplearg") resprefix::Symbol = gensym("sampleresult") @@ -102,6 +112,7 @@ function sample( sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + # Specify which outputs to add to the trace. traced_output_indices = Int[] for (i, res) in enumerate(linear_results) if TracedUtils.has_idx(res, resprefix) @@ -109,13 +120,60 @@ function sample( end end + # Specify which inputs to pass to logpdf. + traced_input_indices = Int[] + for (i, a) in enumerate(linear_args) + idx, _ = TracedUtils.get_argidx(a, argprefix) + if fnwrap && idx == 1 # TODO: add test for fnwrap + continue + end + + if fnwrap + idx -= 1 + end + + if !(args[idx] isa AbstractRNG) + push!(traced_input_indices, i - 1) + end + end + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + # Construct MLIR attribute if Julia logpdf function is provided. + logpdf_attr = nothing + if logpdf !== nothing + # Just to get static information about the sample. TODO: kwargs? + example_sample = f(args...) + + # Remove AbstractRNG from `f`'s argument list if present, assuming that + # logpdf parameters follows `(sample, args...)` convention. + logpdf_args = (example_sample,) + if !isempty(args) && args[1] isa AbstractRNG + logpdf_args = (example_sample, Base.tail(args)...) # TODO: kwargs? + end + + logpdf_mlir = invokelatest( + TracedUtils.make_mlir_fn, + logpdf, + logpdf_args, + (), + string(logpdf), + false; + do_transpose=false, + args_in_result=:all, + ) + + logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) + end + sample_op = MLIR.Dialects.enzyme.sample( batch_inputs; outputs=out_tys, fn=fn_attr, + logpdf=logpdf_attr, symbol=symbol_addr, + traced_input_indices=traced_input_indices, traced_output_indices=traced_output_indices, ) @@ -143,11 +201,19 @@ function sample( end function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - res = @jit optimize = :probprog generate_internal(f, args...) - return res isa AbstractConcreteArray ? Array(res) : res + trace = ProbProgTrace() + + weight, res = @jit optimize = :probprog generate_internal(f, args...; trace) + + trace.retval = res isa AbstractConcreteArray ? Array(res) : res + trace.weight = Array(weight)[1] + + return trace, trace.weight end -function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function generate_internal( + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace +) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") @@ -169,7 +235,8 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + f_out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + out_tys = [MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)); f_out_tys] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) @@ -186,10 +253,17 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end end - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) + + # Output: (weight, f's outputs...) + gen_op = MLIR.Dialects.enzyme.generate( + batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + ) + + weight = TracedRArray(MLIR.IR.result(gen_op, 1)) for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(gen_op, i) + resv = MLIR.IR.result(gen_op, i + 1) # to skip weight if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -208,7 +282,7 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end end - return result + return weight, result end function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} @@ -299,7 +373,6 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) LAST = '\u2514' indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) @@ -320,6 +393,10 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) n += 1 end + if trace.weight !== nothing + n += 1 + end + cur = 1 if trace.retval !== nothing @@ -328,6 +405,12 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) cur += 1 end + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + for (key, value) in sorted_choices print(io, indent_vert_str) print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") @@ -337,7 +420,7 @@ end function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) println(io, "ProbProgTrace:") - if isempty(trace.choices) && trace.retval === nothing + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing println(io, " (empty)") else _show_pretty(io, trace, 0, ()) @@ -350,7 +433,7 @@ function Base.show(io::IO, trace::ProbProgTrace) has_retval = trace.retval !== nothing print(io, "ProbProgTrace($(choices_count) choices") if has_retval - print(io, ", retval=$(trace.retval)") + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") end print(io, ")") else diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 9c93c7a6f5..8c1a8917a4 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -2,81 +2,42 @@ using Reactant, Test, Random, Statistics using Reactant: ProbProg normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape) - t = ProbProg.sample(normal, rng, s, σ, shape) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf) return t end @testset "Generate" begin - @testset "deterministic" begin - shape = (10000,) - seed1 = Reactant.to_rarray(UInt64[1, 4]) - seed2 = Reactant.to_rarray(UInt64[1, 4]) - μ1 = Reactant.ConcreteRNumber(0.0) - μ2 = Reactant.ConcreteRNumber(1000.0) - σ1 = Reactant.ConcreteRNumber(1.0) - σ2 = Reactant.ConcreteRNumber(1.0) - - generate_model(seed, μ, σ, shape) = - ProbProg.generate_internal(model, seed, μ, σ, shape) - - model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape) - - @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ - Array(model_compiled(seed1, μ1, σ1, shape)) - @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = - 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = - 0.05 - @test !(all( - Array(model_compiled(seed1, μ1, σ1, shape)) .≈ - Array(model_compiled(seed2, μ2, σ2, shape)), - )) - end @testset "hlo" begin - shape = (10000,) + shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( - model, seed, μ, σ, shape + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") after = @code_hlo optimize = :probprog ProbProg.generate_internal( - model, seed, μ, σ, shape + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @testset "normal" begin - shape = (10000,) + shape = (1000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = ProbProg.generate(model, seed, μ, σ, shape) - @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 - end - - @testset "correctness" begin - op(x, y) = x * y' - - function fake_model(x, y) - return ProbProg.sample(op, x, y) - end - - x = reshape(collect(Float64, 1:12), (4, 3)) - y = reshape(collect(Float64, 1:12), (4, 3)) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) - - @test ProbProg.generate(fake_model, x_ra, y_ra) == op(x, y) + trace, weight = ProbProg.generate(model, seed, μ, σ, shape) + @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From b13f8bf58a700876a76d47b28069e5acfeab7346 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 25 Jun 2025 17:47:41 -0500 Subject: [PATCH 37/87] probprog postpasses --- src/Compiler.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index a4e2ff44a5..bd496ba0f1 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1183,6 +1183,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1753,7 +1754,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", @@ -1767,7 +1768,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", @@ -1794,7 +1795,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", @@ -1811,7 +1812,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - "probprog", + probprog_pass, opt_passes2, "canonicalize", "remove-unnecessary-enzyme-ops", From 6e4dc0c4d56efe81f984ce3a42e5e56076f5814b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 13:21:24 -0500 Subject: [PATCH 38/87] bug fix for alising outputs --- src/ProbProg.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 1b7986d3ab..18de59758a 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -139,6 +139,25 @@ function sample( symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + # (out_idx1, in_idx1, out_idx2, in_idx2, ...) + alias_pairs = Int64[] + for (out_idx, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, argprefix) + in_idx = nothing + for (i, arg) in enumerate(linear_args) + if TracedUtils.has_idx(arg, argprefix) && + TracedUtils.get_idx(arg, argprefix) == TracedUtils.get_idx(res, argprefix) + in_idx = i - 1 + break + end + end + @assert in_idx !== nothing "Unable to find operand for aliased result" + push!(alias_pairs, out_idx - 1) + push!(alias_pairs, in_idx) + end + end + alias_attr = MLIR.IR.DenseArrayAttribute(alias_pairs) + # Construct MLIR attribute if Julia logpdf function is provided. logpdf_attr = nothing if logpdf !== nothing @@ -175,6 +194,8 @@ function sample( symbol=symbol_addr, traced_input_indices=traced_input_indices, traced_output_indices=traced_output_indices, + alias_map=alias_attr, + name=Base.String(symbol), ) for (i, res) in enumerate(linear_results) From 5b5c1d15938b33336428b65f3750e63064f661e3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 13:26:48 -0500 Subject: [PATCH 39/87] generate op with constraints --- deps/ReactantExtra/API.cpp | 14 ++++++++++ src/ProbProg.jl | 52 +++++++++++++++++++++++++++++++++----- test/probprog/generate.jl | 19 ++++++++++++++ 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 911b21ae77..9658522757 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -353,6 +353,20 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } +extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet( + MlirContext ctx, uint64_t symbol, MlirAttribute values) { + mlir::Attribute vals = unwrap(values); + auto arr = llvm::dyn_cast(vals); + if (!arr) { + ReactantThrowError( + "enzymeConstraintAttrGet: `values` must be an ArrayAttr"); + return MlirAttribute{nullptr}; + } + mlir::Attribute attr = + mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr); + return wrap(attr); +} + // Create profiler session and start profiling extern "C" tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 18de59758a..6e161676f1 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -146,7 +146,8 @@ function sample( in_idx = nothing for (i, arg) in enumerate(linear_args) if TracedUtils.has_idx(arg, argprefix) && - TracedUtils.get_idx(arg, argprefix) == TracedUtils.get_idx(res, argprefix) + TracedUtils.get_idx(arg, argprefix) == + TracedUtils.get_idx(res, argprefix) in_idx = i - 1 break end @@ -221,10 +222,12 @@ function sample( return result end -function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} trace = ProbProgTrace() - weight, res = @jit optimize = :probprog generate_internal(f, args...; trace) + weight, res = @jit sync = true optimize = :probprog generate_internal( + f, args...; trace, constraints + ) trace.retval = res isa AbstractConcreteArray ? Array(res) : res trace.weight = Array(weight)[1] @@ -233,7 +236,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function generate_internal( - f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace, constraints=nothing ) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -276,9 +279,46 @@ function generate_internal( trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) - # Output: (weight, f's outputs...) + constraints_attr = nothing + if constraints !== nothing && !isempty(constraints) + constraint_attrs = MLIR.IR.Attribute[] + + for (sym, constraint) in constraints + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + + if !(constraint isa AbstractArray) + error( + "Constraints must be an array (one element per traced output) of arrays" + ) + end + + sym_constraint_attrs = MLIR.IR.Attribute[] + for oc in constraint + if !(oc isa AbstractArray) + error("Per-output constraints must be arrays") + end + + push!(sym_constraint_attrs, MLIR.IR.DenseElementsAttribute(oc)) + end + + cattr_ptr = @ccall MLIR.API.mlir_c.enzymeConstraintAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, + sym_addr::UInt64, + MLIR.IR.Attribute(sym_constraint_attrs)::MLIR.API.MlirAttribute, + )::MLIR.API.MlirAttribute + + push!(constraint_attrs, MLIR.IR.Attribute(cattr_ptr)) + end + + constraints_attr = MLIR.IR.Attribute(constraint_attrs) + end + gen_op = MLIR.Dialects.enzyme.generate( - batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + batch_inputs; + outputs=out_tys, + fn=fname, + trace=trace_addr, + constraints=constraints_attr, ) weight = TracedRArray(MLIR.IR.result(gen_op, 1)) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 8c1a8917a4..5ed4f662fc 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -40,4 +40,23 @@ end trace, weight = ProbProg.generate(model, seed, μ, σ, shape) @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 end + + @testset "constraints" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + s_constraint = fill(0.1, shape) + constraints = Dict(:s => [s_constraint]) + + trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraints) + + @test trace.choices[:s] == s_constraint + + expected_weight = + normal_logpdf(s_constraint, 0.0, 1.0, shape) + + normal_logpdf(trace.choices[:t], s_constraint, 1.0, shape) + @test weight ≈ expected_weight atol = 1e-6 + end end From 1ad167a3d3943672cc43bcc3c2a470d7ac25eb0e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 14:58:13 -0500 Subject: [PATCH 40/87] untraced call --- src/ProbProg.jl | 69 +++++++++++++++++++++++++++++++++++++++++ test/probprog/sample.jl | 8 ++--- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 6e161676f1..262482690c 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -222,6 +222,75 @@ function sample( return result end +function call(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + res = @jit optimize = :probprog call_internal(f, args...) + return res isa AbstractConcreteArray ? Array(res) : res +end + +function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("callarg") + resprefix::Symbol = gensym("callresult") + resargprefix::Symbol = gensym("callresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fname) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(call_op, i) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + return result +end + function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} trace = ProbProgTrace() diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 9541b2feb8..904d2a7ccd 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -24,9 +24,9 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.generate_internal(one_sample, seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.call_internal(one_sample, seed, μ, σ, shape) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.generate_internal(two_samples, seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.call_internal(two_samples, seed, μ, σ, shape) @test !contains(repr(after), "enzyme.sample") end @@ -35,8 +35,8 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = ProbProg.generate(one_sample, seed, μ, σ, shape) - Y = ProbProg.generate(two_samples, seed, μ, σ, shape) + X = ProbProg.call(one_sample, seed, μ, σ, shape) + Y = ProbProg.call(two_samples, seed, μ, σ, shape) @test !all(X .≈ Y) end end From 8f66b5f8813120d6d58b1ba9e9138ac68ef86eb7 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 26 Jun 2025 17:19:02 -0500 Subject: [PATCH 41/87] working metropolis hastings (with hacks) --- src/ProbProg.jl | 46 ++++++++++++++++-- test/probprog/linear_regression.jl | 77 ++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 test/probprog/linear_regression.jl diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 262482690c..c7a04abe24 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -14,10 +14,14 @@ mutable struct ProbProgTrace choices::Dict{Symbol,Any} retval::Any weight::Any + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} - function ProbProgTrace() - return new(Dict{Symbol,Any}(), nothing, nothing) + function ProbProgTrace(fn::Function, args::Tuple) + return new(Dict{Symbol,Any}(), nothing, nothing, fn, args) end + + ProbProgTrace() = new(Dict{Symbol,Any}(), nothing, nothing, nothing, ()) end function addSampleToTraceLowered( @@ -292,7 +296,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} - trace = ProbProgTrace() + trace = ProbProgTrace(f, (args...,)) weight, res = @jit sync = true optimize = :probprog generate_internal( f, args...; trace, constraints @@ -416,7 +420,7 @@ function generate_internal( end function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - trace = ProbProgTrace() + trace = ProbProgTrace(f, (args...,)) res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace) @@ -571,4 +575,38 @@ function Base.show(io::IO, trace::ProbProgTrace) end end +struct Selection + symbols::Vector{Symbol} +end + +select(symbol::Symbol) = Selection([symbol]) + +choicemap() = Dict{Symbol,Any}() +get_choices(trace::ProbProgTrace) = trace.choices + +function metropolis_hastings(trace::ProbProgTrace, sel::Selection) + if trace.fn === nothing + error("MH requires a trace with fn and args recorded") + end + + constraints = Dict{Symbol,Any}() + for (sym, val) in trace.choices + sym in sel.symbols && continue + constraints[sym] = [val] + end + + new_trace, _ = generate(trace.fn, trace.args...; constraints) + rng_state = new_trace.retval[1] # TODO: this is a temporary hack + + log_alpha = new_trace.weight - trace.weight + + if log(rand()) < log_alpha + new_trace.args = (rng_state, new_trace.args[2:end]...) + return (new_trace, true) + else + trace.args = (rng_state, trace.args[2:end]...) + return (trace, false) + end +end + end diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl new file mode 100644 index 0000000000..095e4d8aac --- /dev/null +++ b/test/probprog/linear_regression.jl @@ -0,0 +1,77 @@ +using Reactant, Test, Random +using Reactant: ProbProg + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function my_model(seed, xs) + rng = Random.default_rng() + Random.seed!(rng, seed) + + slope = ProbProg.sample( + normal, rng, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + intercept = ProbProg.sample( + normal, rng, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + ys = ProbProg.sample( + normal, + rng, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return rng.seed, ys +end + +function my_inference_program(xs, ys, num_iters) + xs_r = Reactant.to_rarray(xs) + + constraints = ProbProg.choicemap() + constraints[:ys] = [ys] + + seed = Reactant.to_rarray(UInt64[1, 4]) + + trace, _ = ProbProg.generate(my_model, seed, xs_r; constraints) + trace.args = (trace.retval[1], trace.args[2:end]...) # TODO: this is a temporary hack + + for i in 1:num_iters + trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope)) + trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept)) + choices = ProbProg.get_choices(trace) + @show i, choices[:slope], choices[:intercept] + end + + choices = ProbProg.get_choices(trace) + return (choices[:slope], choices[:intercept]) +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace = ProbProg.simulate(my_model, seed, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + + slope, intercept = my_inference_program(xs, ys, 1000) + + @show slope, intercept + end +end \ No newline at end of file From 850e3c4e42a5e7601ff0b49312861c7030370f16 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 13:49:05 -0500 Subject: [PATCH 42/87] set julia rng --- test/probprog/linear_regression.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index 095e4d8aac..f246fead50 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -55,6 +55,7 @@ end @testset "linear_regression" begin @testset "simulate" begin seed = Reactant.to_rarray(UInt64[1, 4]) + Random.seed!(42) # For Julia side RNG xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] xs_r = Reactant.to_rarray(xs) @@ -74,4 +75,4 @@ end @show slope, intercept end -end \ No newline at end of file +end From e1b3bcb2d0e02c15ea1fec2fd3882db7a48efbbf Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 13:49:27 -0500 Subject: [PATCH 43/87] remove print --- test/probprog/blr.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 7c53aaafd7..615edb842d 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -38,6 +38,4 @@ end trace = ProbProg.simulate(blr, seed, N, K) @test size(Array(trace.retval)) == (N,) - - println(trace) end From 659b9637fee1d1da31c0c6c04d24bf5e7f09328a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:12:19 -0500 Subject: [PATCH 44/87] less iterations. hiding prints --- test/probprog/linear_regression.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index f246fead50..a0efed9416 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -45,7 +45,7 @@ function my_inference_program(xs, ys, num_iters) trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope)) trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept)) choices = ProbProg.get_choices(trace) - @show i, choices[:slope], choices[:intercept] + # @show i, choices[:slope], choices[:intercept] end choices = ProbProg.get_choices(trace) @@ -71,8 +71,8 @@ end xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] - slope, intercept = my_inference_program(xs, ys, 1000) + slope, intercept = my_inference_program(xs, ys, 5) - @show slope, intercept + # @show slope, intercept end end From 537de49a6c205b9c8d14c7d5d84aae96a246e8e8 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:12:45 -0500 Subject: [PATCH 45/87] add probprog test group --- test/runtests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 411cf443ea..e7998129f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,4 +60,12 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg BLR" include("probprog/blr.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg Linear Regression" include("probprog/linear_regression.jl") + end end From 8260fee3d8ab2cc451c246e7dee5b2b1e9391e1e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:37:21 -0500 Subject: [PATCH 46/87] format --- test/probprog/sample.jl | 8 ++++++-- test/probprog/simulate.jl | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 904d2a7ccd..ef212a63bf 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -24,9 +24,13 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.call_internal(one_sample, seed, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.call_internal( + one_sample, seed, μ, σ, shape + ) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.call_internal(two_samples, seed, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.call_internal( + two_samples, seed, μ, σ, shape + ) @test !contains(repr(after), "enzyme.sample") end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index a97fc5ae8d..3fbdfdd1ad 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -19,12 +19,12 @@ end σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = false ProbProg.simulate_internal( - model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test contains(repr(before), "enzyme.simulate") after = @code_hlo optimize = :probprog ProbProg.simulate_internal( - model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace() + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() ) @test !contains(repr(after), "enzyme.simulate") end From 0f9416668baf4d9b5002aaf2e263474037e75f34 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 27 Jun 2025 14:46:34 -0500 Subject: [PATCH 47/87] add probprog compile opt --- src/CompileOptions.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index 30dfda915f..9b01785c11 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -221,6 +221,8 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, + :probprog_no_lowering, ] end From a05d2c2cb6af266755901cb5436696735676d049 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 3 Jul 2025 18:06:14 -0500 Subject: [PATCH 48/87] pass all args even when w/o rng --- src/ProbProg.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index c7a04abe24..52f9386e26 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -171,9 +171,11 @@ function sample( # Remove AbstractRNG from `f`'s argument list if present, assuming that # logpdf parameters follows `(sample, args...)` convention. - logpdf_args = (example_sample,) + logpdf_args = nothing if !isempty(args) && args[1] isa AbstractRNG logpdf_args = (example_sample, Base.tail(args)...) # TODO: kwargs? + else + logpdf_args = (example_sample, args...) end logpdf_mlir = invokelatest( From f40960f894f63e9169a7c1e71911db578a4b0bba Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 3 Jul 2025 18:07:17 -0500 Subject: [PATCH 49/87] updated probprog frontend for refactored simulate op --- src/ProbProg.jl | 368 +++++++++++++++++--------------------- test/probprog/simulate.jl | 54 +++++- 2 files changed, 212 insertions(+), 210 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 52f9386e26..365d189f4c 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -6,70 +6,148 @@ using ..Reactant: AbstractConcreteArray, AbstractConcreteNumber, AbstractRNG, - TracedRArray + TracedRArray, + TracedRNumber using ..Compiler: @jit using Enzyme +using Base: ReentrantLock mutable struct ProbProgTrace + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} choices::Dict{Symbol,Any} retval::Any weight::Any - fn::Union{Nothing,Function} - args::Union{Nothing,Tuple} + subtraces::Dict{Symbol,Any} function ProbProgTrace(fn::Function, args::Tuple) - return new(Dict{Symbol,Any}(), nothing, nothing, fn, args) + return new(fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) + end + + function ProbProgTrace() + return new(nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) end +end + +const _trace_ref_lock = ReentrantLock() +const _trace_refs = Vector{Any}() - ProbProgTrace() = new(Dict{Symbol,Any}(), nothing, nothing, nothing, ()) +function _keepalive!(tr::ProbProgTrace) + lock(_trace_ref_lock) + try + push!(_trace_refs, tr) + finally + unlock(_trace_ref_lock) + end + return tr end -function addSampleToTraceLowered( +function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) + tr = ProbProgTrace() + _keepalive!(tr) + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) + return nothing +end + +function addSampleToTrace( trace_ptr_ptr::Ptr{Ptr{Any}}, symbol_ptr_ptr::Ptr{Ptr{Any}}, - sample_ptr::Ptr{Any}, - num_dims_ptr::Ptr{Int64}, - shape_array_ptr::Ptr{Int64}, - datatype_width_ptr::Ptr{Int64}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, ) - trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) - symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) - - num_dims = unsafe_load(num_dims_ptr) - shape_array = unsafe_wrap(Array, shape_array_ptr, num_dims) - datatype_width = unsafe_load(datatype_width_ptr) - - julia_type = if datatype_width == 32 - Float32 - elseif datatype_width == 64 - Float64 - elseif datatype_width == 1 - Bool - else - @ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid - return nothing - end + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end - typed_ptr = Ptr{julia_type}(sample_ptr) - if num_dims == 0 - trace.choices[symbol] = unsafe_load(typed_ptr) - else - trace.choices[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + val = unsafe_load(Ptr{julia_type}(sample_ptr)) + trace.choices[symbol] = val + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + trace.choices[symbol] = copy( + unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + ) + end end return nothing end +function addSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace + + trace.subtraces[symbol] = subtrace + + return nothing +end + function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + add_sample_to_trace_ptr = @cfunction( - addSampleToTraceLowered, + addSampleToTrace, Cvoid, - (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Any}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}) + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) ) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} )::Cvoid + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + return nothing end @@ -142,6 +220,9 @@ function sample( end symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 + )::MLIR.IR.Attribute # (out_idx1, in_idx1, out_idx2, in_idx2, ...) alias_pairs = Int64[] @@ -198,7 +279,7 @@ function sample( outputs=out_tys, fn=fn_attr, logpdf=logpdf_attr, - symbol=symbol_addr, + symbol=symbol_attr, traced_input_indices=traced_input_indices, traced_output_indices=traced_output_indices, alias_map=alias_attr, @@ -297,143 +378,27 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} return result end -function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} - trace = ProbProgTrace(f, (args...,)) - - weight, res = @jit sync = true optimize = :probprog generate_internal( - f, args...; trace, constraints - ) - - trace.retval = res isa AbstractConcreteArray ? Array(res) : res - trace.weight = Array(weight)[1] - - return trace, trace.weight -end - -function generate_internal( - f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace, constraints=nothing -) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") - - mlir_fn_res = invokelatest( - TracedUtils.make_mlir_fn, - f, - args, - (), - string(f), - false; - do_transpose=false, - args_in_result=:all, - argprefix, - resprefix, - resargprefix, - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - f_out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - out_tys = [MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)); f_out_tys] - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) - - constraints_attr = nothing - if constraints !== nothing && !isempty(constraints) - constraint_attrs = MLIR.IR.Attribute[] - - for (sym, constraint) in constraints - sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) - - if !(constraint isa AbstractArray) - error( - "Constraints must be an array (one element per traced output) of arrays" - ) - end - - sym_constraint_attrs = MLIR.IR.Attribute[] - for oc in constraint - if !(oc isa AbstractArray) - error("Per-output constraints must be arrays") - end - - push!(sym_constraint_attrs, MLIR.IR.DenseElementsAttribute(oc)) - end - - cattr_ptr = @ccall MLIR.API.mlir_c.enzymeConstraintAttrGet( - MLIR.IR.context()::MLIR.API.MlirContext, - sym_addr::UInt64, - MLIR.IR.Attribute(sym_constraint_attrs)::MLIR.API.MlirAttribute, - )::MLIR.API.MlirAttribute +function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + old_gc_state = GC.enable(false) - push!(constraint_attrs, MLIR.IR.Attribute(cattr_ptr)) - end + trace = nothing + weight = nothing + res = nothing - constraints_attr = MLIR.IR.Attribute(constraint_attrs) + try + trace, weight, res = @jit optimize = :probprog simulate_internal(f, args...) + finally + GC.enable(old_gc_state) end - gen_op = MLIR.Dialects.enzyme.generate( - batch_inputs; - outputs=out_tys, - fn=fname, - trace=trace_addr, - constraints=constraints_attr, - ) - - weight = TracedRArray(MLIR.IR.result(gen_op, 1)) - - for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(gen_op, i + 1) # to skip weight - if TracedUtils.has_idx(res, resprefix) - path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], resv) - elseif TracedUtils.has_idx(res, argprefix) - idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - else - TracedUtils.set!(res, (), resv) - end - end - - return weight, result -end - -function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - trace = ProbProgTrace(f, (args...,)) - - res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace) - + trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) trace.retval = res isa AbstractConcreteArray ? Array(res) : res + trace.weight = Array(weight)[1] - return trace + return trace, trace.weight end -function simulate_internal( - f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace -) where {Nargs} +function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") resargprefix::Symbol = gensym("simulateresarg") @@ -472,14 +437,16 @@ function simulate_internal( end end - trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) - + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) simulate_op = MLIR.Dialects.enzyme.simulate( - batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + batch_inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fname ) for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(simulate_op, i) + resv = MLIR.IR.result(simulate_op, i + 2) if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) @@ -498,7 +465,26 @@ function simulate_internal( end end - return result + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(simulate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + weight = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(simulate_op, 2)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), weight, ()) + + return trace, weight, result end # Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 @@ -552,6 +538,18 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") cur += 1 end + + sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) + n += length(sorted_subtraces) + + for (key, subtrace) in sorted_subtraces + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") + _show_pretty( + io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) + ) + cur += 1 + end end function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) @@ -577,38 +575,6 @@ function Base.show(io::IO, trace::ProbProgTrace) end end -struct Selection - symbols::Vector{Symbol} -end - -select(symbol::Symbol) = Selection([symbol]) - -choicemap() = Dict{Symbol,Any}() get_choices(trace::ProbProgTrace) = trace.choices -function metropolis_hastings(trace::ProbProgTrace, sel::Selection) - if trace.fn === nothing - error("MH requires a trace with fn and args recorded") - end - - constraints = Dict{Symbol,Any}() - for (sym, val) in trace.choices - sym in sel.symbols && continue - constraints[sym] = [val] - end - - new_trace, _ = generate(trace.fn, trace.args...; constraints) - rng_state = new_trace.retval[1] # TODO: this is a temporary hack - - log_alpha = new_trace.weight - trace.weight - - if log(rand()) < log_alpha - new_trace.args = (rng_state, new_trace.args[2:end]...) - return (new_trace, true) - else - trace.args = (rng_state, trace.args[2:end]...) - return (trace, false) - end -end - end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 3fbdfdd1ad..90874c9675 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -2,29 +2,44 @@ using Reactant, Test, Random using Reactant: ProbProg normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function product_two_normals(rng, μ, σ, shape) + a = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + b = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + return a .* b +end function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s) - t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function model2(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(product_two_normals, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample(product_two_normals, rng, s, σ, shape; symbol=:t) return t end @testset "Simulate" begin - @testset "simulate_hlo" begin + @testset "hlo" begin shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = false ProbProg.simulate_internal( - model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + model, seed, μ, σ, shape ) @test contains(repr(before), "enzyme.simulate") after = @code_hlo optimize = :probprog ProbProg.simulate_internal( - model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + model, seed, μ, σ, shape ) @test !contains(repr(after), "enzyme.simulate") end @@ -35,19 +50,21 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace = ProbProg.simulate(model, seed, μ, σ, shape) + trace, weight = ProbProg.simulate(model, seed, μ, σ, shape) @test size(trace.retval) == shape @test haskey(trace.choices, :s) @test haskey(trace.choices, :t) @test size(trace.choices[:s]) == shape @test size(trace.choices[:t]) == shape + @test trace.weight isa Float64 end - @testset "correctness" begin + @testset "simple_fake" begin op(x, y) = x * y' + logpdf(res, _, _) = sum(res) function fake_model(x, y) - return ProbProg.sample(op, x, y; symbol=:matmul) + return ProbProg.sample(op, x, y; symbol=:matmul, logpdf=logpdf) end x = reshape(collect(Float64, 1:12), (4, 3)) @@ -55,10 +72,29 @@ end x_ra = Reactant.to_rarray(x) y_ra = Reactant.to_rarray(y) - trace = ProbProg.simulate(fake_model, x_ra, y_ra) + trace, weight = ProbProg.simulate(fake_model, x_ra, y_ra) @test Array(trace.retval) == op(x, y) @test haskey(trace.choices, :matmul) @test trace.choices[:matmul] == op(x, y) + @test trace.weight == logpdf(op(x, y), x, y) + end + + @testset "submodel_fake" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate(model2, seed, μ, σ, shape) + + println(trace) + + @test size(trace.retval) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s]) == shape + @test size(trace.choices[:t]) == shape + @test trace.weight isa Float64 end end From f6ee8494cf07e2afe5382569ae8850860a7fa2ef Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 3 Jul 2025 18:07:34 -0500 Subject: [PATCH 50/87] probprog attr mlir api --- deps/ReactantExtra/API.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index af7abcbf6d..8177dc3b76 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -364,6 +364,16 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } +extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { + return wrap(mlir::enzyme::TraceType::get(unwrap(ctx))); +} + +extern "C" MLIR_CAPI_EXPORTED MlirAttribute +enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { + mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol); + return wrap(attr); +} + extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet( MlirContext ctx, uint64_t symbol, MlirAttribute values) { mlir::Attribute vals = unwrap(values); From 38e33dec6edb86cf16bfb968cbe955d23450f3b7 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 4 Jul 2025 00:59:55 -0500 Subject: [PATCH 51/87] adding cfunction mapping for AddWeightToTrace and AddRetvalToTrace ops --- src/ProbProg.jl | 86 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 84 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 365d189f4c..1a50f62719 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -118,6 +118,69 @@ function addSubtrace( return nothing end +function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) + return nothing +end + +function addRetvalToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + retval_ptr_array::Ptr{Ptr{Any}}, + num_results_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + + num_results = unsafe_load(num_results_ptr) + + if num_results == 0 + return nothing + end + + ndims_array = unsafe_wrap(Array, ndims_array, num_results) + width_array = unsafe_wrap(Array, width_array, num_results) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) + retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) + + vals = Any[] + for i in 1:num_results + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + retval_ptr = retval_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) + end + end + + trace.retval = length(vals) == 1 ? vals[1] : vals + return nothing +end + function __init__() init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( @@ -148,6 +211,27 @@ function __init__() :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} )::Cvoid + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + return nothing end @@ -392,8 +476,6 @@ function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) - trace.retval = res isa AbstractConcreteArray ? Array(res) : res - trace.weight = Array(weight)[1] return trace, trace.weight end From 127126d29e1d693919ca9fcf0cfb6fee51ddd702 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 4 Jul 2025 01:32:41 -0500 Subject: [PATCH 52/87] adding traced_output_indices attr to simulate op --- src/ProbProg.jl | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 1a50f62719..1c24919964 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -422,7 +422,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) batch_inputs = MLIR.IR.Value[] for a in linear_args @@ -437,7 +437,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end end - call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fname) + call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fn_attr) for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(call_op, i) @@ -504,7 +504,15 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + # Specify which outputs to add to the trace. + traced_output_indices = Int[] + for (i, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end + end batch_inputs = MLIR.IR.Value[] for a in linear_args @@ -524,7 +532,12 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} )::MLIR.IR.Type weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) simulate_op = MLIR.Dialects.enzyme.simulate( - batch_inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fname + batch_inputs; + trace=trace_ty, + weight=weight_ty, + outputs=out_tys, + fn=fn_attr, + traced_output_indices=traced_output_indices, ) for (i, res) in enumerate(linear_results) From 3d66c7a6528b67da9b76104abefef36f7c3d4fbb Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 4 Jul 2025 16:13:12 -0500 Subject: [PATCH 53/87] update tests --- test/probprog/simulate.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 90874c9675..894ac55697 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -87,14 +87,24 @@ end σ = Reactant.ConcreteRNumber(1.0) trace, weight = ProbProg.simulate(model2, seed, μ, σ, shape) - - println(trace) @test size(trace.retval) == shape + + @test length(trace.choices) == 2 @test haskey(trace.choices, :s) @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + @test size(trace.choices[:s]) == shape @test size(trace.choices[:t]) == shape + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight end end From 15854835464d13b9cc4136853869670f51a52f2f Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 7 Jul 2025 22:09:17 -0500 Subject: [PATCH 54/87] refactored generate op --- deps/ReactantExtra/API.cpp | 19 +-- src/ProbProg.jl | 316 ++++++++++++++++++++++++++++++++++--- test/probprog/generate.jl | 19 ++- 3 files changed, 306 insertions(+), 48 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8177dc3b76..e1883aff96 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -368,26 +368,17 @@ extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { return wrap(mlir::enzyme::TraceType::get(unwrap(ctx))); } +extern "C" MLIR_CAPI_EXPORTED MlirType +enzymeConstraintTypeGet(MlirContext ctx) { + return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx))); +} + extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol); return wrap(attr); } -extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet( - MlirContext ctx, uint64_t symbol, MlirAttribute values) { - mlir::Attribute vals = unwrap(values); - auto arr = llvm::dyn_cast(vals); - if (!arr) { - ReactantThrowError( - "enzymeConstraintAttrGet: `values` must be an ArrayAttr"); - return MlirAttribute{nullptr}; - } - mlir::Attribute attr = - mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr); - return wrap(attr); -} - // Create profiler session and start profiling extern "C" tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 1c24919964..147133689b 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -7,7 +7,9 @@ using ..Reactant: AbstractConcreteNumber, AbstractRNG, TracedRArray, - TracedRNumber + TracedRNumber, + ConcreteRNumber, + Ops using ..Compiler: @jit using Enzyme using Base: ReentrantLock @@ -29,6 +31,8 @@ mutable struct ProbProgTrace end end +const Constraint = Dict{Symbol,Any} + const _trace_ref_lock = ReentrantLock() const _trace_refs = Vector{Any}() @@ -54,20 +58,21 @@ function addSampleToTrace( trace_ptr_ptr::Ptr{Ptr{Any}}, symbol_ptr_ptr::Ptr{Ptr{Any}}, sample_ptr_array::Ptr{Ptr{Any}}, - num_samples_ptr::Ptr{UInt64}, + num_outputs_ptr::Ptr{UInt64}, ndims_array::Ptr{UInt64}, shape_ptr_array::Ptr{Ptr{UInt64}}, width_array::Ptr{UInt64}, ) trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol - num_samples = unsafe_load(num_samples_ptr) - ndims_array = unsafe_wrap(Array, ndims_array, num_samples) - width_array = unsafe_wrap(Array, width_array, num_samples) - shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) - sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) - - for i in 1:num_samples + num_outputs = unsafe_load(num_outputs_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) + width_array = unsafe_wrap(Array, width_array, num_outputs) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) + + tostore = Any[] + for i in 1:num_outputs ndims = ndims_array[i] width = width_array[i] shape_ptr = shape_ptr_array[i] @@ -92,15 +97,17 @@ function addSampleToTrace( if ndims == 0 val = unsafe_load(Ptr{julia_type}(sample_ptr)) - trace.choices[symbol] = val + push!(tostore, val) else shape = unsafe_wrap(Array, shape_ptr, ndims) - trace.choices[symbol] = copy( - unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + push!( + tostore, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape))) ) end end + trace.choices[symbol] = tuple(tostore...) + return nothing end @@ -181,6 +188,94 @@ function addRetvalToTrace( return nothing end +function getSampleFromConstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(constraint, symbol, nothing) + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in constrained sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + if size(dest) != size(tostore[i]) + if length(size(dest)) != length(size(tostore[i])) + @ccall printf( + "Shape size mismatch in constrained sample: %zd != %zd\n"::Cstring, + length(size(dest))::Csize_t, + length(size(tostore[i]))::Csize_t, + )::Cvoid + return nothing + end + for i in 1:length(size(dest)) + d = size(dest)[i] + t = size(tostore[i])[i] + if d != t + @ccall printf( + "Shape mismatch in `%zd`th dimension of constrained sample: %zd != %zd\n"::Cstring, + i::Csize_t, + size(dest)[i]::Csize_t, + size(tostore[i])[i]::Csize_t, + )::Cvoid + return nothing + end + end + end + + dest .= tostore[i] + end + end + + return nothing +end + function __init__() init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( @@ -232,6 +327,24 @@ function __init__() :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} )::Cvoid + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + return nothing end @@ -262,14 +375,14 @@ function sample( fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - batch_inputs = MLIR.IR.Value[] + inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) + TracedUtils.push_val!(inputs, f, path[3:end]) else idx -= fnwrap ? 1 : 0 - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + TracedUtils.push_val!(inputs, args[idx], path[3:end]) end end @@ -359,7 +472,7 @@ function sample( end sample_op = MLIR.Dialects.enzyme.sample( - batch_inputs; + inputs; outputs=out_tys, fn=fn_attr, logpdf=logpdf_attr, @@ -424,20 +537,20 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - batch_inputs = MLIR.IR.Value[] + inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) + TracedUtils.push_val!(inputs, f, path[3:end]) else if fnwrap idx -= 1 end - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + TracedUtils.push_val!(inputs, args[idx], path[3:end]) end end - call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fn_attr) + call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(call_op, i) @@ -514,16 +627,16 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end end - batch_inputs = MLIR.IR.Value[] + inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) + TracedUtils.push_val!(inputs, f, path[3:end]) else if fnwrap idx -= 1 end - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + TracedUtils.push_val!(inputs, args[idx], path[3:end]) end end @@ -531,8 +644,9 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.IR.Type weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + simulate_op = MLIR.Dialects.enzyme.simulate( - batch_inputs; + inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, @@ -582,6 +696,160 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} return trace, weight, result end +function generate( + f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}() +) where {Nargs} + old_gc_state = GC.enable(false) + + trace = nothing + weight = nothing + res = nothing + + try + trace, weight, res = @jit optimize = :probprog generate_internal( + f, args...; constraint + ) + finally + GC.enable(old_gc_state) + end + + trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + + return trace, trace.weight +end + +function generate_internal( + f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}() +) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + # Specify which outputs to add to the trace. + traced_output_indices = Int[] + for (i, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end + end + + inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(inputs, args[idx], path[3:end]) + end + end + + constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + constraint_addr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) + constraint_mlir_val = TracedUtils.get_mlir_data(Ops.constant(constraint_addr)) + + constraint_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [constraint_mlir_val]; outputs=[constraint_ty] + ), + 1, + ) + + constrained_symbols_attr = MLIR.IR.Attribute[] + for sym in keys(constraint) + addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + constrained_symbols_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, addr::UInt64 + )::MLIR.IR.Attribute + ) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + generate_op = MLIR.Dialects.enzyme.generate( + inputs, + constraint_val; + trace=trace_ty, + weight=weight_ty, + outputs=out_tys, + fn=fn_attr, + constrained_symbols=MLIR.IR.Attribute(constrained_symbols_attr), + traced_output_indices, + ) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(generate_op, i + 2) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(generate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + weight = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(generate_op, 2)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), weight, ()) + + return trace, weight, result +end + # Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) VERT = '\u2502' diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 5ed4f662fc..01b09bb5d1 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -20,19 +20,19 @@ end σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( - model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + model, seed, μ, σ, shape ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") after = @code_hlo optimize = :probprog ProbProg.generate_internal( - model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + model, seed, μ, σ, shape ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end - @testset "normal" begin + @testset "unconstrained" begin shape = (1000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) @@ -41,22 +41,21 @@ end @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 end - @testset "constraints" begin + @testset "constrained" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - s_constraint = fill(0.1, shape) - constraints = Dict(:s => [s_constraint]) + constraint = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) - trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraints) + trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraint) - @test trace.choices[:s] == s_constraint + @test trace.choices[:s] == constraint[:s] expected_weight = - normal_logpdf(s_constraint, 0.0, 1.0, shape) + - normal_logpdf(trace.choices[:t], s_constraint, 1.0, shape) + normal_logpdf(constraint[:s][1], 0.0, 1.0, shape) + + normal_logpdf(trace.choices[:t][1], constraint[:s][1], 1.0, shape) @test weight ≈ expected_weight atol = 1e-6 end end From 34f35c43b0e7addda80916e7b67b6be6c903021e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 7 Jul 2025 22:52:00 -0500 Subject: [PATCH 55/87] @compile for generate op --- src/ProbProg.jl | 23 ++++++++++++++--------- test/probprog/generate.jl | 19 ------------------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 147133689b..5169e319d7 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -10,7 +10,7 @@ using ..Reactant: TracedRNumber, ConcreteRNumber, Ops -using ..Compiler: @jit +using ..Compiler: @jit, @compile using Enzyme using Base: ReentrantLock @@ -705,10 +705,15 @@ function generate( weight = nothing res = nothing + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) + + function wrapper_fn(constraint_ptr, args...) + return generate_internal(f, args...; constraint_ptr, constraint) + end + try - trace, weight, res = @jit optimize = :probprog generate_internal( - f, args...; constraint - ) + compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr, args...) + trace, weight, res = compiled_fn(constraint_ptr, args...) finally GC.enable(old_gc_state) end @@ -719,7 +724,10 @@ function generate( end function generate_internal( - f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}() + f::Function, + args::Vararg{Any,Nargs}; + constraint_ptr::TracedRNumber, + constraint::Constraint=Dict{Symbol,Any}(), ) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -771,12 +779,9 @@ function generate_internal( MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.IR.Type - constraint_addr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) - constraint_mlir_val = TracedUtils.get_mlir_data(Ops.constant(constraint_addr)) - constraint_val = MLIR.IR.result( MLIR.Dialects.builtin.unrealized_conversion_cast( - [constraint_mlir_val]; outputs=[constraint_ty] + [TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty] ), 1, ) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 01b09bb5d1..77d1f2d5b0 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -13,25 +13,6 @@ function model(seed, μ, σ, shape) end @testset "Generate" begin - @testset "hlo" begin - shape = (10,) - seed = Reactant.to_rarray(UInt64[1, 4]) - μ = Reactant.ConcreteRNumber(0.0) - σ = Reactant.ConcreteRNumber(1.0) - - before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( - model, seed, μ, σ, shape - ) - @test contains(repr(before), "enzyme.generate") - @test contains(repr(before), "enzyme.sample") - - after = @code_hlo optimize = :probprog ProbProg.generate_internal( - model, seed, μ, σ, shape - ) - @test !contains(repr(after), "enzyme.generate") - @test !contains(repr(after), "enzyme.sample") - end - @testset "unconstrained" begin shape = (1000,) seed = Reactant.to_rarray(UInt64[1, 4]) From f4a6415f63524a3eeb12e25d4b75c7df07599c66 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 7 Jul 2025 23:24:53 -0500 Subject: [PATCH 56/87] improve api --- src/ProbProg.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 5169e319d7..e7051298d6 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -706,9 +706,10 @@ function generate( res = nothing constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) + constrained_symbols = collect(keys(constraint)) function wrapper_fn(constraint_ptr, args...) - return generate_internal(f, args...; constraint_ptr, constraint) + return generate_internal(f, args...; constraint_ptr, constrained_symbols) end try @@ -727,7 +728,7 @@ function generate_internal( f::Function, args::Vararg{Any,Nargs}; constraint_ptr::TracedRNumber, - constraint::Constraint=Dict{Symbol,Any}(), + constrained_symbols::Vector{Symbol}, ) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -787,7 +788,7 @@ function generate_internal( ) constrained_symbols_attr = MLIR.IR.Attribute[] - for sym in keys(constraint) + for sym in constrained_symbols addr = reinterpret(UInt64, pointer_from_objref(sym)) push!( constrained_symbols_attr, From b92a7335b406bd6c69a2069d03a4148964517be3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 7 Jul 2025 23:25:04 -0500 Subject: [PATCH 57/87] compiled generate test --- test/probprog/generate.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 77d1f2d5b0..bf51cdc270 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -39,4 +39,38 @@ end normal_logpdf(trace.choices[:t][1], constraint[:s][1], 1.0, shape) @test weight ≈ expected_weight atol = 1e-6 end + + @testset "compiled" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) + + constrained_symbols = collect(keys(constraint1)) # This doesn't change + + constraint_ptr1 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint1)) + ) + + wrapper_fn(constraint_ptr, seed, μ, σ) = ProbProg.generate_internal( + model, seed, μ, σ, shape; constraint_ptr, constrained_symbols + ) + + compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, seed, μ, σ) + + trace1, weight = compiled_fn(constraint_ptr1, seed, μ, σ) + trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1])) + + constraint2 = Dict{Symbol,Any}(:s => (fill(0.2, shape),)) + constraint_ptr2 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint2)) + ) + + trace2, _ = compiled_fn(constraint_ptr2, seed, μ, σ) + trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1])) + + @test trace1.choices[:s] != trace2.choices[:s] + end end From f4c4a887bafea17a3de8e63c60c88daca8939215 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 14 Jul 2025 23:49:07 -0500 Subject: [PATCH 58/87] save gc change --- src/ProbProg.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index e7051298d6..5c88f2d129 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -699,8 +699,6 @@ end function generate( f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}() ) where {Nargs} - old_gc_state = GC.enable(false) - trace = nothing weight = nothing res = nothing @@ -712,8 +710,10 @@ function generate( return generate_internal(f, args...; constraint_ptr, constrained_symbols) end + compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr, args...) + + old_gc_state = GC.enable(false) try - compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr, args...) trace, weight, res = compiled_fn(constraint_ptr, args...) finally GC.enable(old_gc_state) From d1be27c7f069f5a1daa052cab62c334d71f39019 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 16 Jul 2025 17:26:40 -0500 Subject: [PATCH 59/87] enforcing calling convention (rng being the 0th operand) for sample & untraced call ops --- src/ProbProg.jl | 92 ++++++++++++++++++++++++++++++++--------- test/probprog/sample.jl | 72 +++++++++++++++++++++++--------- 2 files changed, 126 insertions(+), 38 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 5c88f2d129..ebeab14091 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -349,6 +349,23 @@ function __init__() end function sample( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + res = sample_internal(rng, f, args...; symbol, logpdf) + + @assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG" + + res = res[2:end] + + return length(res) == 1 ? res[1] : res +end + +function sample_internal( + rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample"), @@ -358,15 +375,22 @@ function sample( resprefix::Symbol = gensym("sampleresult") resargprefix::Symbol = gensym("sampleresarg") + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + args = (rng, args...) + mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, - f, + wrapper_fn, args, (), string(f), false; do_transpose=false, - args_in_result=:all, + args_in_result=:result, argprefix, resprefix, resargprefix, @@ -378,10 +402,13 @@ function sample( inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap + if idx == 2 && fnwrap TracedUtils.push_val!(inputs, f, path[3:end]) else - idx -= fnwrap ? 1 : 0 + if fnwrap && idx > 1 + idx -= 1 + end + TracedUtils.push_val!(inputs, args[idx], path[3:end]) end end @@ -464,7 +491,7 @@ function sample( string(logpdf), false; do_transpose=false, - args_in_result=:all, + args_in_result=:result, ) logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") @@ -485,20 +512,25 @@ function sample( for (i, res) in enumerate(linear_results) resv = MLIR.IR.result(sample_op, i) + if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) - elseif TracedUtils.has_idx(res, argprefix) + end + + if TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 1 && fnwrap + if fnwrap && idx == 2 TracedUtils.set!(f, path[3:end], resv) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.set!(args[idx], path[3:end], resv) end - else + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) TracedUtils.set!(res, (), resv) end end @@ -506,25 +538,41 @@ function sample( return result end -function call(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - res = @jit optimize = :probprog call_internal(f, args...) - return res isa AbstractConcreteArray ? Array(res) : res +function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + res = @jit optimize = :probprog call_internal(rng, f, args...) + + @assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG" + + res = map(res[2:end]) do r + r isa AbstractConcreteArray ? Array(r) : r + end + + @show res + + return length(res) == 1 ? res[1] : res end -function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("callarg") resprefix::Symbol = gensym("callresult") resargprefix::Symbol = gensym("callresarg") + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + args = (rng, args...) + mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, - f, + wrapper_fn, args, (), string(f), false; do_transpose=false, - args_in_result=:all, + args_in_result=:result, argprefix, resprefix, resargprefix, @@ -533,6 +581,8 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + @show length(linear_results), linear_results + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) @@ -557,17 +607,21 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) - elseif TracedUtils.has_idx(res, argprefix) + end + + if TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 1 && fnwrap + if fnwrap && idx == 2 TracedUtils.set!(f, path[3:end], resv) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.set!(args[idx], path[3:end], resv) end - else + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) TracedUtils.set!(res, (), resv) end end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index ef212a63bf..28a4ab6ee9 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -1,46 +1,80 @@ using Reactant, Test, Random -using Reactant: ProbProg +using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function one_sample(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape) +function one_sample(rng, μ, σ, shape) + s = ProbProg.sample(rng, normal, μ, σ, shape) return s end -function two_samples(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - _ = ProbProg.sample(normal, rng, μ, σ, shape) - t = ProbProg.sample(normal, rng, μ, σ, shape) +function two_samples(rng, μ, σ, shape) + _ = ProbProg.sample(rng, normal, μ, σ, shape) + t = ProbProg.sample(rng, normal, μ, σ, shape) + return t +end + +function compose(rng, μ, σ, shape) + s = ProbProg.sample(rng, normal, μ, σ, shape) + t = ProbProg.sample(rng, normal, s, σ, shape) return t end @testset "test" begin - @testset "sample_hlo" begin + @testset "normal_hlo" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.call_internal( - one_sample, seed, μ, σ, shape - ) + + code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "two_samples_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "compose" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.call(rng, compose, μ, σ, shape) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.call_internal( - two_samples, seed, μ, σ, shape - ) + + after = @code_hlo optimize = :probprog ProbProg.call(rng, compose, μ, σ, shape) @test !contains(repr(after), "enzyme.sample") end @testset "rng_state" begin shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - X = ProbProg.call(one_sample, seed, μ, σ, shape) - Y = ProbProg.call(two_samples, seed, μ, σ, shape) + + rng1 = ReactantRNG(copy(seed)) + + X = ProbProg.call(rng1, one_sample, μ, σ, shape) + @test !all(rng1.seed .== seed) + + rng2 = ReactantRNG(copy(seed)) + Y = ProbProg.call(rng2, two_samples, μ, σ, shape) + + @test !all(rng2.seed .== seed) + @test !all(rng2.seed .== rng1.seed) + @test !all(X .≈ Y) end end From b66681311815246a89dff686690f43e37fb499d4 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 17 Jul 2025 18:33:20 -0500 Subject: [PATCH 60/87] enforcing calling convention (rng being 0th operand) for simulate/generate ops --- src/ProbProg.jl | 179 ++++++++++++++------------------------ test/probprog/generate.jl | 33 +++---- test/probprog/simulate.jl | 73 +++++++++------- 3 files changed, 124 insertions(+), 161 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index ebeab14091..2ffd333859 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -71,7 +71,7 @@ function addSampleToTrace( shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) - tostore = Any[] + vals = Any[] for i in 1:num_outputs ndims = ndims_array[i] width = width_array[i] @@ -96,17 +96,14 @@ function addSampleToTrace( end if ndims == 0 - val = unsafe_load(Ptr{julia_type}(sample_ptr)) - push!(tostore, val) + push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) else shape = unsafe_wrap(Array, shape_ptr, ndims) - push!( - tostore, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape))) - ) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) end end - trace.choices[symbol] = tuple(tostore...) + trace.choices[symbol] = tuple(vals...) return nothing end @@ -184,7 +181,8 @@ function addRetvalToTrace( end end - trace.retval = length(vals) == 1 ? vals[1] : vals + trace.retval = tuple(vals...) + return nothing end @@ -418,56 +416,11 @@ function sample_internal( sym = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - # Specify which outputs to add to the trace. - traced_output_indices = Int[] - for (i, res) in enumerate(linear_results) - if TracedUtils.has_idx(res, resprefix) - push!(traced_output_indices, i - 1) - end - end - - # Specify which inputs to pass to logpdf. - traced_input_indices = Int[] - for (i, a) in enumerate(linear_args) - idx, _ = TracedUtils.get_argidx(a, argprefix) - if fnwrap && idx == 1 # TODO: add test for fnwrap - continue - end - - if fnwrap - idx -= 1 - end - - if !(args[idx] isa AbstractRNG) - push!(traced_input_indices, i - 1) - end - end - symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 )::MLIR.IR.Attribute - # (out_idx1, in_idx1, out_idx2, in_idx2, ...) - alias_pairs = Int64[] - for (out_idx, res) in enumerate(linear_results) - if TracedUtils.has_idx(res, argprefix) - in_idx = nothing - for (i, arg) in enumerate(linear_args) - if TracedUtils.has_idx(arg, argprefix) && - TracedUtils.get_idx(arg, argprefix) == - TracedUtils.get_idx(res, argprefix) - in_idx = i - 1 - break - end - end - @assert in_idx !== nothing "Unable to find operand for aliased result" - push!(alias_pairs, out_idx - 1) - push!(alias_pairs, in_idx) - end - end - alias_attr = MLIR.IR.DenseArrayAttribute(alias_pairs) - # Construct MLIR attribute if Julia logpdf function is provided. logpdf_attr = nothing if logpdf !== nothing @@ -504,9 +457,6 @@ function sample_internal( fn=fn_attr, logpdf=logpdf_attr, symbol=symbol_attr, - traced_input_indices=traced_input_indices, - traced_output_indices=traced_output_indices, - alias_map=alias_attr, name=Base.String(symbol), ) @@ -547,8 +497,6 @@ function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nar r isa AbstractConcreteArray ? Array(r) : r end - @show res - return length(res) == 1 ? res[1] : res end @@ -581,8 +529,6 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - @show length(linear_results), linear_results - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) @@ -590,10 +536,10 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap + if idx == 2 && fnwrap TracedUtils.push_val!(inputs, f, path[3:end]) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.push_val!(inputs, args[idx], path[3:end]) @@ -629,15 +575,14 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w return result end -function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - old_gc_state = GC.enable(false) - +function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} trace = nothing - weight = nothing - res = nothing + compiled_fn = @compile optimize = :probprog simulate_internal(rng, f, args...) + + old_gc_state = GC.enable(false) try - trace, weight, res = @jit optimize = :probprog simulate_internal(f, args...) + trace, _, _ = compiled_fn(rng, f, args...) finally GC.enable(old_gc_state) end @@ -647,20 +592,29 @@ function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} return trace, trace.weight end -function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +function simulate_internal( + rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs} +) where {Nargs} argprefix::Symbol = gensym("simulatearg") resprefix::Symbol = gensym("simulateresult") resargprefix::Symbol = gensym("simulateresarg") + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + args = (rng, args...) + mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, - f, + wrapper_fn, args, (), string(f), false; do_transpose=false, - args_in_result=:all, + args_in_result=:result, argprefix, resprefix, resargprefix, @@ -673,21 +627,13 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - # Specify which outputs to add to the trace. - traced_output_indices = Int[] - for (i, res) in enumerate(linear_results) - if TracedUtils.has_idx(res, resprefix) - push!(traced_output_indices, i - 1) - end - end - inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap + if idx == 2 && fnwrap TracedUtils.push_val!(inputs, f, path[3:end]) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.push_val!(inputs, args[idx], path[3:end]) @@ -700,12 +646,7 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) simulate_op = MLIR.Dialects.enzyme.simulate( - inputs; - trace=trace_ty, - weight=weight_ty, - outputs=out_tys, - fn=fn_attr, - traced_output_indices=traced_output_indices, + inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr ) for (i, res) in enumerate(linear_results) @@ -713,17 +654,21 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) - elseif TracedUtils.has_idx(res, argprefix) + end + + if TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 1 && fnwrap + if idx == 2 && fnwrap TracedUtils.set!(f, path[3:end], resv) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.set!(args[idx], path[3:end], resv) end - else + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) TracedUtils.set!(res, (), resv) end end @@ -751,24 +696,25 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function generate( - f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}() + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint::Constraint=Dict{Symbol,Any}(), ) where {Nargs} trace = nothing - weight = nothing - res = nothing constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) constrained_symbols = collect(keys(constraint)) - function wrapper_fn(constraint_ptr, args...) - return generate_internal(f, args...; constraint_ptr, constrained_symbols) + function wrapper_fn(rng, constraint_ptr, args...) + return generate_internal(rng, f, args...; constraint_ptr, constrained_symbols) end - compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr, args...) + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) old_gc_state = GC.enable(false) try - trace, weight, res = compiled_fn(constraint_ptr, args...) + trace, _, _ = compiled_fn(rng, constraint_ptr, args...) finally GC.enable(old_gc_state) end @@ -779,6 +725,7 @@ function generate( end function generate_internal( + rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}; constraint_ptr::TracedRNumber, @@ -788,15 +735,22 @@ function generate_internal( resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + args = (rng, args...) + mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, - f, + wrapper_fn, args, (), string(f), false; do_transpose=false, - args_in_result=:all, + args_in_result=:result, argprefix, resprefix, resargprefix, @@ -809,21 +763,13 @@ function generate_internal( fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - # Specify which outputs to add to the trace. - traced_output_indices = Int[] - for (i, res) in enumerate(linear_results) - if TracedUtils.has_idx(res, resprefix) - push!(traced_output_indices, i - 1) - end - end - inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap + if idx == 2 && fnwrap TracedUtils.push_val!(inputs, f, path[3:end]) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.push_val!(inputs, args[idx], path[3:end]) @@ -865,7 +811,6 @@ function generate_internal( outputs=out_tys, fn=fn_attr, constrained_symbols=MLIR.IR.Attribute(constrained_symbols_attr), - traced_output_indices, ) for (i, res) in enumerate(linear_results) @@ -873,17 +818,21 @@ function generate_internal( if TracedUtils.has_idx(res, resprefix) path = TracedUtils.get_idx(res, resprefix) TracedUtils.set!(result, path[2:end], resv) - elseif TracedUtils.has_idx(res, argprefix) + end + + if TracedUtils.has_idx(res, argprefix) idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 1 && fnwrap + if idx == 2 && fnwrap TracedUtils.set!(f, path[3:end], resv) else - if fnwrap + if fnwrap && idx > 2 idx -= 1 end TracedUtils.set!(args[idx], path[3:end], resv) end - else + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) TracedUtils.set!(res, (), resv) end end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index bf51cdc270..bed55ab634 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -1,14 +1,12 @@ using Reactant, Test, Random, Statistics -using Reactant: ProbProg +using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) -function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) - t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf) +function model(rng, μ, σ, shape) + s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) return t end @@ -16,23 +14,25 @@ end @testset "unconstrained" begin shape = (1000,) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.generate(model, seed, μ, σ, shape) - @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 + trace, weight = ProbProg.generate(rng, model, μ, σ, shape) + @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 end @testset "constrained" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) constraint = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) - trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraint) + trace, weight = ProbProg.generate(rng, model, μ, σ, shape; constraint) - @test trace.choices[:s] == constraint[:s] + @test trace.choices[:s][1] == constraint[:s][1] expected_weight = normal_logpdf(constraint[:s][1], 0.0, 1.0, shape) + @@ -43,6 +43,7 @@ end @testset "compiled" begin shape = (10,) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) @@ -54,13 +55,13 @@ end reinterpret(UInt64, pointer_from_objref(constraint1)) ) - wrapper_fn(constraint_ptr, seed, μ, σ) = ProbProg.generate_internal( - model, seed, μ, σ, shape; constraint_ptr, constrained_symbols + wrapper_fn(constraint_ptr, rng, μ, σ) = ProbProg.generate_internal( + rng, model, μ, σ, shape; constraint_ptr, constrained_symbols ) - compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, seed, μ, σ) + compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, rng, μ, σ) - trace1, weight = compiled_fn(constraint_ptr1, seed, μ, σ) + trace1, weight = compiled_fn(constraint_ptr1, rng, μ, σ) trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1])) constraint2 = Dict{Symbol,Any}(:s => (fill(0.2, shape),)) @@ -68,9 +69,9 @@ end reinterpret(UInt64, pointer_from_objref(constraint2)) ) - trace2, _ = compiled_fn(constraint_ptr2, seed, μ, σ) + trace2, _ = compiled_fn(constraint_ptr2, rng, μ, σ) trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1])) - @test trace1.choices[:s] != trace2.choices[:s] + @test trace1.choices[:s][1] != trace2.choices[:s][1] end end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 894ac55697..6f36004a4c 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -1,28 +1,24 @@ using Reactant, Test, Random -using Reactant: ProbProg +using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) function product_two_normals(rng, μ, σ, shape) - a = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) - b = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) return a .* b end -function model(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) - t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf) +function model(rng, μ, σ, shape) + s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) return t end -function model2(seed, μ, σ, shape) - rng = Random.default_rng() - Random.seed!(rng, seed) - s = ProbProg.sample(product_two_normals, rng, μ, σ, shape; symbol=:s) - t = ProbProg.sample(product_two_normals, rng, s, σ, shape; symbol=:t) +function model2(rng, μ, σ, shape) + s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) return t end @@ -30,65 +26,82 @@ end @testset "hlo" begin shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) before = @code_hlo optimize = false ProbProg.simulate_internal( - model, seed, μ, σ, shape + rng, model, μ, σ, shape ) @test contains(repr(before), "enzyme.simulate") + unlowered = @code_hlo optimize = :probprog_no_lowering ProbProg.simulate_internal( + rng, model, μ, σ, shape + ) + @test !contains(repr(unlowered), "enzyme.simulate") + @test contains(repr(unlowered), "enzyme.addSampleToTrace") + @test contains(repr(unlowered), "enzyme.addWeightToTrace") + @test contains(repr(unlowered), "enzyme.addRetvalToTrace") + after = @code_hlo optimize = :probprog ProbProg.simulate_internal( - model, seed, μ, σ, shape + rng, model, μ, σ, shape ) @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.addSampleToTrace") + @test !contains(repr(after), "enzyme.addWeightToTrace") + @test !contains(repr(after), "enzyme.addRetvalToTrace") end @testset "normal_simulate" begin shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.simulate(model, seed, μ, σ, shape) + trace, weight = ProbProg.simulate(rng, model, μ, σ, shape) + println(trace) - @test size(trace.retval) == shape + @test size(trace.retval[1]) == shape @test haskey(trace.choices, :s) @test haskey(trace.choices, :t) - @test size(trace.choices[:s]) == shape - @test size(trace.choices[:t]) == shape + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape @test trace.weight isa Float64 end @testset "simple_fake" begin - op(x, y) = x * y' + op(_, x, y) = x * y' logpdf(res, _, _) = sum(res) - function fake_model(x, y) - return ProbProg.sample(op, x, y; symbol=:matmul, logpdf=logpdf) + function fake_model(rng, x, y) + return ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) end x = reshape(collect(Float64, 1:12), (4, 3)) y = reshape(collect(Float64, 1:12), (4, 3)) x_ra = Reactant.to_rarray(x) y_ra = Reactant.to_rarray(y) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) - trace, weight = ProbProg.simulate(fake_model, x_ra, y_ra) + trace, weight = ProbProg.simulate(rng, fake_model, x_ra, y_ra) - @test Array(trace.retval) == op(x, y) + @test Array(trace.retval[1]) == op(rng, x, y) @test haskey(trace.choices, :matmul) - @test trace.choices[:matmul] == op(x, y) - @test trace.weight == logpdf(op(x, y), x, y) + @test trace.choices[:matmul][1] == op(rng, x, y) + @test trace.weight == logpdf(op(rng, x, y), x, y) end @testset "submodel_fake" begin shape = (3, 3, 3) seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.simulate(model2, seed, μ, σ, shape) + trace, weight = ProbProg.simulate(rng, model2, μ, σ, shape) - @test size(trace.retval) == shape + @test size(trace.retval[1]) == shape @test length(trace.choices) == 2 @test haskey(trace.choices, :s) @@ -100,8 +113,8 @@ end @test haskey(trace.subtraces[:t].choices, :a) @test haskey(trace.subtraces[:t].choices, :b) - @test size(trace.choices[:s]) == shape - @test size(trace.choices[:t]) == shape + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape @test trace.weight isa Float64 From c57a1e4114bdc3b86f9f5c217a03d4aa13aa7697 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 17 Jul 2025 21:18:14 -0500 Subject: [PATCH 61/87] clean up --- src/ProbProg.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 2ffd333859..cb478d4ec4 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -491,8 +491,6 @@ end function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} res = @jit optimize = :probprog call_internal(rng, f, args...) - @assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG" - res = map(res[2:end]) do r r isa AbstractConcreteArray ? Array(r) : r end From 2b81db9a4db23a475071f1738defe268ace7012e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 17 Jul 2025 22:38:59 -0500 Subject: [PATCH 62/87] refactored mh inference steps with new calling convention enforced --- src/ProbProg.jl | 98 ++++++++++++++++++++++++++++-- test/probprog/linear_regression.jl | 49 ++++++++------- 2 files changed, 121 insertions(+), 26 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index cb478d4ec4..b1162fcb8a 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -13,6 +13,7 @@ using ..Reactant: using ..Compiler: @jit, @compile using Enzyme using Base: ReentrantLock +using Random mutable struct ProbProgTrace fn::Union{Nothing,Function} @@ -21,13 +22,18 @@ mutable struct ProbProgTrace retval::Any weight::Any subtraces::Dict{Symbol,Any} + rng::Union{Nothing,AbstractRNG} function ProbProgTrace(fn::Function, args::Tuple) - return new(fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) + return new( + fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing + ) end function ProbProgTrace() - return new(nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) + return new( + nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing + ) end end @@ -587,6 +593,10 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + trace.fn = f + trace.args = args + trace.rng = rng + return trace, trace.weight end @@ -702,7 +712,7 @@ function generate( trace = nothing constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) - constrained_symbols = collect(keys(constraint)) + constrained_symbols = Set(keys(constraint)) function wrapper_fn(rng, constraint_ptr, args...) return generate_internal(rng, f, args...; constraint_ptr, constrained_symbols) @@ -719,6 +729,10 @@ function generate( trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + trace.fn = f + trace.args = args + trace.rng = rng + return trace, trace.weight end @@ -727,7 +741,7 @@ function generate_internal( f::Function, args::Vararg{Any,Nargs}; constraint_ptr::TracedRNumber, - constrained_symbols::Vector{Symbol}, + constrained_symbols::Set{Symbol}, ) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -947,4 +961,80 @@ end get_choices(trace::ProbProgTrace) = trace.choices +const Selection = Set{Symbol} +select(syms::Symbol...) = Set(syms) +choicemap() = Constraint() +const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} + +function metropolis_hastings( + trace::ProbProgTrace, + sel::Selection; + compiled_cache::Union{Nothing,CompiledFnCache}=nothing, +) + if trace.fn === nothing || trace.rng === nothing + error("MH requires a trace with fn and rng recorded (use generate to create trace)") + end + + constraints = Dict{Symbol,Any}() + constrained_symbols = Set{Symbol}() + + for (sym, val) in trace.choices + if !(sym in sel) + constraints[sym] = val + push!(constrained_symbols, sym) + end + end + + cache_key = (typeof(trace.fn), constrained_symbols) + + compiled_fn = nothing + if compiled_cache !== nothing + compiled_fn = get(compiled_cache, cache_key, nothing) + end + + if compiled_fn === nothing + function wrapper_fn(rng, constraint_ptr, args...) + return generate_internal( + rng, trace.fn, args...; constraint_ptr, constrained_symbols + ) + end + + constraint_ptr = ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraints)) + ) + + compiled_fn = @compile optimize = :probprog wrapper_fn( + trace.rng, constraint_ptr, trace.args... + ) + + if compiled_cache !== nothing + compiled_cache[cache_key] = compiled_fn + end + end + + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraints))) + + old_gc_state = GC.enable(false) + new_trace_ptr = nothing + try + new_trace_ptr, _, _ = compiled_fn(trace.rng, constraint_ptr, trace.args...) + finally + GC.enable(old_gc_state) + end + + new_trace = unsafe_pointer_to_objref(Ptr{Any}(Array(new_trace_ptr)[1])) + + new_trace.fn = trace.fn + new_trace.args = trace.args + new_trace.rng = trace.rng + + log_alpha = new_trace.weight - trace.weight + + if log(rand()) < log_alpha + return (new_trace, true) + else + return (trace, false) + end +end + end diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index a0efed9416..6cf629cd31 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -1,25 +1,22 @@ using Reactant, Test, Random -using Reactant: ProbProg +using Reactant: ProbProg, ReactantRNG # Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) -function my_model(seed, xs) - rng = Random.default_rng() - Random.seed!(rng, seed) - +function my_model(rng, xs) slope = ProbProg.sample( - normal, rng, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf ) intercept = ProbProg.sample( - normal, rng, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf ) ys = ProbProg.sample( - normal, rng, + normal, slope .* xs .+ intercept, 1.0, (length(xs),); @@ -27,40 +24,44 @@ function my_model(seed, xs) logpdf=normal_logpdf, ) - return rng.seed, ys + return ys end function my_inference_program(xs, ys, num_iters) xs_r = Reactant.to_rarray(xs) - constraints = ProbProg.choicemap() - constraints[:ys] = [ys] + constraint = ProbProg.choicemap() + constraint[:ys] = [ys] seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) - trace, _ = ProbProg.generate(my_model, seed, xs_r; constraints) - trace.args = (trace.retval[1], trace.args[2:end]...) # TODO: this is a temporary hack + trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint) + + compiled_cache = ProbProg.CompiledFnCache() for i in 1:num_iters - trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope)) - trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept)) - choices = ProbProg.get_choices(trace) - # @show i, choices[:slope], choices[:intercept] + trace, _ = ProbProg.metropolis_hastings( + trace, ProbProg.select(:slope); compiled_cache + ) + trace, _ = ProbProg.metropolis_hastings( + trace, ProbProg.select(:intercept); compiled_cache + ) end choices = ProbProg.get_choices(trace) - return (choices[:slope], choices[:intercept]) + return (Array(choices[:slope][1])[1], Array(choices[:intercept][1])[1]) end @testset "linear_regression" begin @testset "simulate" begin seed = Reactant.to_rarray(UInt64[1, 4]) - Random.seed!(42) # For Julia side RNG + rng = ReactantRNG(seed) xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] xs_r = Reactant.to_rarray(xs) - trace = ProbProg.simulate(my_model, seed, xs_r) + trace, _ = ProbProg.simulate(rng, my_model, xs_r) @test haskey(trace.choices, :slope) @test haskey(trace.choices, :intercept) @@ -68,11 +69,15 @@ end end @testset "inference" begin + Random.seed!(1) # For Julia side RNG xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] - slope, intercept = my_inference_program(xs, ys, 5) + slope, intercept = my_inference_program(xs, ys, 10000) + + @show slope, intercept - # @show slope, intercept + @test slope ≈ -2.0 rtol = 0.05 + @test intercept ≈ 10.0 rtol = 0.05 end end From e647b0d5ea44ca2605dd6e7e1c40ca904d6bd95a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 20 Jul 2025 17:56:38 -0500 Subject: [PATCH 63/87] improve --- src/ProbProg.jl | 5 +++++ test/probprog/linear_regression.jl | 20 +++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index b1162fcb8a..162f299143 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -966,6 +966,11 @@ select(syms::Symbol...) = Set(syms) choicemap() = Constraint() const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} +function with_compiled_cache(f) + cache = CompiledFnCache() + return f(cache) +end + function metropolis_hastings( trace::ProbProgTrace, sel::Selection; diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index 6cf629cd31..1930fdc47a 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -38,15 +38,17 @@ function my_inference_program(xs, ys, num_iters) trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint) - compiled_cache = ProbProg.CompiledFnCache() - - for i in 1:num_iters - trace, _ = ProbProg.metropolis_hastings( - trace, ProbProg.select(:slope); compiled_cache - ) - trace, _ = ProbProg.metropolis_hastings( - trace, ProbProg.select(:intercept); compiled_cache - ) + trace = ProbProg.with_compiled_cache() do cache + local t = trace + for _ in 1:num_iters + t, _ = ProbProg.metropolis_hastings( + t, ProbProg.select(:slope); compiled_cache=cache + ) + t, _ = ProbProg.metropolis_hastings( + t, ProbProg.select(:intercept); compiled_cache=cache + ) + end + return t end choices = ProbProg.get_choices(trace) From 94b9e3a08a8ee6cd0cbfa4b51ebbe5f7402c5df7 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 20 Jul 2025 18:39:23 -0500 Subject: [PATCH 64/87] reorganize --- src/Reactant.jl | 2 +- src/probprog/Display.jl | 87 ++++ src/probprog/FFI.jl | 301 ++++++++++++ src/probprog/Inference.jl | 73 +++ src/{ProbProg.jl => probprog/Modeling.jl} | 527 +--------------------- src/probprog/ProbProg.jl | 29 ++ src/probprog/Types.jl | 49 ++ 7 files changed, 542 insertions(+), 526 deletions(-) create mode 100644 src/probprog/Display.jl create mode 100644 src/probprog/FFI.jl create mode 100644 src/probprog/Inference.jl rename src/{ProbProg.jl => probprog/Modeling.jl} (50%) create mode 100644 src/probprog/ProbProg.jl create mode 100644 src/probprog/Types.jl diff --git a/src/Reactant.jl b/src/Reactant.jl index 61fb3bfbb2..b3fad50edc 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -189,7 +189,7 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") -include("ProbProg.jl") +include("probprog/ProbProg.jl") # Serialization include("serialization/Serialization.jl") diff --git a/src/probprog/Display.jl b/src/probprog/Display.jl new file mode 100644 index 0000000000..8c78fc2f0e --- /dev/null +++ b/src/probprog/Display.jl @@ -0,0 +1,87 @@ +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + if trace.weight !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + + sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) + n += length(sorted_subtraces) + + for (key, subtrace) in sorted_subtraces + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") + _show_pretty( + io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) + ) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end \ No newline at end of file diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl new file mode 100644 index 0000000000..4637ff3cac --- /dev/null +++ b/src/probprog/FFI.jl @@ -0,0 +1,301 @@ +using ..Reactant: MLIR + +function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) + tr = ProbProgTrace() + _keepalive!(tr) + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) + return nothing +end + +function addSampleToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_outputs_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_outputs = unsafe_load(num_outputs_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) + width_array = unsafe_wrap(Array, width_array, num_outputs) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) + + vals = Any[] + for i in 1:num_outputs + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) + end + end + + trace.choices[symbol] = tuple(vals...) + + return nothing +end + +function addSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace + + trace.subtraces[symbol] = subtrace + + return nothing +end + +function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) + return nothing +end + +function addRetvalToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + retval_ptr_array::Ptr{Ptr{Any}}, + num_results_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + + num_results = unsafe_load(num_results_ptr) + + if num_results == 0 + return nothing + end + + ndims_array = unsafe_wrap(Array, ndims_array, num_results) + width_array = unsafe_wrap(Array, width_array, num_results) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) + retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) + + vals = Any[] + for i in 1:num_results + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + retval_ptr = retval_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) + end + end + + trace.retval = tuple(vals...) + + return nothing +end + +function getSampleFromConstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(constraint, symbol, nothing) + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in constrained sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + if size(dest) != size(tostore[i]) + if length(size(dest)) != length(size(tostore[i])) + @ccall printf( + "Shape size mismatch in constrained sample: %zd != %zd\n"::Cstring, + length(size(dest))::Csize_t, + length(size(tostore[i]))::Csize_t, + )::Cvoid + return nothing + end + for i in 1:length(size(dest)) + d = size(dest)[i] + t = size(tostore[i])[i] + if d != t + @ccall printf( + "Shape mismatch in `%zd`th dimension of constrained sample: %zd != %zd\n"::Cstring, + i::Csize_t, + size(dest)[i]::Csize_t, + size(tostore[i])[i]::Csize_t, + )::Cvoid + return nothing + end + end + end + + dest .= tostore[i] + end + end + + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_sample_to_trace_ptr = @cfunction( + addSampleToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + + return nothing +end \ No newline at end of file diff --git a/src/probprog/Inference.jl b/src/probprog/Inference.jl new file mode 100644 index 0000000000..cd01dc65c3 --- /dev/null +++ b/src/probprog/Inference.jl @@ -0,0 +1,73 @@ +using ..Reactant: ConcreteRNumber +using ..Compiler: @compile + +function metropolis_hastings( + trace::ProbProgTrace, + sel::Selection; + compiled_cache::Union{Nothing,CompiledFnCache}=nothing, +) + if trace.fn === nothing || trace.rng === nothing + error("MH requires a trace with fn and rng recorded (use generate to create trace)") + end + + constraints = Dict{Symbol,Any}() + constrained_symbols = Set{Symbol}() + + for (sym, val) in trace.choices + if !(sym in sel) + constraints[sym] = val + push!(constrained_symbols, sym) + end + end + + cache_key = (typeof(trace.fn), constrained_symbols) + + compiled_fn = nothing + if compiled_cache !== nothing + compiled_fn = get(compiled_cache, cache_key, nothing) + end + + if compiled_fn === nothing + function wrapper_fn(rng, constraint_ptr, args...) + return generate_internal( + rng, trace.fn, args...; constraint_ptr, constrained_symbols + ) + end + + constraint_ptr = ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraints)) + ) + + compiled_fn = @compile optimize = :probprog wrapper_fn( + trace.rng, constraint_ptr, trace.args... + ) + + if compiled_cache !== nothing + compiled_cache[cache_key] = compiled_fn + end + end + + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraints))) + + old_gc_state = GC.enable(false) + new_trace_ptr = nothing + try + new_trace_ptr, _, _ = compiled_fn(trace.rng, constraint_ptr, trace.args...) + finally + GC.enable(old_gc_state) + end + + new_trace = unsafe_pointer_to_objref(Ptr{Any}(Array(new_trace_ptr)[1])) + + new_trace.fn = trace.fn + new_trace.args = trace.args + new_trace.rng = trace.rng + + log_alpha = new_trace.weight - trace.weight + + if log(rand()) < log_alpha + return (new_trace, true) + else + return (trace, false) + end +end \ No newline at end of file diff --git a/src/ProbProg.jl b/src/probprog/Modeling.jl similarity index 50% rename from src/ProbProg.jl rename to src/probprog/Modeling.jl index 162f299143..cb0fe0d0c0 100644 --- a/src/ProbProg.jl +++ b/src/probprog/Modeling.jl @@ -1,356 +1,6 @@ -module ProbProg - using ..Reactant: - MLIR, - TracedUtils, - AbstractConcreteArray, - AbstractConcreteNumber, - AbstractRNG, - TracedRArray, - TracedRNumber, - ConcreteRNumber, - Ops + MLIR, TracedUtils, AbstractRNG, AbstractConcreteArray, TracedRArray, ConcreteRNumber using ..Compiler: @jit, @compile -using Enzyme -using Base: ReentrantLock -using Random - -mutable struct ProbProgTrace - fn::Union{Nothing,Function} - args::Union{Nothing,Tuple} - choices::Dict{Symbol,Any} - retval::Any - weight::Any - subtraces::Dict{Symbol,Any} - rng::Union{Nothing,AbstractRNG} - - function ProbProgTrace(fn::Function, args::Tuple) - return new( - fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing - ) - end - - function ProbProgTrace() - return new( - nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing - ) - end -end - -const Constraint = Dict{Symbol,Any} - -const _trace_ref_lock = ReentrantLock() -const _trace_refs = Vector{Any}() - -function _keepalive!(tr::ProbProgTrace) - lock(_trace_ref_lock) - try - push!(_trace_refs, tr) - finally - unlock(_trace_ref_lock) - end - return tr -end - -function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) - tr = ProbProgTrace() - _keepalive!(tr) - - unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) - return nothing -end - -function addSampleToTrace( - trace_ptr_ptr::Ptr{Ptr{Any}}, - symbol_ptr_ptr::Ptr{Ptr{Any}}, - sample_ptr_array::Ptr{Ptr{Any}}, - num_outputs_ptr::Ptr{UInt64}, - ndims_array::Ptr{UInt64}, - shape_ptr_array::Ptr{Ptr{UInt64}}, - width_array::Ptr{UInt64}, -) - trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace - symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol - num_outputs = unsafe_load(num_outputs_ptr) - ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) - width_array = unsafe_wrap(Array, width_array, num_outputs) - shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) - sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) - - vals = Any[] - for i in 1:num_outputs - ndims = ndims_array[i] - width = width_array[i] - shape_ptr = shape_ptr_array[i] - sample_ptr = sample_ptr_array[i] - - julia_type = if width == 32 - Float32 - elseif width == 64 - Float64 - elseif width == 1 - Bool - else - nothing - end - - if julia_type === nothing - @ccall printf( - "Unsupported datatype width: %lld\n"::Cstring, width::Int64 - )::Cvoid - return nothing - end - - if ndims == 0 - push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) - else - shape = unsafe_wrap(Array, shape_ptr, ndims) - push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) - end - end - - trace.choices[symbol] = tuple(vals...) - - return nothing -end - -function addSubtrace( - trace_ptr_ptr::Ptr{Ptr{Any}}, - symbol_ptr_ptr::Ptr{Ptr{Any}}, - subtrace_ptr_ptr::Ptr{Ptr{Any}}, -) - trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace - symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol - subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace - - trace.subtraces[symbol] = subtrace - - return nothing -end - -function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) - trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace - trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) - return nothing -end - -function addRetvalToTrace( - trace_ptr_ptr::Ptr{Ptr{Any}}, - retval_ptr_array::Ptr{Ptr{Any}}, - num_results_ptr::Ptr{UInt64}, - ndims_array::Ptr{UInt64}, - shape_ptr_array::Ptr{Ptr{UInt64}}, - width_array::Ptr{UInt64}, -) - trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace - - num_results = unsafe_load(num_results_ptr) - - if num_results == 0 - return nothing - end - - ndims_array = unsafe_wrap(Array, ndims_array, num_results) - width_array = unsafe_wrap(Array, width_array, num_results) - shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) - retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) - - vals = Any[] - for i in 1:num_results - ndims = ndims_array[i] - width = width_array[i] - shape_ptr = shape_ptr_array[i] - retval_ptr = retval_ptr_array[i] - - julia_type = if width == 32 - Float32 - elseif width == 64 - Float64 - elseif width == 1 - Bool - else - nothing - end - - if julia_type === nothing - @ccall printf( - "Unsupported datatype width: %lld\n"::Cstring, width::Int64 - )::Cvoid - return nothing - end - - if ndims == 0 - push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) - else - shape = unsafe_wrap(Array, shape_ptr, ndims) - push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) - end - end - - trace.retval = tuple(vals...) - - return nothing -end - -function getSampleFromConstraint( - constraint_ptr_ptr::Ptr{Ptr{Any}}, - symbol_ptr_ptr::Ptr{Ptr{Any}}, - sample_ptr_array::Ptr{Ptr{Any}}, - num_samples_ptr::Ptr{UInt64}, - ndims_array::Ptr{UInt64}, - shape_ptr_array::Ptr{Ptr{UInt64}}, - width_array::Ptr{UInt64}, -) - constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint - symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol - num_samples = unsafe_load(num_samples_ptr) - ndims_array = unsafe_wrap(Array, ndims_array, num_samples) - width_array = unsafe_wrap(Array, width_array, num_samples) - shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) - sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) - - tostore = get(constraint, symbol, nothing) - - for i in 1:num_samples - ndims = ndims_array[i] - width = width_array[i] - shape_ptr = shape_ptr_array[i] - sample_ptr = sample_ptr_array[i] - - julia_type = if width == 32 - Float32 - elseif width == 64 - Float64 - elseif width == 1 - Bool - else - nothing - end - - if julia_type === nothing - @ccall printf( - "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t - )::Cvoid - return nothing - end - - if julia_type != eltype(tostore[i]) - @ccall printf( - "Type mismatch in constrained sample: %s != %s\n"::Cstring, - string(julia_type)::Cstring, - string(eltype(tostore[i]))::Cstring, - )::Cvoid - return nothing - end - - if ndims == 0 - unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) - else - shape = unsafe_wrap(Array, shape_ptr, ndims) - dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) - - if size(dest) != size(tostore[i]) - if length(size(dest)) != length(size(tostore[i])) - @ccall printf( - "Shape size mismatch in constrained sample: %zd != %zd\n"::Cstring, - length(size(dest))::Csize_t, - length(size(tostore[i]))::Csize_t, - )::Cvoid - return nothing - end - for i in 1:length(size(dest)) - d = size(dest)[i] - t = size(tostore[i])[i] - if d != t - @ccall printf( - "Shape mismatch in `%zd`th dimension of constrained sample: %zd != %zd\n"::Cstring, - i::Csize_t, - size(dest)[i]::Csize_t, - size(tostore[i])[i]::Csize_t, - )::Cvoid - return nothing - end - end - end - - dest .= tostore[i] - end - end - - return nothing -end - -function __init__() - init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} - )::Cvoid - - add_sample_to_trace_ptr = @cfunction( - addSampleToTrace, - Cvoid, - ( - Ptr{Ptr{Any}}, - Ptr{Ptr{Any}}, - Ptr{Ptr{Any}}, - Ptr{UInt64}, - Ptr{UInt64}, - Ptr{Ptr{UInt64}}, - Ptr{UInt64}, - ) - ) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} - )::Cvoid - - add_subtrace_ptr = @cfunction( - addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) - ) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} - )::Cvoid - - add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} - )::Cvoid - - add_retval_to_trace_ptr = @cfunction( - addRetvalToTrace, - Cvoid, - ( - Ptr{Ptr{Any}}, - Ptr{Ptr{Any}}, - Ptr{UInt64}, - Ptr{UInt64}, - Ptr{Ptr{UInt64}}, - Ptr{UInt64}, - ), - ) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} - )::Cvoid - - get_sample_from_constraint_ptr = @cfunction( - getSampleFromConstraint, - Cvoid, - ( - Ptr{Ptr{Any}}, - Ptr{Ptr{Any}}, - Ptr{Ptr{Any}}, - Ptr{UInt64}, - Ptr{UInt64}, - Ptr{Ptr{UInt64}}, - Ptr{UInt64}, - ) - ) - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( - :enzyme_probprog_get_sample_from_constraint::Cstring, - get_sample_from_constraint_ptr::Ptr{Cvoid}, - )::Cvoid - - return nothing -end function sample( rng::AbstractRNG, @@ -869,177 +519,4 @@ function generate_internal( weight = TracedRArray{Float64,0}((), weight, ()) return trace, weight, result -end - -# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 -function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - - indent_vert_str = join(indent_vert) - indent_str = join(indent) - indent_last_str = join(indent_last) - - sorted_choices = sort(collect(trace.choices); by=x -> x[1]) - n = length(sorted_choices) - - if trace.retval !== nothing - n += 1 - end - - if trace.weight !== nothing - n += 1 - end - - cur = 1 - - if trace.retval !== nothing - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") - cur += 1 - end - - if trace.weight !== nothing - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") - cur += 1 - end - - for (key, value) in sorted_choices - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - - sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) - n += length(sorted_subtraces) - - for (key, subtrace) in sorted_subtraces - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") - _show_pretty( - io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) - ) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) - println(io, "ProbProgTrace:") - if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing - println(io, " (empty)") - else - _show_pretty(io, trace, 0, ()) - end -end - -function Base.show(io::IO, trace::ProbProgTrace) - if get(io, :compact, false) - choices_count = length(trace.choices) - has_retval = trace.retval !== nothing - print(io, "ProbProgTrace($(choices_count) choices") - if has_retval - print(io, ", retval=$(trace.retval), weight=$(trace.weight)") - end - print(io, ")") - else - show(io, MIME"text/plain"(), trace) - end -end - -get_choices(trace::ProbProgTrace) = trace.choices - -const Selection = Set{Symbol} -select(syms::Symbol...) = Set(syms) -choicemap() = Constraint() -const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} - -function with_compiled_cache(f) - cache = CompiledFnCache() - return f(cache) -end - -function metropolis_hastings( - trace::ProbProgTrace, - sel::Selection; - compiled_cache::Union{Nothing,CompiledFnCache}=nothing, -) - if trace.fn === nothing || trace.rng === nothing - error("MH requires a trace with fn and rng recorded (use generate to create trace)") - end - - constraints = Dict{Symbol,Any}() - constrained_symbols = Set{Symbol}() - - for (sym, val) in trace.choices - if !(sym in sel) - constraints[sym] = val - push!(constrained_symbols, sym) - end - end - - cache_key = (typeof(trace.fn), constrained_symbols) - - compiled_fn = nothing - if compiled_cache !== nothing - compiled_fn = get(compiled_cache, cache_key, nothing) - end - - if compiled_fn === nothing - function wrapper_fn(rng, constraint_ptr, args...) - return generate_internal( - rng, trace.fn, args...; constraint_ptr, constrained_symbols - ) - end - - constraint_ptr = ConcreteRNumber( - reinterpret(UInt64, pointer_from_objref(constraints)) - ) - - compiled_fn = @compile optimize = :probprog wrapper_fn( - trace.rng, constraint_ptr, trace.args... - ) - - if compiled_cache !== nothing - compiled_cache[cache_key] = compiled_fn - end - end - - constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraints))) - - old_gc_state = GC.enable(false) - new_trace_ptr = nothing - try - new_trace_ptr, _, _ = compiled_fn(trace.rng, constraint_ptr, trace.args...) - finally - GC.enable(old_gc_state) - end - - new_trace = unsafe_pointer_to_objref(Ptr{Any}(Array(new_trace_ptr)[1])) - - new_trace.fn = trace.fn - new_trace.args = trace.args - new_trace.rng = trace.rng - - log_alpha = new_trace.weight - trace.weight - - if log(rand()) < log_alpha - return (new_trace, true) - else - return (trace, false) - end -end - -end +end \ No newline at end of file diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl new file mode 100644 index 0000000000..8a98e900b5 --- /dev/null +++ b/src/probprog/ProbProg.jl @@ -0,0 +1,29 @@ +module ProbProg + +using ..Reactant: + MLIR, + TracedUtils, + AbstractConcreteArray, + AbstractConcreteNumber, + AbstractRNG, + TracedRArray, + TracedRNumber, + ConcreteRNumber, + Ops +using ..Compiler: @jit, @compile +using Enzyme + +include("Types.jl") +include("FFI.jl") +include("Modeling.jl") +include("Inference.jl") +include("Display.jl") + +export ProbProgTrace, Constraint, Selection, CompiledFnCache +export get_choices, select, choicemap, with_compiled_cache + +export sample, call, simulate, generate + +export metropolis_hastings + +end \ No newline at end of file diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl new file mode 100644 index 0000000000..c74e92ad20 --- /dev/null +++ b/src/probprog/Types.jl @@ -0,0 +1,49 @@ +using Base: ReentrantLock + +mutable struct ProbProgTrace + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} + choices::Dict{Symbol,Any} + retval::Any + weight::Any + subtraces::Dict{Symbol,Any} + rng::Union{Nothing,AbstractRNG} + + function ProbProgTrace(fn::Function, args::Tuple) + return new( + fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing + ) + end + + function ProbProgTrace() + return new( + nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing + ) + end +end + +const Constraint = Dict{Symbol,Any} +const Selection = Set{Symbol} +const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} + +const _trace_ref_lock = ReentrantLock() +const _trace_refs = Vector{Any}() + +function _keepalive!(tr::ProbProgTrace) + lock(_trace_ref_lock) + try + push!(_trace_refs, tr) + finally + unlock(_trace_ref_lock) + end + return tr +end + +get_choices(trace::ProbProgTrace) = trace.choices +select(syms::Symbol...) = Set(syms) +choicemap() = Constraint() + +function with_compiled_cache(f) + cache = CompiledFnCache() + return f(cache) +end \ No newline at end of file From b2d583a0382e13b71478309a56e70b48bae7ad77 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 20 Jul 2025 18:41:51 -0500 Subject: [PATCH 65/87] format --- src/probprog/Display.jl | 2 +- src/probprog/FFI.jl | 2 +- src/probprog/Inference.jl | 2 +- src/probprog/Modeling.jl | 2 +- src/probprog/ProbProg.jl | 2 +- src/probprog/Types.jl | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/probprog/Display.jl b/src/probprog/Display.jl index 8c78fc2f0e..a81992eb71 100644 --- a/src/probprog/Display.jl +++ b/src/probprog/Display.jl @@ -84,4 +84,4 @@ function Base.show(io::IO, trace::ProbProgTrace) else show(io, MIME"text/plain"(), trace) end -end \ No newline at end of file +end diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl index 4637ff3cac..eb05cc402d 100644 --- a/src/probprog/FFI.jl +++ b/src/probprog/FFI.jl @@ -298,4 +298,4 @@ function __init__() )::Cvoid return nothing -end \ No newline at end of file +end diff --git a/src/probprog/Inference.jl b/src/probprog/Inference.jl index cd01dc65c3..6a9d4e1aa2 100644 --- a/src/probprog/Inference.jl +++ b/src/probprog/Inference.jl @@ -70,4 +70,4 @@ function metropolis_hastings( else return (trace, false) end -end \ No newline at end of file +end diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index cb0fe0d0c0..523771c1f7 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -519,4 +519,4 @@ function generate_internal( weight = TracedRArray{Float64,0}((), weight, ()) return trace, weight, result -end \ No newline at end of file +end diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index 8a98e900b5..a4f55fa0dd 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -26,4 +26,4 @@ export sample, call, simulate, generate export metropolis_hastings -end \ No newline at end of file +end diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl index c74e92ad20..12a5ba8921 100644 --- a/src/probprog/Types.jl +++ b/src/probprog/Types.jl @@ -46,4 +46,4 @@ choicemap() = Constraint() function with_compiled_cache(f) cache = CompiledFnCache() return f(cache) -end \ No newline at end of file +end From 65d359583fb0d75280df28a5932dffe00e4e19c5 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 20 Jul 2025 19:53:09 -0500 Subject: [PATCH 66/87] fix up tests --- test/probprog/blr.jl | 34 +++++++++++++++++------------- test/probprog/generate.jl | 7 ++++-- test/probprog/linear_regression.jl | 5 ++++- test/probprog/simulate.jl | 5 ++++- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index 615edb842d..a4756262a5 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -1,31 +1,32 @@ using Reactant, Test, Random -using Reactant: ProbProg +using Reactant: ProbProg, ReactantRNG -function normal(rng, μ, σ, shape) - return μ .+ σ .* randn(rng, shape) -end +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function bernoulli_logit(rng, logit, shape) - return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) end -function blr(seed, N, K) - rng = Random.default_rng() - Random.seed!(rng, seed) +bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +bernoulli_logit_logpdf(x, logit, _) = sum(x .* logit .- log1p.(exp.(logit))) +# https://github.com/facebookresearch/pplbench/blob/main/pplbench/models/logistic_regression.py +function blr(rng, N, K) # α ~ Normal(0, 10, size = 1) - α = ProbProg.sample(normal, rng, 0, 10, (1,); symbol=:α) + α = ProbProg.sample(rng, normal, 0, 10, (1,); symbol=:α, logpdf=normal_logpdf) # β ~ Normal(0, 2.5, size = K) - β = ProbProg.sample(normal, rng, 0, 2.5, (K,); symbol=:β) + β = ProbProg.sample(rng, normal, 0, 2.5, (K,); symbol=:β, logpdf=normal_logpdf) # X ~ Normal(0, 10, size = (N, K)) - X = ProbProg.sample(normal, rng, 0, 10, (N, K); symbol=:X) + X = ProbProg.sample(rng, normal, 0, 10, (N, K); symbol=:X, logpdf=normal_logpdf) # μ = α .+ X * β μ = α .+ X * β - Y = ProbProg.sample(bernoulli_logit, rng, μ, (N,); symbol=:Y) + Y = ProbProg.sample( + rng, bernoulli_logit, μ, (N,); symbol=:Y, logpdf=bernoulli_logit_logpdf + ) return Y end @@ -35,7 +36,10 @@ end K = 3 # number of features seed = Reactant.to_rarray(UInt64[1, 4]) - trace = ProbProg.simulate(blr, seed, N, K) + rng = ReactantRNG(seed) + + trace, _ = ProbProg.simulate(rng, blr, N, K) + println(trace) - @test size(Array(trace.retval)) == (N,) + @test size(trace.retval[1]) == (N,) end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index bed55ab634..321eb997c5 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -2,7 +2,10 @@ using Reactant, Test, Random, Statistics using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end function model(rng, μ, σ, shape) s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) @@ -49,7 +52,7 @@ end constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) - constrained_symbols = collect(keys(constraint1)) # This doesn't change + constrained_symbols = Set(keys(constraint1)) constraint_ptr1 = Reactant.ConcreteRNumber( reinterpret(UInt64, pointer_from_objref(constraint1)) diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index 1930fdc47a..f9269b1cd1 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -4,7 +4,10 @@ using Reactant: ProbProg, ReactantRNG # Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end function my_model(rng, xs) slope = ProbProg.sample( diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 6f36004a4c..b7bca27569 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -2,7 +2,10 @@ using Reactant, Test, Random using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function normal_logpdf(x, μ, σ, _) + return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end function product_two_normals(rng, μ, σ, shape) a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) From a29fbed6ad0872b8163ed143cbf6f81da67377fb Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sun, 20 Jul 2025 21:53:38 -0500 Subject: [PATCH 67/87] remove redundant cast --- src/probprog/Modeling.jl | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 523771c1f7..51e450d558 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -339,16 +339,8 @@ function simulate_internal( 1, ) - weight = MLIR.IR.result( - MLIR.Dialects.builtin.unrealized_conversion_cast( - [MLIR.IR.result(simulate_op, 2)]; - outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))], - ), - 1, - ) - trace = TracedRArray{UInt64,0}((), trace, ()) - weight = TracedRArray{Float64,0}((), weight, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ()) return trace, weight, result end @@ -507,16 +499,8 @@ function generate_internal( 1, ) - weight = MLIR.IR.result( - MLIR.Dialects.builtin.unrealized_conversion_cast( - [MLIR.IR.result(generate_op, 2)]; - outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))], - ), - 1, - ) - trace = TracedRArray{UInt64,0}((), trace, ()) - weight = TracedRArray{Float64,0}((), weight, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ()) return trace, weight, result end From 87ced72bb7c984d4153054d0bc5ed701ced3c66c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 28 Jul 2025 23:27:59 -0500 Subject: [PATCH 68/87] generate op fixup: replacing constrained_symbols with constrained_addresses. Address struct are now the key of Constraint dicts --- src/probprog/FFI.jl | 47 +++++++++++++++++++++++++- src/probprog/Modeling.jl | 39 ++++++++++++++-------- src/probprog/ProbProg.jl | 2 +- src/probprog/Types.jl | 58 +++++++++++++++++++++++++++----- test/probprog/generate.jl | 69 ++++++++++++++++++++++++++++++++++----- 5 files changed, 183 insertions(+), 32 deletions(-) diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl index eb05cc402d..70fb6c0618 100644 --- a/src/probprog/FFI.jl +++ b/src/probprog/FFI.jl @@ -157,7 +157,14 @@ function getSampleFromConstraint( shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) - tostore = get(constraint, symbol, nothing) + tostore = get(constraint, Address(symbol), nothing) + + if tostore === nothing + @ccall printf( + "No constraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + return nothing + end for i in 1:num_samples ndims = ndims_array[i] @@ -228,6 +235,37 @@ function getSampleFromConstraint( return nothing end +function getSubconstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subconstraint_ptr_ptr::Ptr{Ptr{Any}}, +) + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subconstraint = Constraint() + + for (key, value) in constraint + if key.path[1] == symbol + @assert isa(key, Address) "Expected Address type for constraint key" + @assert length(key.path) > 1 "Expected composite address with length > 1" + tail_address = Address(key.path[2:end]) + subconstraint[tail_address] = value + end + end + + if isempty(subconstraint) + @ccall printf( + "No subconstraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + return nothing + end + + _keepalive!(subconstraint) + unsafe_store!(subconstraint_ptr_ptr, pointer_from_objref(subconstraint)) + return nothing +end + function __init__() init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( @@ -297,5 +335,12 @@ function __init__() get_sample_from_constraint_ptr::Ptr{Cvoid}, )::Cvoid + get_subconstraint_ptr = @cfunction( + getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid} + )::Cvoid + return nothing end diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 51e450d558..750b6340b0 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -349,15 +349,16 @@ function generate( rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}; - constraint::Constraint=Dict{Symbol,Any}(), + constraint::Constraint=Constraint(), ) where {Nargs} trace = nothing constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) - constrained_symbols = Set(keys(constraint)) + + constrained_addresses = _extract_addresses(constraint) function wrapper_fn(rng, constraint_ptr, args...) - return generate_internal(rng, f, args...; constraint_ptr, constrained_symbols) + return generate_internal(rng, f, args...; constraint_ptr, constrained_addresses) end compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) @@ -378,12 +379,18 @@ function generate( return trace, trace.weight end +function _extract_addresses(constraint::Constraint) + addresses = Set(keys(constraint)) + + return addresses +end + function generate_internal( rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}; constraint_ptr::TracedRNumber, - constrained_symbols::Set{Symbol}, + constrained_addresses::Set{Address}, ) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") @@ -441,15 +448,19 @@ function generate_internal( 1, ) - constrained_symbols_attr = MLIR.IR.Attribute[] - for sym in constrained_symbols - addr = reinterpret(UInt64, pointer_from_objref(sym)) - push!( - constrained_symbols_attr, - @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( - MLIR.IR.context()::MLIR.API.MlirContext, addr::UInt64 - )::MLIR.IR.Attribute - ) + constrained_addresses_attr = MLIR.IR.Attribute[] + for address in constrained_addresses + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(constrained_addresses_attr, MLIR.IR.Attribute(address_attr)) end trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( @@ -464,7 +475,7 @@ function generate_internal( weight=weight_ty, outputs=out_tys, fn=fn_attr, - constrained_symbols=MLIR.IR.Attribute(constrained_symbols_attr), + constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), ) for (i, res) in enumerate(linear_results) diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index a4f55fa0dd..b795677a52 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -19,7 +19,7 @@ include("Modeling.jl") include("Inference.jl") include("Display.jl") -export ProbProgTrace, Constraint, Selection, CompiledFnCache +export ProbProgTrace, Constraint, Selection, CompiledFnCache, Address export get_choices, select, choicemap, with_compiled_cache export sample, call, simulate, generate diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl index 12a5ba8921..d51f5f9f78 100644 --- a/src/probprog/Types.jl +++ b/src/probprog/Types.jl @@ -22,26 +22,68 @@ mutable struct ProbProgTrace end end -const Constraint = Dict{Symbol,Any} +struct Address + path::Vector{Symbol} + + Address(path::Vector{Symbol}) = new(path) +end +Address(sym::Symbol) = Address([sym]) +Address(syms::Symbol...) = Address([syms...]) + +Base.:(==)(a::Address, b::Address) = a.path == b.path +Base.hash(a::Address, h::UInt) = hash(a.path, h) + +mutable struct Constraint <: AbstractDict{Address,Any} + dict::Dict{Address,Any} + + function Constraint(pairs::Pair...) + dict = Dict{Address,Any}() + for pair in pairs + symbols = Symbol[] + current = pair + while isa(current, Pair) && isa(current.first, Symbol) + push!(symbols, current.first) + current = current.second + end + dict[Address(symbols...)] = current + end + return new(dict) + end + + Constraint() = new(Dict{Address,Any}()) + Constraint(d::Dict{Address,Any}) = new(d) +end + +Base.getindex(c::Constraint, k::Address) = c.dict[k] +Base.setindex!(c::Constraint, v, k::Address) = (c.dict[k] = v) +Base.delete!(c::Constraint, k::Address) = delete!(c.dict, k) +Base.keys(c::Constraint) = keys(c.dict) +Base.values(c::Constraint) = values(c.dict) +Base.iterate(c::Constraint) = iterate(c.dict) +Base.iterate(c::Constraint, state) = iterate(c.dict, state) +Base.length(c::Constraint) = length(c.dict) +Base.isempty(c::Constraint) = isempty(c.dict) +Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k) +Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) + const Selection = Set{Symbol} const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} -const _trace_ref_lock = ReentrantLock() -const _trace_refs = Vector{Any}() +const _probprog_ref_lock = ReentrantLock() +const _probprog_refs = IdDict() -function _keepalive!(tr::ProbProgTrace) - lock(_trace_ref_lock) +function _keepalive!(tr::Any) + lock(_probprog_ref_lock) try - push!(_trace_refs, tr) + _probprog_refs[tr] = tr finally - unlock(_trace_ref_lock) + unlock(_probprog_ref_lock) end return tr end get_choices(trace::ProbProgTrace) = trace.choices select(syms::Symbol...) = Set(syms) -choicemap() = Constraint() function with_compiled_cache(f) cache = CompiledFnCache() diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 321eb997c5..189f2a0f44 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -13,6 +13,19 @@ function model(rng, μ, σ, shape) return t end +function two_normals(rng, μ, σ, shape) + x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) + y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) + return y +end + +function nested_model(rng, μ, σ, shape) + s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) + u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) + return u +end + @testset "Generate" begin @testset "unconstrained" begin shape = (1000,) @@ -31,15 +44,55 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - constraint = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) + constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) trace, weight = ProbProg.generate(rng, model, μ, σ, shape; constraint) - @test trace.choices[:s][1] == constraint[:s][1] + @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] expected_weight = - normal_logpdf(constraint[:s][1], 0.0, 1.0, shape) + - normal_logpdf(trace.choices[:t][1], constraint[:s][1], 1.0, shape) + normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) + + normal_logpdf( + trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape + ) + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "composite addresses" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint( + :s => (fill(0.1, shape),), + :t => :x => (fill(0.2, shape),), + :u => :y => (fill(0.3, shape),), + ) + + trace, weight = ProbProg.generate(rng, nested_model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == fill(0.1, shape) + @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) + @test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape) + + s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape) + tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape) + ty_weight = normal_logpdf( + trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape + ) + ux_weight = normal_logpdf( + trace.subtraces[:u].choices[:x][1], + trace.subtraces[:t].choices[:y][1], + 1.0, + shape, + ) + uy_weight = normal_logpdf( + fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape + ) + + expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight @test weight ≈ expected_weight atol = 1e-6 end @@ -50,16 +103,16 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) + constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),)) - constrained_symbols = Set(keys(constraint1)) + constrained_addresses = ProbProg._extract_addresses(constraint1) constraint_ptr1 = Reactant.ConcreteRNumber( reinterpret(UInt64, pointer_from_objref(constraint1)) ) wrapper_fn(constraint_ptr, rng, μ, σ) = ProbProg.generate_internal( - rng, model, μ, σ, shape; constraint_ptr, constrained_symbols + rng, model, μ, σ, shape; constraint_ptr, constrained_addresses ) compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, rng, μ, σ) @@ -67,7 +120,7 @@ end trace1, weight = compiled_fn(constraint_ptr1, rng, μ, σ) trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1])) - constraint2 = Dict{Symbol,Any}(:s => (fill(0.2, shape),)) + constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) constraint_ptr2 = Reactant.ConcreteRNumber( reinterpret(UInt64, pointer_from_objref(constraint2)) ) From ebec467f5a570519af37c59e7a16a900e4423fa3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 28 Jul 2025 23:58:29 -0500 Subject: [PATCH 69/87] minor --- src/probprog/Modeling.jl | 12 +++--------- src/probprog/Types.jl | 21 ++++++++++++--------- test/probprog/generate.jl | 2 +- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 750b6340b0..bea6ecf75a 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -243,9 +243,9 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + trace.rng = rng trace.fn = f trace.args = args - trace.rng = rng return trace, trace.weight end @@ -355,7 +355,7 @@ function generate( constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) - constrained_addresses = _extract_addresses(constraint) + constrained_addresses = extract_addresses(constraint) function wrapper_fn(rng, constraint_ptr, args...) return generate_internal(rng, f, args...; constraint_ptr, constrained_addresses) @@ -372,19 +372,13 @@ function generate( trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) + trace.rng = rng trace.fn = f trace.args = args - trace.rng = rng return trace, trace.weight end -function _extract_addresses(constraint::Constraint) - addresses = Set(keys(constraint)) - - return addresses -end - function generate_internal( rng::AbstractRNG, f::Function, diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl index d51f5f9f78..d0b43b186a 100644 --- a/src/probprog/Types.jl +++ b/src/probprog/Types.jl @@ -1,23 +1,23 @@ using Base: ReentrantLock mutable struct ProbProgTrace - fn::Union{Nothing,Function} - args::Union{Nothing,Tuple} choices::Dict{Symbol,Any} retval::Any weight::Any subtraces::Dict{Symbol,Any} rng::Union{Nothing,AbstractRNG} - - function ProbProgTrace(fn::Function, args::Tuple) - return new( - fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing - ) - end + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} function ProbProgTrace() return new( - nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing + Dict{Symbol,Any}(), + nothing, + nothing, + Dict{Symbol,Any}(), + nothing, + nothing, + nothing, ) end end @@ -27,6 +27,7 @@ struct Address Address(path::Vector{Symbol}) = new(path) end + Address(sym::Symbol) = Address([sym]) Address(syms::Symbol...) = Address([syms...]) @@ -66,6 +67,8 @@ Base.isempty(c::Constraint) = isempty(c.dict) Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k) Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) +extract_addresses(constraint::Constraint) = Set(keys(constraint)) + const Selection = Set{Symbol} const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 189f2a0f44..efa717733d 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -105,7 +105,7 @@ end constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),)) - constrained_addresses = ProbProg._extract_addresses(constraint1) + constrained_addresses = ProbProg.extract_addresses(constraint1) constraint_ptr1 = Reactant.ConcreteRNumber( reinterpret(UInt64, pointer_from_objref(constraint1)) From f771bcb6266b3d38d7e10c9125df803076e09074 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 29 Jul 2025 00:08:31 -0500 Subject: [PATCH 70/87] update legacy inference API --- src/probprog/Inference.jl | 18 +++++++++--------- src/probprog/Types.jl | 2 +- test/probprog/linear_regression.jl | 5 ++--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/probprog/Inference.jl b/src/probprog/Inference.jl index 6a9d4e1aa2..6f9c85a7d9 100644 --- a/src/probprog/Inference.jl +++ b/src/probprog/Inference.jl @@ -10,17 +10,17 @@ function metropolis_hastings( error("MH requires a trace with fn and rng recorded (use generate to create trace)") end - constraints = Dict{Symbol,Any}() - constrained_symbols = Set{Symbol}() - + constraint_pairs = Pair{Symbol,Any}[] for (sym, val) in trace.choices if !(sym in sel) - constraints[sym] = val - push!(constrained_symbols, sym) + push!(constraint_pairs, sym => val) end end + constraint = Constraint(constraint_pairs...) + + constrained_addresses = extract_addresses(constraint) - cache_key = (typeof(trace.fn), constrained_symbols) + cache_key = (typeof(trace.fn), constrained_addresses) compiled_fn = nothing if compiled_cache !== nothing @@ -30,12 +30,12 @@ function metropolis_hastings( if compiled_fn === nothing function wrapper_fn(rng, constraint_ptr, args...) return generate_internal( - rng, trace.fn, args...; constraint_ptr, constrained_symbols + rng, trace.fn, args...; constraint_ptr, constrained_addresses ) end constraint_ptr = ConcreteRNumber( - reinterpret(UInt64, pointer_from_objref(constraints)) + reinterpret(UInt64, pointer_from_objref(constraint)) ) compiled_fn = @compile optimize = :probprog wrapper_fn( @@ -47,7 +47,7 @@ function metropolis_hastings( end end - constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraints))) + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) old_gc_state = GC.enable(false) new_trace_ptr = nothing diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl index d0b43b186a..51df28e176 100644 --- a/src/probprog/Types.jl +++ b/src/probprog/Types.jl @@ -70,7 +70,7 @@ Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) extract_addresses(constraint::Constraint) = Set(keys(constraint)) const Selection = Set{Symbol} -const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any} +const CompiledFnCache = Dict{Tuple{Type,Set{Address}},Any} const _probprog_ref_lock = ReentrantLock() const _probprog_refs = IdDict() diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl index f9269b1cd1..1a4ff0b181 100644 --- a/test/probprog/linear_regression.jl +++ b/test/probprog/linear_regression.jl @@ -33,13 +33,12 @@ end function my_inference_program(xs, ys, num_iters) xs_r = Reactant.to_rarray(xs) - constraint = ProbProg.choicemap() - constraint[:ys] = [ys] + observations = ProbProg.Constraint(:ys => (ys,)) seed = Reactant.to_rarray(UInt64[1, 4]) rng = ReactantRNG(seed) - trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint) + trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint=observations) trace = ProbProg.with_compiled_cache() do cache local t = trace From 190818801f1caf170550ea352ccf7cb1caec4174 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 29 Jul 2025 00:28:15 -0500 Subject: [PATCH 71/87] simplify --- src/probprog/Modeling.jl | 327 ++++++++++++--------------------------- 1 file changed, 97 insertions(+), 230 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index bea6ecf75a..a447f47621 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -2,40 +2,16 @@ using ..Reactant: MLIR, TracedUtils, AbstractRNG, AbstractConcreteArray, TracedRArray, ConcreteRNumber using ..Compiler: @jit, @compile -function sample( - rng::AbstractRNG, - f::Function, - args::Vararg{Any,Nargs}; - symbol::Symbol=gensym("sample"), - logpdf::Union{Nothing,Function}=nothing, -) where {Nargs} - res = sample_internal(rng, f, args...; symbol, logpdf) - - @assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG" - - res = res[2:end] - - return length(res) == 1 ? res[1] : res -end - -function sample_internal( - rng::AbstractRNG, - f::Function, - args::Vararg{Any,Nargs}; - symbol::Symbol=gensym("sample"), - logpdf::Union{Nothing,Function}=nothing, -) where {Nargs} - argprefix::Symbol = gensym("samplearg") - resprefix::Symbol = gensym("sampleresult") - resargprefix::Symbol = gensym("sampleresarg") +function process_mlir_function(f::Function, args::Tuple, op_name::String) + argprefix = gensym(op_name * "arg") + resprefix = gensym(op_name * "result") + resargprefix = gensym(op_name * "resarg") wrapper_fn = (all_args...) -> begin res = f(all_args...) (all_args[1], (res isa Tuple ? res : (res,))...) end - args = (rng, args...) - mlir_fn_res = invokelatest( TracedUtils.make_mlir_fn, wrapper_fn, @@ -49,10 +25,11 @@ function sample_internal( resprefix, resargprefix, ) - (; result, linear_args, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f + return mlir_fn_res, argprefix, resprefix, resargprefix +end + +function process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) inputs = MLIR.IR.Value[] for a in linear_args idx, path = TracedUtils.get_argidx(a, argprefix) @@ -62,11 +39,74 @@ function sample_internal( if fnwrap && idx > 1 idx -= 1 end - TracedUtils.push_val!(inputs, args[idx], path[3:end]) end end + return inputs +end + +function process_mlir_outputs( + op, linear_results, result, f, args, fnwrap, resprefix, argprefix, start_idx=0 +) + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(op, i + start_idx) + + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + end + + if TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if fnwrap && idx == 2 + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap && idx > 2 + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) + TracedUtils.set!(res, (), resv) + end + end +end + +function sample( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + res = sample_internal(rng, f, args...; symbol, logpdf) + + @assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG" + + res = res[2:end] + + return length(res) == 1 ? res[1] : res +end + +function sample_internal( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function( + f, args, "sample" + ) + + (; result, linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + inputs = process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] sym = TracedUtils.get_attribute_by_name(func2, "sym_name") @@ -116,30 +156,9 @@ function sample_internal( name=Base.String(symbol), ) - for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(sample_op, i) - - if TracedUtils.has_idx(res, resprefix) - path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], resv) - end - - if TracedUtils.has_idx(res, argprefix) - idx, path = TracedUtils.get_argidx(res, argprefix) - if fnwrap && idx == 2 - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - - if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) - TracedUtils.set!(res, (), resv) - end - end + process_mlir_outputs( + sample_op, linear_results, result, f, args, fnwrap, resprefix, argprefix + ) return result end @@ -155,76 +174,24 @@ function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nar end function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix::Symbol = gensym("callarg") - resprefix::Symbol = gensym("callresult") - resargprefix::Symbol = gensym("callresarg") - - wrapper_fn = (all_args...) -> begin - res = f(all_args...) - (all_args[1], (res isa Tuple ? res : (res,))...) - end - args = (rng, args...) + mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function(f, args, "call") - mlir_fn_res = invokelatest( - TracedUtils.make_mlir_fn, - wrapper_fn, - args, - (), - string(f), - false; - do_transpose=false, - args_in_result=:result, - argprefix, - resprefix, - resargprefix, - ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + inputs = process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 2 && fnwrap - TracedUtils.push_val!(inputs, f, path[3:end]) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.push_val!(inputs, args[idx], path[3:end]) - end - end - call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) - for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(call_op, i) - if TracedUtils.has_idx(res, resprefix) - path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], resv) - end - - if TracedUtils.has_idx(res, argprefix) - idx, path = TracedUtils.get_argidx(res, argprefix) - if fnwrap && idx == 2 - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - - if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) - TracedUtils.set!(res, (), resv) - end - end + process_mlir_outputs( + call_op, linear_results, result, f, args, fnwrap, resprefix, argprefix + ) return result end @@ -253,51 +220,21 @@ end function simulate_internal( rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs} ) where {Nargs} - argprefix::Symbol = gensym("simulatearg") - resprefix::Symbol = gensym("simulateresult") - resargprefix::Symbol = gensym("simulateresarg") - - wrapper_fn = (all_args...) -> begin - res = f(all_args...) - (all_args[1], (res isa Tuple ? res : (res,))...) - end - args = (rng, args...) - - mlir_fn_res = invokelatest( - TracedUtils.make_mlir_fn, - wrapper_fn, - args, - (), - string(f), - false; - do_transpose=false, - args_in_result=:result, - argprefix, - resprefix, - resargprefix, + mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function( + f, args, "simulate" ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + inputs = process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 2 && fnwrap - TracedUtils.push_val!(inputs, f, path[3:end]) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.push_val!(inputs, args[idx], path[3:end]) - end - end - trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.IR.Type @@ -307,29 +244,9 @@ function simulate_internal( inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr ) - for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(simulate_op, i + 2) - if TracedUtils.has_idx(res, resprefix) - path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], resv) - end - - if TracedUtils.has_idx(res, argprefix) - idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 2 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - - if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) - TracedUtils.set!(res, (), resv) - end - end + process_mlir_outputs( + simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) trace = MLIR.IR.result( MLIR.Dialects.builtin.unrealized_conversion_cast( @@ -386,51 +303,21 @@ function generate_internal( constraint_ptr::TracedRNumber, constrained_addresses::Set{Address}, ) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") - - wrapper_fn = (all_args...) -> begin - res = f(all_args...) - (all_args[1], (res isa Tuple ? res : (res,))...) - end - args = (rng, args...) - - mlir_fn_res = invokelatest( - TracedUtils.make_mlir_fn, - wrapper_fn, - args, - (), - string(f), - false; - do_transpose=false, - args_in_result=:result, - argprefix, - resprefix, - resargprefix, + mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function( + f, args, "generate" ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f + inputs = process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 2 && fnwrap - TracedUtils.push_val!(inputs, f, path[3:end]) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.push_val!(inputs, args[idx], path[3:end]) - end - end - constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( MLIR.IR.context()::MLIR.API.MlirContext )::MLIR.IR.Type @@ -472,29 +359,9 @@ function generate_internal( constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), ) - for (i, res) in enumerate(linear_results) - resv = MLIR.IR.result(generate_op, i + 2) - if TracedUtils.has_idx(res, resprefix) - path = TracedUtils.get_idx(res, resprefix) - TracedUtils.set!(result, path[2:end], resv) - end - - if TracedUtils.has_idx(res, argprefix) - idx, path = TracedUtils.get_argidx(res, argprefix) - if idx == 2 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap && idx > 2 - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - - if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) - TracedUtils.set!(res, (), resv) - end - end + process_mlir_outputs( + generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) trace = MLIR.IR.result( MLIR.Dialects.builtin.unrealized_conversion_cast( From 0b71444219536be02cad33dde8c3f0f822d13255 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 29 Jul 2025 00:30:13 -0500 Subject: [PATCH 72/87] cleanup --- src/probprog/Modeling.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index a447f47621..1796936b51 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -83,8 +83,6 @@ function sample( ) where {Nargs} res = sample_internal(rng, f, args...; symbol, logpdf) - @assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG" - res = res[2:end] return length(res) == 1 ? res[1] : res From 9bd1dee3337b5a61b66e2f6a636bdc66c559ae4a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 31 Jul 2025 02:00:26 -0500 Subject: [PATCH 73/87] fix deadlock --- src/Types.jl | 2 ++ src/probprog/Modeling.jl | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/Types.jl b/src/Types.jl index 48221a431d..4cb153b723 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -215,6 +215,7 @@ function ConcretePJRTArray( end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) +Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) @@ -412,6 +413,7 @@ function ConcreteIFRTArray( end Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data) +Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data) XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) return XLA.device(x.data) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 1796936b51..632ae7d6fb 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -199,11 +199,13 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where compiled_fn = @compile optimize = :probprog simulate_internal(rng, f, args...) - old_gc_state = GC.enable(false) - try + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer begin trace, _, _ = compiled_fn(rng, f, args...) - finally - GC.enable(old_gc_state) + + while !isready(trace) + yield() + end end trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) @@ -278,11 +280,13 @@ function generate( compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) - old_gc_state = GC.enable(false) - try + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint begin trace, _, _ = compiled_fn(rng, constraint_ptr, args...) - finally - GC.enable(old_gc_state) + + while !isready(trace) + yield() + end end trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1])) From c9ff7c05a2a67ea827f75dfa4cffe0ed69e22d60 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 31 Jul 2025 03:06:38 -0500 Subject: [PATCH 74/87] fix test --- test/probprog/generate.jl | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index efa717733d..2bae4df530 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -111,13 +111,21 @@ end reinterpret(UInt64, pointer_from_objref(constraint1)) ) - wrapper_fn(constraint_ptr, rng, μ, σ) = ProbProg.generate_internal( + wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate_internal( rng, model, μ, σ, shape; constraint_ptr, constrained_addresses ) - compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, rng, μ, σ) + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ) - trace1, weight = compiled_fn(constraint_ptr1, rng, μ, σ) + trace1 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint1 begin + trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ) + + while !isready(trace1) + yield() + end + end trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1])) constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) @@ -125,7 +133,15 @@ end reinterpret(UInt64, pointer_from_objref(constraint2)) ) - trace2, _ = compiled_fn(constraint_ptr2, rng, μ, σ) + trace2 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint2 begin + trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ) + + while !isready(trace2) + yield() + end + end trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1])) @test trace1.choices[:s][1] != trace2.choices[:s][1] From 31969890fcabb34f785f710b895207317e2df8c6 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 31 Jul 2025 13:49:06 -0500 Subject: [PATCH 75/87] don't print --- test/probprog/blr.jl | 1 - test/probprog/simulate.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl index a4756262a5..1dafcce76c 100644 --- a/test/probprog/blr.jl +++ b/test/probprog/blr.jl @@ -39,7 +39,6 @@ end rng = ReactantRNG(seed) trace, _ = ProbProg.simulate(rng, blr, N, K) - println(trace) @test size(trace.retval[1]) == (N,) end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index b7bca27569..0a5a870b14 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -63,7 +63,6 @@ end σ = Reactant.ConcreteRNumber(1.0) trace, weight = ProbProg.simulate(rng, model, μ, σ, shape) - println(trace) @test size(trace.retval[1]) == shape @test haskey(trace.choices, :s) From 4afda71657c6ae2de663914b309edcfc50d56bdb Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 31 Jul 2025 13:49:28 -0500 Subject: [PATCH 76/87] clean up postpasses --- src/Compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index f655ceec5c..e1c6313f83 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1231,7 +1231,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" -const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() From f7aab5b2dfa586c9fe708c4aaab7f65a78e9d741 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 19 Sep 2025 21:47:39 -0500 Subject: [PATCH 77/87] format --- deps/ReactantExtra/API.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index b0c441d73e..a20f4cd90c 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -502,8 +502,8 @@ MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices, return client.release(); } #else - *error = "ReactantExtra was not built with GPU support"; - return nullptr; + *error = "ReactantExtra was not built with GPU support"; + return nullptr; #endif } From 834de058112c2d1f18978d05ef3a971edd37566c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 14:46:04 -0500 Subject: [PATCH 78/87] remove probprog_no_lowering --- src/CompileOptions.jl | 1 - src/Compiler.jl | 55 --------------------------------------- test/probprog/simulate.jl | 8 ------ 3 files changed, 64 deletions(-) diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index f38dbc2e57..925c357e1a 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -230,7 +230,6 @@ function CompileOptions(; :just_batch, :none, :probprog, - :probprog_no_lowering, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index 054788e42b..87b9269e20 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1880,61 +1880,6 @@ function compile_mlir!( ), "no_enzyme", ) - elseif compile_options.optimization_passes === :probprog_no_lowering - run_pass_pipeline!( - mod, - join( - if compile_options.raise_first - [ - "mark-func-memory-effects", - opt_passes, - kern, - raise_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - probprog_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - ( - if compile_options.legalize_chlo_to_stablehlo - ["func.func(chlo-legalize-to-stablehlo)"] - else - [] - end - )..., - opt_passes2, - ] - else - [ - "mark-func-memory-effects", - opt_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - probprog_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - ( - if compile_options.legalize_chlo_to_stablehlo - ["func.func(chlo-legalize-to-stablehlo)"] - else - [] - end - )..., - opt_passes2, - kern, - raise_passes, - ] - end, - ",", - ), - "probprog_no_lowering", - ) elseif compile_options.optimization_passes === :probprog run_pass_pipeline!( mod, diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 0a5a870b14..a3c024c012 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -38,14 +38,6 @@ end ) @test contains(repr(before), "enzyme.simulate") - unlowered = @code_hlo optimize = :probprog_no_lowering ProbProg.simulate_internal( - rng, model, μ, σ, shape - ) - @test !contains(repr(unlowered), "enzyme.simulate") - @test contains(repr(unlowered), "enzyme.addSampleToTrace") - @test contains(repr(unlowered), "enzyme.addWeightToTrace") - @test contains(repr(unlowered), "enzyme.addRetvalToTrace") - after = @code_hlo optimize = :probprog ProbProg.simulate_internal( rng, model, μ, σ, shape ) From 7ad561d9634e2d7c9f3393d4172f3d0baa1735dd Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 14:51:03 -0500 Subject: [PATCH 79/87] undo jll change --- deps/ReactantExtra/API.cpp | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index a20f4cd90c..d8b1e188f8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -373,21 +373,6 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } -extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { - return wrap(mlir::enzyme::TraceType::get(unwrap(ctx))); -} - -extern "C" MLIR_CAPI_EXPORTED MlirType -enzymeConstraintTypeGet(MlirContext ctx) { - return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx))); -} - -extern "C" MLIR_CAPI_EXPORTED MlirAttribute -enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { - mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol); - return wrap(attr); -} - // Create profiler session and start profiling REACTANT_ABI tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, @@ -502,8 +487,8 @@ MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices, return client.release(); } #else - *error = "ReactantExtra was not built with GPU support"; - return nullptr; + *error = "ReactantExtra was not built with GPU support"; + return nullptr; #endif } From a31f43904331b36a7464c8a879330394a7c70c71 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 14:52:07 -0500 Subject: [PATCH 80/87] undo jll change --- deps/ReactantExtra/API.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index d8b1e188f8..b0c441d73e 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -373,6 +373,21 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } +extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { + return wrap(mlir::enzyme::TraceType::get(unwrap(ctx))); +} + +extern "C" MLIR_CAPI_EXPORTED MlirType +enzymeConstraintTypeGet(MlirContext ctx) { + return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx))); +} + +extern "C" MLIR_CAPI_EXPORTED MlirAttribute +enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { + mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol); + return wrap(attr); +} + // Create profiler session and start profiling REACTANT_ABI tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, From 96569142cf418d92165f933cbf1525bd2e08f326 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 15:01:35 -0500 Subject: [PATCH 81/87] clean --- src/probprog/Inference.jl | 73 ------------------------- src/probprog/ProbProg.jl | 3 -- test/probprog/blr.jl | 44 --------------- test/probprog/linear_regression.jl | 87 ------------------------------ test/runtests.jl | 2 - 5 files changed, 209 deletions(-) delete mode 100644 src/probprog/Inference.jl delete mode 100644 test/probprog/blr.jl delete mode 100644 test/probprog/linear_regression.jl diff --git a/src/probprog/Inference.jl b/src/probprog/Inference.jl deleted file mode 100644 index 6f9c85a7d9..0000000000 --- a/src/probprog/Inference.jl +++ /dev/null @@ -1,73 +0,0 @@ -using ..Reactant: ConcreteRNumber -using ..Compiler: @compile - -function metropolis_hastings( - trace::ProbProgTrace, - sel::Selection; - compiled_cache::Union{Nothing,CompiledFnCache}=nothing, -) - if trace.fn === nothing || trace.rng === nothing - error("MH requires a trace with fn and rng recorded (use generate to create trace)") - end - - constraint_pairs = Pair{Symbol,Any}[] - for (sym, val) in trace.choices - if !(sym in sel) - push!(constraint_pairs, sym => val) - end - end - constraint = Constraint(constraint_pairs...) - - constrained_addresses = extract_addresses(constraint) - - cache_key = (typeof(trace.fn), constrained_addresses) - - compiled_fn = nothing - if compiled_cache !== nothing - compiled_fn = get(compiled_cache, cache_key, nothing) - end - - if compiled_fn === nothing - function wrapper_fn(rng, constraint_ptr, args...) - return generate_internal( - rng, trace.fn, args...; constraint_ptr, constrained_addresses - ) - end - - constraint_ptr = ConcreteRNumber( - reinterpret(UInt64, pointer_from_objref(constraint)) - ) - - compiled_fn = @compile optimize = :probprog wrapper_fn( - trace.rng, constraint_ptr, trace.args... - ) - - if compiled_cache !== nothing - compiled_cache[cache_key] = compiled_fn - end - end - - constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) - - old_gc_state = GC.enable(false) - new_trace_ptr = nothing - try - new_trace_ptr, _, _ = compiled_fn(trace.rng, constraint_ptr, trace.args...) - finally - GC.enable(old_gc_state) - end - - new_trace = unsafe_pointer_to_objref(Ptr{Any}(Array(new_trace_ptr)[1])) - - new_trace.fn = trace.fn - new_trace.args = trace.args - new_trace.rng = trace.rng - - log_alpha = new_trace.weight - trace.weight - - if log(rand()) < log_alpha - return (new_trace, true) - else - return (trace, false) - end -end diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index b795677a52..86da3bb386 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -16,7 +16,6 @@ using Enzyme include("Types.jl") include("FFI.jl") include("Modeling.jl") -include("Inference.jl") include("Display.jl") export ProbProgTrace, Constraint, Selection, CompiledFnCache, Address @@ -24,6 +23,4 @@ export get_choices, select, choicemap, with_compiled_cache export sample, call, simulate, generate -export metropolis_hastings - end diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl deleted file mode 100644 index 1dafcce76c..0000000000 --- a/test/probprog/blr.jl +++ /dev/null @@ -1,44 +0,0 @@ -using Reactant, Test, Random -using Reactant: ProbProg, ReactantRNG - -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) - -function normal_logpdf(x, μ, σ, _) - return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) -end - -bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) -bernoulli_logit_logpdf(x, logit, _) = sum(x .* logit .- log1p.(exp.(logit))) - -# https://github.com/facebookresearch/pplbench/blob/main/pplbench/models/logistic_regression.py -function blr(rng, N, K) - # α ~ Normal(0, 10, size = 1) - α = ProbProg.sample(rng, normal, 0, 10, (1,); symbol=:α, logpdf=normal_logpdf) - - # β ~ Normal(0, 2.5, size = K) - β = ProbProg.sample(rng, normal, 0, 2.5, (K,); symbol=:β, logpdf=normal_logpdf) - - # X ~ Normal(0, 10, size = (N, K)) - X = ProbProg.sample(rng, normal, 0, 10, (N, K); symbol=:X, logpdf=normal_logpdf) - - # μ = α .+ X * β - μ = α .+ X * β - - Y = ProbProg.sample( - rng, bernoulli_logit, μ, (N,); symbol=:Y, logpdf=bernoulli_logit_logpdf - ) - - return Y -end - -@testset "BLR" begin - N = 5 # number of observations - K = 3 # number of features - seed = Reactant.to_rarray(UInt64[1, 4]) - - rng = ReactantRNG(seed) - - trace, _ = ProbProg.simulate(rng, blr, N, K) - - @test size(trace.retval[1]) == (N,) -end diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl deleted file mode 100644 index 1a4ff0b181..0000000000 --- a/test/probprog/linear_regression.jl +++ /dev/null @@ -1,87 +0,0 @@ -using Reactant, Test, Random -using Reactant: ProbProg, ReactantRNG - -# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ - -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) - -function normal_logpdf(x, μ, σ, _) - return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) -end - -function my_model(rng, xs) - slope = ProbProg.sample( - rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf - ) - intercept = ProbProg.sample( - rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf - ) - - ys = ProbProg.sample( - rng, - normal, - slope .* xs .+ intercept, - 1.0, - (length(xs),); - symbol=:ys, - logpdf=normal_logpdf, - ) - - return ys -end - -function my_inference_program(xs, ys, num_iters) - xs_r = Reactant.to_rarray(xs) - - observations = ProbProg.Constraint(:ys => (ys,)) - - seed = Reactant.to_rarray(UInt64[1, 4]) - rng = ReactantRNG(seed) - - trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint=observations) - - trace = ProbProg.with_compiled_cache() do cache - local t = trace - for _ in 1:num_iters - t, _ = ProbProg.metropolis_hastings( - t, ProbProg.select(:slope); compiled_cache=cache - ) - t, _ = ProbProg.metropolis_hastings( - t, ProbProg.select(:intercept); compiled_cache=cache - ) - end - return t - end - - choices = ProbProg.get_choices(trace) - return (Array(choices[:slope][1])[1], Array(choices[:intercept][1])[1]) -end - -@testset "linear_regression" begin - @testset "simulate" begin - seed = Reactant.to_rarray(UInt64[1, 4]) - rng = ReactantRNG(seed) - - xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] - xs_r = Reactant.to_rarray(xs) - - trace, _ = ProbProg.simulate(rng, my_model, xs_r) - - @test haskey(trace.choices, :slope) - @test haskey(trace.choices, :intercept) - @test haskey(trace.choices, :ys) - end - - @testset "inference" begin - Random.seed!(1) # For Julia side RNG - xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] - ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] - - slope, intercept = my_inference_program(xs, ys, 10000) - - @show slope, intercept - - @test slope ≈ -2.0 rtol = 0.05 - @test intercept ≈ 10.0 rtol = 0.05 - end -end diff --git a/test/runtests.jl b/test/runtests.jl index a796d58be6..ccb96984ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,9 +71,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" @safetestset "ProbProg Sample" include("probprog/sample.jl") - @safetestset "ProbProg BLR" include("probprog/blr.jl") @safetestset "ProbProg Simulate" include("probprog/simulate.jl") @safetestset "ProbProg Generate" include("probprog/generate.jl") - @safetestset "ProbProg Linear Regression" include("probprog/linear_regression.jl") end end From 1532269b6ec09ffbddf057a91d472093b1c36ec3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 16:48:08 -0500 Subject: [PATCH 82/87] clean up and improve --- src/probprog/Modeling.jl | 90 +++++++++++---------------------------- src/probprog/ProbProg.jl | 13 ++++-- test/probprog/generate.jl | 22 +++++----- test/probprog/sample.jl | 16 +++---- test/probprog/simulate.jl | 25 +++++------ 5 files changed, 68 insertions(+), 98 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 632ae7d6fb..f07ee78659 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -81,30 +81,14 @@ function sample( symbol::Symbol=gensym("sample"), logpdf::Union{Nothing,Function}=nothing, ) where {Nargs} - res = sample_internal(rng, f, args...; symbol, logpdf) - - res = res[2:end] - - return length(res) == 1 ? res[1] : res -end - -function sample_internal( - rng::AbstractRNG, - f::Function, - args::Vararg{Any,Nargs}; - symbol::Symbol=gensym("sample"), - logpdf::Union{Nothing,Function}=nothing, -) where {Nargs} - args = (rng, args...) - mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function( - f, args, "sample" - ) + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_mlir_function(f, args_with_rng, "sample") (; result, linear_args, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - inputs = process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) + inputs = process_mlir_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] sym = TracedUtils.get_attribute_by_name(func2, "sym_name") @@ -115,23 +99,15 @@ function sample_internal( MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 )::MLIR.IR.Attribute - # Construct MLIR attribute if Julia logpdf function is provided. + # Construct logpdf attribute if `logpdf` function is provided. logpdf_attr = nothing - if logpdf !== nothing - # Just to get static information about the sample. TODO: kwargs? - example_sample = f(args...) - - # Remove AbstractRNG from `f`'s argument list if present, assuming that - # logpdf parameters follows `(sample, args...)` convention. - logpdf_args = nothing - if !isempty(args) && args[1] isa AbstractRNG - logpdf_args = (example_sample, Base.tail(args)...) # TODO: kwargs? - else - logpdf_args = (example_sample, args...) - end + if logpdf isa Function + samples = f(args_with_rng...) - logpdf_mlir = invokelatest( - TracedUtils.make_mlir_fn, + # Assume that logpdf parameters follow `(sample, args...)` convention. + logpdf_args = (samples, args...) + + logpdf_mlir = TracedUtils.make_mlir_fn( logpdf, logpdf_args, (), @@ -155,31 +131,21 @@ function sample_internal( ) process_mlir_outputs( - sample_op, linear_results, result, f, args, fnwrap, resprefix, argprefix + sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix ) return result end -function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} - res = @jit optimize = :probprog call_internal(rng, f, args...) - - res = map(res[2:end]) do r - r isa AbstractConcreteArray ? Array(r) : r - end - - return length(res) == 1 ? res[1] : res -end - -function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} - args = (rng, args...) - mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function(f, args, "call") +function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_mlir_function(f, args_with_rng, "call") (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped func2 = mlir_fn_res.f - inputs = process_mlir_inputs(linear_args, f, args, fnwrap, argprefix) + inputs = process_mlir_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] fname = TracedUtils.get_attribute_by_name(func2, "sym_name") @@ -188,16 +154,17 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) process_mlir_outputs( - call_op, linear_results, result, f, args, fnwrap, resprefix, argprefix + call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix ) return result end -function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} +# Gen-like helper function. +function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} trace = nothing - compiled_fn = @compile optimize = :probprog simulate_internal(rng, f, args...) + compiled_fn = @compile optimize = :probprog simulate(rng, f, args...) seed_buffer = only(rng.seed.data).buffer GC.@preserve seed_buffer begin @@ -217,13 +184,9 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where return trace, trace.weight end -function simulate_internal( - rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs} -) where {Nargs} +function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} args = (rng, args...) - mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function( - f, args, "simulate" - ) + mlir_fn_res, argprefix, resprefix, _ = process_mlir_function(f, args, "simulate") (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -262,7 +225,8 @@ function simulate_internal( return trace, weight, result end -function generate( +# Gen-like helper function. +function generate_( rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}; @@ -275,7 +239,7 @@ function generate( constrained_addresses = extract_addresses(constraint) function wrapper_fn(rng, constraint_ptr, args...) - return generate_internal(rng, f, args...; constraint_ptr, constrained_addresses) + return generate(rng, f, args...; constraint_ptr, constrained_addresses) end compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) @@ -298,7 +262,7 @@ function generate( return trace, trace.weight end -function generate_internal( +function generate( rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}; @@ -306,9 +270,7 @@ function generate_internal( constrained_addresses::Set{Address}, ) where {Nargs} args = (rng, args...) - mlir_fn_res, argprefix, resprefix, resargprefix = process_mlir_function( - f, args, "generate" - ) + mlir_fn_res, argprefix, resprefix, _ = process_mlir_function(f, args, "generate") (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index 86da3bb386..151efc6b8c 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -18,9 +18,16 @@ include("FFI.jl") include("Modeling.jl") include("Display.jl") -export ProbProgTrace, Constraint, Selection, CompiledFnCache, Address -export get_choices, select, choicemap, with_compiled_cache +# Types. +export ProbProgTrace, Constraint, Selection, Address -export sample, call, simulate, generate +# Utility functions. +export get_choices, select, choicemap + +# Core MLIR ops. +export sample, untraced_call, simulate, generate + +# Gen-like helper functions. +export simulate_, generate_ end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 2bae4df530..fdbcf20b6d 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -8,21 +8,21 @@ function normal_logpdf(x, μ, σ, _) end function model(rng, μ, σ, shape) - s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) - t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) return t end function two_normals(rng, μ, σ, shape) - x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) - y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) + _, x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) + _, y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) return y end function nested_model(rng, μ, σ, shape) - s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) - t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) - u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) + _, u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) return u end @@ -33,7 +33,7 @@ end rng = ReactantRNG(seed) μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.generate(rng, model, μ, σ, shape) + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape) @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 end @@ -46,7 +46,7 @@ end constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) - trace, weight = ProbProg.generate(rng, model, μ, σ, shape; constraint) + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint) @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] @@ -71,7 +71,7 @@ end :u => :y => (fill(0.3, shape),), ) - trace, weight = ProbProg.generate(rng, nested_model, μ, σ, shape; constraint) + trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint) @test trace.choices[:s][1] == fill(0.1, shape) @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) @@ -111,7 +111,7 @@ end reinterpret(UInt64, pointer_from_objref(constraint1)) ) - wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate_internal( + wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate( rng, model, μ, σ, shape; constraint_ptr, constrained_addresses ) diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index 28a4ab6ee9..a6f41d8f22 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -4,19 +4,19 @@ using Reactant: ProbProg, ReactantRNG normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) function one_sample(rng, μ, σ, shape) - s = ProbProg.sample(rng, normal, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) return s end function two_samples(rng, μ, σ, shape) _ = ProbProg.sample(rng, normal, μ, σ, shape) - t = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, μ, σ, shape) return t end function compose(rng, μ, σ, shape) - s = ProbProg.sample(rng, normal, μ, σ, shape) - t = ProbProg.sample(rng, normal, s, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, s, σ, shape) return t end @@ -50,10 +50,10 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.call(rng, compose, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.untraced_call(rng, compose, μ, σ, shape) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.call(rng, compose, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.untraced_call(rng, compose, μ, σ, shape) @test !contains(repr(after), "enzyme.sample") end @@ -66,11 +66,11 @@ end rng1 = ReactantRNG(copy(seed)) - X = ProbProg.call(rng1, one_sample, μ, σ, shape) + _, X = @jit optimize = :probprog ProbProg.untraced_call(rng1, one_sample, μ, σ, shape) @test !all(rng1.seed .== seed) rng2 = ReactantRNG(copy(seed)) - Y = ProbProg.call(rng2, two_samples, μ, σ, shape) + _, Y = @jit optimize = :probprog ProbProg.untraced_call(rng2, two_samples, μ, σ, shape) @test !all(rng2.seed .== seed) @test !all(rng2.seed .== rng1.seed) diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index a3c024c012..551e669af1 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -8,20 +8,20 @@ function normal_logpdf(x, μ, σ, _) end function product_two_normals(rng, μ, σ, shape) - a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) - b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + _, a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + _, b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) return a .* b end function model(rng, μ, σ, shape) - s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) - t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) return t end function model2(rng, μ, σ, shape) - s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) - t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) return t end @@ -33,12 +33,12 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.simulate_internal( + before = @code_hlo optimize = false ProbProg.simulate( rng, model, μ, σ, shape ) @test contains(repr(before), "enzyme.simulate") - after = @code_hlo optimize = :probprog ProbProg.simulate_internal( + after = @code_hlo optimize = :probprog ProbProg.simulate( rng, model, μ, σ, shape ) @test !contains(repr(after), "enzyme.simulate") @@ -54,7 +54,7 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.simulate(rng, model, μ, σ, shape) + trace, weight = ProbProg.simulate_(rng, model, μ, σ, shape) @test size(trace.retval[1]) == shape @test haskey(trace.choices, :s) @@ -68,7 +68,8 @@ end op(_, x, y) = x * y' logpdf(res, _, _) = sum(res) function fake_model(rng, x, y) - return ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) + _, res = ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) + return res end x = reshape(collect(Float64, 1:12), (4, 3)) @@ -78,7 +79,7 @@ end seed = Reactant.to_rarray(UInt64[1, 4]) rng = ReactantRNG(seed) - trace, weight = ProbProg.simulate(rng, fake_model, x_ra, y_ra) + trace, weight = ProbProg.simulate_(rng, fake_model, x_ra, y_ra) @test Array(trace.retval[1]) == op(rng, x, y) @test haskey(trace.choices, :matmul) @@ -93,7 +94,7 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - trace, weight = ProbProg.simulate(rng, model2, μ, σ, shape) + trace, weight = ProbProg.simulate_(rng, model2, μ, σ, shape) @test size(trace.retval[1]) == shape From eda74e39166110d6454cefdbab9c933cdca47ec3 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 16:54:25 -0500 Subject: [PATCH 83/87] format --- test/probprog/sample.jl | 16 ++++++++++++---- test/probprog/simulate.jl | 8 ++------ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl index a6f41d8f22..b7889c46dd 100644 --- a/test/probprog/sample.jl +++ b/test/probprog/sample.jl @@ -50,10 +50,14 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.untraced_call(rng, compose, μ, σ, shape) + before = @code_hlo optimize = false ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog ProbProg.untraced_call(rng, compose, μ, σ, shape) + after = @code_hlo optimize = :probprog ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) @test !contains(repr(after), "enzyme.sample") end @@ -66,11 +70,15 @@ end rng1 = ReactantRNG(copy(seed)) - _, X = @jit optimize = :probprog ProbProg.untraced_call(rng1, one_sample, μ, σ, shape) + _, X = @jit optimize = :probprog ProbProg.untraced_call( + rng1, one_sample, μ, σ, shape + ) @test !all(rng1.seed .== seed) rng2 = ReactantRNG(copy(seed)) - _, Y = @jit optimize = :probprog ProbProg.untraced_call(rng2, two_samples, μ, σ, shape) + _, Y = @jit optimize = :probprog ProbProg.untraced_call( + rng2, two_samples, μ, σ, shape + ) @test !all(rng2.seed .== seed) @test !all(rng2.seed .== rng1.seed) diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl index 551e669af1..423a818ebf 100644 --- a/test/probprog/simulate.jl +++ b/test/probprog/simulate.jl @@ -33,14 +33,10 @@ end μ = Reactant.ConcreteRNumber(0.0) σ = Reactant.ConcreteRNumber(1.0) - before = @code_hlo optimize = false ProbProg.simulate( - rng, model, μ, σ, shape - ) + before = @code_hlo optimize = false ProbProg.simulate(rng, model, μ, σ, shape) @test contains(repr(before), "enzyme.simulate") - after = @code_hlo optimize = :probprog ProbProg.simulate( - rng, model, μ, σ, shape - ) + after = @code_hlo optimize = :probprog ProbProg.simulate(rng, model, μ, σ, shape) @test !contains(repr(after), "enzyme.simulate") @test !contains(repr(after), "enzyme.addSampleToTrace") @test !contains(repr(after), "enzyme.addWeightToTrace") From db314a75f1a7b59461cab5cb9a3dd156ac6b2b21 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 16:54:35 -0500 Subject: [PATCH 84/87] remove invokelatest --- src/probprog/Modeling.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index f07ee78659..1732bab04e 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -12,8 +12,7 @@ function process_mlir_function(f::Function, args::Tuple, op_name::String) (all_args[1], (res isa Tuple ? res : (res,))...) end - mlir_fn_res = invokelatest( - TracedUtils.make_mlir_fn, + mlir_fn_res = TracedUtils.make_mlir_fn( wrapper_fn, args, (), From df764c166a7d43650f54a073db24b2424a1b240c Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 17:55:25 -0500 Subject: [PATCH 85/87] clean up --- src/probprog/ProbProg.jl | 2 +- src/probprog/Types.jl | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index 151efc6b8c..64e0519d58 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -22,7 +22,7 @@ include("Display.jl") export ProbProgTrace, Constraint, Selection, Address # Utility functions. -export get_choices, select, choicemap +export get_choices, select # Core MLIR ops. export sample, untraced_call, simulate, generate diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl index 51df28e176..bd5e50fc84 100644 --- a/src/probprog/Types.jl +++ b/src/probprog/Types.jl @@ -70,7 +70,6 @@ Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) extract_addresses(constraint::Constraint) = Set(keys(constraint)) const Selection = Set{Symbol} -const CompiledFnCache = Dict{Tuple{Type,Set{Address}},Any} const _probprog_ref_lock = ReentrantLock() const _probprog_refs = IdDict() @@ -87,8 +86,3 @@ end get_choices(trace::ProbProgTrace) = trace.choices select(syms::Symbol...) = Set(syms) - -function with_compiled_cache(f) - cache = CompiledFnCache() - return f(cache) -end From c0cc686561b56027b19c1e33e753fa32167aaae4 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 18:12:38 -0500 Subject: [PATCH 86/87] ci --- src/probprog/Modeling.jl | 2 +- src/probprog/ProbProg.jl | 11 +---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index 1732bab04e..c5826c252d 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -1,5 +1,5 @@ using ..Reactant: - MLIR, TracedUtils, AbstractRNG, AbstractConcreteArray, TracedRArray, ConcreteRNumber + MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber using ..Compiler: @jit, @compile function process_mlir_function(f::Function, args::Tuple, op_name::String) diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl index 64e0519d58..4eb7024ffe 100644 --- a/src/probprog/ProbProg.jl +++ b/src/probprog/ProbProg.jl @@ -1,17 +1,8 @@ module ProbProg using ..Reactant: - MLIR, - TracedUtils, - AbstractConcreteArray, - AbstractConcreteNumber, - AbstractRNG, - TracedRArray, - TracedRNumber, - ConcreteRNumber, - Ops + MLIR, TracedUtils, AbstractRNG, TracedRArray, TracedRNumber, ConcreteRNumber using ..Compiler: @jit, @compile -using Enzyme include("Types.jl") include("FFI.jl") From 1d3a7d82fb265c31fcf05f10b975720d2f92ae61 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 30 Sep 2025 18:14:33 -0500 Subject: [PATCH 87/87] format --- src/probprog/Modeling.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl index c5826c252d..4678864d0d 100644 --- a/src/probprog/Modeling.jl +++ b/src/probprog/Modeling.jl @@ -1,5 +1,4 @@ -using ..Reactant: - MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber +using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber using ..Compiler: @jit, @compile function process_mlir_function(f::Function, args::Tuple, op_name::String)