Skip to content
1 change: 1 addition & 0 deletions src/CompileOptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:probprog,
]
end

Expand Down
66 changes: 66 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,7 @@ end
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}"
const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}"

function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
pm = MLIR.IR.PassManager()
Expand Down Expand Up @@ -1923,6 +1924,71 @@ function compile_mlir!(
),
"no_enzyme",
)
elseif compile_options.optimization_passes === :probprog
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
probprog_pass,
"lower-probprog-to-stablehlo{backend=$backend}",
"outline-enzyme-regions",
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
"lower-probprog-trace-ops{backend=$backend}",
jit,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
probprog_pass,
"lower-probprog-to-stablehlo{backend=$backend}",
"outline-enzyme-regions",
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
lower_enzymexla_linalg_pass,
"lower-probprog-trace-ops{backend=$backend}",
jit,
]
end,
",",
),
"probprog",
)
elseif compile_options.optimization_passes === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
3 changes: 3 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ include("Overlay.jl")
# Serialization
include("serialization/Serialization.jl")

# ProbProg
include("probprog/ProbProg.jl")

using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
export ConcreteRArray,
ConcreteRNumber,
Expand Down
2 changes: 2 additions & 0 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ function ConcretePJRTArray(
end

Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data)
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data))
Expand Down Expand Up @@ -420,6 +421,7 @@ function ConcreteIFRTArray(
end

Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data)
Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data)
XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber})
return XLA.device(x.data)
Expand Down
87 changes: 87 additions & 0 deletions src/probprog/Display.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
VERT = '\u2502'
PLUS = '\u251C'
HORZ = '\u2500'
LAST = '\u2514'

indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])

for i in vert_bars
indent_vert[i] = VERT
indent[i] = VERT
indent_last[i] = VERT
end

indent_vert_str = join(indent_vert)
indent_str = join(indent)
indent_last_str = join(indent_last)

sorted_choices = sort(collect(trace.choices); by=x -> x[1])
n = length(sorted_choices)

if trace.retval !== nothing
n += 1
end

if trace.weight !== nothing
n += 1
end

cur = 1

if trace.retval !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
cur += 1
end

if trace.weight !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
cur += 1
end

for (key, value) in sorted_choices
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
cur += 1
end

sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
n += length(sorted_subtraces)

for (key, subtrace) in sorted_subtraces
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
_show_pretty(
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
)
cur += 1
end
end

function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
println(io, "ProbProgTrace:")
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
println(io, " (empty)")
else
_show_pretty(io, trace, 0, ())
end
end

function Base.show(io::IO, trace::ProbProgTrace)
if get(io, :compact, false)
choices_count = length(trace.choices)
has_retval = trace.retval !== nothing
print(io, "ProbProgTrace($(choices_count) choices")
if has_retval
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
end
print(io, ")")
else
show(io, MIME"text/plain"(), trace)
end
end
Loading
Loading