diff --git a/docs/src/api/serialization.md b/docs/src/api/serialization.md index bfe77a6c98..56ee09a21e 100644 --- a/docs/src/api/serialization.md +++ b/docs/src/api/serialization.md @@ -46,3 +46,18 @@ additional Julia dependencies. ```@docs Reactant.Serialization.EnzymeJAX.export_to_enzymejax ``` + +## Exporting to Standalone Reactant Script + +This export functionality generates: + +1. A `.mlir` file containing the StableHLO representation of your Julia function +2. Input `.jls` files containing the input arrays (using Julia's Serialization) +3. A Julia script that can load and execute the function using only Reactant + +The generated Julia script serves as a minimal reproducer that can be shared when reporting +issues or debugging. It only depends on Reactant (no external packages required). + +```@docs +Reactant.Serialization.ReactantExport.export_to_reactant_script +``` diff --git a/src/Compiler.jl b/src/Compiler.jl index 6b08d0b5af..cfd16082c0 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -4,6 +4,7 @@ using Reactant_jll using Libdl: dlsym using LinearAlgebra: BlasInt using Functors: Functors +import p7zip_jll: p7zip import ..Reactant: Reactant, @@ -1355,6 +1356,102 @@ end const context_gc_vector = Dict{MLIR.IR.Context,Vector{Union{TracedRArray,TracedRNumber}}}() +""" + create_pass_failure_zip(mod, f, args, pass_pipeline_key, error_msg) + +Create a zip file containing the unoptimized IR and a Julia script for reproducing the issue. +This is automatically called when a pass pipeline fails during compilation. + +Returns the path to the created zip file. +""" +function create_pass_failure_zip( + mod::MLIR.IR.Module, f, args, pass_pipeline_key::String, error_msg::String +) + try + # Create a temporary directory for the files + temp_dir = mktempdir(; prefix="reactant_failure_", cleanup=false) + + # Save the unoptimized IR + mlir_path = joinpath(temp_dir, "unoptimized_ir.mlir") + open(mlir_path, "w") do io + println(io, "// Pass pipeline that failed: ", pass_pipeline_key) + println(io, "// Error message: ", error_msg) + println(io) + show(IOContext(io, :debug => true), mod) + end + + # Try to export inputs and create a Julia script using Serialization + function_name = string(f) + script_path = nothing + try + # Check if NPZ is available for serialization + if Reactant.Serialization.serialization_supported(Val(:NPZ)) + script_path = Reactant.Serialization.export_to_reactant_script( + f, args...; output_dir=temp_dir, function_name=function_name + ) + end + catch e + @debug "Could not create Julia script for reproduction" exception = e + end + + # Create README with instructions + readme_path = joinpath(temp_dir, "README.md") + open(readme_path, "w") do io + println(io, "# Reactant Compilation Failure Report") + println(io) + println(io, "This archive contains information about a failed Reactant compilation.") + println(io) + println(io, "## Contents") + println(io) + println(io, "- `unoptimized_ir.mlir`: The MLIR module before optimization passes") + println(io, "- `README.md`: This file") + if script_path !== nothing + println(io, "- `$(function_name).jl`: Julia script for reproduction") + println(io, "- `$(function_name)*.mlir`: Exported HLO code") + println(io, "- `$(function_name)*_inputs.npz`: Input data") + end + println(io) + println(io, "## Error Information") + println(io) + println(io, "**Pass Pipeline Key**: `$(pass_pipeline_key)`") + println(io) + println(io, "**Error Message**:") + println(io, "```") + println(io, error_msg) + println(io, "```") + println(io) + println(io, "## How to Report") + println(io) + println(io, "1. Upload this zip file to a file sharing service") + println(io, "2. Open an issue at https://github.com/EnzymeAD/Reactant.jl/issues") + println(io, "3. Include the link to the uploaded zip file in your issue") + println(io, "4. Describe what you were trying to do when the error occurred") + println(io) + println(io, "## Debugging") + println(io) + println(io, "You can inspect the `unoptimized_ir.mlir` file to see the IR before it failed.") + if script_path !== nothing + println(io, "You can also try running the `$(function_name).jl` script to reproduce the issue.") + end + end + + # Create the zip file + zip_path = temp_dir * ".zip" + # Note: temp_files are passed as command arguments (not via shell expansion) + # which prevents shell injection even if paths contain special characters + temp_files = readdir(temp_dir; join=true) + run(pipeline(`$(p7zip()) a -tzip $(zip_path) $(temp_files...)`, devnull)) + + # Clean up the temp directory (but keep the zip) + rm(temp_dir; recursive=true, force=true) + + return zip_path + catch e + @error "Failed to create debug zip file" exception = e + return nothing + end +end + # helper for debug purposes: String -> Text function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true) return MLIR.IR.with_context() do ctx @@ -1425,16 +1522,42 @@ function compile_mlir(f, args; client=nothing, drop_unsupported_attributes=false mod = MLIR.IR.Module(MLIR.IR.Location()) compile_options, kwargs = __get_compile_options_and_kwargs(; kwargs...) - mlir_fn_res = compile_mlir!( - mod, - f, - args, - compile_options; - backend, - runtime=XLA.runtime(client), - client, - kwargs..., - ) + + # Wrap compile_mlir! to catch pass pipeline failures + mlir_fn_res = try + compile_mlir!( + mod, + f, + args, + compile_options; + backend, + runtime=XLA.runtime(client), + client, + kwargs..., + ) + catch e + # Check if this is a pass pipeline failure + error_msg = string(e) + if contains(error_msg, "failed to run pass manager") + # Create a debug zip file with the unoptimized IR + zip_path = create_pass_failure_zip(mod, f, args, "compilation", error_msg) + if zip_path !== nothing + error( + "Compilation failed during pass pipeline execution. " * + "A debug zip file has been created at: $(zip_path)\n" * + "Please upload this file when reporting the issue at: " * + "https://github.com/EnzymeAD/Reactant.jl/issues\n" * + "Original error: $(error_msg)" + ) + else + # If we couldn't create the zip, just rethrow the original error + rethrow() + end + else + # Not a pass pipeline failure, rethrow + rethrow() + end + end # Attach a name, and partitioning attributes to the module __add_mhlo_attributes_and_name!( diff --git a/src/serialization/ReactantExport.jl b/src/serialization/ReactantExport.jl new file mode 100644 index 0000000000..18725ddc43 --- /dev/null +++ b/src/serialization/ReactantExport.jl @@ -0,0 +1,249 @@ +module ReactantExport + +using ..Reactant: Reactant, Compiler, Serialization +using Serialization: serialize, deserialize + +""" + export_to_reactant_script( + f, + args...; + output_dir::Union{String,Nothing}=nothing, + function_name::String=string(f) + ) + +Export a Julia function to a standalone Reactant script. + +This function: +1. Compiles the function to StableHLO via Reactant's compile_mlir +2. Saves the MLIR/StableHLO code to a `.mlir` file +3. Saves all input arrays to a serialized `.jls` file using Julia's Serialization +4. Generates a Julia script that only depends on Reactant for loading and executing + +## Requirements + +No external dependencies required - uses Julia's standard library Serialization + +## Arguments + + - `f`: The Julia function to export + - `args...`: The arguments to the function (used to infer types and shapes) + +## Keyword Arguments + + - `output_dir::Union{String,Nothing}`: Directory where output files will be saved. If + `nothing`, uses a temporary directory and prints the path. + - `function_name::String`: Base name for generated files + +## Returns + +The path to the generated Julia script as a `String`. + +## Files Generated + + - `{function_name}_{id}.mlir`: The StableHLO/MLIR module (where `{id}` is a numeric counter) + - `{function_name}_{id}_inputs.jls`: Serialized file containing all input arrays + - `{function_name}.jl`: Julia script that loads and executes the exported function + +## Example + +```julia +using Reactant + +# Define a simple function +function my_function(x::AbstractArray, y::AbstractArray) + return x .+ y +end + +# Create some example inputs +x = Reactant.to_rarray(rand(Float32, 2, 3)) +y = Reactant.to_rarray(rand(Float32, 2, 3)) + +# Export to Reactant script +julia_file_path = Reactant.Serialization.export_to_reactant_script(my_function, x, y) +``` + +Then in Julia: +```julia +# Run the generated Julia script +include(julia_file_path) +``` +""" +function export_to_reactant_script( + f, + args...; + output_dir::Union{String,Nothing}=nothing, + function_name::String=string(f), +) + if output_dir === nothing + output_dir = mktempdir(; cleanup=false) + @info "Output directory is $(output_dir)" + else + mkpath(output_dir) + end + + # Generate the StableHLO/MLIR code using compile_mlir + # This returns compilation result with traced argument information + argprefix = gensym("exportarg") + mod, mlir_fn_res = Compiler.compile_mlir( + f, + args; + argprefix, + drop_unsupported_attributes=true, + shardy_passes=:none, + ) + hlo_code = string(mod) + + # Save MLIR code + fnid = 0 + while isfile(joinpath(output_dir, "$(function_name)_$(fnid).mlir")) + fnid += 1 + end + mlir_path = joinpath(output_dir, "$(function_name)_$(fnid).mlir") + write(mlir_path, hlo_code) + + # Process and save inputs based on the linearized arguments + input_data = Dict{String,Union{AbstractArray,Number}}() + input_info = [] + input_idx = 1 + for (concrete_arg, traced_arg) in mlir_fn_res.seen_args + path = Reactant.TracedUtils.get_idx(traced_arg, argprefix)[2:end] + + # Store input data for the single NPZ file + arr_key = "arr_$input_idx" + input_data[arr_key] = _to_array(concrete_arg) + + push!( + input_info, + ( + shape=size(concrete_arg), + dtype=Reactant.unwrapped_eltype(concrete_arg), + path="arg." * join(string.(path), "."), + key=arr_key, + ), + ) + input_idx += 1 + end + + # Save all inputs to a serialized file + input_path = joinpath(output_dir, "$(function_name)_$(fnid)_inputs.jls") + save_inputs_jls(input_path, input_data) + + # Generate Julia script + julia_path = joinpath(output_dir, "$(function_name).jl") + _generate_julia_script(julia_path, function_name, mlir_path, input_path, input_info) + return julia_path +end + +_to_array(x::Reactant.ConcreteRArray) = Array(x) +_to_array(x::Reactant.ConcreteRNumber{T}) where {T} = T(x) + +function save_inputs_jls(output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}}) + open(output_path, "w") do io + serialize(io, inputs) + end + return output_path +end + +function _generate_julia_script( + julia_path::String, + function_name::String, + mlir_path::String, + input_path::String, + input_info::Vector, +) + # Get relative paths for the Julia script + output_dir = dirname(julia_path) + mlir_rel = relpath(mlir_path, output_dir) + input_rel = relpath(input_path, output_dir) + + # Generate argument list and documentation + arg_names = ["arg$i" for i in 1:length(input_info)] + arg_list = join(arg_names, ", ") + + # Generate docstring for arguments + arg_docs = join( + [ + if length(info.shape) == 0 + " $(arg_names[i]): Scalar of type $(info.dtype). Path: $(info.path)" + else + " $(arg_names[i]): Array of shape $(info.shape) and type $(info.dtype). Path: $(info.path)" + end + for (i, info) in enumerate(input_info) + ], + "\n", + ) + + load_inputs = ["inputs_data[\"$(info.key)\"]" for info in input_info] + + # Build a cleaner representation of the load_inputs code - no transpose needed for Julia Serialization + load_input_lines = String[] + for load in load_inputs + push!(load_input_lines, load) + end + load_inputs_code = join(load_input_lines, ",\n ") + + # Build the complete Julia script + script = """ + \"\"\" + Auto-generated Julia script for calling exported Reactant function. + + This script was generated by Reactant.Serialization.export_to_reactant_script(). + \"\"\" + + using Reactant + using Serialization + + # Get the directory of this script + const SCRIPT_DIR = @__DIR__ + + # Load the MLIR/StableHLO code + const HLO_CODE = read(joinpath(SCRIPT_DIR, "$(mlir_rel)"), String) + + function load_inputs() + \"\"\"Load the example inputs that were exported from Julia.\"\"\" + inputs_data = open(joinpath(SCRIPT_DIR, "$(input_rel)"), "r") do io + deserialize(io) + end + inputs = [ + $(load_inputs_code) + ] + return tuple(inputs...) + end + + function run_$(function_name)($(arg_list)) + \"\"\" + Execute the exported Julia function using Reactant. + + Args: + $arg_docs + + Returns: + The result of calling the exported function. + \"\"\" + # Load HLO module from string + # TODO: This will use Reactant's HLO execution API once available + # For now, we document that this is a placeholder that will be implemented + error("Direct HLO execution from loaded IR is not yet implemented in Reactant. " * + "This script serves as a template for future functionality.") + end + + # Main execution when script is run directly + if abspath(PROGRAM_FILE) == @__FILE__ + # Load the example inputs + ($(arg_list),) = load_inputs() + + # Convert to RArrays + $(join(["$arg = Reactant.to_rarray($arg)" for arg in arg_names], "\n ")) + + # Run the function + println("Running $(function_name)...") + result = run_$(function_name)($(arg_list)) + println("Result: ", result) + end + """ + + write(julia_path, strip(script) * "\n") + return nothing +end + +end # module diff --git a/src/serialization/Serialization.jl b/src/serialization/Serialization.jl index 5fdc28c297..f22238339b 100644 --- a/src/serialization/Serialization.jl +++ b/src/serialization/Serialization.jl @@ -30,6 +30,7 @@ const NUMPY_SIMPLE_TYPES = Dict( include("TFSavedModel.jl") include("EnzymeJAX.jl") +include("ReactantExport.jl") """ export_as_tf_saved_model( @@ -128,5 +129,6 @@ function export_as_tf_saved_model( end const export_to_enzymejax = EnzymeJAX.export_to_enzymejax +const export_to_reactant_script = ReactantExport.export_to_reactant_script end diff --git a/test/integration/reactant_export.jl b/test/integration/reactant_export.jl new file mode 100644 index 0000000000..30c00dab2b --- /dev/null +++ b/test/integration/reactant_export.jl @@ -0,0 +1,135 @@ +using Reactant, Test + +@testset "ReactantExport" begin + @testset "Simple function export" begin + f_simple(x) = sin.(x) .+ cos.(x) + + x_data = Reactant.TestUtils.construct_test_array(Float32, 4, 5) + x = Reactant.to_rarray(x_data) + + # Export the function + julia_file_path = Reactant.Serialization.export_to_reactant_script( + f_simple, x; output_dir=mktempdir(; cleanup=true) + ) + + @test isfile(julia_file_path) + @test endswith(julia_file_path, ".jl") + + # Check that generated files exist + output_dir = dirname(julia_file_path) + mlir_files = filter(f -> endswith(f, ".mlir"), readdir(output_dir)) + jls_files = filter(f -> endswith(f, ".jls"), readdir(output_dir)) + + @test length(mlir_files) > 0 + @test length(jls_files) > 0 + + # Verify Julia script contains key components + julia_content = read(julia_file_path, String) + @test contains(julia_content, "using Reactant") + @test contains(julia_content, "using Serialization") + @test contains(julia_content, "f_simple") + @test contains(julia_content, "load_inputs") + @test contains(julia_content, "run_f_simple") + + # We can't execute the full script since HLO execution isn't implemented yet, + # but we can verify the structure is correct + end + + @testset "Matrix multiplication export" begin + f_matmul(x, y) = x * y + + x_data = Reactant.TestUtils.construct_test_array(Float32, 3, 4) + y_data = Reactant.TestUtils.construct_test_array(Float32, 4, 5) + x = Reactant.to_rarray(x_data) + y = Reactant.to_rarray(y_data) + + # Export the function + julia_file_path = Reactant.Serialization.export_to_reactant_script( + f_matmul, x, y; output_dir=mktempdir(; cleanup=true), function_name="matmul" + ) + + @test isfile(julia_file_path) + + output_dir = dirname(julia_file_path) + jls_files = filter(f -> endswith(f, ".jls"), readdir(output_dir)) + @test length(jls_files) > 0 + + # Verify the JLS file contains both inputs + using Serialization + inputs_data = open(first(filter(f -> endswith(f, ".jls"), readdir(output_dir; join=true))), "r") do io + deserialize(io) + end + @test haskey(inputs_data, "arr_1") || haskey(inputs_data, "arr_2") + + # Verify Julia script structure + julia_content = read(julia_file_path, String) + @test contains(julia_content, "matmul") + @test contains(julia_content, "arg1") + @test contains(julia_content, "arg2") + end + + @testset "Complex function with multiple arguments" begin + f_complex(x, y, z) = sum(x .* y .+ sin.(z); dims=2) + + x_data = Reactant.TestUtils.construct_test_array(Float32, 5, 4) + y_data = Reactant.TestUtils.construct_test_array(Float32, 5, 4) + z_data = Reactant.TestUtils.construct_test_array(Float32, 5, 4) + x = Reactant.to_rarray(x_data) + y = Reactant.to_rarray(y_data) + z = Reactant.to_rarray(z_data) + + # Export the function + julia_file_path = Reactant.Serialization.export_to_reactant_script( + f_complex, + x, + y, + z; + output_dir=mktempdir(; cleanup=true), + function_name="complex_fn", + ) + + @test isfile(julia_file_path) + + output_dir = dirname(julia_file_path) + mlir_files = filter(f -> endswith(f, ".mlir"), readdir(output_dir)) + jls_files = filter(f -> endswith(f, ".jls"), readdir(output_dir)) + + @test length(mlir_files) > 0 + @test length(jls_files) > 0 + + julia_content = read(julia_file_path, String) + @test contains(julia_content, "complex_fn") + @test contains(julia_content, "arg1") + @test contains(julia_content, "arg2") + @test contains(julia_content, "arg3") + end + + @testset "Test Serialization input/output consistency" begin + # Test that inputs are saved and can be loaded correctly + f_test(x) = x .+ 1.0f0 + + x_data = Float32[1.0 2.0; 3.0 4.0] + x = Reactant.to_rarray(x_data) + + julia_file_path = Reactant.Serialization.export_to_reactant_script( + f_test, x; output_dir=mktempdir(; cleanup=true), function_name="test_jls" + ) + + output_dir = dirname(julia_file_path) + jls_files = filter(f -> endswith(f, ".jls"), readdir(output_dir; join=true)) + @test !isempty(jls_files) + jls_path = first(jls_files) + + # Load the JLS file and verify the data + using Serialization + inputs_data = open(jls_path, "r") do io + deserialize(io) + end + @test haskey(inputs_data, "arr_1") + + # The data should be in Julia's native format (no transposition needed) + loaded_data = inputs_data["arr_1"] + @test size(loaded_data) == size(x_data) + @test isapprox(loaded_data, x_data; rtol=1e-5) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..a2694095c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,7 @@ end @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") + @safetestset "ReactantExport" include("integration/reactant_export.jl") if ENZYMEJAX_INSTALLED[] && !Sys.isapple() @safetestset "EnzymeJAX Export" include("integration/enzymejax.jl") end