diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 0806368e03..559b875cf6 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -10,7 +10,6 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/cpu/compiled_preamble.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/jit_compiler.h" #include "mlx/device.h" @@ -316,7 +315,9 @@ void Compiled::eval_cpu( // Get the function auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() { std::ostringstream kernel; - kernel << get_kernel_preamble() << std::endl; + kernel << std::get<2>(JitCompiler::get_preamble()) << std::endl; + kernel << "using namespace mlx::core;" << std::endl; + kernel << "using namespace mlx::core::detail;" << std::endl; kernel << "extern \"C\" {" << std::endl; build_kernel( kernel, diff --git a/mlx/backend/cpu/compiled_preamble.h b/mlx/backend/cpu/compiled_preamble.h index 31ca1b4685..02e3db8e78 100644 --- a/mlx/backend/cpu/compiled_preamble.h +++ b/mlx/backend/cpu/compiled_preamble.h @@ -9,4 +9,4 @@ #include "mlx/backend/cpu/binary_ops.h" // clang-format on -const char* get_kernel_preamble(); +const char* get_prebuilt_preamble(); diff --git a/mlx/backend/cpu/jit_compiler.cpp b/mlx/backend/cpu/jit_compiler.cpp index 267e8fd867..4099ecf97d 100644 --- a/mlx/backend/cpu/jit_compiler.cpp +++ b/mlx/backend/cpu/jit_compiler.cpp @@ -1,6 +1,8 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/cpu/jit_compiler.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/compiled_preamble.h" #include #include @@ -86,30 +88,61 @@ const VisualStudioInfo& GetVisualStudioInfo() { #endif // _MSC_VER +const std::tuple& JitCompiler::get_preamble() { + static auto preamble = []() -> std::tuple { + // Check whether the headers are shipped with the binary, if so use the + // preamble from the headers, otherwise use the prebuilt one embeded in + // binary, which may not work with all compilers. + auto root_dir = current_binary_dir(); +#if !defined(_WIN32) + root_dir = root_dir.parent_path(); +#endif + auto include_dir = root_dir / "include"; + if (std::filesystem::exists(include_dir / "mlx")) { + return std::make_tuple( + true, + include_dir.string(), + "#include \"mlx/backend/cpu/compiled_preamble.h\"\n"); + } else { + return std::make_tuple(false, "", get_prebuilt_preamble()); + } + }(); + return preamble; +} + std::string JitCompiler::build_command( const std::filesystem::path& dir, const std::string& source_file_name, const std::string& shared_lib_name) { + auto& [use_include, include_dir, preamble] = get_preamble(); #ifdef _MSC_VER + std::string extra_flags; + if (use_include) { + extra_flags += fmt::format("/I \"{}\"", include_dir); + } const VisualStudioInfo& info = GetVisualStudioInfo(); - std::string libpaths; for (const std::string& lib : info.libpaths) { - libpaths += fmt::format(" /libpath:\"{0}\"", lib); + extra_flags += fmt::format(" /libpath:\"{}\"", lib); } return fmt::format( "\"" - "cd /D \"{0}\" && " - "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " - "/link /out:\"{3}\" {4} 2>&1" + "cd /D \"{}\" && " + "\"{}\" /LD /EHsc /MD /Ox /nologo /std:c++17 {} \"{}\" " + "/link /out:\"{}\" 2>&1" "\"", dir.string(), info.cl_exe, + extra_flags, source_file_name, - shared_lib_name, - libpaths); + shared_lib_name); #else + std::string extra_flags; + if (use_include) { + extra_flags = fmt::format("-I \"{}\"", include_dir); + } return fmt::format( - "g++ -std=c++17 -O3 -Wall -fPIC -shared \"{0}\" -o \"{1}\" 2>&1", + "g++ -std=c++17 -O3 -Wall -fPIC -shared {} \"{}\" -o \"{}\" 2>&1", + extra_flags, (dir / source_file_name).string(), (dir / shared_lib_name).string()); #endif diff --git a/mlx/backend/cpu/jit_compiler.h b/mlx/backend/cpu/jit_compiler.h index 3a9e988da7..af49a5efeb 100644 --- a/mlx/backend/cpu/jit_compiler.h +++ b/mlx/backend/cpu/jit_compiler.h @@ -7,6 +7,9 @@ namespace mlx::core { class JitCompiler { public: + // Return the includes that should be prepended to the source code. + static const std::tuple& get_preamble(); + // Build a shell command that compiles a source code file to a shared library. static std::string build_command( const std::filesystem::path& dir, diff --git a/mlx/backend/cpu/make_compiled_preamble.ps1 b/mlx/backend/cpu/make_compiled_preamble.ps1 index 0cd2d1f170..d231f3ede1 100644 --- a/mlx/backend/cpu/make_compiled_preamble.ps1 +++ b/mlx/backend/cpu/make_compiled_preamble.ps1 @@ -15,13 +15,6 @@ $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } # Concatenate to string. $CONTENT = $CONTENT -join "`n" -# Append extra content. -$CONTENT = @" -$($CONTENT) -using namespace mlx::core; -using namespace mlx::core::detail; -"@ - # Convert each char to ASCII code. # Unlike the unix script that outputs string literal directly, the output from # MSVC is way too large to be embedded as string and compilation will fail, so @@ -29,7 +22,7 @@ using namespace mlx::core::detail; $CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0' $OUTPUT = @" -const char* get_kernel_preamble() { +const char* get_prebuilt_preamble() { static char preamble[] = { $CHARCODES }; return preamble; } diff --git a/mlx/backend/cpu/make_compiled_preamble.sh b/mlx/backend/cpu/make_compiled_preamble.sh index 88b4c4615e..3ae91ac734 100644 --- a/mlx/backend/cpu/make_compiled_preamble.sh +++ b/mlx/backend/cpu/make_compiled_preamble.sh @@ -30,12 +30,10 @@ fi CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null) cat << EOF > "$OUTPUT_FILE" -const char* get_kernel_preamble() { +const char* get_prebuilt_preamble() { return R"preamble( $INCLUDES $CONTENT -using namespace mlx::core; -using namespace mlx::core::detail; )preamble"; } EOF