Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlx/backend/cpu/compiled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
Expand Down Expand Up @@ -316,7 +315,9 @@ void Compiled::eval_cpu(
// Get the function
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << std::get<2>(JitCompiler::get_preamble()) << std::endl;
kernel << "using namespace mlx::core;" << std::endl;
kernel << "using namespace mlx::core::detail;" << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cpu/compiled_preamble.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
#include "mlx/backend/cpu/binary_ops.h"
// clang-format on

const char* get_kernel_preamble();
const char* get_prebuilt_preamble();
49 changes: 41 additions & 8 deletions mlx/backend/cpu/jit_compiler.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/compiled_preamble.h"

#include <algorithm>
#include <sstream>
Expand Down Expand Up @@ -86,30 +88,61 @@ const VisualStudioInfo& GetVisualStudioInfo() {

#endif // _MSC_VER

const std::tuple<bool, std::string, std::string>& JitCompiler::get_preamble() {
static auto preamble = []() -> std::tuple<bool, std::string, std::string> {
// Check whether the headers are shipped with the binary, if so use the
// preamble from the headers, otherwise use the prebuilt one embeded in
// binary, which may not work with all compilers.
auto root_dir = current_binary_dir();
#if !defined(_WIN32)
root_dir = root_dir.parent_path();
#endif
auto include_dir = root_dir / "include";
if (std::filesystem::exists(include_dir / "mlx")) {
return std::make_tuple(
true,
include_dir.string(),
"#include \"mlx/backend/cpu/compiled_preamble.h\"\n");
} else {
return std::make_tuple(false, "", get_prebuilt_preamble());
}
}();
return preamble;
}

std::string JitCompiler::build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name) {
auto& [use_include, include_dir, preamble] = get_preamble();
#ifdef _MSC_VER
std::string extra_flags;
if (use_include) {
extra_flags += fmt::format("/I \"{}\"", include_dir);
}
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
extra_flags += fmt::format(" /libpath:\"{}\"", lib);
}
return fmt::format(
"\""
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
"/link /out:\"{3}\" {4} 2>&1"
"cd /D \"{}\" && "
"\"{}\" /LD /EHsc /MD /Ox /nologo /std:c++17 {} \"{}\" "
"/link /out:\"{}\" 2>&1"
"\"",
dir.string(),
info.cl_exe,
extra_flags,
source_file_name,
shared_lib_name,
libpaths);
shared_lib_name);
#else
std::string extra_flags;
if (use_include) {
extra_flags = fmt::format("-I \"{}\"", include_dir);
}
return fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared \"{0}\" -o \"{1}\" 2>&1",
"g++ -std=c++17 -O3 -Wall -fPIC -shared {} \"{}\" -o \"{}\" 2>&1",
extra_flags,
(dir / source_file_name).string(),
(dir / shared_lib_name).string());
#endif
Expand Down
3 changes: 3 additions & 0 deletions mlx/backend/cpu/jit_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ namespace mlx::core {

class JitCompiler {
public:
// Return the includes that should be prepended to the source code.
static const std::tuple<bool, std::string, std::string>& get_preamble();

// Build a shell command that compiles a source code file to a shared library.
static std::string build_command(
const std::filesystem::path& dir,
Expand Down
9 changes: 1 addition & 8 deletions mlx/backend/cpu/make_compiled_preamble.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,14 @@ $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join "`n"

# Append extra content.
$CONTENT = @"
$($CONTENT)
using namespace mlx::core;
using namespace mlx::core::detail;
"@

# Convert each char to ASCII code.
# Unlike the unix script that outputs string literal directly, the output from
# MSVC is way too large to be embedded as string and compilation will fail, so
# we store it as static array instead.
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'

$OUTPUT = @"
const char* get_kernel_preamble() {
const char* get_prebuilt_preamble() {
static char preamble[] = { $CHARCODES };
return preamble;
}
Expand Down
4 changes: 1 addition & 3 deletions mlx/backend/cpu/make_compiled_preamble.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ fi
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)

cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {
const char* get_prebuilt_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}
EOF
Loading