From e175e6390192ec3e24692fcbd2f1e0203b5db59e Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Sun, 19 Apr 2026 13:56:49 +0300 Subject: [PATCH 1/6] add metal 4 backend --- build.sh | 463 +++++--- config/default.ini | 5 + src/cpu_inference.h | 328 ++++++ src/metal_bindings.mm | 566 +++++++++ src/metal_kernels.mm | 1561 +++++++++++++++++++++++++ src/metal_platform.h | 208 ++++ src/metal_platform.mm | 1272 ++++++++++++++++++++ src/metal_pufferlib.mm | 1265 ++++++++++++++++++++ src/metal_shader_src.h | 2523 ++++++++++++++++++++++++++++++++++++++++ src/puf_types.h | 596 ++++++++++ src/tensor.h | 8 + src/vecenv.h | 113 +- 12 files changed, 8739 insertions(+), 169 deletions(-) create mode 100644 src/cpu_inference.h create mode 100644 src/metal_bindings.mm create mode 100644 src/metal_kernels.mm create mode 100644 src/metal_platform.h create mode 100644 src/metal_platform.mm create mode 100644 src/metal_pufferlib.mm create mode 100644 src/metal_shader_src.h create mode 100644 src/puf_types.h diff --git a/build.sh b/build.sh index 8980552f94..68740090ed 100755 --- a/build.sh +++ b/build.sh @@ -1,33 +1,26 @@ #!/bin/bash -set -e - -# Usage: -# ./build.sh breakout # Build _C.so with breakout statically linked -# ./build.sh breakout --float # float32 precision (required for --slowly) -# ./build.sh breakout --cpu # CPU fallback, torch only -# ./build.sh breakout --debug # Debug build -# ./build.sh breakout --local # Standalone executable (debug, sanitizers) -# ./build.sh breakout --fast # Standalone executable (optimized) -# ./build.sh breakout --web # Emscripten web build -# ./build.sh breakout --profile # Kernel profiling binary -# ./build.sh all # Build all envs with default and --float - -if [ -z "$1" ]; then +set -euo pipefail + +if [ -z "${1:-}" ]; then echo "Usage: ./build.sh ENV_NAME [--float] [--debug] [--local|--fast|--web|--profile|--cpu|--all]" exit 1 fi + ENV=$1 shift +MODE="" +PRECISION="" +DEBUG="" for arg in "$@"; do case $arg in --float) 
PRECISION="-DPRECISION_FLOAT" ;; --debug) DEBUG=1 ;; --local) MODE=local ;; - --fast) MODE=fast ;; - --web) MODE=web ;; + --fast) MODE=fast ;; + --web) MODE=web ;; --profile) MODE=profile ;; - --cpu) MODE=cpu; PRECISION="-DPRECISION_FLOAT" ;; + --cpu) MODE=cpu; PRECISION="-DPRECISION_FLOAT" ;; *) echo "Error: unknown argument '$arg'" && exit 1 ;; esac done @@ -35,36 +28,22 @@ done if [ "$ENV" = "all" ]; then FAILED="" for env_dir in ocean/*/; do - env=$(basename "$env_dir") - if bash "$0" "$env" && bash "$0" "$env" --float; then - echo "OK: $env" + env_name=$(basename "$env_dir") + if bash "$0" "$env_name" && bash "$0" "$env_name" --float; then + echo "OK: $env_name" else - echo "FAIL: $env" - FAILED="$FAILED\n $env" + echo "FAIL: $env_name" + FAILED="$FAILED\n $env_name" fi done - if [ -n "$FAILED" ]; then echo -e "\nFailed builds:$FAILED" + exit 1 fi exit 0 fi -# Linux/mac PLATFORM="$(uname -s)" -if [ "$PLATFORM" = "Linux" ]; then - RAYLIB_NAME='raylib-5.5_linux_amd64' - OMP_LIB=-lomp5 - SANITIZE_FLAGS=(-fsanitize=address,undefined,bounds,pointer-overflow,leak -fno-omit-frame-pointer) - STANDALONE_LDFLAGS=(-lGL) - SHARED_LDFLAGS=(-Bsymbolic-functions) -else - RAYLIB_NAME='raylib-5.5_macos' - OMP_LIB=-lomp - SANITIZE_FLAGS=() - STANDALONE_LDFLAGS=(-framework Cocoa -framework IOKit -framework CoreVideo -framework OpenGL) - SHARED_LDFLAGS=(-framework Cocoa -framework OpenGL -framework IOKit -undefined dynamic_lookup) -fi CLANG_WARN=( -Wall @@ -76,16 +55,84 @@ CLANG_WARN=( -Wno-error=array-parameter ) +if [ -n "$DEBUG" ] || [ "$MODE" = "local" ]; then + CLANG_OPT=(-g -O0 "${CLANG_WARN[@]}") + NVCC_OPT=(-O0 -g) + LINK_OPT=(-g) + if [ "$PLATFORM" = "Linux" ]; then + CLANG_OPT+=(-fsanitize=address,undefined,bounds,pointer-overflow,leak) + CLANG_OPT+=(-fno-omit-frame-pointer) + fi +else + CLANG_OPT=(-O2 -DNDEBUG "${CLANG_WARN[@]}") + NVCC_OPT=(-O2 --threads 0) + LINK_OPT=(-O2) +fi + download() { - local name=$1 url=$2 + local name=$1 + local url=$2 [ -d "$name" ] && 
return echo "Downloading $name..." case "$url" in *.zip) curl -sL "$url" -o "$name.zip" && unzip -q "$name.zip" && rm "$name.zip" ;; - *) curl -sL "$url" -o "$name.tar.gz" && tar xf "$name.tar.gz" && rm "$name.tar.gz" ;; + *) curl -sL "$url" -o "$name.tar.gz" && tar xf "$name.tar.gz" && rm "$name.tar.gz" ;; esac } +find_omp_include() { + if ! command -v brew >/dev/null 2>&1; then + return 0 + fi + local omp_prefix + omp_prefix=$(brew --prefix libomp 2>/dev/null || true) + if [ -n "$omp_prefix" ] && [ -d "$omp_prefix/include" ]; then + echo "$omp_prefix/include" + fi +} + +DARWIN_OMP_SOURCE="system" +DARWIN_OMP_LINK=(-lomp) +resolve_darwin_omp_link() { + local torch_omp="" + torch_omp=$(python -c "import torch; import os; print(os.path.join(torch.__path__[0], 'lib', 'libomp.dylib'))" 2>/dev/null || true) + if [ -n "$torch_omp" ] && [ -f "$torch_omp" ]; then + local omp_dir + omp_dir=$(dirname "$torch_omp") + DARWIN_OMP_SOURCE="torch" + DARWIN_OMP_LINK=(-L"$omp_dir" -Wl,-rpath,"$omp_dir" -lomp) + return + fi + + if command -v brew >/dev/null 2>&1; then + local omp_prefix + omp_prefix=$(brew --prefix libomp 2>/dev/null || true) + if [ -n "$omp_prefix" ] && [ -d "$omp_prefix/lib" ]; then + DARWIN_OMP_SOURCE="homebrew" + DARWIN_OMP_LINK=(-L"$omp_prefix/lib" -Wl,-rpath,"$omp_prefix/lib" -lomp) + return + fi + fi + + DARWIN_OMP_SOURCE="system" + DARWIN_OMP_LINK=(-lomp) +} + +if [ "$PLATFORM" = "Linux" ]; then + RAYLIB_NAME='raylib-5.5_linux_amd64' + OMP_LIB=-lomp5 + STANDALONE_LDFLAGS=(-lGL) + SHARED_LDFLAGS=(-Bsymbolic-functions) + DEFAULT_CC=${CC:-clang} + DEFAULT_CXX=${CXX:-g++} +else + RAYLIB_NAME='raylib-5.5_macos' + STANDALONE_LDFLAGS=(-framework Cocoa -framework IOKit -framework CoreVideo -framework OpenGL) + SHARED_LDFLAGS=(-undefined dynamic_lookup) + DEFAULT_CC=${CC:-clang} + DEFAULT_CXX=${CXX:-clang++} +fi + RAYLIB_URL="https://github.com/raysan5/raylib/releases/download/5.5" if [ "$MODE" = "web" ]; then RAYLIB_NAME='raylib-5.5_webassembly' @@ -108,9 +155,12 @@ 
elif [ "$ENV" = "trailer" ]; then OUTPUT_NAME="trailer/trailer" elif [ "$ENV" = "impulse_wars" ]; then SRC_DIR="ocean/$ENV" - if [ "$MODE" = "web" ]; then BOX2D_NAME='box2d-web' - elif [ "$PLATFORM" = "Linux" ]; then BOX2D_NAME='box2d-linux-amd64' - else BOX2D_NAME='box2d-macos-arm64' + if [ "$MODE" = "web" ]; then + BOX2D_NAME='box2d-web' + elif [ "$PLATFORM" = "Linux" ]; then + BOX2D_NAME='box2d-linux-amd64' + else + BOX2D_NAME='box2d-macos-arm64' fi BOX2D_URL="https://github.com/capnspacehook/box2d/releases/latest/download" download "$BOX2D_NAME" "$BOX2D_URL/$BOX2D_NAME.tar.gz" @@ -124,30 +174,32 @@ fi OUTPUT_NAME=${OUTPUT_NAME:-$ENV} -# Standalone environment build -if [ -n "$DEBUG" ] || [ "$MODE" = "local" ]; then - CLANG_OPT=(-g -O0 "${CLANG_WARN[@]}" "${SANITIZE_FLAGS[@]}") - NVCC_OPT="-O0 -g" - LINK_OPT="-g" -else - CLANG_OPT=(-O2 -DNDEBUG "${CLANG_WARN[@]}") - NVCC_OPT="-O2 --threads 0" - LINK_OPT="-O2" -fi if [ "$MODE" = "local" ] || [ "$MODE" = "fast" ]; then FLAGS=( "${INCLUDES[@]}" - "$SRC_DIR/$ENV.c" $EXTRA_SRC -o "$OUTPUT_NAME" + "$SRC_DIR/$ENV.c" $EXTRA_SRC + -o "$OUTPUT_NAME" "${LINK_ARCHIVES[@]}" "${STANDALONE_LDFLAGS[@]}" - -lm -lpthread -fopenmp -DPLATFORM_DESKTOP + -lm ) + if [ "$PLATFORM" = "Darwin" ]; then + OMP_INC=$(find_omp_include || true) + [ -n "${OMP_INC:-}" ] && FLAGS+=(-I"$OMP_INC") + FLAGS+=(-Xclang -fopenmp) + resolve_darwin_omp_link + FLAGS+=("${DARWIN_OMP_LINK[@]}") + else + FLAGS+=(-fopenmp -lpthread) + fi echo "Compiling $ENV..." - ${CC:-clang} "${CLANG_OPT[@]}" "${FLAGS[@]}" + "$DEFAULT_CC" "${CLANG_OPT[@]}" "${FLAGS[@]}" echo "Built: ./$OUTPUT_NAME" exit 0 -elif [ "$MODE" = "web" ]; then +fi + +if [ "$MODE" = "web" ]; then mkdir -p "build/web/$ENV" echo "Compiling $ENV for web..." 
emcc \ @@ -168,8 +220,147 @@ elif [ "$MODE" = "web" ]; then exit 0 fi -# Find cuDNN path -CUDA_HOME=${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(which nvcc)")")}} +PYTHON_INCLUDE=$(python -c "import sysconfig; print(sysconfig.get_path('include'))") +PYBIND_INCLUDE=$(python -c "import pybind11; print(pybind11.get_include())") +NUMPY_INCLUDE=$(python -c "import numpy; print(numpy.get_include())") +EXT_SUFFIX=$(python -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))") +OUTPUT="pufferlib/_C${EXT_SUFFIX}" + +BINDING_SRC="$SRC_DIR/binding.c" +mkdir -p build +STATIC_OBJ="build/libstatic_${ENV}.o" +STATIC_LIB="build/libstatic_${ENV}.a" + +if [ ! -f "$BINDING_SRC" ]; then + echo "Error: $BINDING_SRC not found" + exit 1 +fi + +STATIC_CFLAGS=( + "${CLANG_OPT[@]}" + -I. -Isrc -I"$SRC_DIR" -Ivendor + -I./"$RAYLIB_NAME"/include + -DPLATFORM_DESKTOP + -fvisibility=hidden + -fPIC +) + +if [ "$PLATFORM" = "Darwin" ]; then + OMP_INC=$(find_omp_include || true) + [ -n "${OMP_INC:-}" ] && STATIC_CFLAGS+=(-I"$OMP_INC") + STATIC_CFLAGS+=(-Xclang -fopenmp) +else + STATIC_CFLAGS+=(-fno-semantic-interposition) + STATIC_CFLAGS+=(-fopenmp) +fi + +echo "Compiling static library for $ENV..." +"$DEFAULT_CC" -c "${STATIC_CFLAGS[@]}" "$BINDING_SRC" -o "$STATIC_OBJ" +ar rcs "$STATIC_LIB" "$STATIC_OBJ" + +if [ "$MODE" = "cpu" ]; then + CPU_CFLAGS=( + -c -fPIC + -D_GLIBCXX_USE_CXX11_ABI=1 + -DPLATFORM_DESKTOP + -DENV_NAME="$ENV" + -std=c++17 + -I. -Isrc + -I"$PYTHON_INCLUDE" -I"$PYBIND_INCLUDE" + ${PRECISION:+$PRECISION} + "${LINK_OPT[@]}" + ) + if [ "$PLATFORM" = "Darwin" ]; then + [ -n "${OMP_INC:-}" ] && CPU_CFLAGS+=(-I"$OMP_INC") + CPU_CFLAGS+=(-Xclang -fopenmp) + resolve_darwin_omp_link + else + CPU_CFLAGS+=(-fopenmp) + fi + + echo "Compiling CPU training backend..." 
+ "$DEFAULT_CXX" "${CPU_CFLAGS[@]}" src/bindings_cpu.cpp -o build/bindings_cpu.o + + if [ "$PLATFORM" = "Darwin" ]; then + LINK_CMD=( + "$DEFAULT_CXX" -shared -fPIC + build/bindings_cpu.o "$STATIC_LIB" "$RAYLIB_A" + -lm -lpthread + "${DARWIN_OMP_LINK[@]}" + -framework Cocoa -framework OpenGL -framework IOKit -framework CoreVideo + "${LINK_OPT[@]}" + "${SHARED_LDFLAGS[@]}" + -o "$OUTPUT" + ) + else + LINK_CMD=( + "$DEFAULT_CXX" -shared -fPIC -fopenmp + build/bindings_cpu.o "$STATIC_LIB" "$RAYLIB_A" + -lm -lpthread "$OMP_LIB" + "${LINK_OPT[@]}" + "${SHARED_LDFLAGS[@]}" + -o "$OUTPUT" + ) + fi + "${LINK_CMD[@]}" + echo "Built: $OUTPUT" + exit 0 +fi + +if [ "$PLATFORM" = "Darwin" ]; then + if [ "$MODE" = "profile" ]; then + echo "Error: --profile is only supported on the CUDA path" + exit 1 + fi + + resolve_darwin_omp_link + METAL_CFLAGS=( + -c -fPIC -std=c++17 -ObjC++ -fobjc-arc + -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION + -DPLATFORM_DESKTOP + -DENV_NAME="$ENV" + -DWITH_METAL + -I. -Isrc + -I"$PYTHON_INCLUDE" -I"$PYBIND_INCLUDE" -I"$NUMPY_INCLUDE" + -I./"$RAYLIB_NAME"/include + "${CLANG_OPT[@]}" + ) + [ -n "${PRECISION:-}" ] && METAL_CFLAGS+=("$PRECISION") + [ -n "${OMP_INC:-}" ] && METAL_CFLAGS+=(-I"$OMP_INC") + METAL_CFLAGS+=(-Xclang -fopenmp) + + echo "Compiling Metal training backend..." 
+ "$DEFAULT_CXX" "${METAL_CFLAGS[@]}" src/metal_bindings.mm -o build/metal_bindings.o + "$DEFAULT_CXX" "${METAL_CFLAGS[@]}" src/metal_platform.mm -o build/metal_platform.o + + LINK_CMD=( + "$DEFAULT_CXX" -shared -fPIC + build/metal_bindings.o build/metal_platform.o + "$STATIC_LIB" "$RAYLIB_A" + -framework Metal -framework Accelerate -framework Foundation + -framework Cocoa -framework OpenGL -framework IOKit + -framework CoreGraphics -framework CoreFoundation + -framework CoreVideo -framework CoreAudio + -framework AudioToolbox -framework UniformTypeIdentifiers + "${DARWIN_OMP_LINK[@]}" + "${LINK_OPT[@]}" + "${SHARED_LDFLAGS[@]}" + -o "$OUTPUT" + ) + "${LINK_CMD[@]}" + + for install_name in \ + /opt/llvm-openmp/lib/libomp.dylib \ + /opt/homebrew/opt/libomp/lib/libomp.dylib \ + /usr/local/opt/libomp/lib/libomp.dylib; do + install_name_tool -change "$install_name" "@rpath/libomp.dylib" "$OUTPUT" 2>/dev/null || true + done + + echo "Built: $OUTPUT" + exit 0 +fi + +CUDA_HOME=${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}} CUDNN_IFLAG="" CUDNN_LFLAG="" for dir in /usr/local/cuda/include /usr/include; do @@ -191,15 +382,19 @@ if [ -z "$CUDNN_LFLAG" ]; then CUDNN_LFLAG=$(python -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || echo "") fi -# NCCL include/lib fallback (mirrors the cuDNN fallback above). -# Needed when NCCL is provided by the nvidia-nccl-cu12 wheel in the active venv. 
NCCL_IFLAG="" NCCL_LFLAG="" for dir in /usr/include /usr/local/cuda/include; do - if [ -f "$dir/nccl.h" ]; then NCCL_IFLAG="-I$dir"; break; fi + if [ -f "$dir/nccl.h" ]; then + NCCL_IFLAG="-I$dir" + break + fi done for dir in /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64; do - if [ -f "$dir/libnccl.so" ] || [ -f "$dir/libnccl.so.2" ]; then NCCL_LFLAG="-L$dir"; break; fi + if [ -f "$dir/libnccl.so" ] || [ -f "$dir/libnccl.so.2" ]; then + NCCL_LFLAG="-L$dir" + break + fi done if [ -z "$NCCL_IFLAG" ]; then NCCL_IFLAG=$(python -c "import nvidia.nccl, os; print('-I' + os.path.join(nvidia.nccl.__path__[0], 'include'))" 2>/dev/null || echo "") @@ -208,110 +403,70 @@ if [ -z "$NCCL_LFLAG" ]; then NCCL_LFLAG=$(python -c "import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2>/dev/null || echo "") fi -export CCACHE_DIR="${CCACHE_DIR:-$HOME/.ccache}" -export CCACHE_BASEDIR="$(pwd)" -export CCACHE_COMPILERCHECK=content -NVCC="ccache $CUDA_HOME/bin/nvcc" -CC="${CC:-$(command -v ccache >/dev/null && echo 'ccache clang' || echo 'clang')}" -ARCH=${NVCC_ARCH:-native} - -PYTHON_INCLUDE=$(python -c "import sysconfig; print(sysconfig.get_path('include'))") -PYBIND_INCLUDE=$(python -c "import pybind11; print(pybind11.get_include())") -NUMPY_INCLUDE=$(python -c "import numpy; print(numpy.get_include())") -EXT_SUFFIX=$(python -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))") -OUTPUT="pufferlib/_C${EXT_SUFFIX}" - -BINDING_SRC="$SRC_DIR/binding.c" -mkdir -p build -STATIC_OBJ="build/libstatic_${ENV}.o" -STATIC_LIB="build/libstatic_${ENV}.a" - -if [ ! -f "$BINDING_SRC" ]; then - echo "Error: $BINDING_SRC not found" - exit 1 -fi - -echo "Compiling static library for $ENV..." -${CC:-clang} -c "${CLANG_OPT[@]}" \ - -I. 
-Isrc -I$SRC_DIR -Ivendor \ - -I./$RAYLIB_NAME/include -I$CUDA_HOME/include \ - -DPLATFORM_DESKTOP \ - -fno-semantic-interposition -fvisibility=hidden \ - -fPIC -fopenmp \ - "$BINDING_SRC" -o "$STATIC_OBJ" -ar rcs "$STATIC_LIB" "$STATIC_OBJ" - -# Brittle hack: have to extract the tensor type from the static lib to build trainer -OBS_TENSOR_T=$(awk '/^#define OBS_TENSOR_T/{print $3}' "$BINDING_SRC") +OBS_TENSOR_T=$(awk ' + /^#define OBS_TENSOR_T/ { print $3; exit } + /^#define OBS_TYPE/ { + if ($3 == "FLOAT") print "FloatTensor"; + else if ($3 == "UNSIGNED_CHAR") print "ByteTensor"; + else if ($3 == "INT") print "IntTensor"; + else if ($3 == "CHAR") print "ByteTensor"; + exit + } +' "$BINDING_SRC") if [ -z "$OBS_TENSOR_T" ]; then echo "Error: Could not find OBS_TENSOR_T in $BINDING_SRC" exit 1 fi -if [ -z "$MODE" ]; then - echo "Compiling CUDA ($ARCH) training backend..." - $NVCC -c -arch=$ARCH -Xcompiler -fPIC \ - -Xcompiler=-D_GLIBCXX_USE_CXX11_ABI=1 \ - -Xcompiler=-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION \ - -Xcompiler=-DPLATFORM_DESKTOP \ - -std=c++17 \ - -I. -Isrc \ - -I$PYTHON_INCLUDE -I$PYBIND_INCLUDE -I$NUMPY_INCLUDE \ - -I$CUDA_HOME/include $CUDNN_IFLAG $NCCL_IFLAG -I$RAYLIB_NAME/include \ - -Xcompiler=-fopenmp \ - -DOBS_TENSOR_T=$OBS_TENSOR_T \ - -DENV_NAME=$ENV \ - $PRECISION $NVCC_OPT \ - src/bindings.cu -o build/bindings.o - - LINK_CMD=( - ${CXX:-g++} -shared -fPIC -fopenmp - build/bindings.o "$STATIC_LIB" "$RAYLIB_A" - -L$CUDA_HOME/lib64 $CUDNN_LFLAG $NCCL_LFLAG - -lcudart -lnccl -lnvidia-ml -lcublas -lcusolver -lcurand -lcudnn - $OMP_LIB $LINK_OPT - "${SHARED_LDFLAGS[@]}" - -o "$OUTPUT" - ) - "${LINK_CMD[@]}" - echo "Built: $OUTPUT" - -elif [ "$MODE" = "cpu" ]; then - echo "Compiling CPU training backend..." - ${CXX:-g++} -c -fPIC -fopenmp \ - -D_GLIBCXX_USE_CXX11_ABI=1 \ - -DPLATFORM_DESKTOP \ - -std=c++17 \ - -I. 
-Isrc \ - -I$PYTHON_INCLUDE -I$PYBIND_INCLUDE \ - -DOBS_TENSOR_T=$OBS_TENSOR_T \ - -DENV_NAME=$ENV \ - $PRECISION $LINK_OPT \ - src/bindings_cpu.cpp -o build/bindings_cpu.o - LINK_CMD=( - ${CXX:-g++} -shared -fPIC -fopenmp - build/bindings_cpu.o "$STATIC_LIB" "$RAYLIB_A" - -lm -lpthread $OMP_LIB $LINK_OPT - "${SHARED_LDFLAGS[@]}" - -o "$OUTPUT" - ) - "${LINK_CMD[@]}" - echo "Built: $OUTPUT" +export CCACHE_DIR="${CCACHE_DIR:-$HOME/.ccache}" +export CCACHE_BASEDIR="$(pwd)" +export CCACHE_COMPILERCHECK=content +NVCC="ccache $CUDA_HOME/bin/nvcc" +ARCH=${NVCC_ARCH:-native} -elif [ "$MODE" = "profile" ]; then +if [ "$MODE" = "profile" ]; then echo "Compiling profile binary ($ARCH)..." - $NVCC $NVCC_OPT -arch=$ARCH -std=c++17 \ - -I. -Isrc -I$SRC_DIR -Ivendor \ - -I$CUDA_HOME/include $CUDNN_IFLAG $NCCL_IFLAG -I$RAYLIB_NAME/include \ - -DOBS_TENSOR_T=$OBS_TENSOR_T \ - -DENV_NAME=$ENV \ + $NVCC "${NVCC_OPT[@]}" -arch=$ARCH -std=c++17 \ + -I. -Isrc -I"$SRC_DIR" -Ivendor \ + -I"$CUDA_HOME/include" $CUDNN_IFLAG $NCCL_IFLAG -I"$RAYLIB_NAME/include" \ + -DOBS_TENSOR_T="$OBS_TENSOR_T" \ + -DENV_NAME="$ENV" \ -Xcompiler=-DPLATFORM_DESKTOP \ - $PRECISION \ + ${PRECISION:+$PRECISION} \ -Xcompiler=-fopenmp \ tests/profile_kernels.cu vendor/ini.c \ "$STATIC_LIB" "$RAYLIB_A" \ -lnccl -lnvidia-ml -lcublas -lcurand -lcudnn \ - -lGL -lm -lpthread $OMP_LIB \ + -lGL -lm -lpthread "$OMP_LIB" \ -o profile echo "Built: ./profile" + exit 0 fi + +echo "Compiling CUDA ($ARCH) training backend..." +$NVCC -c -arch=$ARCH -Xcompiler -fPIC \ + -Xcompiler=-D_GLIBCXX_USE_CXX11_ABI=1 \ + -Xcompiler=-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION \ + -Xcompiler=-DPLATFORM_DESKTOP \ + -std=c++17 \ + -I. 
-Isrc \ + -I"$PYTHON_INCLUDE" -I"$PYBIND_INCLUDE" -I"$NUMPY_INCLUDE" \ + -I"$CUDA_HOME/include" $CUDNN_IFLAG $NCCL_IFLAG -I"$RAYLIB_NAME/include" \ + -Xcompiler=-fopenmp \ + -DOBS_TENSOR_T="$OBS_TENSOR_T" \ + -DENV_NAME="$ENV" \ + ${PRECISION:+$PRECISION} "${NVCC_OPT[@]}" \ + src/bindings.cu -o build/bindings.o + +LINK_CMD=( + "$DEFAULT_CXX" -shared -fPIC -fopenmp + build/bindings.o "$STATIC_LIB" "$RAYLIB_A" + -L"$CUDA_HOME/lib64" $CUDNN_LFLAG $NCCL_LFLAG + -lcudart -lnccl -lnvidia-ml -lcublas -lcusolver -lcurand -lcudnn + "$OMP_LIB" + "${LINK_OPT[@]}" + "${SHARED_LDFLAGS[@]}" + -o "$OUTPUT" +) +"${LINK_CMD[@]}" +echo "Built: $OUTPUT" diff --git a/config/default.ini b/config/default.ini index fd53bf7d95..8fc61c21dc 100644 --- a/config/default.ini +++ b/config/default.ini @@ -54,8 +54,13 @@ vf_clip_coef = 0.2 max_grad_norm = 1.5 ent_coef = 0.001 beta1 = 0.95 +weight_decay = 0.0 beta2 = 0.999 eps = 1e-12 +overlap = 0 +cpu_inference = 0 +train_fp16 = 0 +ns_iters = 5 minibatch_size = 8192 horizon = 64 vtrace_rho_clip = 1.0 diff --git a/src/cpu_inference.h b/src/cpu_inference.h new file mode 100644 index 0000000000..45ac0be97e --- /dev/null +++ b/src/cpu_inference.h @@ -0,0 +1,328 @@ +// CPU rollout inference path for Metal. + +#ifndef PUFFERLIB_CPU_INFERENCE_H +#define PUFFERLIB_CPU_INFERENCE_H + +#include +#include +#include +#include +#include + +// ============================================================================ +// CPU Philox 4x32-10 RNG — matches MSL philox4x32_10 exactly +// ============================================================================ + +// Single Philox round: mulhi-based bijection. +// MSL uses mulhi(M, x) = (uint64_t(M) * x) >> 32. 
static inline void cpu_philox_round(uint32_t ctr[4], const uint32_t key[2]) {
    // Philox-4x32 multiplier constants.
    constexpr uint32_t kMul0 = 0xD2511F53u;
    constexpr uint32_t kMul1 = 0xCD9E8D57u;

    // Full 32x32 -> 64-bit products: the upper word is the "mulhi" term,
    // the lower word is the wrapped 32-bit product.
    const uint64_t wide0 = static_cast<uint64_t>(kMul0) * ctr[0];
    const uint64_t wide1 = static_cast<uint64_t>(kMul1) * ctr[2];
    const uint32_t hi0 = static_cast<uint32_t>(wide0 >> 32);
    const uint32_t lo0 = static_cast<uint32_t>(wide0);
    const uint32_t hi1 = static_cast<uint32_t>(wide1 >> 32);
    const uint32_t lo1 = static_cast<uint32_t>(wide1);

    // Lane layout follows the MSL kernel:
    //   uint4(hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0)
    const uint32_t mixed[4] = {
        hi1 ^ ctr[1] ^ key[0],
        lo1,
        hi0 ^ ctr[3] ^ key[1],
        lo0,
    };
    for (int lane = 0; lane < 4; lane++) {
        ctr[lane] = mixed[lane];
    }
}

// Ten-round Philox, intended to mirror MSL philox4x32_10.
// Callers pass counter = {idx, offset, 0, 0} and key = {seed_lo, seed_hi}.
// Both arrays are updated in place: `counter` ends up holding the four output
// words and `key` is advanced by one Weyl increment per round, so pass copies
// if the originals must survive.
static inline void cpu_philox4x32_10(uint32_t counter[4], uint32_t key[2]) {
    constexpr uint32_t kWeyl0 = 0x9E3779B9u;
    constexpr uint32_t kWeyl1 = 0xBB67AE85u;
    int rounds_left = 10;
    while (rounds_left-- > 0) {
        cpu_philox_round(counter, key);
        key[0] += kWeyl0;
        key[1] += kWeyl1;
    }
}

// RNG state kept per agent while sampling on the CPU.
struct CpuPhiloxState {
    uint32_t counter[4];  // {idx, offset, block index, 0}
    uint32_t key[2];      // {seed low word, seed high word}
    uint32_t output[4];   // most recently generated block of random words
    int consumed;         // number of entries of output[0..3] already handed out
};

// Seeds `s` for stream (idx, offset) of `seed` and generates the first block
// of four random words so the state is immediately ready to be drawn from.
static inline void cpu_philox_init(CpuPhiloxState &s, uint32_t idx,
                                   uint32_t offset, uint64_t seed) {
    s.counter[0] = idx;
    s.counter[1] = offset;
    s.counter[2] = 0;
    s.counter[3] = 0;
    s.key[0] = static_cast<uint32_t>(seed & 0xFFFFFFFF);
    s.key[1] = static_cast<uint32_t>(seed >> 32);

    // Run the block function on scratch copies: it overwrites both arguments,
    // and the stored counter/key must stay intact for later refills.
    uint32_t block[4];
    uint32_t round_key[2];
    memcpy(block, s.counter, sizeof(block));
    memcpy(round_key, s.key, sizeof(round_key));
    cpu_philox4x32_10(block, round_key);

    memcpy(s.output, block, sizeof(s.output));
    s.consumed = 0;
}

// MSL philox_uniform: reads output[state_idx & 3], increments state_idx.
// When all 4 consumed, bumps counter.z and regenerates.
+static inline float cpu_philox_uniform(CpuPhiloxState &s) { + if (s.consumed >= 4) { + s.counter[2]++; + uint32_t ctr[4] = {s.counter[0], s.counter[1], s.counter[2], s.counter[3]}; + uint32_t key[2] = {s.key[0], s.key[1]}; + cpu_philox4x32_10(ctr, key); + memcpy(s.output, ctr, 16); + s.consumed = 0; + } + uint32_t val = s.output[s.consumed++]; + return ((float)(val >> 8) + 0.5f) / 16777216.0f; +} + +// ============================================================================ +// CPU activation functions — matching MSL implementations exactly +// ============================================================================ + +// Stable sigmoid: matches MSL sigmoid_f. Used for gate and proj. +static inline float cpu_sigmoid(float x) { + float z = expf(-fabsf(x)); + return x >= 0.0f ? 1.0f / (1.0f + z) : z / (1.0f + z); +} + +// Horner polynomial tanh: matches MSL fast_tanh_f. +static inline float cpu_fast_tanh(float x) { + float v1 = x < -9.0f ? -9.0f : (x > 9.0f ? 9.0f : x); + float v2 = v1 * v1; + float p = v2 * (-2.76076847742355e-16f) + 2.00018790482477e-13f; + p = v2 * p + (-8.60467152213735e-11f); + p = v2 * p + 5.12229709037114e-08f; + p = v2 * p + 1.48572235717979e-05f; + p = v2 * p + 6.37261928875436e-04f; + p = v2 * p + 4.89352455891786e-03f; + p = v1 * p; + float q = v2 * 1.19825839466702e-06f + 1.18534705686654e-04f; + q = v2 * q + 2.26843463243900e-03f; + q = v2 * q + 4.89352518554385e-03f; + return p / q; +} + +// Polynomial sigmoid: matches MSL fast_sigmoid_f. Used inside tilde_relu only. +static inline float cpu_fast_sigmoid(float x) { + float y = cpu_fast_tanh(x * 0.5f); + float result = (y + 1.0f) * 0.5f; + return result < 0.0f ? 0.0f : (result > 1.0f ? 1.0f : result); +} + +// tilde_relu: x >= 0 ? x + 0.5 : fast_sigmoid(x). Matches MSL tilde_relu_fwd. +static inline float cpu_tilde_relu(float x) { + return x >= 0.0f ? x + 0.5f : cpu_fast_sigmoid(x); +} + +// Careful lerp avoiding catastrophic cancellation: matches MSL lerp_f. 
+static inline float cpu_lerp(float a, float b, float w) { + float diff = b - a; + return fabsf(w) < 0.5f ? a + w * diff : b - diff * (1.0f - w); +} + +// ============================================================================ +// CPU GEMM — cblas_sgemm via Accelerate (AMX/SME on Apple Silicon) +// ============================================================================ + +// out(M,N) = a(M,K) @ b(N,K)^T — matches puf_mm convention. +// Weight b is stored as (N,K) row-major (N rows of K elements). +static inline void cpu_mm_nt(const float *a, const float *b, float *out, + int M, int K, int N) { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + M, N, K, 1.0f, a, K, b, K, 0.0f, out, N); +} + +// ============================================================================ +// CPU MinGRU gate — element-wise matching MSL mingru_gate_inference +// ============================================================================ + +// combined: (B, 3*H) = [hidden | gate | proj] per row +// state_in, x_in, out, next_state: all (B, H) +static void cpu_mingru_gate(float *out, float *next_state, + const float *combined, const float *state_in, + const float *x_in, int H, int B) { + int BH = B * H; + for (int idx = 0; idx < BH; idx++) { + int b = idx / H; + int h = idx % H; + int base = b * 3 * H; + + float hidden = combined[base + h]; + float gate = combined[base + H + h]; + float proj = combined[base + 2 * H + h]; + float state = state_in[idx]; + float x = x_in[idx]; + + float gate_sig = cpu_sigmoid(gate); + float hidden_tilde = cpu_tilde_relu(hidden); + float mingru_out = cpu_lerp(state, hidden_tilde, gate_sig); + float proj_sig = cpu_sigmoid(proj); + + next_state[idx] = fmaxf(mingru_out, 1e-30f); + out[idx] = proj_sig * mingru_out + (1.0f - proj_sig) * x; + } +} + +// ============================================================================ +// CPU discrete action sampling — matches MSL sample_logits_kernel +// 
============================================================================ + +// Apply action mask: invalid actions get -1e9. +static inline float cpu_mask_logit(float logit, float mask) { + if (mask < 0.5f) logit = -1e9f; + return logit; +} + +// Sample all action heads for B agents. +// dec_out: (B, fused_cols) — logits + value in last column +// act_sizes: (num_atns,) — number of discrete actions per head +// action_out_f32: (B, num_atns) — sampled action indices +// logprobs: (B,) — scalar joint log-probability (sum across heads, matches CUDA) +// value_out: (B,) — value head output +static void cpu_sample_logits( + const float *dec_out, int fused_cols, + const int *act_sizes, int num_atns, + float *action_out_f32, float *logprobs, float *value_out, + const float *action_mask, int mask_stride, + uint64_t seed, uint32_t *offset_ptr, int B) { + + uint32_t offset_snapshot = *offset_ptr; + *offset_ptr = offset_snapshot + 1u; + + for (int idx = 0; idx < B; idx++) { + CpuPhiloxState rng; + cpu_philox_init(rng, (uint32_t)idx, offset_snapshot, seed); + + const float *logits = dec_out + idx * fused_cols; + // mask_stride=0 means all agents read the same mask (all-ones fallback) + const float *mask = (mask_stride == 0) + ? 
action_mask + : action_mask + idx * mask_stride; + + int logits_offset = 0; + // CUDA joint-ratio: accumulate scalar total_log_prob across heads + float total_log_prob = 0.0f; + + for (int h = 0; h < num_atns; h++) { + int A = act_sizes[h]; + + // max for numerical stability + float max_val = -INFINITY; + for (int a = 0; a < A; a++) { + float l = cpu_mask_logit(logits[logits_offset + a], + mask[logits_offset + a]); + if (l > max_val) max_val = l; + } + + // logsumexp + float sum_exp = 0.0f; + for (int a = 0; a < A; a++) { + float l = cpu_mask_logit(logits[logits_offset + a], + mask[logits_offset + a]); + sum_exp += expf(l - max_val); + } + float logsumexp_val = max_val + logf(sum_exp); + + // Philox uniform sample + float rand_val = cpu_philox_uniform(rng); + + // Inverse CDF sampling + float cumsum = 0.0f; + int sampled = A - 1; + for (int a = 0; a < A; a++) { + float l = cpu_mask_logit(logits[logits_offset + a], + mask[logits_offset + a]); + cumsum += expf(l - logsumexp_val); + if (rand_val < cumsum) { + sampled = a; + break; + } + } + + float sl = cpu_mask_logit(logits[logits_offset + sampled], + mask[logits_offset + sampled]); + total_log_prob += sl - logsumexp_val; + + action_out_f32[idx * num_atns + h] = (float)sampled; + logits_offset += A; + } + // Scalar joint logprob matching CUDA: logprobs[idx] = sum of per-head log probs + logprobs[idx] = total_log_prob; + value_out[idx] = dec_out[idx * fused_cols + (fused_cols - 1)]; + } +} + +// ============================================================================ +// CPU forward pass — complete rollout step for one buffer +// ============================================================================ + +// Full CPU forward: encoder → MinGRU layers → decoder → sampling. +// Reads weights (unified memory, read-only). Writes to activation buffers +// (unified memory, per-buffer isolation). No GPU dispatch, no sync. 
+static void cpu_forward_and_sample( + PufTensor &obs, // (B, obs_dim) + PufTensor &state, // (num_layers, B, H) + PolicyWeights &weights, + int hidden_dim, + PolicyActivations &acts, + IntTensor &act_sizes_puf, + FloatTensor &act_f32_buf, // (B, num_atns) scratch + float *logprobs_out, // (B,) + float *values_out, // (B,) + const float *action_mask, int mask_stride, + uint64_t rng_seed, uint32_t *rng_offset_ptr) { + + int B = (int)obs.shape[0]; + int obs_dim = (int)obs.shape[1]; + int H = hidden_dim; + + EncoderWeights *ew = (EncoderWeights *)weights.encoder; + MinGRUWeights *mw = (MinGRUWeights *)weights.network; + MinGRUActivations *ma = (MinGRUActivations *)acts.network; + DecoderWeights *dw = (DecoderWeights *)weights.decoder; + DecoderActivations *da = (DecoderActivations *)acts.decoder; + + // --- Encoder --- + EncoderActivations *ea = (EncoderActivations *)acts.encoder; + cpu_mm_nt((const float *)obs.bytes, ew->weight.data, + ea->out.data, B, obs_dim, ew->out_dim); + float *layer_input = ea->out.data; + + // --- MinGRU layers --- + for (int i = 0; i < mw->num_layers; i++) { + float *state_i = (float *)state.bytes + i * B * H; + int input_K = H; + cpu_mm_nt(layer_input, mw->weights[i].data, + ma->combined[i].data, B, input_K, 3 * H); + + cpu_mingru_gate(ma->out.data, ma->next_state.data, + ma->combined[i].data, + state_i, layer_input, H, B); + + // Update RNN state + memcpy(state_i, ma->next_state.data, B * H * sizeof(float)); + + layer_input = ma->out.data; + } + + // --- Decoder --- + int fused_cols = dw->output_dim + 1; + cpu_mm_nt(layer_input, dw->weight.data, + da->out.data, B, H, fused_cols); + + // --- Sampling --- + cpu_sample_logits( + da->out.data, fused_cols, + act_sizes_puf.data, (int)puf_numel(act_sizes_puf.shape), + act_f32_buf.data, logprobs_out, values_out, + action_mask, mask_stride, + rng_seed, rng_offset_ptr, B); +} + +#endif // PUFFERLIB_CPU_INFERENCE_H diff --git a/src/metal_bindings.mm b/src/metal_bindings.mm new file mode 100644 index 
0000000000..94f867b8c6 --- /dev/null +++ b/src/metal_bindings.mm @@ -0,0 +1,566 @@ +#include "metal_pufferlib.mm" + +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#define _PUFFER_STRINGIFY(x) #x +#define PUFFER_STRINGIFY(x) _PUFFER_STRINGIFY(x) + +static double wall_clock() { + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec + tv.tv_usec * 1e-6; +} + +static std::string tensor_repr(const FloatTensor& tensor) { + char buffer[256]; + int position = snprintf(buffer, sizeof(buffer), "FloatTensor(["); + int dims = puf_ndim(tensor.shape); + for (int i = 0; i < dims && position < (int)sizeof(buffer) - 32; i++) { + position += snprintf( + buffer + position, + sizeof(buffer) - position, + "%s%lld", + i ? ", " : "", + (long long)tensor.shape[i]); + } + snprintf( + buffer + position, + sizeof(buffer) - position, + "], %lld elems)", + (long long)puf_numel(tensor.shape)); + return std::string(buffer); +} + +static py::dict get_utilization(int gpu_id) { + (void)gpu_id; + py::dict result; + + MetalContext* ctx = mtl_ctx(); + if (ctx->device) { + uint64_t gpu_budget = [ctx->device recommendedMaxWorkingSetSize]; + uint64_t gpu_current = [ctx->device currentAllocatedSize]; + result["gpu_percent"] = 0.0f; + if (gpu_budget > 0) { + result["gpu_mem"] = 100.0f * (float)gpu_current / (float)gpu_budget; + } else { + result["gpu_mem"] = 0.0f; + } + result["vram_used_gb"] = (float)gpu_current / (1024.0f * 1024.0f * 1024.0f); + result["vram_total_gb"] = (float)gpu_budget / (1024.0f * 1024.0f * 1024.0f); + } + + struct mach_task_basic_info info; + mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT; + if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, + (task_info_t)&info, &count) == KERN_SUCCESS) { + result["cpu_mem_gb"] = (float)info.resident_size / (1024.0f * 1024.0f * 1024.0f); + } + + return result; +} + +static py::dict puf_log(py::object pufferl_obj) { + auto& pufferl = pufferl_obj.cast(); + py::dict result; + + if 
(pufferl.train_pending) { + auto wait_start = std::chrono::high_resolution_clock::now(); + sync_pending_train(pufferl); + float wait_ms = std::chrono::duration( + std::chrono::high_resolution_clock::now() - wait_start).count(); + pufferl.profile.accum[PROF_TRAIN_SYNC] += wait_ms; + } + mtl_ensure_stream_synced((cudaStream_t)mtl_stream()); + + long global_step = pufferl.global_step; + int epoch = pufferl.epoch; + double now = wall_clock(); + double dt = now - pufferl.last_log_time; + long sps = dt > 0 ? (long)((global_step - pufferl.last_log_step) / dt) : 0; + pufferl.last_log_time = now; + pufferl.last_log_step = global_step; + + result["SPS"] = sps; + result["agent_steps"] = global_step; + result["uptime"] = now - pufferl.start_time; + result["epoch"] = epoch; + + py::dict env_dict; + Dict* env_out = log_environments_impl(pufferl); + for (int i = 0; i < env_out->size; i++) { + env_dict[env_out->items[i].key] = env_out->items[i].value; + } + result["env"] = env_dict; + + py::dict loss_dict; + float* losses = pufferl.losses_puf.data; + float n = losses[LOSS_N]; + if (n > 0) { + float inv_n = 1.0f / n; + loss_dict["policy"] = losses[LOSS_PG] * inv_n; + loss_dict["value"] = losses[LOSS_VF] * inv_n; + loss_dict["entropy"] = losses[LOSS_ENT] * inv_n; + loss_dict["total"] = losses[LOSS_TOTAL] * inv_n; + loss_dict["old_kl"] = losses[LOSS_OLD_APPROX_KL] * inv_n; + loss_dict["kl"] = losses[LOSS_APPROX_KL] * inv_n; + loss_dict["clipfrac"] = losses[LOSS_CLIPFRAC] * inv_n; + } + cudaStream_t loss_stream = pufferl.overlap_enabled + ? 
(cudaStream_t)mtl_train_stream() + : (cudaStream_t)mtl_stream(); + mtl_fill_f32(losses, 0.0f, (int)puf_numel(pufferl.losses_puf.shape), loss_stream); + result["loss"] = loss_dict; + + py::dict perf_dict; + float train_ms = 0.0f; + for (int i = 0; i < NUM_PROF; i++) { + float sec = pufferl.profile.accum[i] / 1000.0f; + perf_dict[PROF_NAMES[i]] = sec; + if (i >= PROF_TRAIN_PRELOOP) { + train_ms += pufferl.profile.accum[i]; + } + } + perf_dict["train"] = train_ms / 1000.0f; + memset(pufferl.profile.accum, 0, sizeof(pufferl.profile.accum)); + pufferl.rollout_sync_count = 0; + pufferl.rollout_sync_ms = 0; + pufferl.train_sync_count = 0; + pufferl.train_sync_ms = 0; + result["perf"] = perf_dict; + + result["util"] = get_utilization(0); + return result; +} + +static py::dict puf_eval_log(py::object pufferl_obj) { + auto& pufferl = pufferl_obj.cast(); + py::dict result; + + double now = wall_clock(); + pufferl.last_log_time = now; + pufferl.last_log_step = pufferl.global_step; + + py::dict env_dict; + Dict* env_out = create_dict(32); + static_vec_eval_log(pufferl.vec, env_out); + for (int i = 0; i < env_out->size; i++) { + env_dict[env_out->items[i].key] = env_out->items[i].value; + } + result["env"] = env_dict; + return result; +} + +static void render(py::object pufferl_obj, int env_id) { + PuffeRL& pufferl = pufferl_obj.cast(); + static_vec_render(pufferl.vec, env_id); +} + +static void rollouts(py::object pufferl_obj) { + PuffeRL& pufferl = pufferl_obj.cast(); + + { int count; double ms; mtl_sync_stats(&count, &ms); } + + py::gil_scoped_release no_gil; + + if (pufferl.train_pending) { + auto wait_start = std::chrono::high_resolution_clock::now(); + sync_pending_train(pufferl); + float wait_ms = std::chrono::duration( + std::chrono::high_resolution_clock::now() - wait_start).count(); + pufferl.profile.accum[PROF_TRAIN_SYNC] += wait_ms; + } + + auto t0 = std::chrono::high_resolution_clock::now(); + if (!pufferl.cpu_inference) puf_set_gpu_training(true); + 
static_vec_omp_step(pufferl.vec); + if (!pufferl.cpu_inference) puf_set_gpu_training(false); + float sec = std::chrono::duration( + std::chrono::high_resolution_clock::now() - t0).count(); + pufferl.profile.accum[PROF_ROLLOUT] += sec * 1000.0f; + + float eval_prof[NUM_EVAL_PROF]; + static_vec_read_profile(pufferl.vec, eval_prof); + pufferl.profile.accum[PROF_EVAL_GPU] += eval_prof[EVAL_GPU]; + pufferl.profile.accum[PROF_EVAL_ENV] += eval_prof[EVAL_ENV_STEP]; + + mtl_sync_stats(&pufferl.rollout_sync_count, &pufferl.rollout_sync_ms); + pufferl.global_step += pufferl.hypers.horizon * pufferl.hypers.total_agents; +} + +static py::dict train(py::object pufferl_obj) { + PuffeRL& pufferl = pufferl_obj.cast(); + { int count; double ms; mtl_sync_stats(&count, &ms); } + + { + py::gil_scoped_release no_gil; + train_impl(pufferl); + } + + mtl_sync_stats(&pufferl.train_sync_count, &pufferl.train_sync_ms); + return py::dict(); +} + +static void puf_close(py::object pufferl_obj) { + PuffeRL& pufferl = pufferl_obj.cast(); + close_impl(pufferl); +} + +static void save_weights(py::object pufferl_obj, const std::string& path) { + PuffeRL& pufferl = pufferl_obj.cast(); + int64_t nbytes = pufferl.alloc_fp32.params.total_elems * sizeof(float); + + sync_pending_train(pufferl); + mtl_ensure_stream_synced((cudaStream_t)mtl_stream()); + + FILE* f = fopen(path.c_str(), "wb"); + if (!f) { + throw std::runtime_error("Failed to open " + path + " for writing"); + } + fwrite(pufferl.alloc_fp32.params.mem, 1, nbytes, f); + fclose(f); +} + +static void load_weights(py::object pufferl_obj, const std::string& path) { + PuffeRL& pufferl = pufferl_obj.cast(); + int64_t param_count = pufferl.alloc_fp32.params.total_elems; + int64_t nbytes = param_count * sizeof(float); + + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + throw std::runtime_error("Failed to open " + path + " for reading"); + } + fseek(f, 0, SEEK_END); + long file_size = ftell(f); + fseek(f, 0, SEEK_SET); + if (file_size != nbytes) { + 
fclose(f); + throw std::runtime_error( + "Weight file size mismatch: expected " + std::to_string(nbytes) + + " bytes, got " + std::to_string(file_size)); + } + + sync_pending_train(pufferl); + mtl_ensure_stream_synced((cudaStream_t)mtl_stream()); + + size_t nread = fread(pufferl.alloc_fp32.params.mem, 1, nbytes, f); + if ((int64_t)nread != nbytes) { + fclose(f); + throw std::runtime_error("Failed to read weight file"); + } + fclose(f); + + copy_weights_to_infer(pufferl); + if (pufferl.train_fp16) { + cudaStream_t stream = (cudaStream_t)mtl_stream(); + mtl_cast_f32_to_f16( + pufferl.param_fp16_puf.bytes, + (const float*)pufferl.alloc_fp32.params.mem, + (int)param_count, + stream); + mtl_ensure_stream_synced(stream); + } +} + +static void py_puff_advantage_cpu( + long long values_ptr, long long rewards_ptr, + long long dones_ptr, long long importance_ptr, + long long advantages_ptr, + int num_steps, int horizon, + float gamma, float lambda, float rho_clip, float c_clip) { + const float* values = (const float*)values_ptr; + const float* rewards = (const float*)rewards_ptr; + const float* dones = (const float*)dones_ptr; + const float* importance = (const float*)importance_ptr; + float* advantages = (float*)advantages_ptr; + + for (int row = 0; row < num_steps; row++) { + int offset = row * horizon; + float last = 0.0f; + for (int t = horizon - 2; t >= 0; t--) { + int next_t = t + 1; + float next_nonterminal = 1.0f - dones[offset + next_t]; + float imp = importance[offset + t]; + float rho_t = imp < rho_clip ? imp : rho_clip; + float c_t = imp < c_clip ? 
imp : c_clip; + float delta = rho_t * rewards[offset + next_t] + + gamma * values[offset + next_t] * next_nonterminal + - values[offset + t]; + last = delta + gamma * lambda * c_t * last * next_nonterminal; + advantages[offset + t] = last; + } + } +} + +static double get_config(py::dict& kwargs, const char* key) { + assert(kwargs.contains(key) && "Missing config key"); + return kwargs[key].cast(); +} + +static Dict* py_dict_to_c_dict(py::dict py_dict) { + Dict* c_dict = create_dict(py_dict.size()); + for (auto item : py_dict) { + const char* key = PyUnicode_AsUTF8(item.first.ptr()); + try { + dict_set(c_dict, key, item.second.cast()); + } catch (const py::cast_error&) { + } + } + return c_dict; +} + +struct VecEnv { + StaticVec* vec; + int total_agents; + int obs_size; + int num_atns; + std::vector act_sizes; + std::string obs_dtype; + size_t obs_elem_size; +}; + +static std::unique_ptr create_vec(py::dict args, int gpu = 0) { + (void)gpu; + py::dict vec_kwargs = args["vec"].cast(); + py::dict env_kwargs = args["env"].cast(); + + int total_agents = (int)get_config(vec_kwargs, "total_agents"); + int num_buffers = (int)get_config(vec_kwargs, "num_buffers"); + Dict* vec_dict = py_dict_to_c_dict(vec_kwargs); + Dict* env_dict = py_dict_to_c_dict(env_kwargs); + + auto ve = std::make_unique(); + { + py::gil_scoped_release no_gil; + ve->vec = create_static_vec(total_agents, num_buffers, 0, vec_dict, env_dict); + } + + ve->total_agents = total_agents; + ve->obs_size = get_obs_size(); + ve->num_atns = get_num_atns(); + { + int* raw = get_act_sizes(); + int n = get_num_act_sizes(); + ve->act_sizes = std::vector(raw, raw + n); + } + ve->obs_dtype = std::string(get_obs_dtype()); + ve->obs_elem_size = get_obs_elem_size(); + return ve; +} + +static void vec_reset(VecEnv& ve) { + py::gil_scoped_release no_gil; + static_vec_reset(ve.vec); +} + +static void cpu_vec_step_py(VecEnv& ve, long long actions_ptr) { + memcpy( + ve.vec->actions, + (void*)actions_ptr, + 
(size_t)ve.total_agents * ve.num_atns * sizeof(float)); + { + py::gil_scoped_release no_gil; + cpu_vec_step(ve.vec); + } +} + +static py::dict vec_log(VecEnv& ve) { + Dict* out = create_dict(32); + static_vec_log(ve.vec, out); + py::dict result; + for (int i = 0; i < out->size; i++) { + result[out->items[i].key] = out->items[i].value; + } + free(out->items); + free(out); + return result; +} + +static void vec_close(VecEnv& ve) { + static_vec_close(ve.vec); + ve.vec = nullptr; +} + +static std::unique_ptr create_pufferl(py::dict args) { + py::dict train_kwargs = args["train"].cast(); + py::dict vec_kwargs = args["vec"].cast(); + py::dict env_kwargs = args["env"].cast(); + py::dict policy_kwargs = args["policy"].cast(); + + HypersT hypers; + hypers.total_agents = get_config(vec_kwargs, "total_agents"); + hypers.num_buffers = get_config(vec_kwargs, "num_buffers"); + hypers.num_threads = get_config(vec_kwargs, "num_threads"); + hypers.horizon = get_config(train_kwargs, "horizon"); + hypers.hidden_size = get_config(policy_kwargs, "hidden_size"); + hypers.num_layers = get_config(policy_kwargs, "num_layers"); + hypers.seed = args.contains("seed") ? (uint64_t)get_config(args, "seed") + : train_kwargs.contains("seed") ? 
(uint64_t)get_config(train_kwargs, "seed") : 42; + hypers.lr = get_config(train_kwargs, "learning_rate"); + hypers.min_lr_ratio = get_config(train_kwargs, "min_lr_ratio"); + hypers.anneal_lr = get_config(train_kwargs, "anneal_lr"); + hypers.beta1 = get_config(train_kwargs, "beta1"); + hypers.weight_decay = get_config(train_kwargs, "weight_decay"); + hypers.minibatch_size = get_config(train_kwargs, "minibatch_size"); + hypers.replay_ratio = get_config(train_kwargs, "replay_ratio"); + hypers.total_timesteps = get_config(train_kwargs, "total_timesteps"); + hypers.max_grad_norm = get_config(train_kwargs, "max_grad_norm"); + hypers.clip_coef = get_config(train_kwargs, "clip_coef"); + hypers.vf_clip_coef = get_config(train_kwargs, "vf_clip_coef"); + hypers.vf_coef = get_config(train_kwargs, "vf_coef"); + hypers.ent_coef = get_config(train_kwargs, "ent_coef"); + hypers.gamma = get_config(train_kwargs, "gamma"); + hypers.gae_lambda = get_config(train_kwargs, "gae_lambda"); + hypers.vtrace_rho_clip = get_config(train_kwargs, "vtrace_rho_clip"); + hypers.vtrace_c_clip = get_config(train_kwargs, "vtrace_c_clip"); + hypers.prio_alpha = get_config(train_kwargs, "prio_alpha"); + hypers.prio_beta0 = get_config(train_kwargs, "prio_beta0"); + hypers.reset_state = + (args.contains("reset_state") && get_config(args, "reset_state") > 0) || + (train_kwargs.contains("reset_state") && get_config(train_kwargs, "reset_state") > 0); + hypers.profile = train_kwargs.contains("profile") ? get_config(train_kwargs, "profile") + : args.contains("profile") ? 
get_config(args, "profile") : 0; + hypers.overlap = + (train_kwargs.contains("overlap") && get_config(train_kwargs, "overlap") > 0) || + (args.contains("overlap") && get_config(args, "overlap") > 0); + hypers.cpu_inference = + (train_kwargs.contains("cpu_inference") && get_config(train_kwargs, "cpu_inference") > 0) || + (args.contains("cpu_inference") && get_config(args, "cpu_inference") > 0); + hypers.train_fp16 = + (train_kwargs.contains("train_fp16") && get_config(train_kwargs, "train_fp16") > 0) || + (args.contains("train_fp16") && get_config(args, "train_fp16") > 0); + hypers.ns_iters = train_kwargs.contains("ns_iters") ? (int)get_config(train_kwargs, "ns_iters") + : args.contains("ns_iters") ? (int)get_config(args, "ns_iters") : 5; + hypers.gpu_id = args.contains("gpu_id") ? (int)get_config(args, "gpu_id") : 0; + + mtl_enable_gpu_timing(hypers.profile); + + std::string env_name = args["env_name"].cast(); + Dict* vec_dict = py_dict_to_c_dict(vec_kwargs); + Dict* env_dict = py_dict_to_c_dict(env_kwargs); + + std::unique_ptr pufferl; + { + py::gil_scoped_release no_gil; + pufferl = create_pufferl_impl(hypers, env_name, vec_dict, env_dict); + } + + return pufferl; +} + +PYBIND11_MODULE(_C, m) { + m.def("get_nccl_id", []() -> py::bytes { + throw std::runtime_error("Metal backend does not support multi-GPU"); + }); + m.def("get_utilization", &get_utilization); + + m.attr("precision_bytes") = 4; + m.attr("env_name") = PUFFER_STRINGIFY(ENV_NAME); + m.attr("gpu") = 0; + + m.def("log", &puf_log); + m.def("eval_log", &puf_eval_log); + m.def("render", &render); + m.def("rollouts", &rollouts); + m.def("train", &train); + m.def("close", &puf_close); + m.def("save_weights", &save_weights); + m.def("load_weights", &load_weights); + m.def("uptime", [](py::object pufferl_obj) -> double { + PuffeRL& pufferl = pufferl_obj.cast(); + return wall_clock() - pufferl.start_time; + }); + + m.def("puff_advantage_cpu", &py_puff_advantage_cpu); + m.def("create_vec", &create_vec, 
py::arg("args"), py::arg("gpu") = 0); + + py::class_(m, "Policy"); + py::class_(m, "Muon"); + py::class_(m, "Allocator").def(py::init<>()); + + py::class_(m, "HypersT") + .def_readwrite("horizon", &HypersT::horizon) + .def_readwrite("total_agents", &HypersT::total_agents) + .def_readwrite("num_buffers", &HypersT::num_buffers) + .def_readwrite("num_atns", &HypersT::num_atns) + .def_readwrite("hidden_size", &HypersT::hidden_size) + .def_readwrite("replay_ratio", &HypersT::replay_ratio) + .def_readwrite("num_layers", &HypersT::num_layers) + .def_readwrite("seed", &HypersT::seed) + .def_readwrite("lr", &HypersT::lr) + .def_readwrite("min_lr_ratio", &HypersT::min_lr_ratio) + .def_readwrite("anneal_lr", &HypersT::anneal_lr) + .def_readwrite("beta1", &HypersT::beta1) + .def_readwrite("weight_decay", &HypersT::weight_decay) + .def_readwrite("total_timesteps", &HypersT::total_timesteps) + .def_readwrite("max_grad_norm", &HypersT::max_grad_norm) + .def_readwrite("clip_coef", &HypersT::clip_coef) + .def_readwrite("vf_clip_coef", &HypersT::vf_clip_coef) + .def_readwrite("vf_coef", &HypersT::vf_coef) + .def_readwrite("ent_coef", &HypersT::ent_coef) + .def_readwrite("gamma", &HypersT::gamma) + .def_readwrite("gae_lambda", &HypersT::gae_lambda) + .def_readwrite("vtrace_rho_clip", &HypersT::vtrace_rho_clip) + .def_readwrite("vtrace_c_clip", &HypersT::vtrace_c_clip) + .def_readwrite("prio_alpha", &HypersT::prio_alpha) + .def_readwrite("prio_beta0", &HypersT::prio_beta0) + .def_readwrite("reset_state", &HypersT::reset_state) + .def_readwrite("profile", &HypersT::profile) + .def_readwrite("overlap", &HypersT::overlap) + .def_readwrite("cpu_inference", &HypersT::cpu_inference) + .def_readwrite("train_fp16", &HypersT::train_fp16) + .def_readwrite("ns_iters", &HypersT::ns_iters) + .def_readwrite("gpu_id", &HypersT::gpu_id); + + py::class_(m, "FloatTensor") + .def("__repr__", [](const FloatTensor& t) { return tensor_repr(t); }) + .def("ndim", [](const FloatTensor& t) { return 
puf_ndim(t.shape); }) + .def("numel", [](const FloatTensor& t) { return puf_numel(t.shape); }); + m.attr("PrecisionTensor") = m.attr("FloatTensor"); + + py::class_(m, "RolloutBuf") + .def_readwrite("observations", &RolloutBuf::observations) + .def_readwrite("actions", &RolloutBuf::actions) + .def_readwrite("values", &RolloutBuf::values) + .def_readwrite("logprobs", &RolloutBuf::logprobs) + .def_readwrite("rewards", &RolloutBuf::rewards) + .def_readwrite("terminals", &RolloutBuf::terminals) + .def_readwrite("ratio", &RolloutBuf::ratio) + .def_readwrite("importance", &RolloutBuf::importance); + + py::class_>(m, "VecEnv") + .def_readonly("total_agents", &VecEnv::total_agents) + .def_readonly("obs_size", &VecEnv::obs_size) + .def_readonly("num_atns", &VecEnv::num_atns) + .def_readonly("act_sizes", &VecEnv::act_sizes) + .def_readonly("obs_dtype", &VecEnv::obs_dtype) + .def_readonly("obs_elem_size", &VecEnv::obs_elem_size) + .def_property_readonly("gpu", [](VecEnv&) { return 0; }) + .def_property_readonly("obs_ptr", [](VecEnv& ve) { return (long long)ve.vec->observations; }) + .def_property_readonly("rewards_ptr", [](VecEnv& ve) { return (long long)ve.vec->rewards; }) + .def_property_readonly("terminals_ptr", [](VecEnv& ve) { return (long long)ve.vec->terminals; }) + .def("reset", &vec_reset) + .def("cpu_step", &cpu_vec_step_py) + .def("render", [](VecEnv& ve, int env_id) { static_vec_render(ve.vec, env_id); }) + .def("log", &vec_log) + .def("close", &vec_close); + + m.def("create_pufferl", &create_pufferl); + py::class_>(m, "PuffeRL") + .def_readwrite("policy", &PuffeRL::policy) + .def_readwrite("muon", &PuffeRL::muon) + .def_readwrite("hypers", &PuffeRL::hypers) + .def_readwrite("rollouts", &PuffeRL::rollouts) + .def_readonly("epoch", &PuffeRL::epoch) + .def_readonly("global_step", &PuffeRL::global_step) + .def_readonly("last_log_time", &PuffeRL::last_log_time) + .def("num_params", [](PuffeRL& self) -> int64_t { + return self.alloc_fp32.params.total_elems; + }); +} 
diff --git a/src/metal_kernels.mm b/src/metal_kernels.mm new file mode 100644 index 0000000000..47ae1d6598 --- /dev/null +++ b/src/metal_kernels.mm @@ -0,0 +1,1561 @@ +#import "metal_platform.h" + +#include +#include +#include +#include + +static inline void mtl_set_ptr(MetalStream *ms, const void *ptr, + uint32_t index) { + MetalContext *ctx = mtl_ctx(); + for (auto &wb : ctx->buffers) { + if ((const char *)ptr >= wb.base && + (const char *)ptr < wb.base + wb.size) { + NSUInteger offset = (NSUInteger)((const char *)ptr - wb.base); + uint64_t addr = wb.buffer.gpuAddress + offset; + if (ms->bound_addresses[index] != addr) { + [ms->arg_table setAddress:addr atIndex:index]; + ms->bound_addresses[index] = addr; + } + return; + } + } + assert(false && "Pointer not in any wrapped allocator buffer"); +} + +static inline void mtl_unwrap_ptr(const void *ptr_base) { + auto &bufs = mtl_ctx()->buffers; + bufs.erase( + std::remove_if(bufs.begin(), bufs.end(), + [ptr_base](const WrappedBuffer &wb) { + return wb.base == (const char *)ptr_base; + }), + bufs.end()); +} + +void mtl_fill_f32(float *ptr, float value, int count, cudaStream_t stream); +void mtl_copy_f32(float *dst, const float *src, int count, cudaStream_t stream); +void mtl_fill_f16(void *ptr, int count, cudaStream_t stream); +void mtl_copy_f16(void *dst, const void *src, int count, cudaStream_t stream); +void mtl_mingru_scan_forward_fp16(PrefixScan &scan, cudaStream_t stream); +void mtl_mingru_scan_backward_fp16(PrefixScan &scan, const void *grad, + const void *grad_next_state, + cudaStream_t stream); +void mtl_assemble_decoder_grad_f32_to_f16(void *grad_out, + const float *grad_logits, + const float *grad_value, int B_TT, + int od, int od1, + cudaStream_t stream); + +void puf_copy(PufTensor &dst, const PufTensor &src, cudaStream_t stream) { + assert(dst.numel() == src.numel() && "puf_copy: size mismatch"); + assert(dst.dtype_size == src.dtype_size && "puf_copy: dtype mismatch"); + bool gpu = puf_is_gpu_training() || 
puf_stream_has_encoder(stream); + if (gpu && dst.dtype_size == 4) { + mtl_copy_f32((float *)dst.bytes, (const float *)src.bytes, + (int)dst.numel(), stream); + } else if (gpu && dst.dtype_size == 2) { + mtl_copy_f16(dst.bytes, src.bytes, (int)dst.numel(), stream); + } else { + mtl_ensure_stream_synced(stream); + memcpy(dst.bytes, src.bytes, dst.numel() * dst.dtype_size); + } +} + +void puf_zero(PufTensor *dst, cudaStream_t stream) { + if (puf_is_gpu_training() && dst->dtype_size == 4) { + mtl_fill_f32((float *)dst->bytes, 0.0f, (int)dst->numel(), stream); + } else if (puf_is_gpu_training() && dst->dtype_size == 2) { + mtl_fill_f16(dst->bytes, (int)dst->numel(), stream); + } else { + mtl_ensure_stream_synced(stream); + memset(dst->bytes, 0, dst->numel() * dst->dtype_size); + } +} + +void puf_copy(FloatTensor &dst, const FloatTensor &src, cudaStream_t stream) { + assert(puf_numel(dst.shape) == puf_numel(src.shape) && "puf_copy: size mismatch"); + bool gpu = puf_is_gpu_training() || puf_stream_has_encoder(stream); + if (gpu) { + mtl_copy_f32(dst.data, src.data, (int)puf_numel(dst.shape), stream); + } else { + mtl_ensure_stream_synced(stream); + memcpy(dst.data, src.data, puf_numel(dst.shape) * sizeof(float)); + } +} + +void puf_zero(FloatTensor *dst, cudaStream_t stream) { + if (puf_is_gpu_training()) { + mtl_fill_f32(dst->data, 0.0f, (int)puf_numel(dst->shape), stream); + } else { + mtl_ensure_stream_synced(stream); + memset(dst->data, 0, puf_numel(dst->shape) * sizeof(float)); + } +} + +void puf_add(PufTensor &dst, const PufTensor &src, cudaStream_t stream) { + assert(dst.numel() == src.numel() && "puf_add: size mismatch"); + assert(dst.dtype_size == src.dtype_size && "puf_add: dtype mismatch"); + if (puf_is_gpu_training()) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + const char *name = (dst.dtype_size == 2) ? 
"add_f16" : "add_f32"; + auto pso = mtl_pipeline(name); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst.bytes, 0); + mtl_set_ptr(ms, src.bytes, 1); + int count = (int)dst.numel(); + mtl_set_params(ms, count, 2); + mtl_dispatch_1d(ms, pso, count); + } else { + assert(dst.dtype_size == 4 && "puf_add: CPU path supports f32 only"); + mtl_ensure_stream_synced(stream); + float *d = (float *)dst.bytes; + const float *s = (const float *)src.bytes; + int64_t n = dst.numel(); + for (int64_t i = 0; i < n; i++) + d[i] += s[i]; + } +} + +void puf_transpose_01(PufTensor &dst, const PufTensor &src, + cudaStream_t stream) { + int A = (int)src.shape[0], B = (int)src.shape[1]; + int C = (src.ndim() >= 3) ? (int)src.shape[2] : 1; + assert(dst.shape[0] == B && dst.shape[1] == A); + assert(dst.dtype_size == src.dtype_size); + + if (src.dtype_size == 8) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("transpose_01_u64"); + mtl_set_pso(ms, pso); + mtl_set_tensor(ms, dst, 0); + mtl_set_tensor(ms, src, 1); + struct { + int A, B, C; + } params = {A, B, C}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, A * B * C); + return; + } + + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("transpose_01"); + mtl_set_pso(ms, pso); + mtl_set_tensor(ms, dst, 0); + mtl_set_tensor(ms, src, 1); + struct { + int A, B, C; + } params = {A, B, C}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, A * B * C); +} + +void puf_transpose_01(FloatTensor &dst, const FloatTensor &src, + cudaStream_t stream) { + int A = (int)src.shape[0], B = (int)src.shape[1]; + int C = (puf_ndim(src.shape) >= 3) ? 
(int)src.shape[2] : 1; + assert(dst.shape[0] == B && dst.shape[1] == A); + + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("transpose_01"); + mtl_set_pso(ms, pso); + mtl_set_tensor(ms, dst, 0); + mtl_set_tensor(ms, src, 1); + struct { + int A, B, C; + } params = {A, B, C}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, A * B * C); +} + +void cpu_cast_u8_to_f32(float *dst, const uint8_t *src, int count) { + int i = 0; + for (; i + 16 <= count; i += 16) { + uint8x16_t v = vld1q_u8(src + i); + uint16x8_t lo16 = vmovl_u8(vget_low_u8(v)); + uint16x8_t hi16 = vmovl_u8(vget_high_u8(v)); + vst1q_f32(dst + i, vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16)))); + vst1q_f32(dst + i + 4, vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16)))); + vst1q_f32(dst + i + 8, vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16)))); + vst1q_f32(dst + i + 12, vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16)))); + } + for (; i < count; i++) dst[i] = (float)src[i]; +} + +void puf_cast_u8_to_f32(PufTensor &dst, const PufTensor &src, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("cast_u8_to_f32"); + mtl_set_pso(ms, pso); + mtl_set_tensor(ms, dst, 0); + mtl_set_tensor(ms, src, 1); + struct { + int count; + } params = {(int)src.numel()}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, (int)src.numel()); +} + +void mtl_cast_f32_to_f16(void *dst, const float *src, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("cast_f32_to_f16"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst, 0); + mtl_set_ptr(ms, src, 1); + mtl_set_params(ms, count, 2); + mtl_dispatch_1d(ms, pso, count); +} + +void mtl_cast_f16_to_f32(float *dst, const void *src, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("cast_f16_to_f32"); + 
mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst, 0); + mtl_set_ptr(ms, src, 1); + mtl_set_params(ms, count, 2); + mtl_dispatch_1d(ms, pso, count); +} + +void mtl_fill_f16(void *ptr, int count, cudaStream_t stream) { + assert(count % 2 == 0 && "mtl_fill_f16: odd count would overwrite adjacent memory"); + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("fill_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, ptr, 0); + int f32_count = count / 2; + struct { + float value; + int count; + } params = {0.0f, f32_count}; + mtl_set_params(ms, params, 1); + mtl_dispatch_1d(ms, pso, f32_count); +} + +void mtl_copy_f16(void *dst, const void *src, int count, + cudaStream_t stream) { + assert(count % 2 == 0 && "mtl_copy_f16: odd count would over-read/write adjacent memory"); + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("copy_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst, 0); + mtl_set_ptr(ms, src, 1); + int f32_count = count / 2; + mtl_set_params(ms, f32_count, 2); + mtl_dispatch_1d(ms, pso, f32_count); +} + +void mtl_fill_f32(float *ptr, float value, int count, cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("fill_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, ptr, 0); + struct { + float value; + int count; + } params = {value, count}; + mtl_set_params(ms, params, 1); + mtl_dispatch_1d(ms, pso, count); +} + +void mtl_copy_f32(float *dst, const float *src, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("copy_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst, 0); + mtl_set_ptr(ms, src, 1); + mtl_set_params(ms, count, 2); + mtl_dispatch_1d(ms, pso, count); +} + +void mtl_clamp_f32(float *ptr, float lo, float hi, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + 
ms->compute_encoder(); + auto pso = mtl_pipeline("clamp_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, ptr, 0); + struct { + float lo, hi; + int count; + } params = {lo, hi, count}; + mtl_set_params(ms, params, 1); + mtl_dispatch_1d(ms, pso, count); +} + +void mtl_scale_f32(float *ptr, float scale, int count, cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("scale_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, ptr, 0); + struct { + float scale; + int count; + } params = {scale, count}; + mtl_set_params(ms, params, 1); + mtl_dispatch_1d(ms, pso, count); +} + +// dst += alpha * src +void mtl_axpy_f32(float *dst, const float *src, float alpha, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("axpy_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst, 0); + mtl_set_ptr(ms, src, 1); + struct { + float alpha; + int count; + } params = {alpha, count}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, count); +} + +void mtl_nesterov_f32(float *momentum, const float *grad, float mu, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("nesterov_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, momentum, 0); + mtl_set_ptr(ms, grad, 1); + struct { + float mu; + int count; + } params = {mu, count}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, count); +} + +static float *norm_partials_buf = nullptr; +static void ensure_norm_partials() { + if (!norm_partials_buf) + norm_partials_buf = (float *)mtl_alloc_scratch(256 * sizeof(float)); +} + +void mtl_norm_f32(float *partials, const float *data, int count, + int num_blocks, cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("norm_f32_kernel"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, partials, 0); 
+ mtl_set_ptr(ms, data, 1); + struct { + int count; + } params = {count}; + mtl_set_params(ms, params, 2); + mtl_dispatch_groups(ms, pso, num_blocks, 256); +} + +void mtl_norm_reduce(float *result, const float *partials, int num_blocks, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("norm_reduce_kernel"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, result, 0); + mtl_set_ptr(ms, partials, 1); + struct { + int num_blocks; + } params = {num_blocks}; + mtl_set_params(ms, params, 2); + mtl_dispatch_groups(ms, pso, 1, 256); +} + +void mtl_clip_by_norm_f32(float *data, const float *norm_ptr, + float max_norm, float eps, int count, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("clip_by_norm_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, data, 0); + mtl_set_ptr(ms, norm_ptr, 1); + struct { + float max_norm, eps; + int count; + } params = {max_norm, eps, count}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, count); +} + +// Matches CUDA normalize_f32_kernel: inv_norm = 1/max(sqrt(norm), eps), no cap. +void mtl_normalize_f32(float *data, const float *norm_ptr, float eps, + int count, cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("normalize_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, data, 0); + mtl_set_ptr(ms, norm_ptr, 1); + struct { + float eps; + int count; + } params = {eps, count}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, count); +} + +// Convenience: compute L2 norm of grad, clip in-place if > max_norm. +// scratch must point to a float in wrapped MTLBuffer memory. 
+void clip_grad_norm_f32(FloatTensor &grad, float *scratch, float max_norm, + float eps, cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ensure_norm_partials(); + int count = (int)puf_numel(grad.shape); + int num_blocks = (count + 255) / 256; + if (num_blocks > 256) num_blocks = 256; + mtl_norm_f32(norm_partials_buf, grad.data, count, num_blocks, stream); + mtl_barrier(ms); + mtl_norm_reduce(scratch, norm_partials_buf, num_blocks, stream); + mtl_barrier(ms); + mtl_clip_by_norm_f32(grad.data, scratch, max_norm, eps, count, stream); +} + +void mtl_transpose_f32(float *dst, const float *src, int rows, int cols, + cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("transpose_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, dst, 0); + mtl_set_ptr(ms, src, 1); + struct { + int rows, cols; + } params = {rows, cols}; + mtl_set_params(ms, params, 2); + mtl_dispatch_1d(ms, pso, rows * cols); +} + +void mtl_assemble_decoder_grad_f32(float *grad_out, const float *grad_logits, + const float *grad_value, int B_TT, int od, + int od1, cudaStream_t stream) { + MetalStream *ms = mtl_resolve_stream(stream); + ms->compute_encoder(); + auto pso = mtl_pipeline("assemble_decoder_grad_f32"); + mtl_set_pso(ms, pso); + mtl_set_ptr(ms, grad_out, 0); + mtl_set_ptr(ms, grad_logits, 1); + mtl_set_ptr(ms, grad_value, 2); + struct { + int B_TT, od, od1; + } params = {B_TT, od, od1}; + mtl_set_params(ms, params, 3); + mtl_dispatch_1d(ms, pso, B_TT * od1); +} + +// Assemble fp32 PPO gradients into fp16 decoder gradient output. 
+// Same slot layout as mtl_assemble_decoder_grad_f32, but grad_out is untyped
+// (void *) and the pipeline is "assemble_decoder_grad_f32_to_f16".
+//   slot 0: grad_out, slot 1: grad_logits, slot 2: grad_value,
+//   slot 3: params {B_TT, od, od_plus_1}
+void mtl_assemble_decoder_grad_f32_to_f16(void *grad_out,
+                                          const float *grad_logits,
+                                          const float *grad_value, int B_TT,
+                                          int od, int od1,
+                                          cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("assemble_decoder_grad_f32_to_f16");
+  mtl_set_pso(ms, pso);
+  mtl_set_ptr(ms, grad_out, 0);
+  mtl_set_ptr(ms, grad_logits, 1);
+  mtl_set_ptr(ms, grad_value, 2);
+  struct { int B_TT, od, od_plus_1; } params = {B_TT, od, od1};
+  mtl_set_params(ms, params, 3);
+  mtl_dispatch_1d(ms, pso, B_TT * od1);
+}
+
+// Encodes "sum_rows_to_f32_kernel" with `cols` threads.
+//   slot 0: dst     (output)
+//   slot 1: src     (input, described by {rows, cols})
+//   slot 2: params  {rows, cols}
+// NOTE(review): the thread count suggests one thread per column summing over
+// rows -- confirm in the shader.
+void mtl_sum_rows_to_f32(float *dst, const float *src, int rows, int cols,
+                         cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("sum_rows_to_f32_kernel");
+  mtl_set_pso(ms, pso);
+  mtl_set_ptr(ms, dst, 0);
+  mtl_set_ptr(ms, src, 1);
+  struct {
+    int rows, cols;
+  } params = {rows, cols};
+  mtl_set_params(ms, params, 2);
+  mtl_dispatch_1d(ms, pso, cols);
+}
+
+// Encodes "mingru_gate_inference" with B * H threads.
+//   slot 0: out         (output)
+//   slot 1: next_state  (output)
+//   slot 2: combined    (input)
+//   slot 3: state_in    (input)
+//   slot 4: x_in        (input)
+//   slot 5: params      {H, B}
+void mtl_mingru_gate(float *out, float *next_state, const float *combined,
+                     const float *state_in, const float *x_in, int H, int B,
+                     cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("mingru_gate_inference");
+  mtl_set_pso(ms, pso);
+  mtl_set_ptr(ms, out, 0);
+  mtl_set_ptr(ms, next_state, 1);
+  mtl_set_ptr(ms, combined, 2);
+  mtl_set_ptr(ms, state_in, 3);
+  mtl_set_ptr(ms, x_in, 4);
+  struct {
+    int H, B;
+  } params = {H, B};
+  mtl_set_params(ms, params, 5);
+  mtl_dispatch_1d(ms, pso, B * H);
+}
+
+// Shared scan dispatch: binds all buffers and dispatches the named kernel.
+// Forward scan slot layout:
+//   0 out, 1 next_state, 2 a_star, 3 s_vals, 4 log_values_buf,
+//   5 combined_ptr, 6 state_ptr, 7 input_ptr, 8 params {T_seq, H, B}
+// Launched with B * H threads. combined_ptr / state_ptr / input_ptr must be
+// set on `scan` by the caller before this is invoked.
+static void dispatch_scan_forward(const char *kernel_name, PrefixScan &scan,
+                                  cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline(kernel_name);
+  mtl_set_pso(ms, pso);
+  mtl_set_ptr(ms, scan.out.data, 0);
+  mtl_set_ptr(ms, scan.next_state.data, 1);
+  mtl_set_ptr(ms, scan.a_star.data, 2);
+  mtl_set_ptr(ms, scan.s_vals.data, 3);
+  mtl_set_ptr(ms, scan.log_values_buf.data, 4);
+  mtl_set_ptr(ms, scan.combined_ptr, 5);
+  mtl_set_ptr(ms, scan.state_ptr, 6);
+  mtl_set_ptr(ms, scan.input_ptr, 7);
+  struct { int T_seq, H, B; } params = {scan.T, scan.H, scan.B};
+  mtl_set_params(ms, params, 8);
+  mtl_dispatch_1d(ms, pso, scan.B * scan.H);
+}
+
+// Backward scan slot layout:
+//   0 grad_combined, 1 grad_state, 2 grad_input   (outputs)
+//   3 grad, 4 grad_next_state                     (incoming gradients)
+//   5 combined_ptr, 6 state_ptr, 7 input_ptr      (saved forward inputs)
+//   8 a_star, 9 s_vals, 10 log_values_buf         (buffers also bound by the
+//                                                  forward dispatch)
+//   11 params {T_seq, H, B}
+// Launched with B * H threads.
+static void dispatch_scan_backward(const char *kernel_name, PrefixScan &scan,
+                                   const void *grad, const void *grad_next_state,
+                                   cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline(kernel_name);
+  mtl_set_pso(ms, pso);
+  mtl_set_ptr(ms, scan.grad_combined.data, 0);
+  mtl_set_ptr(ms, scan.grad_state.data, 1);
+  mtl_set_ptr(ms, scan.grad_input.data, 2);
+  mtl_set_ptr(ms, grad, 3);
+  mtl_set_ptr(ms, grad_next_state, 4);
+  mtl_set_ptr(ms, scan.combined_ptr, 5);
+  mtl_set_ptr(ms, scan.state_ptr, 6);
+  mtl_set_ptr(ms, scan.input_ptr, 7);
+  mtl_set_ptr(ms, scan.a_star.data, 8);
+  mtl_set_ptr(ms, scan.s_vals.data, 9);
+  mtl_set_ptr(ms, scan.log_values_buf.data, 10);
+  struct { int T_seq, H, B; } params = {scan.T, scan.H, scan.B};
+  mtl_set_params(ms, params, 11);
+  mtl_dispatch_1d(ms, pso, scan.B * scan.H);
+}
+
+// Public entry points: each forwards to the shared dispatcher above and
+// differs only in the kernel name (checkpointed vs. checkpointed_fp16).
+void mtl_mingru_scan_forward(PrefixScan &scan, cudaStream_t stream) {
+  dispatch_scan_forward("mingru_scan_forward_checkpointed", scan, stream);
+}
+void mtl_mingru_scan_backward(PrefixScan &scan, const float *grad,
+                              const float *grad_next_state, cudaStream_t stream) {
+  dispatch_scan_backward("mingru_scan_backward_checkpointed", scan, grad, grad_next_state, stream);
+}
+void mtl_mingru_scan_forward_fp16(PrefixScan &scan, cudaStream_t stream) {
+  dispatch_scan_forward("mingru_scan_forward_checkpointed_fp16", scan, stream);
+}
+void mtl_mingru_scan_backward_fp16(PrefixScan &scan, const void *grad,
+                                   const void *grad_next_state, cudaStream_t stream) {
+  dispatch_scan_backward("mingru_scan_backward_checkpointed_fp16", scan, grad, grad_next_state, stream);
+}
+
+// Dispatch GPU sampling kernel on the current command buffer (no sync).
+// Call BEFORE ensure_gpu_synced so sampling runs in the same command buffer
+// as the forward pass.
+//
+// dec_out is read as [B, fused_cols]; the last column is bound as the value
+// input (slot 5) and the remaining fused_cols - 1 columns are passed as
+// num_atns_total. Slot layout:
+//   0 action_out_f32, 1 logprobs, 2 value_out      (outputs)
+//   3 logits, 4 logstd placeholder, 5 value column (inputs, all inside dec_out)
+//   6 act_sizes, 7 params, 8 action_mask
+// Side effect: *offset_ptr is read and then advanced by 1 on the host, so
+// each call hands the kernel a distinct offset.
+void mtl_sample_logits_dispatch_to(
+    PrecisionTensor &dec_out, IntTensor &act_sizes_puf,
+    float *action_out_f32, float *logprobs, float *value_out,
+    const float *action_mask, int mask_stride,
+    uint64_t seed, uint32_t *offset_ptr, cudaStream_t stream) {
+
+  int B = (int)dec_out.shape[0];
+  int fused_cols = (int)dec_out.shape[1];
+  int num_atns = (int)puf_numel(act_sizes_puf.shape);
+  int A_total = fused_cols - 1;
+
+  assert(action_out_f32 && "sampling destination buffer must be allocated");
+
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("sample_logits_kernel");
+  mtl_set_pso(ms, pso);
+
+  mtl_set_ptr(ms, action_out_f32, 0);
+  mtl_set_ptr(ms, logprobs, 1);
+  mtl_set_ptr(ms, value_out, 2);
+  mtl_set_ptr(ms, dec_out.data, 3);
+  mtl_set_ptr(ms, dec_out.data, 4); // dummy logstd (discrete only)
+  // value column is the last fused decoder column.
+  mtl_set_ptr(ms, dec_out.data + (fused_cols - 1), 5);
+  mtl_set_ptr(ms, act_sizes_puf.data, 6);
+  uint32_t offset_snapshot = *offset_ptr;
+  *offset_ptr = offset_snapshot + 1u;
+
+  // Field order must stay in sync with the shader-side params struct.
+  // logstd_stride and is_continuous are hard-wired to 0 on this path.
+  struct {
+    uint64_t seed;
+    uint32_t offset;
+    int num_atns;
+    int num_atns_total;
+    int B;
+    int logits_stride;
+    int logstd_stride;
+    int value_stride;
+    int is_continuous;
+    int mask_stride;
+  } params = {seed, offset_snapshot, num_atns, A_total, B,
+              fused_cols, 0, fused_cols, 0, mask_stride};
+  mtl_set_params(ms, params, 7);
+
+  mtl_set_ptr(ms, (void *)action_mask, 8);
+
+  mtl_dispatch_1d(ms, pso, B);
+}
+
+// Expand f32 GPU actions to f64 (call after ensure_gpu_synced).
+// Plain host loop: widens `count` floats from f32 into f64.
+void mtl_sample_logits_expand(const float *f32, double *f64, int count) {
+  for (int i = 0; i < count; i++) f64[i] = (double)f32[i];
+}
+
+// Recompute logprobs from CPU-produced logits using GPU fast::exp.
+// Dispatches on the given stream, no sync. The tiny kernel (B threads)
+// completes in ~1us and ensures old_logp matches PPO training precision.
+// Slot layout for "recompute_logprobs_kernel":
+//   0 logprobs (output), 1 logits, 2 actions_f32, 3 act_sizes, 4 action_mask,
+//   5 params {B, num_atns, logits_stride = fused_cols, mask_stride}
+// Launched with B threads.
+void mtl_recompute_logprobs(
+    float *logprobs, const float *logits, const float *actions_f32,
+    const int *act_sizes, const float *action_mask, int mask_stride,
+    int B, int num_atns, int fused_cols, cudaStream_t stream) {
+
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("recompute_logprobs_kernel");
+  mtl_set_pso(ms, pso);
+
+  mtl_set_ptr(ms, logprobs, 0);
+  mtl_set_ptr(ms, (void *)logits, 1);
+  mtl_set_ptr(ms, (void *)actions_f32, 2);
+  mtl_set_ptr(ms, (void *)act_sizes, 3);
+  mtl_set_ptr(ms, (void *)action_mask, 4);
+
+  struct {
+    int B, num_atns, logits_stride, mask_stride;
+  } params = {B, num_atns, fused_cols, mask_stride};
+  mtl_set_params(ms, params, 5);
+
+  mtl_dispatch_1d(ms, pso, B);
+}
+
+// File-scope scratch for per-group PPO loss partials. Grown on demand by
+// ppo_loss_fwd_bwd and never shrunk; capacity is counted in floats.
+static float *ppo_partials_buf = nullptr;
+static int ppo_partials_capacity = 0;
+
+// Fused PPO loss forward + backward, encoded as three dependent passes:
+//   1. "var_mean_kernel"          advantage statistics -> bufs.adv_scratch[0..1]
+//   2. "ppo_loss_fwd_bwd_kernel"  per-group loss partials + gradients
+//   3. "ppo_loss_reduce_kernel"   partials -> bufs.loss_output / losses_acc
+//
+// dec_out is read as [N, T, fused_cols]; the last column is the value head
+// and the first fused_cols - 1 columns are action logits.
+// full_batch_adv, when non-null, replaces graph.mb_advantages as the input to
+// the statistics pass only; pass 2 always reads graph.mb_advantages.
+// ext_mask_ptr, when non-null, supplies the action mask; otherwise the mask
+// is taken from the trailing A_total entries of each graph.mb_obs row.
+void ppo_loss_fwd_bwd(PufTensor &dec_out, PufTensor &logstd, TrainGraph &graph,
+                      IntTensor &act_sizes, FloatTensor &losses_acc,
+                      float clip_coef, float vf_clip_coef, float vf_coef,
+                      float ent_coef, PPOBuffersPuf &bufs, bool is_continuous,
+                      const float *ext_mask_ptr, int ext_mask_stride,
+                      const FloatTensor *full_batch_adv,
+                      cudaStream_t stream) {
+  int N = (int)dec_out.shape[0], T = (int)dec_out.shape[1];
+  int fused_cols = (int)dec_out.shape[2];
+  int num_atns = (int)puf_numel(act_sizes.shape);
+  int A_total = fused_cols - 1;
+  int total = N * T;
+
+  // Logits and values live in the same fused buffer, so they share strides.
+  int logits_stride_n = T * fused_cols;
+  int logits_stride_t = fused_cols;
+  int logits_stride_a = 1;
+  int values_stride_n = T * fused_cols;
+  int values_stride_t = fused_cols;
+
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  // Pass 1: one group of 256 threads writes two statistics into adv_scratch.
+  {
+    const float *adv_data = full_batch_adv ? full_batch_adv->data : graph.mb_advantages.data;
+    int64_t adv_count = full_batch_adv ? puf_numel(full_batch_adv->shape) : puf_numel(graph.mb_advantages.shape);
+    ms->compute_encoder();
+    auto pso = mtl_pipeline("var_mean_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_ptr(ms, adv_data, 0);
+    mtl_set_ptr(ms, bufs.adv_scratch.data, 1);
+    mtl_set_ptr(ms, bufs.adv_scratch.data + 1, 2);
+    struct { int count; } params = {(int)adv_count};
+    mtl_set_params(ms, params, 3);
+    mtl_dispatch_groups(ms, pso, 1, 256);
+  }
+
+  // Each of the ppo_grid groups emits LOSS_N + 1 partial values.
+  int ppo_threads = 256;
+  int ppo_grid = (total + ppo_threads - 1) / ppo_threads;
+  int ppo_partials_needed = ppo_grid * (LOSS_N + 1);
+  if (!ppo_partials_buf || ppo_partials_needed > ppo_partials_capacity) {
+    // Grow-only reallocation: unregister and free the old scratch first.
+    if (ppo_partials_buf) {
+      mtl_unwrap_ptr(ppo_partials_buf);
+      free(ppo_partials_buf);
+    }
+    ppo_partials_capacity = ppo_partials_needed;
+    ppo_partials_buf = (float *)mtl_alloc_scratch(ppo_partials_capacity * sizeof(float));
+  }
+
+  puf_zero(&bufs.loss_output, stream);
+
+  float *ppo_act_f32 = graph.mb_actions.data;
+
+  int input_size = (int)graph.mb_obs.shape[2];
+  const float *mask_ptr;
+  int mask_stride;
+  if (ext_mask_ptr) {
+    mask_ptr = ext_mask_ptr;
+    mask_stride = ext_mask_stride;
+  } else {
+    // Default mask: last A_total floats of every observation row.
+    int mask_offset = input_size - A_total;
+    mask_ptr = graph.mb_obs.data + mask_offset;
+    mask_stride = input_size;
+  }
+
+  // Pass 2 reads the statistics and the zeroed loss_output written above.
+  mtl_barrier(ms);
+
+  {
+    ms->compute_encoder();
+    auto pso = mtl_pipeline("ppo_loss_fwd_bwd_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_ptr(ms, ppo_partials_buf, 0);
+    mtl_set_ptr(ms, bufs.grad_logits.data, 1);
+    // Discrete mode has no logstd; grad_logits is bound as a placeholder.
+    mtl_set_ptr(ms, is_continuous ? bufs.grad_logstd.data
+                                  : bufs.grad_logits.data,
+                2);
+    mtl_set_ptr(ms, bufs.grad_values.data, 3);
+    mtl_set_ptr(ms, dec_out.bytes, 4);
+    mtl_set_ptr(ms, is_continuous ? logstd.bytes : dec_out.bytes, 5);
+    // Value column: offset A_total elements into the fused row.
+    // NOTE(review): the offset is taken in float units, so this assumes
+    // dec_out holds fp32 -- confirm for any half-precision configuration.
+    mtl_set_ptr(ms, (float *)dec_out.bytes + A_total, 6);
+    mtl_set_ptr(ms, ppo_act_f32, 7);
+    mtl_set_ptr(ms, graph.mb_logprobs.data, 8);
+    mtl_set_ptr(ms, graph.mb_advantages.data, 9);
+    mtl_set_ptr(ms, graph.mb_prio.data, 10);
+    mtl_set_ptr(ms, graph.mb_values.data, 11);
+    mtl_set_ptr(ms, graph.mb_returns.data, 12);
+    // Same two adv_scratch entries as pass 1, bound in the opposite order.
+    mtl_set_ptr(ms, bufs.adv_scratch.data + 1, 13);
+    mtl_set_ptr(ms, bufs.adv_scratch.data, 14);
+    mtl_set_ptr(ms, act_sizes.data, 15);
+
+    // Field order must stay in sync with the shader-side params struct.
+    struct {
+      int num_atns;
+      float clip_coef, vf_clip_coef, vf_coef, ent_coef;
+      int T_seq, A_total, N;
+      int logits_stride_n, logits_stride_t, logits_stride_a;
+      int values_stride_n, values_stride_t;
+      int is_continuous;
+      int num_atns_total;
+      int mask_stride_val;
+    } params = {num_atns,
+                clip_coef,
+                vf_clip_coef,
+                vf_coef,
+                ent_coef,
+                T,
+                A_total,
+                N,
+                logits_stride_n,
+                logits_stride_t,
+                logits_stride_a,
+                values_stride_n,
+                values_stride_t,
+                is_continuous ? 1 : 0,
+                A_total,
+                mask_stride};
+    mtl_set_params(ms, params, 16);
+    mtl_set_ptr(ms, (void *)mask_ptr, 17);
+    mtl_set_ptr(ms, graph.mb_ratio.data, 18);
+    mtl_set_ptr(ms, graph.mb_newvalue.data, 19);
+    mtl_dispatch_groups(ms, pso, ppo_grid, ppo_threads);
+  }
+
+  // Pass 3 consumes the partials written by pass 2.
+  mtl_barrier(ms);
+
+  {
+    ms->compute_encoder();
+    auto pso = mtl_pipeline("ppo_loss_reduce_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_ptr(ms, bufs.loss_output.data, 0);
+    mtl_set_ptr(ms, losses_acc.data, 1);
+    mtl_set_ptr(ms, ppo_partials_buf, 2);
+    struct {
+      int num_blocks;
+    } params = {ppo_grid};
+    mtl_set_params(ms, params, 3);
+
+    // One group of LOSS_N + 1 threads, matching the per-group partial count.
+    mtl_dispatch_groups(ms, pso, 1, LOSS_N + 1);
+  }
+}
+
+// Scatter mb_ratio and mb_newvalue from minibatch back into rollout buffers.
+// Called after ppo_loss_fwd_bwd so subsequent minibatches see updated values.
+// `idx` holds one destination row index per minibatch row; the row count is
+// taken from graph.mb_ratio.shape[0]. Both scatters reuse "index_copy_kernel".
+void mtl_scatter_ppo_outputs(TrainGraph& graph, RolloutBuf& rollouts,
+                             const int64_t* idx, cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  int num_idx = (int)graph.mb_ratio.shape[0];
+
+  // Slot layout: 0 dst, 1 idx, 2 src, 3 params {num_idx, row_bytes}.
+  // row_bytes is derived from src.shape[1], i.e. rows are treated as
+  // shape[1] contiguous floats.
+  auto scatter = [&](FloatTensor& dst, FloatTensor& src) {
+    ms->compute_encoder();
+    auto pso = mtl_pipeline("index_copy_kernel");
+    mtl_set_pso(ms, pso);
+    int row_bytes = (int)(src.shape[1] * sizeof(float));
+    mtl_set_ptr(ms, dst.data, 0);
+    mtl_set_ptr(ms, (void*)idx, 1);
+    mtl_set_ptr(ms, src.data, 2);
+    struct { int num_idx; int row_bytes; } p = {num_idx, row_bytes};
+    mtl_set_params(ms, p, 3);
+    mtl_dispatch_groups(ms, pso, (num_idx + 255) / 256, 256);
+  };
+
+  // The two scatters write different destinations, so no barrier between.
+  scatter(rollouts.ratio, graph.mb_ratio);
+  scatter(rollouts.values, graph.mb_newvalue);
+}
+
+// Encodes "puff_advantage_kernel".
+//   slots 0-3: values, rewards, dones, importance  (inputs)
+//   slot 4:    advantages                          (output)
+//   slot 5:    params {gamma, lambda, rho_clip, c_clip, num_steps, horizon}
+// num_steps / horizon are values.shape[0] / values.shape[1]. The launch uses
+// ceil(num_steps / 256) groups of 256 threads.
+// NOTE(review): the launch size suggests one thread per shape[0] row walking
+// `horizon` entries -- confirm in the shader.
+void puff_advantage(FloatTensor &values, FloatTensor &rewards,
+                    FloatTensor &dones, FloatTensor &importance,
+                    FloatTensor &advantages, float gamma, float lambda,
+                    float rho_clip, float c_clip, cudaStream_t stream) {
+  int num_steps = (int)values.shape[0], horizon = (int)values.shape[1];
+
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("puff_advantage_kernel");
+  mtl_set_pso(ms, pso);
+  mtl_set_tensor(ms, values, 0);
+  mtl_set_tensor(ms, rewards, 1);
+  mtl_set_tensor(ms, dones, 2);
+  mtl_set_tensor(ms, importance, 3);
+  mtl_set_tensor(ms, advantages, 4);
+  struct {
+    float gamma, lambda, rho_clip, c_clip;
+    int num_steps, horizon;
+  } params = {gamma, lambda, rho_clip, c_clip, num_steps, horizon};
+  mtl_set_params(ms, params, 5);
+  int blocks = (num_steps + 255) / 256;
+  mtl_dispatch_groups(ms, pso, blocks, 256);
+}
+
+// Phase 1: compute normalized per-segment probabilities on GPU.
+// advantages is read as [S, T]. Two dependent passes write bufs.prio_probs:
+// a per-segment reduction (S groups of 32 threads) followed by an in-place
+// normalisation (one group of 256 threads). The trailing barrier makes the
+// result visible to whatever is encoded next on the same stream.
+void prio_precompute(FloatTensor &advantages, float prio_alpha,
+                     PrioBuffers &bufs, cudaStream_t stream) {
+  int S = (int)advantages.shape[0], T = (int)advantages.shape[1];
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  // Prio adv reduction
+  // Slots: 0 advantages, 1 prio_probs (output), 2 params {prio_alpha, stride=T}
+  {
+    ms->compute_encoder();
+    auto pso = mtl_pipeline("prio_adv_reduction_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_tensor(ms, advantages, 0);
+    mtl_set_ptr(ms, bufs.prio_probs.data, 1);
+    struct {
+      float prio_alpha;
+      int stride;
+    } params = {prio_alpha, T};
+    mtl_set_params(ms, params, 2);
+    mtl_dispatch_groups(ms, pso, S, 32);
+  }
+  mtl_barrier(ms); // reduction -> normalize
+
+  // Normalize
+  // Slots: 0 prio_probs (in/out), 1 params {S}
+  {
+    ms->compute_encoder();
+    auto pso = mtl_pipeline("prio_normalize_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_ptr(ms, bufs.prio_probs.data, 0);
+    struct {
+      int S;
+    } params = {S};
+    mtl_set_params(ms, params, 1);
+    mtl_dispatch_groups(ms, pso, 1, 256);
+  }
+  mtl_barrier(ms); // normalize -> sample
+}
+
+// Phase 2 (per-minibatch): GPU sampling from prio_probs + GPU importance weights.
+// Writes bufs.idx (sampled segment indices) and bufs.mb_prio (importance
+// weights). Side effect: *offset_ptr is advanced by minibatch_segments on the
+// host so that successive calls use disjoint offsets.
+void prio_sample(int minibatch_segments, int total_agents,
+                 float anneal_beta, PrioBuffers &bufs, uint64_t seed,
+                 uint32_t *offset_ptr, cudaStream_t stream) {
+  int S = (int)bufs.prio_probs.shape[0];
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  // prio_probs -> sampled indices
+  // Slots: 0 idx (output), 1 prio_probs, 2 params
+  ms->compute_encoder();
+  {
+    auto pso = mtl_pipeline("prio_sample_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_ptr(ms, bufs.idx.data, 0);
+    mtl_set_ptr(ms, bufs.prio_probs.data, 1);
+    uint32_t base_offset = *offset_ptr;
+    *offset_ptr = base_offset + (uint32_t)minibatch_segments;
+    struct {
+      uint64_t seed;
+      uint32_t base_offset;
+      int total_segments;
+      int minibatch_segments;
+    } params = {seed, base_offset, S, minibatch_segments};
+    mtl_set_params(ms, params, 2);
+    mtl_dispatch_1d(ms, pso, minibatch_segments);
+  }
+
+  mtl_barrier(ms); // sampled idx -> imp-weights
+
+  // sampled indices + prio_probs -> importance weights
+  // Slots: 0 idx, 1 prio_probs, 2 mb_prio (output), 3 params
+  ms->compute_encoder();
+  {
+    auto pso = mtl_pipeline("prio_imp_weights_kernel");
+    mtl_set_pso(ms, pso);
+    mtl_set_ptr(ms, bufs.idx.data, 0);
+    mtl_set_ptr(ms, bufs.prio_probs.data, 1);
+    mtl_set_ptr(ms, bufs.mb_prio.data, 2);
+    struct {
+      int total_agents;
+      float anneal_beta;
+      int minibatch_segments;
+    } params = {total_agents, anneal_beta, minibatch_segments};
+    mtl_set_params(ms, params, 3);
+    // single threadgroup — max-reduction assumes all threads share threadgroup memory
+    mtl_dispatch_groups(ms, pso, 1, 256);
+  }
+}
+
+// Gathers the rollout rows selected by `idx` into the minibatch tensors of
+// `graph` with a single "select_copy_kernel" launch.
+//   slots 0-6:   graph.mb_{obs, actions, logprobs, values, advantages,
+//                returns, prio}                      (destinations)
+//   slots 7-10:  rollouts.{observations, actions, logprobs, values}
+//   slot 11:     advantages, slot 12: idx, slot 13: mb_prio
+//   slot 14:     params {obs_row_bytes, act_row_bytes, lp_row_bytes, horizon}
+//   slot 15:     fp16_obs_out
+// Row sizes are numel / shape[0] elements times sizeof(float).
+// NOTE(review): fp16_obs_out is bound unconditionally; whether nullptr is
+// accepted depends on mtl_set_ptr and the shader -- confirm.
+void mtl_select_copy(RolloutBuf &rollouts, TrainGraph &graph,
+                     const int64_t *idx, const float *advantages,
+                     const float *mb_prio, int mb_segs,
+                     void *fp16_obs_out, cudaStream_t stream) {
+  int obs_row_bytes = (int)(puf_numel(rollouts.observations.shape) /
+                            rollouts.observations.shape[0]) *
+                      (int)sizeof(float);
+  int act_row_bytes = (int)(puf_numel(rollouts.actions.shape) /
+                            rollouts.actions.shape[0]) *
+                      (int)sizeof(float);
+  int lp_row_bytes = (int)(puf_numel(rollouts.logprobs.shape) /
+                           rollouts.logprobs.shape[0]) *
+                     (int)sizeof(float);
+  int horizon = (int)rollouts.values.shape[1];
+
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("select_copy_kernel");
+  mtl_set_pso(ms, pso);
+
+  mtl_set_ptr(ms, graph.mb_obs.data, 0);
+  mtl_set_ptr(ms, graph.mb_actions.data, 1);
+  mtl_set_ptr(ms, graph.mb_logprobs.data, 2);
+  mtl_set_ptr(ms, graph.mb_values.data, 3);
+  mtl_set_ptr(ms, graph.mb_advantages.data, 4);
+  mtl_set_ptr(ms, graph.mb_returns.data, 5);
+  mtl_set_ptr(ms, graph.mb_prio.data, 6);
+  mtl_set_ptr(ms, rollouts.observations.data, 7);
+  mtl_set_ptr(ms, rollouts.actions.data, 8);
+  mtl_set_ptr(ms, rollouts.logprobs.data, 9);
+  mtl_set_ptr(ms, rollouts.values.data, 10);
+  mtl_set_ptr(ms, advantages, 11);
+  mtl_set_ptr(ms, idx, 12);
+  mtl_set_ptr(ms, mb_prio, 13);
+
+  struct {
+    int obs_row_bytes, act_row_bytes, lp_row_bytes, horizon;
+  } params = {obs_row_bytes, act_row_bytes, lp_row_bytes, horizon};
+  mtl_set_params(ms, params, 14);
+
+  mtl_set_ptr(ms, fp16_obs_out, 15);
+
+  // 2D dispatch: (mb_segs, 5) threadgroups, 256 threads each
+  // Issued on the raw encoder because the mtl_dispatch_* helpers used
+  // elsewhere in this file are only ever called with a 1-D grid.
+  [ms->enc dispatchThreadgroups:MTLSizeMake(mb_segs, 5, 1)
+      threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
+}
+
+// Encodes "muon_weight_update_kernel" over `count` weights.
+//   slot 0: weights  (in/out)
+//   slot 1: updates  (input)
+//   slot 2: lr_ptr   (device-side learning rate, read by the kernel)
+//   slot 3: params   {count, weight_decay, scale}
+void mtl_muon_weight_update(float *weights, const float *updates,
+                            const float *lr_ptr, float weight_decay,
+                            float scale, int count,
+                            cudaStream_t stream) {
+  MetalStream *ms = mtl_resolve_stream(stream);
+  ms->compute_encoder();
+  auto pso = mtl_pipeline("muon_weight_update_kernel");
+  mtl_set_pso(ms, pso);
+  mtl_set_ptr(ms, weights, 0);
+  mtl_set_ptr(ms, updates, 1);
+  mtl_set_ptr(ms, lr_ptr, 2);
+  struct {
+    int count;
+    float weight_decay;
+    float scale;
+  } params = {count, weight_decay, scale};
+  mtl_set_params(ms, params, 3);
+  mtl_dispatch_1d(ms, pso, count);
+}
+
+// ============================================================================
+// Kaiming uniform init (CPU-side, matches CUDA puf_kaiming_init)
+//
+// U(-bound, bound) where bound = gain / sqrt(fan_in).
+// For 2D weight [rows, cols], fan_in = cols.
+// Runs once at model init — not perf-critical.
+// ============================================================================
+
+// Fills the 2-D tensor `dst` with U(-bound, bound) samples drawn from a
+// std::mt19937_64 seeded with `seed`, so a given seed always produces the
+// same weights. The stream is synced first because the fill is a host write.
+// NOTE(review): elements are written as float regardless of dst.dtype_size;
+// every caller in this file passes PRECISION_SIZE -- confirm that is 4.
+void puf_kaiming_init(PufTensor &dst, float gain, uint64_t seed,
+                      cudaStream_t stream) {
+  mtl_ensure_stream_synced(stream);
+
+  assert(dst.ndim() == 2);
+  int64_t rows = dst.shape[0], cols = dst.shape[1];
+  assert(rows > 0 && cols > 0);
+
+  float bound = gain / std::sqrt((float)cols);
+  int64_t n = rows * cols;
+  float *dst_f = (float *)dst.bytes;
+
+  std::mt19937_64 rng(seed);
+  // Element type spelled out; it is the same type deduction picked before.
+  std::uniform_real_distribution<float> uniform(-bound, bound);
+  for (int64_t i = 0; i < n; i++)
+    dst_f[i] = uniform(rng);
+}
+
+// Initialises optimizer state and registers its buffers with `alloc`:
+//   lr_puf [1], lr_derived_puf [2], and mb_puf / gc_puf / up_puf, each with
+//   as many elements as weight_buffer.
+// When at least one registered parameter has ndim >= 2, Newton-Schulz
+// workspaces are also registered, sized for the largest min(R, C) x max(R, C)
+// over those parameters. ns_iters outside 1..5 falls back to 5.
+// lr_ptr / lr_derived_ptr are left null here.
+void muon_init(Muon *m, Allocator *param_alloc, FloatTensor weight_buffer,
+               double lr_val, double momentum, double weight_decay,
+               int ns_iters, Allocator &alloc) {
+  m->momentum = momentum;
+  m->weight_decay = weight_decay;
+  m->ns_iters = (ns_iters > 0 && ns_iters <= 5) ? ns_iters : 5;
+  m->lr_val_init = (float)lr_val;
+  m->lr_ptr = nullptr;
+  m->lr_derived_ptr = nullptr;
+  m->wb_puf = weight_buffer;
+  m->param_alloc = param_alloc;
+  m->ns = {};
+  int64_t n = puf_numel(m->wb_puf.shape);
+  m->lr_puf = {.shape = {1}};
+  m->lr_derived_puf = {.shape = {2}};
+  m->mb_puf = {.shape = {n}};
+  m->gc_puf = {.shape = {n}};
+  m->up_puf = {.shape = {n}};
+  alloc_register(&alloc, &m->lr_puf);
+  alloc_register(&alloc, &m->lr_derived_puf);
+  alloc_register(&alloc, &m->mb_puf);
+  alloc_register(&alloc, &m->gc_puf);
+  alloc_register(&alloc, &m->up_puf);
+
+  int64_t max_M = 0, max_N = 0;
+  for (auto &e : param_alloc->regs) {
+    int nd = puf_ndim(e.shape);
+    if (nd >= 2) {
+      // Same zero-row guard as muon_step, so both agree on C.
+      int64_t R = e.shape[0];
+      int64_t C = puf_numel(e.shape) / std::max<int64_t>(1, R);
+      max_M = std::max(max_M, std::min(R, C));
+      max_N = std::max(max_N, std::max(R, C));
+    }
+  }
+  if (max_M > 0) {
+    m->ns.max_M = max_M;
+    m->ns.max_N = max_N;
+    int ns_esz = PRECISION_SIZE; // always 4 on Metal
+    m->ns.x = {.shape = {max_M, max_N}, .dtype_size = ns_esz};
+    m->ns.A = {.shape = {max_M, max_M}, .dtype_size = ns_esz};
+    m->ns.gram = {.shape = {max_M, max_M}, .dtype_size = ns_esz};
+    m->ns.tmp = {.shape = {max_M, max_N}, .dtype_size = ns_esz};
+    m->ns.result_f32 = {.shape = {max_M, max_N}, .dtype_size = ns_esz};
+    m->ns_norm_puf = {.shape = {1}};
+    alloc_register(&alloc, &m->ns.x);
+    alloc_register(&alloc, &m->ns.A);
+    alloc_register(&alloc, &m->ns.gram);
+    alloc_register(&alloc, &m->ns.tmp);
+    alloc_register(&alloc, &m->ns.result_f32);
+    alloc_register(&alloc, &m->ns_norm_puf);
+  }
+}
+
+// One optimizer step over the flat weight buffer:
+//   1. momentum update of mb_puf from gc_puf;
+//   2. per parameter, either Newton-Schulz orthogonalisation of its gradient
+//      slice (2D+ with min(R, C) >= 2) or a plain copy, written to up_puf;
+//   3. weight update from up_puf.
+// Parameters are visited in registration order and `offset` walks the flat
+// gc_puf / up_puf buffers in step with them.
+// NOTE(review): m->ns.norm_ptr and m->lr_ptr are not assigned anywhere in the
+// visible code (muon_init zeroes ns and registers ns_norm_puf); they are
+// presumably bound after allocation -- confirm before relying on this.
+void muon_step(Muon *m, cudaStream_t stream) {
+  assert(m->wb_puf.data != nullptr && "muon_step: weights not initialized");
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  // No NCCL on Metal (single GPU)
+
+  // Nesterov momentum update: mb = mu * mb + gc
+  mtl_nesterov_f32(m->mb_puf.data, m->gc_puf.data,
+                   (float)m->momentum, (int)puf_numel(m->mb_puf.shape), stream);
+  mtl_barrier(ms);
+
+  // Zero update buffer
+  puf_zero(&m->up_puf, stream);
+  mtl_barrier(ms);
+
+  int64_t offset = 0;
+  for (auto &e : m->param_alloc->regs) {
+    float *gc_ptr = m->gc_puf.data + offset;
+    float *up_ptr = m->up_puf.data + offset;
+    int64_t R = e.shape[0];
+    // Explicit <int64_t>: std::max(1, R) mixes int and int64_t, which fails
+    // template argument deduction. The guard avoids dividing by a zero-sized
+    // leading dimension.
+    int64_t C = puf_numel(e.shape) / std::max<int64_t>(1, R);
+    // NS orthogonalization for 2D+ params with M >= 2.
+    // 1-row tensors (e.g. value weight) use direct gradient update.
+    if (puf_ndim(e.shape) >= 2 && std::min(R, C) >= 2) {
+      // Work on the orientation with fewer rows (M <= N).
+      bool transposed_flag = R > C;
+      int64_t M = transposed_flag ? C : R;
+      int64_t N = transposed_flag ? R : C;
+
+      PufTensor G_f32 = {.bytes = (char *)gc_ptr,
+                         .shape = {R, C},
+                         .dtype_size = 4};
+      PufTensor x = ns_slice(m->ns.x, M, N);
+      PufTensor A = ns_slice(m->ns.A, M, M);
+      PufTensor gram = ns_slice(m->ns.gram, M, M);
+      PufTensor tmp = ns_slice(m->ns.tmp, M, N);
+
+      // fp32 path: copy or transpose into x
+      if (transposed_flag) {
+        mtl_transpose_f32((float *)x.bytes, (const float *)G_f32.bytes, (int)R,
+                          (int)C, stream);
+      } else {
+        puf_copy(x, G_f32, stream);
+      }
+      mtl_barrier(ms);
+
+      // Normalize x to unit Frobenius norm (matches CUDA models.cu:1219, no cap)
+      ensure_norm_partials();
+      {
+        int nblk = std::min((int)((x.numel() + 255) / 256), 256);
+        mtl_norm_f32(norm_partials_buf, (const float *)x.bytes,
+                     (int)x.numel(), nblk, stream);
+        mtl_barrier(ms);
+        mtl_norm_reduce(m->ns.norm_ptr, norm_partials_buf, nblk, stream);
+        mtl_barrier(ms);
+      }
+      mtl_normalize_f32((float *)x.bytes, m->ns.norm_ptr, 1e-7f,
+                        (int)x.numel(), stream);
+      mtl_barrier(ms);
+
+      // Newton-Schulz iterations
+      // x and tmp ping-pong: even iterations read x and write tmp, odd
+      // iterations the reverse.
+      for (int i = 0; i < m->ns_iters; ++i) {
+        // Spread iterations over coefficient rows 0..4; the (ns_iters == 1)
+        // term keeps the divisor non-zero for a single iteration.
+        int ci = i * 4 / (m->ns_iters - 1 + (m->ns_iters == 1));
+        float a = (float)ns_coeffs[ci][0], b = (float)ns_coeffs[ci][1],
+              c = (float)ns_coeffs[ci][2];
+        PufTensor &src = (i % 2 == 0) ? x : tmp;
+        PufTensor &dst = (i % 2 == 0) ? tmp : x;
+        puf_mm(src, src, A, stream);
+        mtl_barrier(ms);
+        puf_copy(gram, A, stream);
+        mtl_barrier(ms);
+        puf_addmm_nn(A, A, gram, c, b, stream);
+        mtl_barrier(ms);
+        puf_copy(dst, src, stream);
+        mtl_barrier(ms);
+        puf_addmm_nn(gram, src, dst, 1.0f, a, stream);
+        mtl_barrier(ms);
+      }
+
+      // The last iteration index is ns_iters - 1, so an even iteration count
+      // leaves the result in x and an odd count leaves it in tmp.
+      PufTensor &result_precision = (m->ns_iters % 2 == 0) ? x : tmp;
+
+      // Scale matches CUDA models.cu:1233: sqrt(max(1.0, R/C)).
+      // For tall matrices (R>C), scale up by sqrt(R/C) to compensate for
+      // the transposition used in NS iteration.
+      float scale = (float)std::sqrt(std::max(1.0, (double)R / (double)C));
+      if (scale != 1.0f) {
+        mtl_scale_f32((float *)result_precision.bytes, scale,
+                      (int)result_precision.numel(), stream);
+        mtl_barrier(ms);
+      }
+
+      PufTensor out_f32 = {.bytes = (char *)up_ptr,
+                           .shape = {R, C},
+                           .dtype_size = 4};
+      if (transposed_flag) {
+        // Undo the initial transpose: [M, N] -> [N, M] == [R, C].
+        mtl_transpose_f32((float *)out_f32.bytes,
+                          (const float *)result_precision.bytes, (int)M,
+                          (int)N, stream);
+      } else {
+        puf_copy(out_f32, result_precision, stream);
+      }
+      mtl_barrier(ms);
+    } else {
+      // 1D and tiny matrix params: use direct gradient update.
+      int64_t n = puf_numel(e.shape);
+      PufTensor src_puf = {.bytes = (char *)gc_ptr,
+                           .shape = {n},
+                           .dtype_size = 4};
+      PufTensor dst_puf = {.bytes = (char *)up_ptr,
+                           .shape = {n},
+                           .dtype_size = 4};
+      puf_copy(dst_puf, src_puf, stream);
+      mtl_barrier(ms);
+    }
+    offset += puf_numel(e.shape);
+  }
+
+  // Apply weight update: w = w * (1 - lr*wd) - lr * up
+  // Scale is already baked into up_puf during NS loop, so pass scale=1.0 here.
+  mtl_muon_weight_update(m->wb_puf.data, m->up_puf.data,
+                         m->lr_ptr, (float)m->weight_decay, 1.0f,
+                         (int)puf_numel(m->wb_puf.shape), stream);
+  mtl_barrier(ms);
+}
+
+// Encoder forward: out = puf_mm(input, weight). When a saved_input buffer
+// exists (training registration), the input is also copied into it for the
+// backward pass. Returns the activation tensor owned by `activations`.
+static PrecisionTensor encoder_forward(void *w, void *activations,
+                                       PrecisionTensor input, cudaStream_t stream) {
+  EncoderWeights *ew = (EncoderWeights *)w;
+  EncoderActivations *a = (EncoderActivations *)activations;
+  MetalStream *ms = mtl_resolve_stream(stream);
+  if (a->saved_input.data) {
+    PufTensor dst = to_puf(a->saved_input), src = to_puf(input);
+    puf_copy(dst, src, stream);
+  }
+
+  PufTensor inp = to_puf(input), wt = to_puf(ew->weight), out = to_puf(a->out);
+  puf_mm(inp, wt, out, stream);
+  mtl_barrier(ms);
+
+  return a->out;
+}
+
+// Encoder backward: weight gradient only, wgrad = puf_mm_tn(grad, saved_input).
+// No input gradient is produced and no barrier is issued here.
+static void encoder_backward(void *w, void *activations, PrecisionTensor grad,
+                             cudaStream_t stream) {
+  EncoderActivations *a = (EncoderActivations *)activations;
+  PufTensor g = to_puf(grad), si = to_puf(a->saved_input), wg = to_puf(a->wgrad);
+  puf_mm_tn(g, si, wg, stream);
+}
+
+// Kaiming-initialises the [out_dim, in_dim] encoder weight with gain sqrt(2)
+// and post-increments *seed so the next layer draws a different stream.
+static void encoder_init_weights(void *w, uint64_t *seed,
+                                 cudaStream_t stream) {
+  EncoderWeights *ew = (EncoderWeights *)w;
+  PufTensor wt = {.bytes = (char *)ew->weight.data,
+                  .shape = {ew->out_dim, ew->in_dim},
+                  .dtype_size = PRECISION_SIZE};
+  puf_kaiming_init(wt, std::sqrt(2.0f), (*seed)++, stream);
+}
+
+// Declares the encoder weight shape and registers it with `alloc`.
+static void encoder_reg_params(void *w, Allocator *alloc, int esz) {
+  EncoderWeights *ew = (EncoderWeights *)w;
+  ew->weight = {.shape = {ew->out_dim, ew->in_dim}, .dtype_size = esz};
+  alloc_register(alloc, &ew->weight);
+}
+
+// Declares training-time buffers: out and saved_input go to `acts`, the
+// weight gradient goes to `grads`.
+static void encoder_reg_train(void *w, void *activations,
+                              Allocator *acts, Allocator *grads,
+                              int B_TT, int precision) {
+  EncoderWeights *ew = (EncoderWeights *)w;
+  EncoderActivations *a = (EncoderActivations *)activations;
+  *a = (EncoderActivations){
+      .out = {.shape = {B_TT, ew->out_dim}, .dtype_size = precision},
+      .saved_input = {.shape = {B_TT, ew->in_dim}, .dtype_size = precision},
+      .wgrad = {.shape = {ew->out_dim, ew->in_dim}, .dtype_size = precision},
+  };
+  alloc_register(acts, &a->out);
+  alloc_register(acts, &a->saved_input);
+  alloc_register(grads, &a->wgrad);
+}
+
+// Declares the rollout (inference) output buffer only; saved_input is not
+// registered here.
+static void encoder_reg_rollout(void *w, void *activations,
+                                Allocator *alloc, int B) {
+  EncoderWeights *ew = (EncoderWeights *)w;
+  EncoderActivations *a = (EncoderActivations *)activations;
+  a->out = {.shape = {B, ew->out_dim}, .dtype_size = PRECISION_SIZE};
+  alloc_register(alloc, &a->out);
+}
+
+// Decoder forward: out = puf_mm(input, weight), with the same optional
+// saved_input copy as the encoder.
+static PrecisionTensor decoder_forward(void *w, void *activations,
+                                       PrecisionTensor input,
+                                       cudaStream_t stream) {
+  DecoderWeights *dw = (DecoderWeights *)w;
+  DecoderActivations *a = (DecoderActivations *)activations;
+  MetalStream *ms = mtl_resolve_stream(stream);
+  if (a->saved_input.data) {
+    PufTensor dst = to_puf(a->saved_input), src = to_puf(input);
+    puf_copy(dst, src, stream);
+  }
+  PufTensor inp = to_puf(input), wt = to_puf(dw->weight), out = to_puf(a->out);
+  puf_mm(inp, wt, out, stream);
+  mtl_barrier(ms);
+  return a->out;
+}
+
+// Decoder backward:
+//   1. assemble grad_logits + grad_value into the fused grad_out buffer
+//      (fp16 variant when grad_out.dtype_size == 2);
+//   2. wgrad = puf_mm_tn(grad_out, saved_input);
+//   3. for continuous actions with a logstd gradient, reduce it over rows
+//      into logstd_scratch;
+//   4. grad_input = puf_mm_nn(grad_out, weight), which is returned.
+static PrecisionTensor decoder_backward(void *w, void *activations,
+                                        FloatTensor grad_logits,
+                                        FloatTensor grad_logstd,
+                                        FloatTensor grad_value,
+                                        cudaStream_t stream) {
+  DecoderWeights *dw = (DecoderWeights *)w;
+  DecoderActivations *a = (DecoderActivations *)activations;
+  int B_TT = (int)a->saved_input.shape[0];
+  int od = dw->output_dim, od1 = od + 1;
+
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  if (a->grad_out.dtype_size == 2) {
+    mtl_assemble_decoder_grad_f32_to_f16(a->grad_out.data,
+                                         grad_logits.data,
+                                         grad_value.data, B_TT, od,
+                                         od1, stream);
+  } else {
+    mtl_assemble_decoder_grad_f32((float *)a->grad_out.data,
+                                  grad_logits.data,
+                                  grad_value.data, B_TT, od,
+                                  od1, stream);
+  }
+  mtl_barrier(ms); // assemble writes grad_out, GEMMs read it
+
+  // weight grad: grad_out^T @ saved_input
+  PufTensor go = to_puf(a->grad_out), si = to_puf(a->saved_input), wg = to_puf(a->wgrad);
+  puf_mm_tn(go, si, wg, stream);
+
+  if (dw->continuous && grad_logstd.data != nullptr) {
+    mtl_sum_rows_to_f32((float *)a->logstd_scratch.data,
+                        grad_logstd.data, B_TT,
+                        dw->output_dim, stream);
+  }
+
+  // grad -> hidden: grad_out @ weight
+  PufTensor wt = to_puf(dw->weight), gi = to_puf(a->grad_input);
+  puf_mm_nn(go, wt, gi, stream);
+  mtl_barrier(ms); // grad_input consumed by mingru_backward
+  return a->grad_input;
+}
+
+// Kaiming-initialises the [output_dim + 1, hidden_dim] decoder weight with
+// gain 1 and post-increments *seed. logstd is not touched here.
+static void decoder_init_weights(void *w, uint64_t *seed,
+                                 cudaStream_t stream) {
+  DecoderWeights *dw = (DecoderWeights *)w;
+  int od1 = dw->output_dim + 1;
+  PufTensor wt = {.bytes = (char *)dw->weight.data,
+                  .shape = {od1, dw->hidden_dim},
+                  .dtype_size = PRECISION_SIZE};
+  puf_kaiming_init(wt, 1.0f, (*seed)++, stream);
+}
+
+// Registers the fused [output_dim + 1, hidden_dim] weight and, for continuous
+// action spaces, a [1, output_dim] logstd parameter after it.
+static void decoder_reg_params(void *w, Allocator *alloc, int esz) {
+  DecoderWeights *dw = (DecoderWeights *)w;
+  int od = dw->output_dim;
+  int H = dw->hidden_dim;
+  dw->weight = {.shape = {od + 1, H}, .dtype_size = esz};
+  alloc_register(alloc, &dw->weight);
+  if (dw->continuous) {
+    dw->logstd = {.shape = {1, od}, .dtype_size = esz};
+    alloc_register(alloc, &dw->logstd);
+  }
+}
+
+// Declares training-time buffers. out, saved_input, grad_out and grad_input
+// go to `acts`; wgrad and (continuous only) logstd_scratch go to `grads`, in
+// the same order decoder_reg_params registers weight and logstd.
+static void decoder_reg_train(void *w, void *activations,
+                              Allocator *acts, Allocator *grads,
+                              int B_TT, int precision) {
+  DecoderWeights *dw = (DecoderWeights *)w;
+  DecoderActivations *a = (DecoderActivations *)activations;
+  int od1 = dw->output_dim + 1;
+  *a = (DecoderActivations){
+      .out = {.shape = {B_TT, od1}, .dtype_size = precision},
+      .grad_out = {.shape = {B_TT, od1}, .dtype_size = precision},
+      .saved_input = {.shape = {B_TT, dw->hidden_dim}, .dtype_size = precision},
+      .grad_input = {.shape = {B_TT, dw->hidden_dim}, .dtype_size = precision},
+      .wgrad = {.shape = {od1, dw->hidden_dim}, .dtype_size = precision},
+      .logstd_scratch = {.shape = {1, dw->output_dim}, .dtype_size = precision},
+  };
+  alloc_register(acts, &a->out);
+  alloc_register(acts, &a->saved_input);
+  // grad registration order MUST match param registration order in reg_params
+  alloc_register(acts, &a->grad_out);
+  alloc_register(acts, &a->grad_input);
+  alloc_register(grads, &a->wgrad);
+  if (dw->continuous)
+    alloc_register(grads, &a->logstd_scratch);
+}
+
+// Resets the activation struct and declares only the rollout output buffer,
+// so saved_input.data stays null and decoder_forward skips the copy.
+static void decoder_reg_rollout(void *w, void *activations,
+                                Allocator *alloc, int B) {
+  DecoderWeights *dw = (DecoderWeights *)w;
+  DecoderActivations *a = (DecoderActivations *)activations;
+  int od1 = dw->output_dim + 1;
+  *a = {};
+  a->out = {.shape = {B, od1}, .dtype_size = PRECISION_SIZE};
+  alloc_register(alloc, &a->out);
+}
+
+// Kaiming-initialises each layer's [3 * hidden, hidden] weight with gain 1,
+// consuming one seed value per layer.
+static void mingru_init_weights(void *w, uint64_t *seed, cudaStream_t stream) {
+  MinGRUWeights *m = (MinGRUWeights *)w;
+  for (int i = 0; i < m->num_layers; i++) {
+    PufTensor w2d = {.bytes = (char *)m->weights[i].data,
+                     .shape = {3 * m->hidden, m->hidden},
+                     .dtype_size = PRECISION_SIZE};
+    puf_kaiming_init(w2d, 1.0f, (*seed)++, stream);
+  }
+}
+
+// Registers one [3 * hidden, hidden] weight per layer, in layer order.
+static void mingru_reg_params(void *w, Allocator *alloc, int esz) {
+  MinGRUWeights *m = (MinGRUWeights *)w;
+  for (int i = 0; i < m->num_layers; i++) {
+    m->weights[i] = {
+        .shape = {3 * m->hidden, m->hidden},
+        .dtype_size = esz};
+    alloc_register(alloc, &m->weights[i]);
+  }
+}
+
+// Declares all training-time MinGRU buffers. B is derived as B_TT / horizon.
+// Per layer: saved input, combined projection, scan outputs/state, three
+// [B, TT + 1, H] scan buffers, and backward scratch go to `acts`; the weight
+// gradient goes to `grads` in layer order.
+// NOTE(review): a_star / s_vals / log_values_buf are declared without a
+// dtype_size, so they take the tensor type's default -- confirm that is the
+// intended element size.
+static void mingru_reg_train(void *w, void *activations, Allocator *acts,
+                             Allocator *grads, int B_TT, int precision) {
+  MinGRUWeights *m = (MinGRUWeights *)w;
+  MinGRUActivations *a = (MinGRUActivations *)activations;
+  int H = m->hidden, TT = m->horizon, B = B_TT / TT;
+  a->num_layers = m->num_layers;
+  a->saved_inputs.resize(m->num_layers);
+  a->scan_bufs.resize(m->num_layers);
+  a->combined_bufs.resize(m->num_layers);
+  a->wgrad_scratch.resize(m->num_layers);
+  a->grad_input_buf = {.shape = {B_TT, H}, .dtype_size = precision};
+  a->grad_next_state = {.shape = {B, 1, H}, .dtype_size = precision};
+  alloc_register(acts, &a->grad_input_buf);
+  alloc_register(acts, &a->grad_next_state);
+  for (int i = 0; i < m->num_layers; i++) {
+    a->scan_bufs[i] = {
+        .B = B,
+        .T = TT,
+        .H = H,
+        .a_star = {.shape = {B, TT + 1, H}},
+        .s_vals = {.shape = {B, TT + 1, H}},
+        .log_values_buf = {.shape = {B, TT + 1, H}},
+        .out = {.shape = {B, TT, H}, .dtype_size = precision},
+        .next_state = {.shape = {B, 1, H}, .dtype_size = precision},
+        .grad_combined = {.shape = {B, TT, 3 * H}, .dtype_size = precision},
+        .grad_state = {.shape = {B, 1, H}, .dtype_size = precision},
+        .grad_input = {.shape = {B, TT, H}, .dtype_size = precision},
+    };
+    a->saved_inputs[i] = {.shape = {B, TT, H}, .dtype_size = precision};
+    a->combined_bufs[i] = {.shape = {B_TT, 3 * H}, .dtype_size = precision};
+    a->wgrad_scratch[i] = {.shape = {3 * H, H}, .dtype_size = precision};
+    alloc_register(acts, &a->saved_inputs[i]);
+    alloc_register(acts, &a->combined_bufs[i]);
+    alloc_register(acts, &a->scan_bufs[i].out);
+    alloc_register(acts, &a->scan_bufs[i].next_state);
+    alloc_register(acts, &a->scan_bufs[i].a_star);
+    alloc_register(acts, &a->scan_bufs[i].s_vals);
+    alloc_register(acts, &a->scan_bufs[i].log_values_buf);
+    alloc_register(acts, &a->scan_bufs[i].grad_combined);
+    alloc_register(acts, &a->scan_bufs[i].grad_state);
+    alloc_register(acts, &a->scan_bufs[i].grad_input);
+    alloc_register(grads, &a->wgrad_scratch[i]);
+  }
+}
+
+// Declares rollout (single-step) buffers: one [B_inf, 3 * H] projection per
+// layer plus shared [B_inf, H] out and next_state buffers.
+static void mingru_reg_rollout(void *weights, void *activations,
+                               Allocator *alloc, int B_inf) {
+  MinGRUWeights *w = (MinGRUWeights *)weights;
+  MinGRUActivations *a = (MinGRUActivations *)activations;
+  int H = w->hidden;
+  a->num_layers = w->num_layers;
+  a->combined.resize(w->num_layers);
+  for (int i = 0; i < w->num_layers; i++) {
+    a->combined[i] = {.shape = {B_inf, 3 * H}, .dtype_size = PRECISION_SIZE};
+    alloc_register(alloc, &a->combined[i]);
+  }
+  a->out = {.shape = {B_inf, H}, .dtype_size = PRECISION_SIZE};
+  a->next_state = {.shape = {B_inf, H}, .dtype_size = PRECISION_SIZE};
+  alloc_register(alloc, &a->out);
+  alloc_register(alloc, &a->next_state);
+}
+
+// Single-step (rollout) forward. B and H are read from state.shape[1] and
+// state.shape[2]. Per layer: project x through the layer weight, run the
+// gate kernel, then copy next_state back into that layer's slice of `state`,
+// so `state` is updated in place. Every layer writes the same a->out buffer,
+// which then becomes the next layer's input.
+static PrecisionTensor mingru_forward(void *w, PrecisionTensor x,
+                                      PrecisionTensor state,
+                                      void *activations, cudaStream_t stream) {
+  MinGRUWeights *m = (MinGRUWeights *)w;
+  MinGRUActivations *a = (MinGRUActivations *)activations;
+  int B = (int)state.shape[1];
+  int H = (int)state.shape[2];
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  for (int i = 0; i < m->num_layers; i++) {
+    PrecisionTensor state_i = mingru_state_layer(state, i);
+    PufTensor xp = to_puf(x), wi = to_puf(m->weights[i]), ci = to_puf(a->combined[i]);
+    puf_mm(xp, wi, ci, stream);
+    mtl_barrier(ms);
+    mtl_mingru_gate(a->out.data, a->next_state.data,
+                    (const float *)a->combined[i].data,
+                    state_i.data, x.data, H, B, stream);
+    mtl_barrier(ms);
+    PufTensor si = to_puf(state_i), ns = to_puf(a->next_state);
+    puf_copy(si, ns, stream);
+    // The copy only has to be ordered before another layer reuses
+    // next_state; after the last layer no barrier is issued here.
+    if (i + 1 < m->num_layers)
+      mtl_barrier(ms);
+    x = a->out;
+  }
+  return x;
+}
+
+// Sequence (training) forward. Per layer: save the layer input, project it,
+// point the scan buffers at the projection / state / saved input, then run
+// the scan (fp16 variant when x.dtype_size == 2). Returns the last layer's
+// scan output.
+static PrecisionTensor mingru_forward_train(void *w, PrecisionTensor x,
+                                            PrecisionTensor state,
+                                            void *activations,
+                                            cudaStream_t stream) {
+  MinGRUWeights *m = (MinGRUWeights *)w;
+  MinGRUActivations *a = (MinGRUActivations *)activations;
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  for (int i = 0; i < m->num_layers; i++) {
+    PufTensor si_p = to_puf(a->saved_inputs[i]), xp = to_puf(x);
+    puf_copy(si_p, xp, stream);
+    PrecisionTensor state_i = mingru_state_layer(state, i);
+    PufTensor wi = to_puf(m->weights[i]), cb = to_puf(a->combined_bufs[i]);
+    puf_mm(xp, wi, cb, stream);
+    // Orders both the saved-input copy and the projection before the scan.
+    mtl_barrier(ms);
+    a->scan_bufs[i].combined_ptr = a->combined_bufs[i].data;
+    a->scan_bufs[i].state_ptr = state_i.data;
+    a->scan_bufs[i].input_ptr = a->saved_inputs[i].data;
+    if (x.dtype_size == 2) {
+      mtl_mingru_scan_forward_fp16(a->scan_bufs[i], stream);
+    } else {
+      mtl_mingru_scan_forward(a->scan_bufs[i], stream);
+    }
+    mtl_barrier(ms);
+    x = a->scan_bufs[i].out;
+  }
+  return x;
+}
+
+static PrecisionTensor mingru_backward(void *w, PrecisionTensor grad,
+                                       void *activations,
+                                       cudaStream_t stream) {
+  MinGRUWeights *m = (MinGRUWeights *)w;
+  MinGRUActivations *a = (MinGRUActivations *)activations;
+  MetalStream *ms = mtl_resolve_stream(stream);
+
+  PufTensor gns = 
to_puf(a->grad_next_state); + puf_zero(&gns, stream); + mtl_barrier(ms); + + for (int i = m->num_layers - 1; i >= 0; i--) { + PrefixScan &scan = a->scan_bufs[i]; + if (grad.dtype_size == 2) { + mtl_mingru_scan_backward_fp16(scan, grad.data, + a->grad_next_state.data, stream); + } else { + mtl_mingru_scan_backward(scan, grad.data, + a->grad_next_state.data, stream); + } + mtl_barrier(ms); + PufTensor gc = to_puf(scan.grad_combined), si = to_puf(a->saved_inputs[i]); + PufTensor wgs = to_puf(a->wgrad_scratch[i]); + puf_mm_tn(gc, si, wgs, stream); + PufTensor wi = to_puf(m->weights[i]), gib = to_puf(a->grad_input_buf); + puf_mm_nn(gc, wi, gib, stream); + mtl_barrier(ms); + PufTensor gi = to_puf(scan.grad_input); + puf_add(gib, gi, stream); + mtl_barrier(ms); + grad = a->grad_input_buf; + } + return grad; +} + +void mtl_kernels_reset() { + if (norm_partials_buf) { + mtl_unwrap_ptr(norm_partials_buf); + free(norm_partials_buf); + norm_partials_buf = nullptr; + } + if (ppo_partials_buf) { + mtl_unwrap_ptr(ppo_partials_buf); + free(ppo_partials_buf); + ppo_partials_buf = nullptr; + ppo_partials_capacity = 0; + } +} diff --git a/src/metal_platform.h b/src/metal_platform.h new file mode 100644 index 0000000000..09894b60d7 --- /dev/null +++ b/src/metal_platform.h @@ -0,0 +1,208 @@ +#ifndef PUFFERLIB_METAL_PLATFORM_H +#define PUFFERLIB_METAL_PLATFORM_H + +#import +#include + +#define ACCELERATE_NEW_LAPACK +#import + +#include "puf_types.h" +#include +#include +#include + +struct MetalStream { + id allocator; + id cmd; + id enc; + id arg_table; + id sync_event; + id const_ring; + NSUInteger const_ring_offset = 0; + uint64_t bound_addresses[32] = {}; + + bool enc_active = false; + bool pending_work = false; + bool flushed = false; + uint64_t flush_event_val = 0; + uint64_t sync_event_value = 0; + + void begin(); + void compute_encoder(); + void end_compute(); + void sync(); + void flush(); + void commit_chunk(); + void wait_completed(); +}; + +struct WrappedBuffer { + char 
*base; + int64_t size; + id buffer; +}; + +static const NSUInteger MTL_CONST_RING_SIZE = 2 * 1024 * 1024; + +struct MetalContext { + id device; + id library; + NSMutableDictionary> *pipelines; + + id tensor_ops_gemm_nt_f32; + id tensor_ops_gemm_nn_f32; + id tensor_ops_gemm_tn_f32; + id tensor_ops_gemm_nt_f16; + id tensor_ops_gemm_nn_f16; + id tensor_ops_gemm_tn_f16; + + id queue; + id train_queue; + id residency_set; + + MetalStream stream; + MetalStream train_stream; + std::vector buffers; +}; +void mtl_init(); +MetalContext *mtl_ctx(); +void *mtl_stream(); +void *mtl_train_stream(); +static inline MetalStream *mtl_resolve_stream(cudaStream_t s) { + return s ? (MetalStream *)s : (MetalStream *)mtl_stream(); +} + +static inline void mtl_ensure_stream_synced(cudaStream_t s) { + MetalStream *ms = mtl_resolve_stream(s); + if (ms->enc_active || ms->pending_work) ms->sync(); +} + +void *mtl_create_stream(); +void mtl_destroy_stream(void *stream); + +static inline void *mtl_alloc_scratch(int64_t bytes) { + int64_t page = 16384; + int64_t alloc = ((bytes + page - 1) / page) * page; + void *ptr = nullptr; + posix_memalign(&ptr, page, alloc); + assert(ptr && "mtl_alloc_scratch: posix_memalign failed"); + memset(ptr, 0, alloc); + MetalContext *ctx = mtl_ctx(); + id buf = [ctx->device newBufferWithBytesNoCopy:ptr + length:(NSUInteger)alloc + options:MTLResourceStorageModeShared + deallocator:nil]; + assert(buf); + ctx->buffers.push_back({(char *)ptr, alloc, buf}); + [ctx->residency_set addAllocation:buf]; + [ctx->residency_set commit]; + [ctx->residency_set requestResidency]; + return ptr; +} + +void mtl_destroy(); +void mtl_kernels_reset(); +id mtl_wrap_allocator(Allocator *alloc); +id mtl_buffer_for(const PufTensor &t, NSUInteger *out_offset); +id mtl_pipeline(const char *name); +inline void mtl_set_pso(MetalStream *ms, id pso) { + [ms->enc setComputePipelineState:pso]; +} + +inline void mtl_set_tensor(MetalStream *ms, const PufTensor &t, + uint32_t index) { + NSUInteger 
offset; + id buf = mtl_buffer_for(t, &offset); + uint64_t addr = buf.gpuAddress + offset; + if (ms->bound_addresses[index] != addr) { + [ms->arg_table setAddress:addr atIndex:index]; + ms->bound_addresses[index] = addr; + } +} + +id mtl_buffer_for_ptr(const void *ptr, NSUInteger *out_offset); + +inline void mtl_set_tensor(MetalStream *ms, const FloatTensor &t, + uint32_t index) { + NSUInteger offset; + id buf = mtl_buffer_for_ptr(t.data, &offset); + uint64_t addr = buf.gpuAddress + offset; + if (ms->bound_addresses[index] != addr) { + [ms->arg_table setAddress:addr atIndex:index]; + ms->bound_addresses[index] = addr; + } +} +template +inline void mtl_set_params(MetalStream *ms, const T ¶ms, uint32_t index) { + NSUInteger aligned = (sizeof(T) + 15) & ~15; + assert(ms->const_ring_offset + aligned <= MTL_CONST_RING_SIZE); + memcpy((char *)[ms->const_ring contents] + ms->const_ring_offset, + ¶ms, sizeof(T)); + uint64_t addr = ms->const_ring.gpuAddress + ms->const_ring_offset; + [ms->arg_table setAddress:addr atIndex:index]; + ms->bound_addresses[index] = addr; + ms->const_ring_offset += aligned; +} + +inline void mtl_dispatch_1d(MetalStream *ms, id pso, + int count) { + NSUInteger tg = MIN((NSUInteger)pso.maxTotalThreadsPerThreadgroup, 256); + [ms->enc dispatchThreads:MTLSizeMake(count, 1, 1) + threadsPerThreadgroup:MTLSizeMake(tg, 1, 1)]; +} + +inline void mtl_dispatch_groups(MetalStream *ms, + id pso, + int num_groups, int group_size) { + [ms->enc dispatchThreadgroups:MTLSizeMake(num_groups, 1, 1) + threadsPerThreadgroup:MTLSizeMake(group_size, 1, 1)]; +} + +inline void mtl_set_threadgroup_memory(MetalStream *ms, NSUInteger length, + uint32_t index) { + [ms->enc setThreadgroupMemoryLength:length atIndex:index]; +} + +inline void mtl_barrier(MetalStream *ms) { + if (ms->enc_active) { + [ms->enc barrierAfterEncoderStages:MTLStageDispatch + beforeEncoderStages:MTLStageDispatch + visibilityOptions:MTL4VisibilityOptionDevice]; + ms->pending_work = true; + } +} + +inline 
PufTensor to_puf(PrecisionTensor &t) { + return {.bytes = (char *)t.data, + .shape = {t.shape[0], t.shape[1], t.shape[2], t.shape[3]}, + .dtype_size = t.dtype_size}; +} + +inline PufTensor to_puf(const PrecisionTensor &t) { + return {.bytes = (char *)t.data, + .shape = {t.shape[0], t.shape[1], t.shape[2], t.shape[3]}, + .dtype_size = t.dtype_size}; +} + +void puf_mm(PufTensor &a, PufTensor &b, PufTensor &out, cudaStream_t stream); +void puf_mm_tn(PufTensor &a, PufTensor &b, PufTensor &out, + cudaStream_t stream); +void puf_mm_nn(PufTensor &a, PufTensor &b, PufTensor &out, + cudaStream_t stream); +void puf_addmm_nn(PufTensor &a, PufTensor &b, PufTensor &out, float alpha, + float beta, cudaStream_t stream); +void mtl_sync_stats(int *out_count, double *out_total_ms); +void mtl_enable_gpu_timing(bool enable); +void mtl_gpu_timing_stats(double *gpu_exec_ms, double *sched_wait_ms); +void mtl_gemm_stats(int *tensor_ops_count); +bool puf_stream_has_encoder(cudaStream_t stream); +void puf_set_gpu_training(bool val); +bool puf_is_gpu_training(); +void cpu_cast_u8_to_f32(float *dst, const uint8_t *src, int count); +void mtl_cast_f32_to_f16(void *dst, const float *src, int count, + cudaStream_t stream); +void mtl_cast_f16_to_f32(float *dst, const void *src, int count, + cudaStream_t stream); + +#endif // PUFFERLIB_METAL_PLATFORM_H diff --git a/src/metal_platform.mm b/src/metal_platform.mm new file mode 100644 index 0000000000..9b0ac710d6 --- /dev/null +++ b/src/metal_platform.mm @@ -0,0 +1,1272 @@ +#import "metal_platform.h" +#import // CACurrentMediaTime +#include "metal_shader_src.h" + +#include +#include +#include +#include +#include +#include + +static MetalContext g_ctx = {}; +static std::mutex g_pipeline_mutex; + +// ============================================================================ +// Metal 4 tensor_ops GEMM — MSL source for JIT compilation. +// Separate library (needs metal_tensor + MetalPerformancePrimitives includes). 
+// ============================================================================ + +static const char *get_tensor_ops_shader_source() { + return R"METAL( +#include +#include +#include +using namespace metal; +using namespace mpp::tensor_ops; + +// C(M,N) = A(M,K) @ B(N,K)^T — float32, tensor_inline with device memory. +// Tile: 64 rows (M) x 32 cols (N), dynamic K. +// N MUST be a multiple of 32, M MUST be a multiple of 64. +kernel void tensor_ops_gemm_nt_f32( + device float* A_buf [[buffer(0)]], + device float* B_buf [[buffer(1)]], + device float* C_buf [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tgid [[threadgroup_position_in_grid]]) +{ + auto A = tensor, tensor_inline>( + A_buf, dextents(K, M)); + auto B = tensor, tensor_inline>( + B_buf, dextents(K, N)); + auto C = tensor, tensor_inline>( + C_buf, dextents(N, M)); + + constexpr auto desc = matmul2d_descriptor( + 64, 32, + static_cast(dynamic_extent), + false, true, false + ); + matmul2d> op; + + auto mA = A.slice(0, tgid.y * 64); + auto mB = B.slice(0, tgid.x * 32); + auto mC = C.slice(tgid.x * 32, tgid.y * 64); + + op.run(mA, mB, mC); +} + +// C(M,N) = A(M,K) @ B(K,N) — float32, tensor_inline with device memory. +// Row-major NN maps to col-major: C_cm(N,M) = B_cm(N,K) @ A_cm(K,M). +// Tiling follows NT convention: tgid.y tiles M (stride 64), tgid.x tiles N (stride 32). +// M % 64 == 0 and N % 32 == 0 required (caller falls back to steel_gemm). 
+kernel void tensor_ops_gemm_nn_f32( + device float* A_buf [[buffer(0)]], + device float* B_buf [[buffer(1)]], + device float* C_buf [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tgid [[threadgroup_position_in_grid]]) +{ + // Row-major A(M,K) in memory == col-major A_cm(K,M) + auto A_cm = tensor, tensor_inline>( + A_buf, dextents(K, M)); + // Row-major B(K,N) in memory == col-major B_cm(N,K) + auto B_cm = tensor, tensor_inline>( + B_buf, dextents(N, K)); + // Row-major C(M,N) in memory == col-major C_cm(N,M) + auto C_cm = tensor, tensor_inline>( + C_buf, dextents(N, M)); + + // Col-major NN: C_cm(N,M) = B_cm(N,K) @ A_cm(K,M) + // op.run convention: result = second @ first (same as NT kernel) + // first=A_cm, second=B_cm, no transposes + constexpr auto desc = matmul2d_descriptor( + 64, 32, + static_cast(dynamic_extent), + false, false, false + ); + matmul2d> op; + + // tgid.y tiles M at stride 64 (tile_M), tgid.x tiles N at stride 32 (tile_N) + // Same convention as NT kernel + auto mFirst = A_cm.slice(0, tgid.y * 64); + auto mSecond = B_cm.slice(tgid.x * 32, 0); + auto mResult = C_cm.slice(tgid.x * 32, tgid.y * 64); + + op.run(mFirst, mSecond, mResult); +} + +// ---- fp16 variants: half inputs/outputs, float accumulation inside matmul2d ---- + +// C(M,N) = A(M,K) @ B(N,K)^T — half precision, tensor_inline with device memory. 
+kernel void tensor_ops_gemm_nt_f16( + device half* A_buf [[buffer(0)]], + device half* B_buf [[buffer(1)]], + device half* C_buf [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tgid [[threadgroup_position_in_grid]]) +{ + auto A = tensor, tensor_inline>( + A_buf, dextents(K, M)); + auto B = tensor, tensor_inline>( + B_buf, dextents(K, N)); + auto C = tensor, tensor_inline>( + C_buf, dextents(N, M)); + + constexpr auto desc = matmul2d_descriptor( + 64, 32, + static_cast(dynamic_extent), + false, true, false + ); + matmul2d> op; + + auto mA = A.slice(0, tgid.y * 64); + auto mB = B.slice(0, tgid.x * 32); + auto mC = C.slice(tgid.x * 32, tgid.y * 64); + + op.run(mA, mB, mC); +} + +// C(M,N) = A(M,K) @ B(K,N) — half precision NN variant. +kernel void tensor_ops_gemm_nn_f16( + device half* A_buf [[buffer(0)]], + device half* B_buf [[buffer(1)]], + device half* C_buf [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tgid [[threadgroup_position_in_grid]]) +{ + auto A_cm = tensor, tensor_inline>( + A_buf, dextents(K, M)); + auto B_cm = tensor, tensor_inline>( + B_buf, dextents(N, K)); + auto C_cm = tensor, tensor_inline>( + C_buf, dextents(N, M)); + + constexpr auto desc = matmul2d_descriptor( + 64, 32, + static_cast(dynamic_extent), + false, false, false + ); + matmul2d> op; + + auto mFirst = A_cm.slice(0, tgid.y * 64); + auto mSecond = B_cm.slice(tgid.x * 32, 0); + auto mResult = C_cm.slice(tgid.x * 32, tgid.y * 64); + + op.run(mFirst, mSecond, mResult); +} + +// C(M,N) = A(K,M)^T @ B(K,N) — float32, tensor_inline with device memory. +// TN: backward weight gradient. Row-major A(K,M) = col-major (M,K). +// matmul2d: result(N,M) = second(N,K) @ transpose(first(M,K)) +// M % 64 == 0 and N % 32 == 0 required (caller falls back to steel_gemm). 
+kernel void tensor_ops_gemm_tn_f32( + device float* A_buf [[buffer(0)]], + device float* B_buf [[buffer(1)]], + device float* C_buf [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tgid [[threadgroup_position_in_grid]]) +{ + // Row-major A(K,M) in memory == col-major tensor(M, K) + auto A_cm = tensor, tensor_inline>( + A_buf, dextents(M, K)); + // Row-major B(K,N) in memory == col-major tensor(N, K) + auto B_cm = tensor, tensor_inline>( + B_buf, dextents(N, K)); + // Row-major C(M,N) in memory == col-major tensor(N, M) + auto C_cm = tensor, tensor_inline>( + C_buf, dextents(N, M)); + + // result(N,M) = second(N,K) @ transpose(first(M,K)) + // transpose_first=true: matmul sees first as (K,M) + // (N,K) @ (K,M) = (N,M) = result + constexpr auto desc = matmul2d_descriptor( + 64, 32, + static_cast(dynamic_extent), + true, false, false + ); + matmul2d> op; + + // tgid.y tiles M at stride 64, tgid.x tiles N at stride 32 + auto mFirst = A_cm.slice(tgid.y * 64, 0); + auto mSecond = B_cm.slice(tgid.x * 32, 0); + auto mResult = C_cm.slice(tgid.x * 32, tgid.y * 64); + + op.run(mFirst, mSecond, mResult); +} + +// C(M,N) = A(K,M)^T @ B(K,N) — half precision TN variant. 
+kernel void tensor_ops_gemm_tn_f16( + device half* A_buf [[buffer(0)]], + device half* B_buf [[buffer(1)]], + device half* C_buf [[buffer(2)]], + constant uint& M [[buffer(3)]], + constant uint& N [[buffer(4)]], + constant uint& K [[buffer(5)]], + uint2 tgid [[threadgroup_position_in_grid]]) +{ + auto A_cm = tensor, tensor_inline>( + A_buf, dextents(M, K)); + auto B_cm = tensor, tensor_inline>( + B_buf, dextents(N, K)); + auto C_cm = tensor, tensor_inline>( + C_buf, dextents(N, M)); + + constexpr auto desc = matmul2d_descriptor( + 64, 32, + static_cast(dynamic_extent), + true, false, false + ); + matmul2d> op; + + auto mFirst = A_cm.slice(tgid.y * 64, 0); + auto mSecond = B_cm.slice(tgid.x * 32, 0); + auto mResult = C_cm.slice(tgid.x * 32, tgid.y * 64); + + op.run(mFirst, mSecond, mResult); +} +)METAL"; +} + +void MetalStream::begin() { + enc = nil; + enc_active = false; + pending_work = false; + flushed = false; + memset(bound_addresses, 0, sizeof(bound_addresses)); + [allocator reset]; + [cmd beginCommandBufferWithAllocator:allocator]; + [cmd useResidencySet:mtl_ctx()->residency_set]; + const_ring_offset = 0; +} + +void MetalStream::compute_encoder() { + if (!enc_active) { + enc = [cmd computeCommandEncoder]; + [enc setArgumentTable:arg_table]; + enc_active = true; + pending_work = true; + } +} + +void MetalStream::end_compute() { + if (enc_active) { + [enc endEncoding]; + enc = nil; + enc_active = false; + } +} + +// Sync profiling globals +static int g_sync_count = 0; +static double g_sync_total_ns = 0.0; +static mach_timebase_info_data_t g_timebase = {0, 0}; + +// GPU timing diagnostic — actual kernel execution vs scheduling delay. +// Only active when g_gpu_timing_enabled is true (set via mtl_enable_gpu_timing). +// The MTL4CommitOptions + feedback handler allocation adds ObjC overhead that +// causes measurable scheduling jitter when sampled unconditionally. 
+static bool g_gpu_timing_enabled = false; +static double g_gpu_exec_ns = 0.0; +static double g_sched_wait_ns = 0.0; +static constexpr NSUInteger kMetalSyncTimeoutMs = 300000; // 5 min — worst-case sweep configs push 3K+ minibatches per epoch + +static double mach_to_ns(uint64_t ticks) { + if (g_timebase.denom == 0) mach_timebase_info(&g_timebase); + return (double)ticks * g_timebase.numer / g_timebase.denom; +} + +void MetalStream::sync() { + end_compute(); + uint64_t t0 = mach_absolute_time(); + [cmd endCommandBuffer]; + uint64_t val = ++sync_event_value; + id bufs[] = { cmd }; + MetalContext *ctx = mtl_ctx(); + id q = + (this == &ctx->train_stream) ? ctx->train_queue : ctx->queue; + // Sample GPU timing every 32nd sync, only when profiling is enabled. + // The MTL4CommitOptions + ObjC feedback handler block allocation adds + // measurable jitter (~50-200us) that contributes to SPS variance. + bool sample_timing = g_gpu_timing_enabled && (g_sync_count % 32 == 0); + if (sample_timing) { + CFTimeInterval cpu_commit = CACurrentMediaTime(); + MTL4CommitOptions *opts = [MTL4CommitOptions new]; + __block CFTimeInterval gpu_start = 0, gpu_end = 0; + [opts addFeedbackHandler:^(id fb) { + gpu_start = fb.GPUStartTime; + gpu_end = fb.GPUEndTime; + }]; + [q commit:bufs count:1 options:opts]; + [q signalEvent:sync_event value:val]; + BOOL signaled = [sync_event waitUntilSignaledValue:val timeoutMS:kMetalSyncTimeoutMs]; + if (!signaled) { + assert(false && "Metal sync timeout in MetalStream::sync"); + } + if (gpu_start > 0 && gpu_end > 0) { + g_gpu_exec_ns += (gpu_end - gpu_start) * 1e9; + g_sched_wait_ns += (gpu_start - cpu_commit) * 1e9; + } + } else { + [q commit:bufs count:1]; + [q signalEvent:sync_event value:val]; + BOOL signaled = [sync_event waitUntilSignaledValue:val timeoutMS:kMetalSyncTimeoutMs]; + if (!signaled) { + assert(false && "Metal sync timeout in MetalStream::sync"); + } + } + uint64_t t1 = mach_absolute_time(); + g_sync_count++; + g_sync_total_ns += 
mach_to_ns(t1 - t0); + begin(); +} + +void MetalStream::flush() { + end_compute(); + if (pending_work) { + [cmd endCommandBuffer]; + id bufs[] = { cmd }; + MetalContext *ctx = mtl_ctx(); + id q = + (this == &ctx->train_stream) ? ctx->train_queue : ctx->queue; + // Signal event value for wait_completed(). + flush_event_val = ++sync_event_value; + [q commit:bufs count:1]; + [q signalEvent:sync_event value:flush_event_val]; + flushed = true; + pending_work = false; + } +} + +void MetalStream::commit_chunk() { + end_compute(); + if (!pending_work) return; + + MetalContext *ctx = mtl_ctx(); + [cmd endCommandBuffer]; + id bufs[] = { cmd }; + id q = + (this == &ctx->train_stream) ? ctx->train_queue : ctx->queue; + [q commit:bufs count:1]; + + cmd = [ctx->device newCommandBuffer]; + assert(cmd && "Failed to allocate Metal command buffer for chunked training"); + begin(); +} + +void MetalStream::wait_completed() { + if (flushed) { + uint64_t t0 = mach_absolute_time(); + BOOL signaled = [sync_event waitUntilSignaledValue:flush_event_val timeoutMS:kMetalSyncTimeoutMs]; + if (!signaled) { + assert(false && "Metal sync timeout in MetalStream::wait_completed"); + } + uint64_t t1 = mach_absolute_time(); + g_sync_count++; + g_sync_total_ns += mach_to_ns(t1 - t0); + flushed = false; + begin(); + } +} + +void mtl_sync_stats(int *out_count, double *out_total_ms) { + *out_count = g_sync_count; + *out_total_ms = g_sync_total_ns / 1e6; + g_sync_count = 0; + g_sync_total_ns = 0.0; +} + +void mtl_enable_gpu_timing(bool enable) { + g_gpu_timing_enabled = enable; +} + +void mtl_gpu_timing_stats(double *gpu_exec_ms, double *sched_wait_ms) { + *gpu_exec_ms = g_gpu_exec_ns / 1e6; + *sched_wait_ms = g_sched_wait_ns / 1e6; + g_gpu_exec_ns = 0.0; + g_sched_wait_ns = 0.0; +} + +static int g_gemm_dispatch_count = 0; + +void mtl_gemm_stats(int *tensor_ops_count) { + *tensor_ops_count = g_gemm_dispatch_count; + g_gemm_dispatch_count = 0; +} + +void mtl_init() { + @autoreleasepool { + g_ctx.device = 
MTLCreateSystemDefaultDevice(); + assert(g_ctx.device && "No Metal device found"); + + // JIT-compile all MSL shaders from the embedded source string + NSError *error = nil; + NSString *src = + [NSString stringWithUTF8String:get_all_metal_shader_source()]; + MTLCompileOptions *opts = [[MTLCompileOptions alloc] init]; + opts.mathMode = MTLMathModeFast; + opts.languageVersion = MTLLanguageVersion4_0; + g_ctx.library = + [g_ctx.device newLibraryWithSource:src options:opts error:&error]; + if (!g_ctx.library) + throw std::runtime_error(error ? error.localizedDescription.UTF8String + : "MSL compilation failed"); + + g_ctx.pipelines = [NSMutableDictionary new]; + + // Compile Metal 4 tensor_ops GEMM pipelines. All variants must succeed — + // steel_gemm fallback is 2-3x slower and we target M4 Pro exclusively. + { + NSString *tensor_src = [NSString stringWithUTF8String:get_tensor_ops_shader_source()]; + MTLCompileOptions *tensor_opts = [[MTLCompileOptions alloc] init]; + tensor_opts.mathMode = MTLMathModeFast; + tensor_opts.languageVersion = MTLLanguageVersion4_0; + NSError *tensor_err = nil; + id tensor_lib = [g_ctx.device newLibraryWithSource:tensor_src + options:tensor_opts + error:&tensor_err]; + if (!tensor_lib) + throw std::runtime_error(tensor_err ? tensor_err.localizedDescription.UTF8String + : "tensor_ops library compilation failed"); + + // Helper: compile one PSO from the tensor_ops library, assert on failure. + auto compile_pso = [&](const char *name) -> id { + id fn = [tensor_lib newFunctionWithName: + [NSString stringWithUTF8String:name]]; + assert(fn && "tensor_ops function not found"); + MTLComputePipelineDescriptor *pd = [[MTLComputePipelineDescriptor alloc] init]; + pd.computeFunction = fn; + pd.maxTotalThreadsPerThreadgroup = 128; + NSError *err = nil; + id pso = + [g_ctx.device newComputePipelineStateWithDescriptor:pd + options:0 + reflection:nil + error:&err]; + if (!pso) + throw std::runtime_error(err ? 
err.localizedDescription.UTF8String + : "tensor_ops PSO compilation failed"); + return pso; + }; + + g_ctx.tensor_ops_gemm_nt_f32 = compile_pso("tensor_ops_gemm_nt_f32"); + g_ctx.tensor_ops_gemm_nn_f32 = compile_pso("tensor_ops_gemm_nn_f32"); + g_ctx.tensor_ops_gemm_tn_f32 = compile_pso("tensor_ops_gemm_tn_f32"); + g_ctx.tensor_ops_gemm_nt_f16 = compile_pso("tensor_ops_gemm_nt_f16"); + g_ctx.tensor_ops_gemm_nn_f16 = compile_pso("tensor_ops_gemm_nn_f16"); + g_ctx.tensor_ops_gemm_tn_f16 = compile_pso("tensor_ops_gemm_tn_f16"); + + } + + // Metal 4 reusable command buffer infrastructure + g_ctx.queue = [g_ctx.device newMTL4CommandQueue]; + assert(g_ctx.queue && "Metal 4 required — device must support newMTL4CommandQueue"); + g_ctx.train_queue = [g_ctx.device newMTL4CommandQueue]; + + // Command allocators + reusable command buffers + g_ctx.stream.allocator = [g_ctx.device newCommandAllocator]; + g_ctx.stream.cmd = [g_ctx.device newCommandBuffer]; + g_ctx.stream.sync_event = [g_ctx.device newSharedEvent]; + g_ctx.stream.sync_event_value = 0; + g_ctx.train_stream.allocator = [g_ctx.device newCommandAllocator]; + g_ctx.train_stream.cmd = [g_ctx.device newCommandBuffer]; + g_ctx.train_stream.sync_event = [g_ctx.device newSharedEvent]; + g_ctx.train_stream.sync_event_value = 0; + + // Argument tables — PPO kernel uses slots 0-19, Metal 4 max is 31 + MTL4ArgumentTableDescriptor *atd = [MTL4ArgumentTableDescriptor new]; + atd.maxBufferBindCount = 31; + NSError *at_err = nil; + g_ctx.stream.arg_table = + [g_ctx.device newArgumentTableWithDescriptor:atd error:&at_err]; + assert(g_ctx.stream.arg_table && "Failed to create argument table"); + g_ctx.train_stream.arg_table = + [g_ctx.device newArgumentTableWithDescriptor:atd error:&at_err]; + assert(g_ctx.train_stream.arg_table && + "Failed to create train argument table"); + + // Per-stream constants ring buffers (MTL_CONST_RING_SIZE = 2 MiB each, replaces setBytes) + g_ctx.stream.const_ring = + [g_ctx.device newBufferWithLength:MTL_CONST_RING_SIZE +
options:MTLResourceStorageModeShared]; + g_ctx.train_stream.const_ring = + [g_ctx.device newBufferWithLength:MTL_CONST_RING_SIZE + options:MTLResourceStorageModeShared]; + + // Residency set — populated by mtl_wrap_allocator as buffers arrive + MTLResidencySetDescriptor *rsd = [MTLResidencySetDescriptor new]; + rsd.initialCapacity = 16; + NSError *rs_err = nil; + g_ctx.residency_set = + [g_ctx.device newResidencySetWithDescriptor:rsd error:&rs_err]; + assert(g_ctx.residency_set && "Failed to create residency set"); + [g_ctx.residency_set addAllocation:g_ctx.stream.const_ring]; + [g_ctx.residency_set addAllocation:g_ctx.train_stream.const_ring]; + [g_ctx.residency_set commit]; + [g_ctx.residency_set requestResidency]; + [g_ctx.queue addResidencySet:g_ctx.residency_set]; + [g_ctx.train_queue addResidencySet:g_ctx.residency_set]; + + // Start the default stream (rollout) and training stream + g_ctx.stream.begin(); + g_ctx.train_stream.begin(); + + } +} + +MetalContext *mtl_ctx() { return &g_ctx; } + +// Lazy-init scratch buffer for addmm temp workspace +static char *g_addmm_temp_base; +static int64_t g_addmm_temp_size; + +void *mtl_stream() { return &g_ctx.stream; } + +void *mtl_train_stream() { return &g_ctx.train_stream; } + +void *mtl_create_stream() { + MetalStream *ms = new MetalStream{}; + ms->allocator = [g_ctx.device newCommandAllocator]; + ms->cmd = [g_ctx.device newCommandBuffer]; + ms->sync_event = [g_ctx.device newSharedEvent]; + ms->sync_event_value = 0; + assert(ms->allocator && ms->cmd && "Failed to create Metal stream allocator/cmd"); + + MTL4ArgumentTableDescriptor *atd = [MTL4ArgumentTableDescriptor new]; + atd.maxBufferBindCount = 31; + NSError *at_err = nil; + ms->arg_table = [g_ctx.device newArgumentTableWithDescriptor:atd error:&at_err]; + assert(ms->arg_table && "Failed to create Metal stream argument table"); + + ms->const_ring = [g_ctx.device newBufferWithLength:MTL_CONST_RING_SIZE + options:MTLResourceStorageModeShared]; + 
assert(ms->const_ring && "Failed to create Metal stream constants ring"); + [g_ctx.residency_set addAllocation:ms->const_ring]; + [g_ctx.residency_set commit]; + [g_ctx.residency_set requestResidency]; + + ms->begin(); + return ms; +} + +void mtl_destroy_stream(void *stream) { + assert(stream && "mtl_destroy_stream: null stream"); + MetalStream *ms = (MetalStream *)stream; + if (ms->flushed) { + ms->wait_completed(); + } else if (ms->enc_active || ms->pending_work) { + ms->sync(); + } + ms->arg_table = nil; + ms->cmd = nil; + ms->allocator = nil; + ms->enc = nil; + ms->sync_event = nil; + ms->const_ring = nil; + delete ms; +} + +static void ksplit_reset(); // forward decl — defined near K-split GEMM + +void mtl_destroy() { + // 1. Close any open compute encoders. NOTE(review): end_compute() neither commits nor waits, so this assumes both streams were already synced — confirm. + g_ctx.stream.end_compute(); + g_ctx.train_stream.end_compute(); + + // 2. Free lazy-init scratch buffers BEFORE clearing the buffer registry, + // while MTLBuffer refs still exist (backing memory released after). + mtl_kernels_reset(); + ksplit_reset(); + if (g_addmm_temp_base) { + free(g_addmm_temp_base); + g_addmm_temp_base = nullptr; + g_addmm_temp_size = 0; + } + + // 3. Release all Metal objects inside @autoreleasepool to force immediate + // deallocation. Device released LAST — MTLBuffers/pipelines reference it. 
+ @autoreleasepool { + g_ctx.stream.arg_table = nil; + g_ctx.stream.cmd = nil; + g_ctx.stream.allocator = nil; + g_ctx.stream.enc = nil; + g_ctx.stream.sync_event = nil; + g_ctx.train_stream.arg_table = nil; + g_ctx.train_stream.cmd = nil; + g_ctx.train_stream.allocator = nil; + g_ctx.train_stream.enc = nil; + g_ctx.train_stream.sync_event = nil; + g_ctx.stream.const_ring = nil; + g_ctx.train_stream.const_ring = nil; + g_ctx.residency_set = nil; + g_ctx.queue = nil; + g_ctx.train_queue = nil; + g_ctx.buffers.clear(); + g_ctx.pipelines = nil; + g_ctx.library = nil; + } + g_ctx.device = nil; +} + +id mtl_wrap_allocator(Allocator *alloc) { + assert(alloc->mem && "Allocator::create() not called"); + + // Find the highest byte offset used by any registered tensor + int64_t max_end = 0; + for (auto &e : alloc->regs) { + int64_t end = + ((char *)*e.data_ptr - (char *)alloc->mem) + puf_numel(e.shape) * e.elem_size; + if (end > max_end) + max_end = end; + } + + // Wrap exactly the allocator's used byte range. + // Rounding up past allocated memory can cause pointer-range overlap between + // allocators and incorrect buffer resolution in mtl_buffer_for(). + int64_t size = max_end; + + // Zero-copy wrap: StorageModeShared on Apple Silicon means CPU and GPU + // access the same physical memory pages. No deallocator — Allocator + // owns the memory and frees it in destroy(). + id buf = + [g_ctx.device newBufferWithBytesNoCopy:alloc->mem + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + assert(buf && "newBufferWithBytesNoCopy failed — is Allocator::mem " + "page-aligned? 
(requires WITH_METAL)"); + + g_ctx.buffers.push_back({(char *)alloc->mem, size, buf}); + + // Add to residency set so GPU addresses are valid + [g_ctx.residency_set addAllocation:buf]; + [g_ctx.residency_set commit]; + [g_ctx.residency_set requestResidency]; + + return buf; +} + +id mtl_buffer_for(const PufTensor &t, NSUInteger *out_offset) { + for (auto &wb : g_ctx.buffers) { + if (t.bytes >= wb.base && t.bytes < wb.base + wb.size) { + *out_offset = (NSUInteger)(t.bytes - wb.base); + return wb.buffer; + } + } + assert(false && "PufTensor not in any wrapped allocator buffer"); + __builtin_unreachable(); +} + +// Typed tensor variant: look up buffer by raw pointer. +id mtl_buffer_for_ptr(const void *ptr, NSUInteger *out_offset) { + const char *p = (const char *)ptr; + for (auto &wb : g_ctx.buffers) { + if (p >= wb.base && p < wb.base + wb.size) { + *out_offset = (NSUInteger)(p - wb.base); + return wb.buffer; + } + } + assert(false && "pointer not in any wrapped allocator buffer"); + __builtin_unreachable(); +} + +id mtl_pipeline(const char *name) { + std::lock_guard lock(g_pipeline_mutex); + NSString *key = [NSString stringWithUTF8String:name]; + id pso = g_ctx.pipelines[key]; + if (pso) + return pso; + + id fn = [g_ctx.library newFunctionWithName:key]; + assert(fn && "Kernel function not found in MSL library"); + + NSError *error = nil; + pso = [g_ctx.device newComputePipelineStateWithFunction:fn error:&error]; + if (!pso) + throw std::runtime_error(error ? error.localizedDescription.UTF8String + : "Pipeline creation failed"); + + g_ctx.pipelines[key] = pso; + return pso; +} + +// ============================================================================ +// GEMM dispatch — tensor_ops (aligned) or steel_gemm (unaligned) on GPU. +// All GEMM goes through Metal compute shaders (no CPU cblas path). +// +// Matches cuBLAS calling conventions in models.cu (row-major data, +// column-major API trick: swap A/B and transpose flags). 
+// ============================================================================
+
+// GPU training mode — when true, puf_mm forces GPU GEMM to avoid ensure_gpu_synced.
+// Set by train_impl to keep all training ops on the GPU encoder chain.
+// Written with release ordering and read with acquire ordering.
+static std::atomic_bool g_gpu_training = false;
+void puf_set_gpu_training(bool val) { g_gpu_training.store(val, std::memory_order_release); }
+bool puf_is_gpu_training() { return g_gpu_training.load(std::memory_order_acquire); }
+
+// True when the MetalStream behind `stream` currently has an open encoder.
+bool puf_stream_has_encoder(cudaStream_t stream) {
+    MetalStream *ms = mtl_resolve_stream(stream);
+    return ms->enc_active;
+}
+
+// Find MTLBuffer containing a raw pointer and compute byte offset.
+// Thin alias for mtl_buffer_for_ptr() so the GEMM helpers do not carry their
+// own copy of the registry scan.
+static id buffer_for_ptr(const void *ptr, NSUInteger *out_offset) {
+    return mtl_buffer_for_ptr(ptr, out_offset);
+}
+
+// Bind a pre-resolved MTLBuffer+offset to a stream binding slot (GEMM helpers).
+// Also mirrors the bound GPU address into ms->bound_addresses[index].
+static inline void bind_buf(MetalStream *ms, id buf,
+                            NSUInteger offset, uint32_t index) {
+    uint64_t addr = buf.gpuAddress + offset;
+    [ms->arg_table setAddress:addr atIndex:index];
+    ms->bound_addresses[index] = addr;
+}
+
+// ============================================================================
+// Metal compute GEMM: uses simdgroup_matrix hardware instructions (M3+).
+// Stays on the compute encoder (no encoder transitions).
+// Used for fp32 GEMMs (rollout inference + Muon optimizer).
+// ============================================================================
+
+// Must match MSL GemmParams layout exactly (10 x 4 bytes = 40 bytes).
+struct HostGemmParams {
+    int M, N, K, lda, ldb, ldc;
+    float alpha, beta;
+    int trans_a, trans_b;
+};
+// Enforce the documented size at compile time so a host-side field change
+// cannot silently desynchronise from the shader struct.
+static_assert(sizeof(HostGemmParams) == 40,
+              "HostGemmParams must stay 40 bytes to match MSL GemmParams");
+
+// steel_gemm dispatch: C(M,N) = alpha * op(A) @ op(B) + beta * C.
+// 64x64 output tile per threadgroup, 128 threads (4 simdgroups).
+//
+// kernel_name        pipeline to run (looked up through mtl_pipeline)
+// A, B, C            operands; each must lie inside a registered buffer
+// M, N, K            problem size
+// trans_a, trans_b   forwarded to the kernel as 0/1 flags
+// lda, ldb, ldc      leading dimensions forwarded to the kernel
+// alpha, beta        scaling factors forwarded to the kernel
+// stream             stream whose compute encoder receives the dispatch
+static void steel_gemm_dispatch(const char *kernel_name,
+                                const void *A, const void *B, void *C,
+                                int M, int N, int K,
+                                bool trans_a, bool trans_b,
+                                int lda, int ldb, int ldc,
+                                float alpha, float beta,
+                                cudaStream_t stream) {
+    MetalStream *ms = mtl_resolve_stream(stream);
+    ms->compute_encoder();
+    id pso = mtl_pipeline(kernel_name);
+    mtl_set_pso(ms, pso);
+
+    // Resolve each operand BEFORE binding it. Writing
+    // bind_buf(ms, buffer_for_ptr(A, &off), off, 0) reads `off` in the same
+    // argument list that writes it, and C++ leaves the evaluation order of
+    // function arguments unspecified.
+    NSUInteger off_a = 0, off_b = 0, off_c = 0;
+    id buf_a = buffer_for_ptr(A, &off_a);
+    id buf_b = buffer_for_ptr(B, &off_b);
+    id buf_c = buffer_for_ptr(C, &off_c);
+    bind_buf(ms, buf_a, off_a, 0);
+    bind_buf(ms, buf_b, off_b, 1);
+    bind_buf(ms, buf_c, off_c, 2);
+
+    HostGemmParams params = {M, N, K, lda, ldb, ldc, alpha, beta,
+                             trans_a ? 1 : 0, trans_b ? 1 : 0};
+    mtl_set_params(ms, params, 3);
+
+    // One threadgroup per 64x64 output tile, rounding partial tiles up.
+    [ms->enc dispatchThreadgroups:MTLSizeMake((N + 63) / 64, (M + 63) / 64, 1)
+            threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+    ms->pending_work = true;
+    g_gemm_dispatch_count++;
+}
+
+// fp32 entry point: runs the "steel_gemm" kernel.
+static void compute_gemm(const float *A, const float *B, float *C,
+                         int M, int N, int K, bool trans_a, bool trans_b,
+                         int lda, int ldb, int ldc, float alpha, float beta,
+                         cudaStream_t stream) {
+    steel_gemm_dispatch("steel_gemm", A, B, C, M, N, K, trans_a, trans_b,
+                        lda, ldb, ldc, alpha, beta, stream);
+}
+
+// fp16 entry point: runs the "steel_gemm_f16" kernel.
+static void compute_gemm_f16(const void *A, const void *B, void *C,
+                             int M, int N, int K, bool trans_a, bool trans_b,
+                             int lda, int ldb, int ldc, float alpha, float beta,
+                             cudaStream_t stream) {
+    steel_gemm_dispatch("steel_gemm_f16", A, B, C, M, N, K, trans_a, trans_b,
+                        lda, ldb, ldc, alpha, beta, stream);
+}
+
+// ============================================================================
+// tensor_ops GEMM dispatch. All variants (NT/NN/TN x fp32/fp16) use identical
+// dispatch: bind A/B/C buffers, set M/N/K params, dispatch 64x32 tile groups.
+// Returns false if the PSO is nil (compilation failed).
+// ============================================================================
+
+// Encode one tensor_ops GEMM on the stream's compute encoder.
+// Slots 0..2 receive A, B and C; slots 3..5 receive M, N and K as uint32_t.
+// The grid has one 128-thread group per 64 (rows) x 32 (columns) output tile.
+// A nil pipeline is reported by returning false before anything is encoded.
+static bool tensor_ops_dispatch(id pso,
+                                const void *A, const void *B, void *C,
+                                int M, int N, int K, cudaStream_t stream) {
+    if (pso == nil) {
+        return false;
+    }
+    g_gemm_dispatch_count++;
+
+    MetalStream *ms = mtl_resolve_stream(stream);
+    ms->compute_encoder();
+    mtl_set_pso(ms, pso);
+
+    // Look up every operand first, then bind them in slot order.
+    NSUInteger a_off, b_off, c_off;
+    id a_buf = buffer_for_ptr(A, &a_off);
+    id b_buf = buffer_for_ptr(B, &b_off);
+    id c_buf = buffer_for_ptr(C, &c_off);
+
+    bind_buf(ms, a_buf, a_off, 0);
+    bind_buf(ms, b_buf, b_off, 1);
+    bind_buf(ms, c_buf, c_off, 2);
+
+    // Problem size is passed as three separate 32-bit constants.
+    uint32_t dim_m = (uint32_t)M;
+    uint32_t dim_n = (uint32_t)N;
+    uint32_t dim_k = (uint32_t)K;
+    mtl_set_params(ms, dim_m, 3);
+    mtl_set_params(ms, dim_n, 4);
+    mtl_set_params(ms, dim_k, 5);
+
+    // Partial tiles round up to a whole threadgroup.
+    int row_groups = (M + 63) / 64;
+    int col_groups = (N + 31) / 32;
+    MTLSize grid = MTLSizeMake(col_groups, row_groups, 1);
+    MTLSize group = MTLSizeMake(128, 1, 1);
+    [ms->enc dispatchThreadgroups:grid threadsPerThreadgroup:group];
+
+    ms->pending_work = true;
+    return true;
+}
+
+// Typed wrappers for callers that pass float* (fp32 variants).
+static bool tensor_ops_gemm_nt(const float *A, const float *B, float *C,
+                               int M, int N, int K, cudaStream_t s) {
+    return tensor_ops_dispatch(g_ctx.tensor_ops_gemm_nt_f32, A, B, C, M, N, K, s);
+}
+static bool tensor_ops_gemm_nn(const float *A, const float *B, float *C,
+                               int M, int N, int K, cudaStream_t s) {
+    return tensor_ops_dispatch(g_ctx.tensor_ops_gemm_nn_f32, A, B, C, M, N, K, s);
+}
+static bool tensor_ops_gemm_tn(const float *A, const float *B, float *C,
+                               int M, int N, int K, cudaStream_t s) {
+    return tensor_ops_dispatch(g_ctx.tensor_ops_gemm_tn_f32, A, B, C, M, N, K, s);
+}
+static bool tensor_ops_gemm_nt_f16(const void *A, const void *B, void *C,
+                                   int M, int N, int K, cudaStream_t s) {
+    return tensor_ops_dispatch(g_ctx.tensor_ops_gemm_nt_f16, A, B, C, M, N, K, s);
+}
+static bool tensor_ops_gemm_nn_f16(const void *A, const void *B, void *C,
+                                   int M, int N, int K, cudaStream_t s) {
+    return tensor_ops_dispatch(g_ctx.tensor_ops_gemm_nn_f16, A, B, C, M, N, K, s);
+}
+static bool tensor_ops_gemm_tn_f16(const void *A, const void *B, void *C,
+                                   int M, int N, int K, cudaStream_t s) {
+    return tensor_ops_dispatch(g_ctx.tensor_ops_gemm_tn_f16, A, B, C, M, N, K, s);
+}
+
+// K-split TN GEMM for small-M, small-N, large-K reductions (wgrad).
+// Partitions K across Z threadgroups, each writing partial sums. Then reduces.
+static id ksplit_buf = nil;
+static float *ksplit_ptr = nullptr;
+static int ksplit_capacity = 0;
+
+// Reset K-split state on Metal context teardown.
+static void ksplit_reset() {
+    ksplit_buf = nil;
+    ksplit_ptr = nullptr;
+    ksplit_capacity = 0;
+}
+
+// Two-pass TN GEMM: "sgemm_ksplit" writes num_splits partial (M x N) products
+// into a lazily grown scratch buffer, then "reduce_ksplit" folds them into C
+// with alpha = 1 and beta = 0.
+// Throws std::runtime_error if the scratch buffer cannot be allocated.
+// NOTE(review): the scratch buffer is file-static state with no locking —
+// confirm this is only ever reached from one thread.
+static void compute_gemm_ksplit_tn(const float *A, const float *B, float *C,
+                                   int M, int N, int K,
+                                   int lda, int ldb, int ldc,
+                                   cudaStream_t stream) {
+    int tile_groups_m = (M + 31) / 32;
+    int tile_groups_n = (N + 31) / 32;
+    int tile_groups = tile_groups_m * tile_groups_n;
+
+    // Target ~32 threadgroups per GPU core (16 cores on M4 Pro).
+    int target_tgs = 16 * 32;
+    int num_splits = (target_tgs + tile_groups - 1) / tile_groups;
+    num_splits = std::max(2, std::min(num_splits, K / 32));
+    int k_per_split = (K + num_splits - 1) / num_splits;
+
+    MetalStream *ms = mtl_resolve_stream(stream);
+
+    // Lazy-allocate partials buffer as a shared Metal buffer.
+    int partials_count = num_splits * M * N;
+    if (partials_count > ksplit_capacity) {
+        NSUInteger sz = (NSUInteger)partials_count * sizeof(float);
+        // Allocate the replacement first so a failure leaves the old state intact.
+        id grown = [g_ctx.device newBufferWithLength:sz
+                                             options:MTLResourceStorageModeShared];
+        if (!grown)
+            throw std::runtime_error("Failed to allocate K-split partials buffer");
+
+        // Remove old buffer from residency and buffer list.
+        if (ksplit_buf) {
+            // Earlier dispatches on this stream may still reference the old
+            // partials; finish them before the buffer leaves the residency
+            // set (same drain idiom as mtl_destroy_stream).
+            if (ms->flushed) {
+                ms->wait_completed();
+            } else if (ms->enc_active || ms->pending_work) {
+                ms->sync();
+            }
+            auto &bufs = g_ctx.buffers;
+            bufs.erase(std::remove_if(bufs.begin(), bufs.end(),
+                [](const WrappedBuffer &wb) {
+                    return wb.base == (const char *)ksplit_ptr;
+                }), bufs.end());
+            // Without this the set keeps every outgrown buffer resident.
+            [g_ctx.residency_set removeAllocation:ksplit_buf];
+        }
+        ksplit_buf = grown;
+        ksplit_ptr = (float *)ksplit_buf.contents;
+        g_ctx.buffers.push_back({(char *)ksplit_ptr, (int64_t)sz, ksplit_buf});
+        [g_ctx.residency_set addAllocation:ksplit_buf];
+        [g_ctx.residency_set commit];
+        [g_ctx.residency_set requestResidency];
+        ksplit_capacity = partials_count;
+    }
+
+    // Step 1: K-split GEMM — write partials.
+    ms->compute_encoder();
+    auto pso_ksplit = mtl_pipeline("sgemm_ksplit");
+    mtl_set_pso(ms, pso_ksplit);
+
+    NSUInteger off_a, off_b, off_p;
+    id buf_a = buffer_for_ptr(A, &off_a);
+    id buf_b = buffer_for_ptr(B, &off_b);
+    id buf_p = buffer_for_ptr(ksplit_ptr, &off_p);
+
+    bind_buf(ms, buf_a, off_a, 0);
+    bind_buf(ms, buf_b, off_b, 1);
+    bind_buf(ms, buf_p, off_p, 2);
+
+    HostGemmParams params = {M, N, K, lda, ldb, ldc, 1.0f, 0.0f, 1, 0}; // trans_a=TN
+    mtl_set_params(ms, params, 3);
+    int kps = k_per_split;
+    mtl_set_params(ms, kps, 4);
+
+    // sgemm_ksplit uses 2D threadgroups: (BN/TN, BM/TM) = (8, 8) = 64 threads.
+    [ms->enc dispatchThreadgroups:MTLSizeMake(tile_groups_n, tile_groups_m, num_splits)
+            threadsPerThreadgroup:MTLSizeMake(8, 8, 1)];
+    ms->pending_work = true;
+
+    // Barrier before reduce.
+    mtl_barrier(ms);
+
+    // Step 2: Reduce partials → C.
+    ms->compute_encoder();
+    auto pso_reduce = mtl_pipeline("reduce_ksplit");
+    mtl_set_pso(ms, pso_reduce);
+
+    NSUInteger off_c;
+    id buf_c = buffer_for_ptr(C, &off_c);
+    bind_buf(ms, buf_p, off_p, 0);
+    bind_buf(ms, buf_c, off_c, 1);
+
+    struct { int MN, num_splits; float alpha, beta; }
+        rp = {M * N, num_splits, 1.0f, 0.0f};
+    mtl_set_params(ms, rp, 2);
+
+    mtl_dispatch_1d(ms, pso_reduce, M * N);
+}
+
+// Small compute-encoder GEMM for unaligned N.
+// C(M,N) = A(M,K) @ B(N,K)^T. One threadgroup per row, threads partition columns.
+// Efficient for small N (e.g. decoder output N=40) where 64x64 tile waste dominates.
+static void small_gemm_nt_dispatch(const float *A, const float *B, float *C,
+                                   int M, int N, int K,
+                                   cudaStream_t stream) {
+    MetalStream *ms = mtl_resolve_stream(stream);
+    ms->compute_encoder();
+    id pso = mtl_pipeline("small_gemm_nt_f32");
+    mtl_set_pso(ms, pso);
+
+    NSUInteger off_a, off_b, off_c;
+    id buf_a = buffer_for_ptr(A, &off_a);
+    id buf_b = buffer_for_ptr(B, &off_b);
+    id buf_c = buffer_for_ptr(C, &off_c);
+
+    bind_buf(ms, buf_a, off_a, 0);
+    bind_buf(ms, buf_b, off_b, 1);
+    bind_buf(ms, buf_c, off_c, 2);
+
+    struct { uint32_t M, N, K; } params = {(uint32_t)M, (uint32_t)N, (uint32_t)K};
+    mtl_set_params(ms, params, 3);
+
+    // threadgroup size: round N up to next multiple of 32 for simdgroup alignment
+    int tg_size = ((N + 31) / 32) * 32;
+    tg_size = MIN(tg_size, (int)pso.maxTotalThreadsPerThreadgroup);
+    [ms->enc dispatchThreadgroups:MTLSizeMake(M, 1, 1)
+            threadsPerThreadgroup:MTLSizeMake(tg_size, 1, 1)];
+
+    ms->pending_work = true;
+    g_gemm_dispatch_count++;
+}
+
+// out(...,N) = a(...,K) @ b(N,K)^T — leading dims folded into M
+// Kernel choice: fp16 inputs (dtype_size == 2) use tensor_ops when aligned,
+// else steel_gemm_f16; other inputs use tensor_ops when aligned, the
+// one-row-per-threadgroup kernel when N < 128, else steel_gemm.
+void puf_mm(PufTensor &a, PufTensor &b, PufTensor &out,
+            cudaStream_t stream) {
+    int na = a.ndim(), nb = b.ndim();
+    int M = (int)(a.batch_size() * a.shape[na - 2]);
+    int K = (int)a.shape[na - 1];
+    int N = (int)b.shape[nb - 2];
+    bool aligned = (N % 32 == 0) && (M % 64 == 0);
+    float *a_f32 = (float *)a.bytes;
+    float *b_f32 = (float *)b.bytes;
+
+    if (a.dtype_size == 2) {
+        if (aligned &&
+            tensor_ops_gemm_nt_f16(a.bytes, b.bytes, out.bytes, M, N, K, stream)) {
+            // fp16 tensor_ops (aligned)
+        } else {
+            // fp16 unaligned NT: steel_gemm_f16 (stays on compute encoder)
+            compute_gemm_f16(a.bytes, b.bytes, out.bytes, M, N, K,
+                             /*trans_a=*/false, /*trans_b=*/true,
+                             K, K, N, 1.0f, 0.0f, stream);
+        }
+    } else if (aligned &&
+               tensor_ops_gemm_nt(a_f32, b_f32,
+                                  (float *)out.bytes, M, N, K, stream)) {
+        // fp32 tensor_ops (aligned)
+    } else if (a.dtype_size == 4 && N < 128) {
+        // Small unaligned N: 1-row-per-threadgroup kernel (avoids 64x64 tile waste)
+        small_gemm_nt_dispatch(a_f32, b_f32,
+                               (float *)out.bytes, M, N, K, stream);
+    } else {
+        // f32 unaligned: steel_gemm NT (stays on compute encoder)
+        compute_gemm(a_f32, b_f32,
+                     (float *)out.bytes, M, N, K,
+                     /*trans_a=*/false, /*trans_b=*/true,
+                     K, K, N, 1.0f, 0.0f, stream);
+    }
+}
+
+// out(M,N) = a(...,M)^T @ b(...,N) — leading dims folded into K
+// fp32 inputs with K > 4096 and fewer than 32 output tiles of 32x32 take the
+// K-split path before the aligned tensor_ops path is considered.
+void puf_mm_tn(PufTensor &a, PufTensor &b, PufTensor &out,
+               cudaStream_t stream) {
+    int na = a.ndim(), nb = b.ndim();
+    int K = (int)(a.batch_size() * a.shape[na - 2]);
+    int M = (int)a.shape[na - 1];
+    int N = (int)b.shape[nb - 1];
+
+    bool aligned = (M % 64 == 0) && (N % 32 == 0);
+    float *a_f32 = (float *)a.bytes;
+    float *b_f32 = (float *)b.bytes;
+
+    if (a.dtype_size == 2) {
+        if (aligned &&
+            tensor_ops_gemm_tn_f16(a.bytes, b.bytes, out.bytes, M, N, K, stream)) {
+            // fp16 tensor_ops TN on compute encoder
+        } else {
+            // fp16 unaligned TN: steel_gemm_f16 (stays on compute encoder)
+            compute_gemm_f16(a.bytes, b.bytes, out.bytes, M, N, K,
+                             /*trans_a=*/true, /*trans_b=*/false,
+                             M, N, N, 1.0f, 0.0f, stream);
+        }
+    } else if (a.dtype_size == 4 && K > 4096 &&
+               ((M + 31) / 32) * ((N + 31) / 32) < 32) {
+        // K-split TN for small output (few tiles) + large K reduction.
+        compute_gemm_ksplit_tn(a_f32, b_f32,
+                               (float *)out.bytes, M, N, K,
+                               M, N, N, stream);
+    } else if (aligned &&
+               tensor_ops_gemm_tn(a_f32, b_f32,
+                                  (float *)out.bytes, M, N, K, stream)) {
+        // fp32 tensor_ops TN on compute encoder
+    } else {
+        // fp32 steel_gemm TN fallback (stays on compute encoder)
+        compute_gemm(a_f32, b_f32,
+                     (float *)out.bytes, M, N, K,
+                     /*trans_a=*/true, /*trans_b=*/false,
+                     M, N, N, 1.0f, 0.0f, stream);
+    }
+}
+
+// out(...,N) = a(...,K) @ b(K,N) — leading dims folded into M
+void puf_mm_nn(PufTensor &a, PufTensor &b, PufTensor &out,
+               cudaStream_t stream) {
+    int na = a.ndim(), nb = b.ndim();
+    int M = (int)(a.batch_size() * a.shape[na - 2]);
+    int K = (int)a.shape[na - 1];
+    int N = (int)b.shape[nb - 1];
+
+    bool aligned = (M % 64 == 0) && (N % 32 == 0);
+    float *a_f32 = (float *)a.bytes;
+    float *b_f32 = (float *)b.bytes;
+
+    if (a.dtype_size == 2) {
+        if (aligned &&
+            tensor_ops_gemm_nn_f16(a.bytes, b.bytes, out.bytes, M, N, K, stream)) {
+            // fp16 tensor_ops NN on compute encoder
+        } else {
+            // fp16 unaligned NN: steel_gemm_f16 (stays on compute encoder)
+            compute_gemm_f16(a.bytes, b.bytes, out.bytes, M, N, K,
+                             /*trans_a=*/false, /*trans_b=*/false,
+                             K, N, N, 1.0f, 0.0f, stream);
+        }
+    } else if (aligned &&
+               tensor_ops_gemm_nn(a_f32, b_f32,
+                                  (float *)out.bytes, M, N, K, stream)) {
+        // fp32 tensor_ops NN (aligned)
+    } else {
+        // f32 unaligned: steel_gemm NN (stays on compute encoder)
+        compute_gemm(a_f32, b_f32,
+                     (float *)out.bytes, M, N, K,
+                     /*trans_a=*/false, /*trans_b=*/false,
+                     K, N, N, 1.0f, 0.0f, stream);
+    }
+}
+
+// ============================================================================
+// addmm temp buffer — lazily allocated for tensor_ops addmm decomposition.
+// Only used for aligned muon NS GEMMs (512×512). Page-aligned for MTLBuffer.
+// ============================================================================
+
+// Return a scratch area holding at least `count` floats, growing it on demand.
+// The byte size is rounded up to a multiple of 16384. When growing, the old
+// block is dropped from g_ctx.buffers and freed before the new one is obtained
+// from mtl_alloc_scratch().
+// NOTE(review): the old block is freed without draining the GPU first, so a
+// dispatch that was encoded but has not yet run could still reference it —
+// confirm callers never grow this while such work is pending.
+static float *addmm_temp_buf(int count) {
+    int64_t needed = (int64_t)count * sizeof(float);
+    int64_t page = 16384;
+    int64_t size = (needed + page - 1) & ~(page - 1);
+    if (size <= g_addmm_temp_size) return (float *)g_addmm_temp_base;
+
+    if (g_addmm_temp_base) {
+        auto &bufs = g_ctx.buffers;
+        bufs.erase(std::remove_if(bufs.begin(), bufs.end(),
+            [](const WrappedBuffer &wb) { return wb.base == g_addmm_temp_base; }),
+            bufs.end());
+        free(g_addmm_temp_base);
+    }
+    g_addmm_temp_base = (char *)mtl_alloc_scratch(size);
+    g_addmm_temp_size = size;
+    return (float *)g_addmm_temp_base;
+}
+
+// out(...,N) = beta*out + alpha * a(...,K) @ b(K,N)
+// Only called from Muon Newton-Schulz (small fp32 GEMMs: 512×512).
+// When tensor_ops NN is available and dimensions align, decomposes into:
+//   temp = A @ B             (tensor_ops, compute encoder)
+//   out *= beta              (scale, compute encoder)
+//   out += alpha * temp      (axpy, compute encoder)
+// All three ops stay on the compute encoder.
+// NOTE(review): on this path beta == 0 still multiplies the old contents of
+// `out`, so a NaN/Inf already there would survive — confirm callers never
+// rely on beta == 0 meaning "ignore out".
+void puf_addmm_nn(PufTensor &a, PufTensor &b, PufTensor &out, float alpha,
+                  float beta, cudaStream_t stream) {
+    int na = a.ndim(), nb = b.ndim();
+    int M = (int)(a.batch_size() * a.shape[na - 2]);
+    int K = (int)a.shape[na - 1];
+    int N = (int)b.shape[nb - 1];
+
+    const float *a_f32 = (const float *)a.bytes;
+    const float *b_f32 = (const float *)b.bytes;
+
+    if ((M % 64 == 0) && (N % 32 == 0) && g_ctx.tensor_ops_gemm_nn_f32) {
+        // Decompose: out = beta*out + alpha*(a@b)
+        // Step 1: temp = a @ b via tensor_ops NN (compute encoder)
+        float *temp = addmm_temp_buf(M * N);
+        tensor_ops_gemm_nn(a_f32, b_f32, temp, M, N, K, stream);
+
+        // Resolve the stream once and reuse it for the barrier and both kernels.
+        MetalStream *ms = mtl_resolve_stream(stream);
+        // Metal 4: force visibility of temp writes before scale/axpy reads.
+        mtl_barrier(ms);
+
+        // Step 2: out *= beta (compute encoder)
+        ms->compute_encoder();
+        int count = M * N;
+
+        if (beta != 1.0f) {
+            auto pso = mtl_pipeline("scale_f32");
+            mtl_set_pso(ms, pso);
+            // Resolve before binding: the offset must be written before it is
+            // read, and argument evaluation order is unspecified in C++.
+            NSUInteger off_out = 0;
+            id buf_out = buffer_for_ptr(out.bytes, &off_out);
+            bind_buf(ms, buf_out, off_out, 0);
+            struct { float alpha; int n; } sp = {beta, count};
+            mtl_set_params(ms, sp, 1);
+            mtl_dispatch_1d(ms, pso, count);
+            // Ensure scaled out is visible before axpy accumulation.
+            mtl_barrier(ms);
+            ms->compute_encoder();
+        }
+
+        // Step 3: out += alpha * temp (compute encoder)
+        {
+            auto pso = mtl_pipeline("axpy_f32");
+            mtl_set_pso(ms, pso);
+            NSUInteger off_out = 0, off_temp = 0;
+            id buf_out = buffer_for_ptr(out.bytes, &off_out);
+            id buf_temp = buffer_for_ptr(temp, &off_temp);
+            bind_buf(ms, buf_out, off_out, 0);
+            bind_buf(ms, buf_temp, off_temp, 1);
+            struct { float alpha; int n; } ap = {alpha, count};
+            mtl_set_params(ms, ap, 2);
+            mtl_dispatch_1d(ms, pso, count);
+        }
+        ms->pending_work = true;
+    } else {
+        // Unaligned addmm: steel_gemm with alpha/beta (stays on compute encoder)
+        compute_gemm(a_f32, b_f32, (float *)out.bytes, M, N, K,
+                     /*trans_a=*/false, /*trans_b=*/false,
+                     K, N, N, alpha, beta, stream);
+    }
+}
+
+// ============================================================================
+// CUDA compatibility stubs for vecenv.h
+//
+// vecenv.h declares extern CUDA functions (cudaMemcpy, cudaMalloc, etc.) that
+// the env's binding.c calls. On CUDA, libcudart provides these.
On Metal,
+// we provide trivial implementations:
+//   - cudaHostAlloc/cudaMalloc → calloc (unified memory, no separate GPU alloc)
+//   - cudaMemcpy/cudaMemcpyAsync → memcpy (same physical pages)
+//   - cudaMemset → memset
+//   - cudaStreamCreateWithFlags → mtl_create_stream()
+//   - cudaStreamSynchronize/cudaStreamQuery → no-op
+//   - cudaDeviceSynchronize → sync of g_ctx.stream when its encoder is active
+//   - cudaSetDevice → no-op
+// Return codes: 0 on success, 1 for an invalid argument, 2 when an allocation
+// fails (the values CUDA uses for cudaErrorInvalidValue and
+// cudaErrorMemoryAllocation).
+// ============================================================================
+
+extern "C" {
+
+// Zero-initialised host allocation shared by cudaHostAlloc and cudaMalloc.
+// A null result is only an error for a non-zero request: calloc may
+// legitimately return NULL for size 0.
+static int metal_compat_alloc(void **ptr, size_t size) {
+    if (!ptr)
+        return 1;
+    *ptr = calloc(1, size);
+    if (!*ptr && size != 0)
+        return 2;
+    return 0;
+}
+
+int cudaHostAlloc(void **ptr, size_t size, unsigned int /*flags*/) {
+    return metal_compat_alloc(ptr, size);
+}
+
+int cudaMalloc(void **ptr, size_t size) {
+    return metal_compat_alloc(ptr, size);
+}
+
+int cudaMemcpy(void *dst, const void *src, size_t size, int /*kind*/) {
+    memcpy(dst, src, size);
+    return 0;
+}
+
+int cudaMemcpyAsync(void *dst, const void *src, size_t size, int /*kind*/,
+                    void * /*stream*/) {
+    memcpy(dst, src, size);
+    return 0;
+}
+
+int cudaMemset(void *ptr, int value, size_t size) {
+    memset(ptr, value, size);
+    return 0;
+}
+
+int cudaFree(void *ptr) {
+    free(ptr);
+    return 0;
+}
+
+int cudaFreeHost(void *ptr) {
+    free(ptr);
+    return 0;
+}
+
+int cudaSetDevice(int /*device*/) { return 0; }
+
+// NOTE(review): only g_ctx.stream is synced, and only while its encoder is
+// active; g_ctx.train_stream and streams from cudaStreamCreateWithFlags are
+// not touched — confirm that is intended.
+int cudaDeviceSynchronize(void) {
+    if (g_ctx.stream.enc_active)
+        g_ctx.stream.sync();
+    return 0;
+}
+
+int cudaStreamSynchronize(void * /*stream*/) {
+    // No-op on Metal. GPU work is already synced inside net_callback_wrapper
+    // (ensure_gpu_synced under mutex). The vecenv memcpys are also no-ops
+    // (unified memory). Calling sync() here would race with other buffer
+    // threads that hold the GPU mutex and have an active encoder.
+    return 0;
+}
+
+int cudaStreamCreateWithFlags(void **stream, unsigned int /*flags*/) {
+    assert(stream && "cudaStreamCreateWithFlags expects a non-null stream pointer");
+    // The assert vanishes under NDEBUG; fail with an error code rather than
+    // writing through a null pointer.
+    if (!stream)
+        return 1;
+    *stream = mtl_create_stream();
+    return 0;
+}
+
+int cudaStreamQuery(void * /*stream*/) { return 0; }
+
+// Describes the two non-zero codes the stubs above can return; every other
+// value keeps the original placeholder text.
+const char *cudaGetErrorString(int error) {
+    switch (error) {
+    case 1:
+        return "metal-compat-stub: invalid value";
+    case 2:
+        return "metal-compat-stub: out of memory";
+    default:
+        return "metal-compat-stub";
+    }
+}
+
+} // extern "C"
diff --git a/src/metal_pufferlib.mm b/src/metal_pufferlib.mm
new file mode 100644
index 0000000000..726796c315
--- /dev/null
+++ b/src/metal_pufferlib.mm
@@ -0,0 +1,1265 @@
+#import "metal_platform.h"
+#include "metal_kernels.mm"
+#include "cpu_inference.h"
+#include "vecenv.h"
+
+
+#include
+#include
+#include
+#include
+#include
+
+static thread_local cudaStream_t tl_rollout_stream = 0;
+static std::mutex g_rollout_profile_mutex;
+
+// Convert a pair of mach_absolute_time() ticks to elapsed milliseconds.
+// The timebase is queried once and cached.
+static inline float prof_ms(uint64_t t0, uint64_t t1) {
+    static mach_timebase_info_data_t tb;
+    static std::once_flag tb_once;
+    std::call_once(tb_once, []() { mach_timebase_info(&tb); });
+    return (float)((double)(t1 - t0) * tb.numer / tb.denom / 1e6);
+}
+
+// Size in bytes of one observation element for the given dtype tag.
+// Asserts on an unknown tag (returns 0 when asserts are compiled out).
+int obs_dtype_size(int dtype) {
+    switch (dtype) {
+    case FLOAT:
+        return sizeof(float);
+    case INT:
+        return sizeof(int32_t);
+    case DOUBLE:
+        return sizeof(double);
+    case UNSIGNED_CHAR:
+    case CHAR:
+        return sizeof(char);
+    default:
+        assert(false && "Unsupported observation dtype");
+        return 0;
+    }
+}
+
+// Element-wise conversion of `count` values of type T to float.
+template <typename T>
+static inline void cpu_cast_to_f32(float* dst, const T* src, int count) {
+    for (int i = 0; i < count; i++) {
+        dst[i] = (float)src[i];
+    }
+}
+
+// ============================================================================
+// Environment creation — unified memory, no GPU copy needed
+// ============================================================================
+
+StaticVec* create_environments(int num_buffers, int total_agents,
+    const std::string& env_name, Dict* vec_kwargs, Dict* env_kwargs, EnvBuf& env) {
+    StaticVec* vec = create_static_vec(total_agents, num_buffers, /*gpu=*/0,
vec_kwargs, env_kwargs);
+
+    int obs_size = get_obs_size();
+    int num_atns = get_num_atns();
+    int obs_type = get_obs_type();
+
+    // Unified memory: env obs/actions/rewards/terminals point directly at vecenv buffers
+    env.obs = {.bytes = (char*)vec->gpu_observations, .shape = {total_agents, obs_size}, .dtype_size = obs_dtype_size(obs_type)};
+    env.obs_raw_dtype = obs_type;
+    env.actions = {.bytes = (char*)vec->gpu_actions, .shape = {total_agents, num_atns}, .dtype_size = (int)sizeof(float)};
+    env.rewards = {.data = vec->gpu_rewards, .shape = {total_agents}};
+    env.terminals = {.data = vec->gpu_terminals, .shape = {total_agents}};
+
+    return vec;
+}
+
+// ============================================================================
+// Hyperparameters — single GPU only, no NCCL
+// ============================================================================
+
+// Flat bag of run configuration, grouped by the section comments below.
+typedef struct {
+    // Layout
+    int horizon;
+    int total_agents;
+    int num_buffers;
+    // Model architecture
+    int num_atns;
+    int hidden_size;
+    int num_layers;
+    // Learning rate
+    float lr;
+    float min_lr_ratio;
+    bool anneal_lr;
+    // Optimizer (Muon only — Adam removed)
+    float beta1;
+    float weight_decay;
+    // Training
+    int minibatch_size;
+    float replay_ratio;
+    long total_timesteps;
+    float max_grad_norm;
+    // PPO
+    float clip_coef;
+    float vf_clip_coef;
+    float vf_coef;
+    float ent_coef;
+    // GAE
+    float gamma;
+    float gae_lambda;
+    // VTrace
+    float vtrace_rho_clip;
+    float vtrace_c_clip;
+    // Priority
+    float prio_alpha;
+    float prio_beta0;
+    // Flags
+    bool reset_state;
+    bool profile;
+    bool overlap;        // async training overlap: train on separate GPU queue
+    bool cpu_inference;  // CPU forward pass during rollout (no GPU sync)
+    bool train_fp16;     // fp16 activations/grads during training (rollout stays fp32)
+    int ns_iters;        // Newton-Schulz iterations in muon optimizer (1-5, default 5)
+    // Single GPU (Metal has no multi-GPU, but kept for upstream compat)
+    int gpu_id;
+    // Threading
+    int num_threads;
+    // RNG seed
+    uint64_t seed;
+} HypersT;
+
+// ============================================================================
+// Profiling — CPU-based timing via mach_absolute_time
+// ============================================================================
+
+// Index of each profiled phase. NUM_PROF is the entry count and sizes both
+// PROF_NAMES and ProfileT::accum, so new phases go before it and need a
+// matching name below.
+enum ProfileIdx {
+    PROF_ROLLOUT = 0,
+    PROF_EVAL_GPU,
+    PROF_EVAL_ENV,
+    // Fine-grained rollout sub-phases
+    PROF_ROLLOUT_OBS_COPY,
+    PROF_ROLLOUT_FWD,
+    PROF_ROLLOUT_ACT_COPY,
+    // Fine-grained training sub-phases
+    PROF_TRAIN_PRELOOP,
+    PROF_TRAIN_PRIO,
+    PROF_TRAIN_SELECT,
+    PROF_TRAIN_FWD,
+    PROF_TRAIN_PPO,
+    PROF_TRAIN_BACKWARD,
+    PROF_TRAIN_GRAD_COPY,
+    PROF_TRAIN_GRAD_CLIP,
+    PROF_TRAIN_MUON,
+    PROF_TRAIN_SYNC,
+    NUM_PROF,
+};
+
+// Display name per phase, listed in the same order as ProfileIdx.
+static const char* PROF_NAMES[NUM_PROF] = {
+    "rollout",
+    "eval_gpu",
+    "eval_env",
+    "rollout_obs_copy",
+    "rollout_fwd",
+    "rollout_act_copy",
+    "train_preloop",
+    "train_prio",
+    "train_select",
+    "train_fwd",
+    "train_ppo",
+    "train_backward",
+    "train_grad_copy",
+    "train_grad_clip",
+    "train_muon",
+    "train_sync",
+};
+
+// One accumulator slot per ProfileIdx entry.
+typedef struct {
+    float accum[NUM_PROF];
+} ProfileT;
+
+// ============================================================================
+// PuffeRL state — Metal version (no CUDA graphs, NCCL, nvml, multi-stream)
+// ============================================================================
+
+// Top-level trainer state: policy weights and activations, allocators, the
+// vectorised environment, rollout/training buffers, per-buffer rollout state,
+// profiling counters and progress bookkeeping.
+struct PuffeRL {
+    Policy* policy;
+    PolicyWeights weights_fp32;
+    PolicyWeights weights_fp16;    // fp16 training weights
+    // Double-buffered inference weights for rollout/training overlap.
+    // Rollout reads weights_infer (GPU compute), training writes weights_fp32 (GPU).
+    // After each training sync, weights_fp32 is memcpy'd to weights_infer.
+    PolicyWeights weights_infer;
+    Allocator infer_params_alloc;
+    bool overlap_enabled = false;
+    bool train_pending = false;  // GPU training dispatched but not yet synced
+    PolicyActivations train_activations;
+    AllocSet alloc_fp32;
+    AllocSet alloc_fp16;           // fp16 training: weights, activations, gradients
+    Allocator pufferl_alloc;
+    StaticVec* vec;
+    Muon* muon;
+    HypersT hypers;
+    bool is_continuous;
+    // The vectors below are indexed by rollout buffer (see thread_init_metal
+    // and net_callback_wrapper).
+    std::vector rollout_streams;
+    std::vector buffer_states;
+    std::vector sample_act_f32_buffers;
+    std::vector buffer_activations;
+    std::vector buffer_allocs;
+    RolloutBuf rollouts;
+    RolloutBuf train_rollouts;
+    EnvBuf env;
+    TrainGraph train_buf;
+    FloatTensor old_values_puf;
+    FloatTensor advantages_puf;
+    IntTensor act_sizes_puf;
+    FloatTensor losses_puf;
+    PPOBuffersPuf ppo_bufs_puf;
+    PrioBuffers prio_bufs;
+    FloatTensor param_fp32_puf;
+    PufTensor param_fp16_puf;      // fp16 weight buffer (flat view)
+    PufTensor grad_fp16_puf;       // gradient buffer (fp16 when train_fp16, else fp32)
+    FloatTensor grad_norm_puf;
+    LongTensor rng_offset_puf;
+    // fp16 boundary buffers: obs cast (fp32->fp16), dec_out cast (fp16->fp32),
+    // state (zeroed fp16 for scan initial state)
+    PufTensor fp16_obs_buf;
+    PufTensor fp32_dec_out_buf;
+    PufTensor fp16_state_buf;
+    Allocator fp16_boundary_alloc;
+    ProfileT profile;
+    int rollout_sync_count;
+    double rollout_sync_ms;
+    int train_sync_count;
+    double train_sync_ms;
+    int epoch;
+    long global_step;
+    double start_time;
+    double last_log_time;
+    long last_log_step;
+    uint64_t rng_seed;
+    // Action mask: true if obs embeds a mask in the last act_n columns.
+    // When false, a static all-ones buffer is used instead.
+    bool has_mask = false;
+    int env_obs_width = 0;  // raw obs width from env (e.g. 1096 = features + mask)
+    int mask_width = 0;     // total action mask width = sum of all action head sizes (e.g. 79)
+    FloatTensor ones_mask;  // (act_n) all 1.0f, fallback mask when !has_mask
+    // External mask path: when has_mask, masks are split from obs at rollout time.
+    FloatTensor rollout_masks;
+    FloatTensor train_masks;
+    FloatTensor mb_masks;
+    bool cpu_inference = false;    // CPU forward pass for rollout (no GPU sync)
+    bool train_fp16 = false;       // fp16 training activations/grads
+    // Decoder logits + f32 actions for GPU logprob recompute (cpu_inference only).
+    FloatTensor rollout_logits;      // (horizon, total_agents, fused_cols)
+    FloatTensor train_logits;        // (total_agents, horizon, fused_cols)
+    FloatTensor rollout_actions_f32; // (horizon, total_agents, num_atns)
+    FloatTensor train_actions_f32;   // (total_agents, horizon, num_atns)
+};
+
+// ============================================================================
+// Logging
+// ============================================================================
+
+// Create a Dict (create_dict(128)), let static_vec_log() fill it from the
+// vectorised environment, and return it to the caller.
+Dict* log_environments_impl(PuffeRL& pufferl) {
+    Dict* out = create_dict(128);
+    static_vec_log(pufferl.vec, out);
+    return out;
+}
+
+// ============================================================================
+// Per-buffer thread init — called once per buffer thread at creation
+// ============================================================================
+
+// Store the stream for buffer `buf` in the calling thread's tl_rollout_stream.
+// `ctx` is the PuffeRL instance; `buf` must index rollout_streams and that
+// stream must be non-null (both asserted).
+extern "C" void thread_init_metal(void* ctx, int buf) {
+    PuffeRL* pufferl = (PuffeRL*)ctx;
+    assert(buf >= 0 && buf < (int)pufferl->rollout_streams.size());
+    tl_rollout_stream = pufferl->rollout_streams[buf];
+    assert(tl_rollout_stream && "thread_init_metal requires per-buffer stream");
+}
+
+// ============================================================================
+// Rollout callback — called per buffer per horizon step
+// ============================================================================
+
+extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
+    @autoreleasepool {
+    PuffeRL* pufferl = (PuffeRL*)ctx;
+    HypersT& hypers = pufferl->hypers;
+
+    RolloutBuf& rollouts =
pufferl->rollouts; + EnvBuf& env = pufferl->env; + int block_size = pufferl->vec->total_agents / hypers.num_buffers; + int start = buf * block_size; + cudaStream_t stream = tl_rollout_stream; + assert(stream && "rollout callback requires thread-local stream"); + + uint64_t tp0 = mach_absolute_time(); + + // Copy env obs to rollout buffer, splitting features and mask. + // env writes [features | mask] at env_obs_width stride. + // rollout obs stores only features at input_size stride. + // rollout masks stores only mask at act_n stride. + PufTensor& obs_env = env.obs; + int env_obs_width = pufferl->env_obs_width; + int input_size = (int)rollouts.observations.shape[2]; + int mask_w = pufferl->mask_width; + + FloatTensor obs_dst = puf_slice(rollouts.observations, t, start, block_size); + + if (pufferl->has_mask) { + // split copy: features prefix + mask suffix, row by row + FloatTensor mask_dst = puf_slice(pufferl->rollout_masks, t, start, block_size); + assert(pufferl->env.obs_raw_dtype == FLOAT && "mask split only supports float32 obs"); + const float* src_base = (const float*)(obs_env.bytes + (int64_t)start * env_obs_width * sizeof(float)); + float* feat_base = obs_dst.data; + float* mask_base = mask_dst.data; + for (int b = 0; b < block_size; b++) { + memcpy(feat_base + b * input_size, src_base + b * env_obs_width, + input_size * sizeof(float)); + memcpy(mask_base + b * mask_w, src_base + b * env_obs_width + input_size, + mask_w * sizeof(float)); + } + } else { + // no mask: copy full obs (env_obs_width == input_size) + PufTensor obs_src = { + .bytes = obs_env.bytes + (int64_t)start * env_obs_width * obs_env.dtype_size, + .shape = {block_size, env_obs_width}, + .dtype_size = obs_env.dtype_size + }; + int count = (int)obs_src.numel(); + switch (pufferl->env.obs_raw_dtype) { + case UNSIGNED_CHAR: + cpu_cast_u8_to_f32(obs_dst.data, (const uint8_t*)obs_src.bytes, + count); + break; + case CHAR: + cpu_cast_to_f32(obs_dst.data, (const int8_t*)obs_src.bytes, count); + 
break; + case FLOAT: + memcpy(obs_dst.data, obs_src.bytes, obs_src.numel() * obs_src.dtype_size); + break; + case INT: + cpu_cast_to_f32(obs_dst.data, (const int32_t*)obs_src.bytes, count); + break; + case DOUBLE: + cpu_cast_to_f32(obs_dst.data, (const double*)obs_src.bytes, count); + break; + default: + assert(false && "Unsupported observation dtype"); + } + } + + // Rewards + terminals -- direct memcpy, no sync check needed + FloatTensor rew_dst = puf_slice(rollouts.rewards, t, start, block_size); + memcpy(rew_dst.data, env.rewards.data + start, block_size * sizeof(float)); + + FloatTensor term_dst = puf_slice(rollouts.terminals, t, start, block_size); + memcpy(term_dst.data, env.terminals.data + start, block_size * sizeof(float)); + + uint64_t tp1 = mach_absolute_time(); + + // Forward pass + sampling + FloatTensor act_slice = puf_slice(rollouts.actions, t, start, block_size); + FloatTensor lp_slice = puf_slice(rollouts.logprobs, t, start, block_size); + FloatTensor val_slice = puf_slice(rollouts.values, t, start, block_size); + int num_atns = (int)puf_numel(pufferl->act_sizes_puf.shape); + uint32_t* buf_rng_offset = (uint32_t*)(pufferl->rng_offset_puf.data + buf); + uint64_t buf_rng_seed = pufferl->rng_seed + buf; + + PufTensor state_puf = pufferl->buffer_states[buf]; + PolicyWeights& infer_weights = pufferl->overlap_enabled + ? 
pufferl->weights_infer : pufferl->weights_fp32; + Policy* p = pufferl->policy; + PolicyActivations& acts = pufferl->buffer_activations[buf]; + FloatTensor& act_f32_buf = pufferl->sample_act_f32_buffers[buf]; + // Mask pointer setup for sampling + int fused_cols = ((DecoderWeights *)infer_weights.decoder)->output_dim + 1; + const float* mask_ptr; + int mask_stride; + if (pufferl->has_mask) { + FloatTensor mask_slice = puf_slice(pufferl->rollout_masks, t, start, block_size); + mask_ptr = mask_slice.data; + mask_stride = pufferl->mask_width; + } else { + mask_ptr = pufferl->ones_mask.data; + mask_stride = 0; + } + + if (pufferl->cpu_inference) { + // CPU path: cblas_sgemm + scalar gate + CPU sampling. No GPU, no sync. + // cpu_forward_and_sample still takes PufTensor for obs_dst -- wrap FloatTensor + PufTensor obs_puf = {.bytes = (char*)obs_dst.data, .shape = {obs_dst.shape[0], obs_dst.shape[1]}, .dtype_size = (int)sizeof(float)}; + cpu_forward_and_sample( + obs_puf, state_puf, infer_weights, hypers.hidden_size, acts, + pufferl->act_sizes_puf, act_f32_buf, + lp_slice.data, val_slice.data, + mask_ptr, mask_stride, + buf_rng_seed, buf_rng_offset); + + // Store decoder logits + f32 actions for GPU logprob recompute at + // training start. CPU sampling uses IEEE expf, PPO uses GPU fast::exp. 
+ DecoderActivations *da = (DecoderActivations *)acts.decoder; + FloatTensor logits_dst = puf_slice(pufferl->rollout_logits, t, start, block_size); + memcpy(logits_dst.data, da->out.data, block_size * fused_cols * sizeof(float)); + FloatTensor acts_f32_dst = puf_slice(pufferl->rollout_actions_f32, t, start, block_size); + memcpy(acts_f32_dst.data, act_f32_buf.data, block_size * num_atns * sizeof(float)); + + memcpy(act_slice.data, act_f32_buf.data, block_size * num_atns * sizeof(float)); + } else { + // GPU path: Metal dispatch + sync (original behavior) + PrecisionTensor obs_pt = { + .data = obs_dst.data, + .shape = {obs_dst.shape[0], obs_dst.shape[1]}, + .dtype_size = (int)sizeof(float), + }; + PrecisionTensor state_pt = { + .data = (float*)state_puf.bytes, + .shape = {state_puf.shape[0], state_puf.shape[1], state_puf.shape[2]}, + .dtype_size = state_puf.dtype_size, + }; + PrecisionTensor mingru_input = p->encoder.forward(infer_weights.encoder, acts.encoder, obs_pt, stream); + PrecisionTensor h = p->network.forward(infer_weights.network, mingru_input, state_pt, acts.network, stream); + PrecisionTensor dec_pt = p->decoder.forward(infer_weights.decoder, acts.decoder, h, stream); + + mtl_sample_logits_dispatch_to( + dec_pt, pufferl->act_sizes_puf, + act_f32_buf.data, lp_slice.data, val_slice.data, + mask_ptr, mask_stride, + buf_rng_seed, buf_rng_offset, stream); + + mtl_ensure_stream_synced(stream); + + // Stash logits + f32 actions for logprob recompute when train_fp16=1. + // Rollout uses fp32 weights; training uses fp16 → precision mismatch in + // PPO ratio unless we recompute old_logprobs in fp16 at training start. 
+ if (pufferl->train_fp16 && pufferl->rollout_logits.data) { + DecoderActivations *da = (DecoderActivations *)acts.decoder; + FloatTensor logits_dst = puf_slice(pufferl->rollout_logits, t, start, block_size); + memcpy(logits_dst.data, da->out.data, block_size * fused_cols * sizeof(float)); + FloatTensor acts_f32_dst = puf_slice(pufferl->rollout_actions_f32, t, start, block_size); + memcpy(acts_f32_dst.data, act_f32_buf.data, block_size * num_atns * sizeof(float)); + } + + memcpy(act_slice.data, act_f32_buf.data, block_size * num_atns * sizeof(float)); + } + + // Match upstream: do not zero RNN state on terminal. + + uint64_t tp2 = mach_absolute_time(); + + /* copy float32 actions to env buffer (actions are float after upstream 4.0 migration). + use act_f32_buf (already float) instead of act_slice (which stores doubles in rollout buf). */ + int64_t act_cols = env.actions.shape[1]; + memcpy( + env.actions.bytes + start * act_cols * sizeof(float), + act_f32_buf.data, + block_size * act_cols * sizeof(float)); + + uint64_t tp3 = mach_absolute_time(); + + // Accumulate fine-grained rollout timing (callbacks run concurrently). 
+ { + std::lock_guard lk(g_rollout_profile_mutex); + pufferl->profile.accum[PROF_ROLLOUT_OBS_COPY] += prof_ms(tp0, tp1); + pufferl->profile.accum[PROF_ROLLOUT_FWD] += prof_ms(tp1, tp2); + pufferl->profile.accum[PROF_ROLLOUT_ACT_COPY] += prof_ms(tp2, tp3); + } + } // @autoreleasepool +} + +// ============================================================================ +// Weight copy: weights_fp32 → weights_infer (for rollout/training overlap) +// ============================================================================ + +static void copy_weights_to_infer(PuffeRL& pufferl) { + int64_t nbytes = pufferl.alloc_fp32.params.total_elems * sizeof(float); + memcpy(pufferl.infer_params_alloc.mem, pufferl.alloc_fp32.params.mem, nbytes); +} + +// ============================================================================ +// Forward declaration: waits for async GPU training to complete. +static void sync_pending_train(PuffeRL& pufferl); + +// ============================================================================ +// Training loop +// ============================================================================ + +void train_impl(PuffeRL& pufferl) { + HypersT& hypers = pufferl.hypers; + uint64_t tp_preloop0 = mach_absolute_time(); + + cudaStream_t train_stream = pufferl.overlap_enabled + ? (cudaStream_t)mtl_train_stream() + : (cudaStream_t)mtl_stream(); + + // GPU training: keep all ops on the Metal encoder (GEMM, copy, zero, add). + puf_set_gpu_training(true); + + // Transpose rollouts from (horizon, segments, ...) to (segments, horizon, ...) 
+ RolloutBuf& src = pufferl.rollouts; + RolloutBuf& rollouts = pufferl.train_rollouts; + + puf_transpose_01(rollouts.observations, src.observations, train_stream); + puf_transpose_01(rollouts.actions, src.actions, train_stream); + puf_transpose_01(rollouts.logprobs, src.logprobs, train_stream); + puf_transpose_01(rollouts.rewards, src.rewards, train_stream); + puf_transpose_01(rollouts.terminals, src.terminals, train_stream); + puf_transpose_01(rollouts.ratio, src.ratio, train_stream); + puf_transpose_01(rollouts.values, src.values, train_stream); + if (pufferl.has_mask) + puf_transpose_01(pufferl.train_masks, pufferl.rollout_masks, train_stream); + + // Metal 4: ensure all rollout transposes are visible before consumers read them. + mtl_barrier((MetalStream*)train_stream); + + // Recompute old logprobs when rollout and training use different math. + if (pufferl.cpu_inference || pufferl.train_fp16) { + puf_transpose_01(pufferl.train_logits, pufferl.rollout_logits, train_stream); + puf_transpose_01(pufferl.train_actions_f32, pufferl.rollout_actions_f32, train_stream); + mtl_barrier((MetalStream*)train_stream); + + int total_samples = hypers.total_agents * hypers.horizon; + int fused_cols = (int)pufferl.train_logits.shape[2]; + int num_atns = (int)pufferl.train_actions_f32.shape[2]; + + // Mask: embedded in obs or all-ones fallback + const float *mask_ptr; + int mask_stride; + if (pufferl.has_mask) { + mask_ptr = pufferl.train_masks.data; + mask_stride = pufferl.mask_width; + } else { + mask_ptr = pufferl.ones_mask.data; + mask_stride = 0; + } + + mtl_recompute_logprobs( + rollouts.logprobs.data, + pufferl.train_logits.data, + pufferl.train_actions_f32.data, + pufferl.act_sizes_puf.data, + mask_ptr, mask_stride, + total_samples, num_atns, fused_cols, train_stream); + } + + // Clamp rewards and fill ratio (f32 path only, no bf16) + mtl_clamp_f32(rollouts.rewards.data, -1.0f, 1.0f, + (int)puf_numel(rollouts.rewards.shape), train_stream); + 
mtl_fill_f32(rollouts.ratio.data, 1.0f, + (int)puf_numel(rollouts.ratio.shape), train_stream); + + // old_values = values.clone() + puf_copy(pufferl.old_values_puf, rollouts.values, train_stream); + // Metal 4 visibility boundary before minibatch loop consumes transposed rollouts. + mtl_barrier((MetalStream*)train_stream); + + int batch_size = hypers.total_agents * hypers.horizon; + int minibatch_segments = hypers.minibatch_size / hypers.horizon; + float prio_alpha = hypers.prio_alpha; + int current_epoch = pufferl.epoch; + int total_epochs = hypers.total_timesteps / batch_size; + int total_minibatches = hypers.replay_ratio * batch_size / hypers.minibatch_size; + + if (hypers.anneal_lr) { + float lr_min = hypers.min_lr_ratio * hypers.lr; + float lr = cosine_annealing(hypers.lr, lr_min, current_epoch, total_epochs); + float* lr_ptr = pufferl.muon->lr_ptr; + *lr_ptr = lr; + } + + float anneal_beta = hypers.prio_beta0 + (1.0f - hypers.prio_beta0) + * prio_alpha * (float)current_epoch / (float)total_epochs; + + uint64_t tp_preloop1 = mach_absolute_time(); + pufferl.profile.accum[PROF_TRAIN_PRELOOP] += prof_ms(tp_preloop0, tp_preloop1); + + // Single minibatch step shared by overlap and non-overlap. 
+ auto run_minibatch = [&](cudaStream_t s, uint32_t* rng_offset, bool gpu_profile) { + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp0 = mach_absolute_time(); + + puf_zero(&pufferl.advantages_puf, s); + puff_advantage(rollouts.values, rollouts.rewards, rollouts.terminals, + rollouts.ratio, pufferl.advantages_puf, hypers.gamma, hypers.gae_lambda, + hypers.vtrace_rho_clip, hypers.vtrace_c_clip, s); + + prio_precompute(pufferl.advantages_puf, prio_alpha, pufferl.prio_bufs, s); + prio_sample(minibatch_segments, hypers.total_agents, anneal_beta, + pufferl.prio_bufs, pufferl.rng_seed, rng_offset, s); + mtl_barrier((MetalStream*)s); + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp2 = mach_absolute_time(); + + if (hypers.reset_state) puf_zero(&pufferl.train_buf.mb_state, s); + { + RolloutBuf sel_src = rollouts; + sel_src.values = pufferl.old_values_puf; + mtl_select_copy(sel_src, pufferl.train_buf, + (const int64_t*)pufferl.prio_bufs.idx.data, + pufferl.advantages_puf.data, + pufferl.prio_bufs.mb_prio.data, + minibatch_segments, + pufferl.fp16_obs_buf.bytes, s); + // gather masks from train_masks into mb_masks using same priority indices. + // reuses index_copy_kernel as a gather: dst[i] = src[idx[i]]. 
+ if (pufferl.has_mask) { + MetalStream *ms2 = mtl_resolve_stream(s); + ms2->compute_encoder(); + auto pso = mtl_pipeline("index_gather_kernel"); + mtl_set_pso(ms2, pso); + int mw = pufferl.mask_width; + int mask_seg_bytes = hypers.horizon * mw * (int)sizeof(float); + mtl_set_ptr(ms2, pufferl.mb_masks.data, 0); + mtl_set_ptr(ms2, (void*)pufferl.prio_bufs.idx.data, 1); + mtl_set_ptr(ms2, pufferl.train_masks.data, 2); + struct { int num_idx; int row_bytes; } mp = {minibatch_segments, mask_seg_bytes}; + mtl_set_params(ms2, mp, 3); + mtl_dispatch_groups(ms2, pso, (minibatch_segments + 255) / 256, 256); + } + } + mtl_barrier((MetalStream*)s); + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp3 = mach_absolute_time(); + + PolicyWeights& train_weights = pufferl.train_fp16 ? pufferl.weights_fp16 : pufferl.weights_fp32; + PrecisionTensor obs_pt; + PrecisionTensor state_pt; + if (pufferl.train_fp16) { + obs_pt = { + .data = (float*)pufferl.fp16_obs_buf.bytes, + .shape = {pufferl.fp16_obs_buf.shape[0], pufferl.fp16_obs_buf.shape[1], pufferl.fp16_obs_buf.shape[2]}, + .dtype_size = pufferl.fp16_obs_buf.dtype_size, + }; + state_pt = { + .data = (float*)pufferl.fp16_state_buf.bytes, + .shape = {pufferl.fp16_state_buf.shape[0], pufferl.fp16_state_buf.shape[1], pufferl.fp16_state_buf.shape[2], pufferl.fp16_state_buf.shape[3]}, + .dtype_size = pufferl.fp16_state_buf.dtype_size, + }; + } else { + FloatTensor &mo = pufferl.train_buf.mb_obs; + obs_pt = {.data = mo.data, .shape = {mo.shape[0], mo.shape[1], mo.shape[2]}, .dtype_size = (int)sizeof(float)}; + FloatTensor &ms = pufferl.train_buf.mb_state; + state_pt = {.data = ms.data, .shape = {ms.shape[0], ms.shape[1], ms.shape[2], ms.shape[3]}, .dtype_size = (int)sizeof(float)}; + } + if (pufferl.train_fp16 && hypers.reset_state) puf_zero(&pufferl.fp16_state_buf, s); + + PrecisionTensor dec_pt = policy_forward_train(pufferl.policy, train_weights, + pufferl.train_activations, obs_pt, state_pt, s); + + if (gpu_profile) 
mtl_ensure_stream_synced(s); + uint64_t tp4 = mach_absolute_time(); + + PufTensor dec_puf = to_puf(dec_pt); + if (dec_puf.dtype_size == 2) { + mtl_cast_f16_to_f32((float*)pufferl.fp32_dec_out_buf.bytes, + dec_puf.bytes, + (int)dec_puf.numel(), s); + mtl_barrier((MetalStream*)s); + dec_puf = pufferl.fp32_dec_out_buf; + } + + PrecisionTensor p_logstd = {.dtype_size = (int)sizeof(float)}; + if (pufferl.is_continuous) { + p_logstd = ((DecoderWeights*)pufferl.weights_fp32.decoder)->logstd; + } + + { + const float* ppo_mask_ptr; + int ppo_mask_stride; + if (pufferl.has_mask) { + ppo_mask_ptr = pufferl.mb_masks.data; + ppo_mask_stride = pufferl.mask_width; + } else { + ppo_mask_ptr = pufferl.ones_mask.data; + ppo_mask_stride = 0; + } + // When PER is active, pass full-batch advantages for unbiased var/mean. + const FloatTensor *full_adv = (prio_alpha > 0.0f) ? &pufferl.advantages_puf : nullptr; + PufTensor logstd_puf = to_puf(p_logstd); + ppo_loss_fwd_bwd(dec_puf, logstd_puf, pufferl.train_buf, + pufferl.act_sizes_puf, pufferl.losses_puf, + hypers.clip_coef, hypers.vf_clip_coef, hypers.vf_coef, hypers.ent_coef, + pufferl.ppo_bufs_puf, pufferl.is_continuous, + ppo_mask_ptr, ppo_mask_stride, full_adv, s); + } + mtl_barrier((MetalStream*)s); + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp5 = mach_absolute_time(); + + // policy_backward now takes FloatTensor grads directly (matching upstream) + FloatTensor grad_logstd_ft = pufferl.is_continuous ? 
pufferl.ppo_bufs_puf.grad_logstd : FloatTensor(); + policy_backward(pufferl.policy, train_weights, pufferl.train_activations, + pufferl.ppo_bufs_puf.grad_logits, grad_logstd_ft, + pufferl.ppo_bufs_puf.grad_values, s); + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp6 = mach_absolute_time(); + + FloatTensor& gc = pufferl.muon->gc_puf; + mtl_barrier((MetalStream*)s); // policy_backward writes grads, copy/cast reads them + if (pufferl.grad_fp16_puf.dtype_size == 2) { + mtl_cast_f16_to_f32(gc.data, + pufferl.grad_fp16_puf.bytes, + (int)pufferl.grad_fp16_puf.numel(), s); + } else { + // grad_fp16_puf is fp32 when !train_fp16, copy to gc + mtl_copy_f32(gc.data, (const float*)pufferl.grad_fp16_puf.bytes, + (int)pufferl.grad_fp16_puf.numel(), s); + } + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp7 = mach_absolute_time(); + + mtl_barrier((MetalStream*)s); + { + float* scratch = pufferl.grad_norm_puf.data; + clip_grad_norm_f32(gc, scratch, hypers.max_grad_norm, 1e-6f, s); + } + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp8 = mach_absolute_time(); + + mtl_barrier((MetalStream*)s); + muon_step(pufferl.muon, s); + + + if (pufferl.train_fp16) { + mtl_cast_f32_to_f16(pufferl.param_fp16_puf.bytes, + (const float*)pufferl.alloc_fp32.params.mem, + (int)pufferl.alloc_fp32.params.total_elems, s); + } + + mtl_barrier((MetalStream*)s); + + // Scatter mb_ratio and mb_newvalue back into rollout buffers so + // subsequent minibatches see updated importance weights and values. + // Matches CUDA upstream (pufferlib.cu:1416-1428). 
+ mtl_scatter_ppo_outputs(pufferl.train_buf, rollouts, + (const int64_t*)pufferl.prio_bufs.idx.data, s); + mtl_barrier((MetalStream*)s); + + if (gpu_profile) mtl_ensure_stream_synced(s); + uint64_t tp9 = mach_absolute_time(); + + pufferl.profile.accum[PROF_TRAIN_PRIO] += prof_ms(tp0, tp2); + pufferl.profile.accum[PROF_TRAIN_SELECT] += prof_ms(tp2, tp3); + pufferl.profile.accum[PROF_TRAIN_FWD] += prof_ms(tp3, tp4); + pufferl.profile.accum[PROF_TRAIN_PPO] += prof_ms(tp4, tp5); + pufferl.profile.accum[PROF_TRAIN_BACKWARD] += prof_ms(tp5, tp6); + pufferl.profile.accum[PROF_TRAIN_GRAD_COPY] += prof_ms(tp6, tp7); + pufferl.profile.accum[PROF_TRAIN_GRAD_CLIP] += prof_ms(tp7, tp8); + pufferl.profile.accum[PROF_TRAIN_MUON] += prof_ms(tp8, tp9); + }; + + uint32_t* train_rng_offset = (uint32_t*)(pufferl.rng_offset_puf.data + hypers.num_buffers); + + puf_set_gpu_training(false); + + if (pufferl.overlap_enabled) { + // Overlap: dispatch all minibatches on train_stream (separate Metal command queue). + // GPU executes async during next rollout. 1-iteration policy lag — V-trace compensates. + cudaStream_t ts = train_stream; + + if (pufferl.train_pending) { + sync_pending_train(pufferl); + } + + // Copy trained weights to inference buffer so the NEXT rollout sees them. + { + int64_t total_elems = pufferl.alloc_fp32.params.total_elems; + PufTensor fp32_all = {.bytes = (char*)pufferl.alloc_fp32.params.mem, + .shape = {total_elems}, .dtype_size = sizeof(float)}; + PufTensor infer_all = {.bytes = (char*)pufferl.infer_params_alloc.mem, + .shape = {total_elems}, .dtype_size = sizeof(float)}; + puf_copy(infer_all, fp32_all, ts); + mtl_barrier((MetalStream*)ts); + } + + puf_set_gpu_training(true); + MetalStream* mts = (MetalStream*)ts; + for (int mb = 0; mb < total_minibatches; ++mb) { + run_minibatch(ts, train_rng_offset, false); + // Commit current command buffer when ring is >75% full to prevent + // overflow on high replay_ratio configs. 
Metal queue serial execution + // guarantees the GPU finishes reading ring data before we overwrite it. + if (mb + 1 < total_minibatches && + mts->const_ring_offset > MTL_CONST_RING_SIZE * 3 / 4) { + mts->commit_chunk(); + } + } + puf_set_gpu_training(false); + + ((MetalStream*)train_stream)->flush(); + pufferl.train_pending = true; + pufferl.epoch += 1; + return; + } + + // Non-overlap: run minibatch loop synchronously. + bool gpu_profile = hypers.profile; + puf_set_gpu_training(true); + MetalStream* mts_sync = (MetalStream*)train_stream; + for (int mb = 0; mb < total_minibatches; ++mb) { + run_minibatch(train_stream, train_rng_offset, gpu_profile); + if (mb + 1 < total_minibatches && + mts_sync->const_ring_offset > MTL_CONST_RING_SIZE * 3 / 4) { + mts_sync->commit_chunk(); + } + } + + pufferl.epoch += 1; + + uint64_t tp_sync0 = mach_absolute_time(); + puf_set_gpu_training(false); + mtl_ensure_stream_synced(train_stream); + uint64_t tp_sync1 = mach_absolute_time(); + pufferl.profile.accum[PROF_TRAIN_SYNC] += prof_ms(tp_sync0, tp_sync1); + +} + +// Wait for async GPU training to complete, then snapshot weights for inference. +// Called at the end of rollouts() before the next iteration needs updated weights. +static void sync_pending_train(PuffeRL& pufferl) { + if (!pufferl.train_pending) return; + // Wait for async training on train_stream (separate queue) to complete. + // After this, weights_infer and fused weight are up-to-date. 
+ MetalStream* ts = (MetalStream*)mtl_train_stream(); + ts->wait_completed(); + pufferl.train_pending = false; +} + +// ============================================================================ +// Initialization +// ============================================================================ + +std::unique_ptr create_pufferl_impl(HypersT& hypers, + const std::string& env_name, Dict* vec_kwargs, Dict* env_kwargs) { + auto pufferl = std::make_unique(); + pufferl->hypers = hypers; + + mtl_init(); + + pufferl->rng_seed = hypers.seed; + + StaticVec* vec = create_environments(hypers.num_buffers, hypers.total_agents, + env_name, vec_kwargs, env_kwargs, pufferl->env); + pufferl->vec = vec; + + int num_action_heads = pufferl->env.actions.shape[1]; + int* raw_act_sizes = get_act_sizes(); + int act_n = 0; + for (int i = 0; i < num_action_heads; i++) { + act_n += raw_act_sizes[i]; + } + + // CPU-based profiling (no CUDA events) + memset(pufferl->profile.accum, 0, sizeof(pufferl->profile.accum)); + + // Determine action space type + int num_continuous = 0; + int num_discrete = 0; + for (int i = 0; i < num_action_heads; i++) { + if (raw_act_sizes[i] == 1) { + num_continuous++; + } else { + num_discrete++; + } + } + if (num_continuous > 0 && num_discrete > 0) { + assert(false && "Mixed continuous/discrete action spaces not supported"); + } + pufferl->is_continuous = (num_continuous > 0); + assert(!(hypers.train_fp16 && pufferl->is_continuous) && + "train_fp16 currently supports discrete action spaces only"); + + int env_obs_width = pufferl->env.obs.shape[1]; + int hidden_size = hypers.hidden_size; + int num_layers = hypers.num_layers; + + // Action mask: env_config "mask_in_obs" > 0 means mask is embedded in obs. 
+ { + DictItem* mask_entry = dict_get_unsafe(env_kwargs, "mask_in_obs"); + pufferl->has_mask = (mask_entry && mask_entry->value > 0.0f); + } + pufferl->env_obs_width = env_obs_width; + pufferl->mask_width = act_n; // total mask width = sum of action head sizes + // when mask is embedded, split it: encoder sees only the feature prefix + int input_size = pufferl->has_mask ? (env_obs_width - act_n) : env_obs_width; + + bool is_continuous = pufferl->is_continuous; + int decoder_output_size = is_continuous ? num_action_heads : act_n; + + int minibatch_segments = hypers.minibatch_size / hypers.horizon; + int inf_batch = vec->total_agents / hypers.num_buffers; + + // ======================================================================== + // fp32 master weights (for optimizer) + // ======================================================================== + + int esz_fp32 = sizeof(float); + pufferl->alloc_fp32.esz = esz_fp32; + Allocator& fp32_params = pufferl->alloc_fp32.params; + + Encoder encoder = { + .forward = encoder_forward, + .backward = encoder_backward, + .init_weights = encoder_init_weights, + .reg_params = encoder_reg_params, + .reg_train = encoder_reg_train, + .reg_rollout = encoder_reg_rollout, + }; + Decoder decoder = { + .forward = decoder_forward, + .backward = decoder_backward, + .init_weights = decoder_init_weights, + .reg_params = decoder_reg_params, + .reg_train = decoder_reg_train, + .reg_rollout = decoder_reg_rollout, + }; + Network network = { + .forward = mingru_forward, + .forward_train = mingru_forward_train, + .backward = mingru_backward, + .init_weights = mingru_init_weights, + .reg_params = mingru_reg_params, + .reg_train = mingru_reg_train, + .reg_rollout = mingru_reg_rollout, + }; + + Policy* policy = new Policy{ + .encoder = encoder, .decoder = decoder, .network = network, + .input_dim = input_size, .hidden_dim = hidden_size, .output_dim = decoder_output_size, + .num_atns = act_n, + }; + pufferl->policy = policy; + + // fp32 master 
weights + auto new_weights = [&]() -> PolicyWeights { + PolicyWeights w; + w.encoder = new EncoderWeights{.in_dim = input_size, .out_dim = hidden_size}; + w.decoder = new DecoderWeights{.hidden_dim = hidden_size, .output_dim = decoder_output_size, .continuous = is_continuous}; + w.network = new MinGRUWeights{.hidden = hidden_size, .num_layers = num_layers, .horizon = hypers.horizon}; + ((MinGRUWeights*)w.network)->weights.resize(num_layers); + return w; + }; + + pufferl->weights_fp32 = new_weights(); + PolicyWeights& wfp32 = pufferl->weights_fp32; + encoder.reg_params(wfp32.encoder, &fp32_params, esz_fp32); + decoder.reg_params(wfp32.decoder, &fp32_params, esz_fp32); + network.reg_params(wfp32.network, &fp32_params, esz_fp32); + + pufferl->alloc_fp32.create(); + + // Wrap fp32 params allocator for Metal GPU access + mtl_wrap_allocator(&fp32_params); + + pufferl->param_fp32_puf = {.data = (float*)fp32_params.mem, .shape = {fp32_params.total_elems}}; + + // Init weights on fp32 master + { + cudaStream_t default_stream = (cudaStream_t)mtl_stream(); + uint64_t init_seed = hypers.seed; + encoder.init_weights(wfp32.encoder, &init_seed, default_stream); + decoder.init_weights(wfp32.decoder, &init_seed, default_stream); + network.init_weights(wfp32.network, &init_seed, default_stream); + mtl_ensure_stream_synced(default_stream); + } + + // Fused encoder+layer0 is disabled. + // The null guard in mingru_forward falls back to encoder.forward(). + + // ======================================================================== + // Double-buffered inference weights (for rollout/training overlap) + // Same shapes as weights_fp32, separate allocator for isolation. 
+ // ======================================================================== + + { + Allocator& infer_alloc = pufferl->infer_params_alloc; + pufferl->weights_infer = new_weights(); + PolicyWeights& wi = pufferl->weights_infer; + encoder.reg_params(wi.encoder, &infer_alloc, esz_fp32); + decoder.reg_params(wi.decoder, &infer_alloc, esz_fp32); + network.reg_params(wi.network, &infer_alloc, esz_fp32); + infer_alloc.create(); + mtl_wrap_allocator(&infer_alloc); + // Initial copy: weights_fp32 → weights_infer + copy_weights_to_infer(*pufferl); + // Fused encoder+layer0 disabled (nonlinear encoder) + } + pufferl->overlap_enabled = hypers.overlap; + pufferl->cpu_inference = hypers.cpu_inference; + pufferl->train_fp16 = hypers.train_fp16; + + // fp16 training weights, activations, and gradients. + + int B_TT = minibatch_segments * hypers.horizon; + int esz_fp16 = 2; + + // fp16 training weights (separate allocation from fp32 master) + pufferl->alloc_fp16.esz = esz_fp16; + Allocator& fp16_params = pufferl->alloc_fp16.params; + Allocator& acts = pufferl->alloc_fp16.acts; + Allocator& grads = pufferl->alloc_fp16.grads; + + pufferl->weights_fp16 = new_weights(); + PolicyWeights& wfp16 = pufferl->weights_fp16; + + encoder.reg_params(wfp16.encoder, &fp16_params, esz_fp16); + decoder.reg_params(wfp16.decoder, &fp16_params, esz_fp16); + network.reg_params(wfp16.network, &fp16_params, esz_fp16); + + // Register train activations/grads. + // train_fp16: activations/grads use fp16 (esz_fp16=2), enabling fp16 GEMM paths. + // Otherwise: activations/grads stay fp32 (PRECISION_SIZE=4) even in fp16 allocator. + int train_precision = pufferl->train_fp16 ? 
esz_fp16 : PRECISION_SIZE; + PolicyActivations& tb = pufferl->train_activations; + tb.encoder = new EncoderActivations{}; + tb.decoder = new DecoderActivations{}; + tb.network = new MinGRUActivations{}; + encoder.reg_train(wfp16.encoder, tb.encoder, &acts, &grads, B_TT, train_precision); + decoder.reg_train(wfp16.decoder, tb.decoder, &acts, &grads, B_TT, train_precision); + network.reg_train(wfp16.network, tb.network, &acts, &grads, B_TT, train_precision); + + pufferl->alloc_fp16.create(); + + // Wrap fp16 allocators for Metal GPU access + mtl_wrap_allocator(&fp16_params); + mtl_wrap_allocator(&acts); + mtl_wrap_allocator(&grads); + + pufferl->param_fp16_puf = {.bytes = (char*)fp16_params.mem, .shape = {fp16_params.total_elems}, .dtype_size = esz_fp16}; + // When train_fp16: grads are fp16, need cast to fp32 before muon. + // When fp32: grads are fp32, just copy (no cast needed). + int grad_dtype = pufferl->train_fp16 ? esz_fp16 : esz_fp32; + pufferl->grad_fp16_puf = {.bytes = (char*)grads.mem, .shape = {grads.total_elems}, .dtype_size = grad_dtype}; + + // Cast fp32 master weights → fp16 training weights + { + cudaStream_t s = (cudaStream_t)mtl_stream(); + mtl_cast_f32_to_f16(pufferl->param_fp16_puf.bytes, + (const float*)fp32_params.mem, + (int)fp32_params.total_elems, s); + mtl_ensure_stream_synced(s); + } + + // Boundary buffers: fp16 obs (encoder input), fp32 dec_out (PPO input), + // fp16 state (scan initial state — mb_state is fp32 but scan reads half*) + { + int dec_fused = decoder_output_size + 1; + pufferl->fp16_obs_buf = {.shape = {minibatch_segments, hypers.horizon, input_size}, .dtype_size = esz_fp16}; + pufferl->fp32_dec_out_buf = {.shape = {minibatch_segments, hypers.horizon, dec_fused}, .dtype_size = esz_fp32}; + pufferl->fp16_state_buf = {.shape = {num_layers, minibatch_segments, 1, hidden_size}, .dtype_size = esz_fp16}; + alloc_register(&pufferl->fp16_boundary_alloc, &pufferl->fp16_obs_buf); + alloc_register(&pufferl->fp16_boundary_alloc, 
&pufferl->fp32_dec_out_buf); + alloc_register(&pufferl->fp16_boundary_alloc, &pufferl->fp16_state_buf); + pufferl->fp16_boundary_alloc.create(); + mtl_wrap_allocator(&pufferl->fp16_boundary_alloc); + } + + // ======================================================================== + // Optimizer (Muon) — operates on fp32 master weights + // ======================================================================== + + float lr = hypers.lr; + float beta1 = hypers.beta1; + pufferl->muon = new Muon{}; + int horizon = hypers.horizon; + int total_agents = vec->total_agents; + int batch = total_agents / hypers.num_buffers; + int num_buffers = hypers.num_buffers; + + // ======================================================================== + // Register all init-to-close buffers into pufferl_alloc, then create once + // ======================================================================== + Allocator& alloc = pufferl->pufferl_alloc; + + int p = PRECISION_SIZE; + pufferl->rng_offset_puf = {.shape = {num_buffers + 1}}; + pufferl->act_sizes_puf = {.shape = {num_action_heads}}; + pufferl->losses_puf = {.shape = {NUM_LOSSES}}; + pufferl->grad_norm_puf = {.shape = {1}}; + alloc_register(&alloc, &pufferl->rng_offset_puf); + alloc_register(&alloc, &pufferl->act_sizes_puf); + alloc_register(&alloc, &pufferl->losses_puf); + alloc_register(&alloc, &pufferl->grad_norm_puf); + + // Per-buffer RNN states + pufferl->buffer_states.resize(num_buffers); + pufferl->sample_act_f32_buffers.resize(num_buffers); + for (int i = 0; i < num_buffers; i++) { + pufferl->buffer_states[i] = {.shape = {num_layers, batch, hidden_size}, .dtype_size = p}; + alloc_register(&alloc, &pufferl->buffer_states[i]); + pufferl->sample_act_f32_buffers[i] = {.shape = {batch, num_action_heads}}; + alloc_register(&alloc, &pufferl->sample_act_f32_buffers[i]); + } + + // Rollout buffers (horizon, total_agents, ...) 
+ register_rollout_buffers(pufferl->rollouts, alloc, horizon, total_agents, input_size, num_action_heads); + + // Train graph buffers + register_train_buffers(pufferl->train_buf, alloc, minibatch_segments, horizon, input_size, + hidden_size, num_action_heads, num_layers); + + // Pre-allocated transposed rollouts for train_impl (total_agents, horizon, ...) + register_rollout_buffers(pufferl->train_rollouts, alloc, total_agents, horizon, input_size, num_action_heads); + + // Pre-allocated train temporaries + pufferl->old_values_puf = {.shape = {total_agents, horizon}}; + pufferl->advantages_puf = {.shape = {total_agents, horizon}}; + alloc_register(&alloc, &pufferl->old_values_puf); + alloc_register(&alloc, &pufferl->advantages_puf); + + // PPO loss buffers + register_ppo_buffers(pufferl->ppo_bufs_puf, alloc, minibatch_segments, hypers.horizon, decoder_output_size, is_continuous); + + // Priority replay buffers + register_prio_buffers(pufferl->prio_bufs, alloc, hypers.total_agents, minibatch_segments); + + // Mask buffers: when has_mask, masks are split from obs at rollout time. + if (pufferl->has_mask) { + pufferl->rollout_masks = {.shape = {horizon, total_agents, act_n}}; + pufferl->train_masks = {.shape = {total_agents, horizon, act_n}}; + pufferl->mb_masks = {.shape = {minibatch_segments, hypers.horizon, act_n}}; + alloc_register(&alloc, &pufferl->rollout_masks); + alloc_register(&alloc, &pufferl->train_masks); + alloc_register(&alloc, &pufferl->mb_masks); + } else { + pufferl->ones_mask = {.shape = {act_n}}; + alloc_register(&alloc, &pufferl->ones_mask); + } + + // Decoder logits + f32 actions for logprob recompute at training start. 
+ if (pufferl->cpu_inference || hypers.train_fp16) { + int fused = decoder_output_size + 1; + int na = num_action_heads; + pufferl->rollout_logits = {.shape = {horizon, total_agents, fused}}; + pufferl->train_logits = {.shape = {total_agents, horizon, fused}}; + pufferl->rollout_actions_f32 = {.shape = {horizon, total_agents, na}}; + pufferl->train_actions_f32 = {.shape = {total_agents, horizon, na}}; + alloc_register(&alloc, &pufferl->rollout_logits); + alloc_register(&alloc, &pufferl->train_logits); + alloc_register(&alloc, &pufferl->rollout_actions_f32); + alloc_register(&alloc, &pufferl->train_actions_f32); + } + + // Optimizer init (register buffers with shared allocator) + muon_init(pufferl->muon, &fp32_params, + pufferl->param_fp32_puf, lr, beta1, (double)hypers.weight_decay, + hypers.ns_iters, alloc); + // Single allocation for all registered buffers + alloc.create(); + + // Wrap pufferl_alloc for Metal GPU access + mtl_wrap_allocator(&alloc); + + // Post-create initialization: unified memory, write directly + memset(pufferl->rng_offset_puf.data, 0, (num_buffers + 1) * sizeof(long)); + memcpy(pufferl->act_sizes_puf.data, raw_act_sizes, num_action_heads * sizeof(int32_t)); + memset(pufferl->losses_puf.data, 0, NUM_LOSSES * sizeof(float)); + + // Fill all-ones mask (after alloc.create + mtl_wrap) + if (!pufferl->has_mask) { + float* ones = pufferl->ones_mask.data; + for (int i = 0; i < act_n; i++) ones[i] = 1.0f; + } + + + // muon_post_create: write lr and zero momentum (unified memory) + pufferl->muon->lr_ptr = pufferl->muon->lr_puf.data; + pufferl->muon->lr_derived_ptr = pufferl->muon->lr_derived_puf.data; + if (pufferl->muon->ns_norm_puf.data) + pufferl->muon->ns.norm_ptr = pufferl->muon->ns_norm_puf.data; + *pufferl->muon->lr_ptr = pufferl->muon->lr_val_init; + memset(pufferl->muon->lr_derived_ptr, 0, 2 * sizeof(float)); + memset(pufferl->muon->mb_puf.data, 0, puf_numel(pufferl->muon->mb_puf.shape) * sizeof(float)); + + // Per-buffer inference activations 
(separate allocators) + pufferl->buffer_activations.resize(num_buffers); + pufferl->buffer_allocs.resize(num_buffers); + for (int i = 0; i < num_buffers; i++) { + PolicyActivations& rbuf = pufferl->buffer_activations[i]; + Allocator& ralloc = pufferl->buffer_allocs[i]; + rbuf.encoder = new EncoderActivations{}; + rbuf.decoder = new DecoderActivations{}; + rbuf.network = new MinGRUActivations{}; + // Rollout uses fp32 — register with fp32 weights (dimensions only, dtype from PRECISION_SIZE) + encoder.reg_rollout(pufferl->weights_fp32.encoder, rbuf.encoder, &ralloc, inf_batch); + decoder.reg_rollout(pufferl->weights_fp32.decoder, rbuf.decoder, &ralloc, inf_batch); + network.reg_rollout(pufferl->weights_fp32.network, rbuf.network, &ralloc, inf_batch); + ralloc.create(); + // Wrap each per-buffer allocator for Metal GPU access + mtl_wrap_allocator(&ralloc); + } + + // No CUDA graph warmup on Metal (cudagraphs always -1) + + // Create per-buffer Metal streams for rollout callback workers. + pufferl->rollout_streams.resize(num_buffers); + for (int i = 0; i < num_buffers; i++) { + cudaStream_t s = (cudaStream_t)mtl_create_stream(); + pufferl->rollout_streams[i] = s; + vec->streams[i] = s; + } + + // Create threads for vecenv — thread_init binds per-thread rollout stream + create_static_threads(vec, hypers.num_threads, horizon, pufferl.get(), + net_callback_wrapper, thread_init_metal); + static_vec_reset(vec); + + pufferl->epoch = 0; + pufferl->global_step = 0; + struct timeval tv; + gettimeofday(&tv, NULL); + double now = tv.tv_sec + tv.tv_usec * 1e-6; + pufferl->start_time = now; + pufferl->last_log_time = now; + pufferl->last_log_step = 0; + + return pufferl; +} + +// ============================================================================ +// Cleanup +// ============================================================================ + +void close_impl(PuffeRL& pufferl) { + sync_pending_train(pufferl); + mtl_ensure_stream_synced((cudaStream_t)mtl_stream()); + + delete 
pufferl.muon; + + auto delete_weights = [](PolicyWeights& w) { + delete (EncoderWeights*)w.encoder; + delete (DecoderWeights*)w.decoder; + delete (MinGRUWeights*)w.network; + }; + delete (EncoderActivations*)pufferl.train_activations.encoder; + delete (DecoderActivations*)pufferl.train_activations.decoder; + delete (MinGRUActivations*)pufferl.train_activations.network; + delete_weights(pufferl.weights_fp32); + delete_weights(pufferl.weights_fp16); + delete_weights(pufferl.weights_infer); + for (auto& rbuf : pufferl.buffer_activations) { + delete (EncoderActivations*)rbuf.encoder; + delete (DecoderActivations*)rbuf.decoder; + delete (MinGRUActivations*)rbuf.network; + } + delete pufferl.policy; + + static_vec_close(pufferl.vec); + + for (cudaStream_t s : pufferl.rollout_streams) { + mtl_destroy_stream((void*)s); + } + pufferl.rollout_streams.clear(); + + // Release MTLBuffers BEFORE freeing the underlying memory they reference. + // MTLBuffers created with newBufferWithBytesNoCopy need their backing pages + // still mapped when ARC releases them (Metal unmaps the GPU address space). + mtl_destroy(); + + pufferl.alloc_fp32.destroy(); + pufferl.alloc_fp16.destroy(); + pufferl.fp16_boundary_alloc.destroy(); + pufferl.infer_params_alloc.destroy(); + pufferl.pufferl_alloc.destroy(); + for (auto& a : pufferl.buffer_allocs) { + a.destroy(); + } +} diff --git a/src/metal_shader_src.h b/src/metal_shader_src.h new file mode 100644 index 0000000000..561c60ba17 --- /dev/null +++ b/src/metal_shader_src.h @@ -0,0 +1,2523 @@ +#ifndef PUFFERLIB_METAL_SHADER_SRC_H +#define PUFFERLIB_METAL_SHADER_SRC_H + +static const char *get_all_metal_shader_source() { + return R"METAL( +#include +#include +#include +#include +#include +using namespace metal; + +inline float sigmoid_f(float x) { + float z = exp(-abs(x)); + return x >= 0.0f ? 
1.0f / (1.0f + z) : z / (1.0f + z); +} + +inline float sigmoid_backward_f(float x, float grad_output) { + float sig = sigmoid_f(x); + return grad_output * sig * (1.0f - sig); +} + +inline float fast_tanh_f(float x) { + float v1 = clamp(x, -9.0f, 9.0f); + float v2 = v1 * v1; + // Horner polynomial (matches PyTorch implementation) + float p = v2 * (-2.76076847742355e-16f) + 2.00018790482477e-13f; + p = v2 * p + (-8.60467152213735e-11f); + p = v2 * p + 5.12229709037114e-08f; + p = v2 * p + 1.48572235717979e-05f; + p = v2 * p + 6.37261928875436e-04f; + p = v2 * p + 4.89352455891786e-03f; + p = v1 * p; + float q = v2 * 1.19825839466702e-06f + 1.18534705686654e-04f; + q = v2 * q + 2.26843463243900e-03f; + q = v2 * q + 4.89352518554385e-03f; + return p / q; +} + +inline float fast_sigmoid_f(float x) { + float y = fast_tanh_f(x * 0.5f); + return clamp((y + 1.0f) * 0.5f, 0.0f, 1.0f); +} + +inline float tilde_relu_fwd(float x) { + return x >= 0.0f ? x + 0.5f : fast_sigmoid_f(x); +} + +inline float tilde_relu_bwd(float x, float grad) { + if (x >= 0.0f) return grad; + float sig = fast_sigmoid_f(x); + return grad * sig * (1.0f - sig); +} + +inline float lerp_f(float a, float b, float w) { + float diff = b - a; + return abs(w) < 0.5f ? a + w * diff : b - diff * (1.0f - w); +} + +// MSL has no built-in log1p. Goldberg's trick: compensates for rounding in 1+x. +inline float log1p_f(float x) { + float u = 1.0f + x; + return (u == 1.0f) ? x : log(u) * x / (u - 1.0f); +} + +constant float SOFTPLUS_BETA = 1.0f; +constant float SOFTPLUS_THRESHOLD = 20.0f; + +inline float softplus_fwd(float x) { + float xs = x * SOFTPLUS_BETA; + return xs > SOFTPLUS_THRESHOLD ? 
x : log1p_f(exp(xs)) / SOFTPLUS_BETA; +} + +inline float softplus_bwd(float grad_output, float x) { + float beta_x = SOFTPLUS_BETA * x; + if (beta_x > SOFTPLUS_THRESHOLD) return grad_output; + float exp_beta_x = exp(beta_x); + return grad_output * (exp_beta_x / (1.0f + exp_beta_x)); +} + +inline void log_coeffs_and_values_fwd(float gate, float hidden, + thread float& log_coeff, thread float& log_value) { + float abs_gate = abs(gate); + float sp_neg = log1p_f(exp(-abs_gate)); + float softplus_gate, softplus_neg_gate; + if (gate >= 0.0f) { + softplus_gate = gate + sp_neg; + softplus_neg_gate = sp_neg; + } else { + softplus_gate = sp_neg; + softplus_neg_gate = -gate + sp_neg; + } + log_coeff = -softplus_gate; + float log_z = -softplus_neg_gate; + float log_tilde_h = hidden >= 0.0f ? log(hidden + 0.5f) : -softplus_fwd(-hidden); + log_value = log_z + log_tilde_h; +} + +inline void log_coeffs_and_values_bwd(float grad_lc, float grad_lv, + float gate, float hidden, + thread float& grad_gate, thread float& grad_hidden) { + float sig_gate = sigmoid_f(gate); + grad_gate = -grad_lc * sig_gate + grad_lv * (1.0f - sig_gate); + if (hidden >= 0.0f) { + grad_hidden = grad_lv / (hidden + 0.5f); + } else { + grad_hidden = grad_lv * sigmoid_f(-hidden); + } +} + +inline float relu_f(float x) { return max(0.0f, x); } +inline float relu_backward_f(float x, float grad_output) { return (x > 0.0f) ? 
grad_output : 0.0f; } + +struct Philox4x32 { + uint4 counter; + uint2 key; +}; + +inline uint4 philox4x32_round(uint4 ctr, uint2 key) { + constexpr uint PHILOX_M0 = 0xD2511F53u; + constexpr uint PHILOX_M1 = 0xCD9E8D57u; + uint hi0 = mulhi(PHILOX_M0, ctr.x); + uint lo0 = PHILOX_M0 * ctr.x; + uint hi1 = mulhi(PHILOX_M1, ctr.z); + uint lo1 = PHILOX_M1 * ctr.z; + return uint4(hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0); +} + +inline uint4 philox4x32_10(uint4 counter, uint2 key) { + constexpr uint PHILOX_W0 = 0x9E3779B9u; + constexpr uint PHILOX_W1 = 0xBB67AE85u; + for (int i = 0; i < 10; i++) { + counter = philox4x32_round(counter, key); + key.x += PHILOX_W0; + key.y += PHILOX_W1; + } + return counter; +} + +inline float philox_uniform(thread uint& state_idx, uint4 rng_out) { + uint val; + switch (state_idx & 3) { + case 0: val = rng_out.x; break; + case 1: val = rng_out.y; break; + case 2: val = rng_out.z; break; + default: val = rng_out.w; break; + } + state_idx++; + // Convert to (0, 1] uniform + return (float(val >> 8) + 0.5f) / 16777216.0f; +} + +// Box-Muller transform for Normal(0,1) +inline float philox_normal(float u1, float u2) { + return sqrt(-2.0f * log(u1)) * cos(2.0f * M_PI_F * u2); +} + +// mingru_gate_inference: fused chunk + tilde_relu + lerp + highway output gate +// combined is (B, 3*H) = [hidden, gate, proj], state is (B, H) +// x_in is (B, H) = input before projection +// out = sigmoid(proj) * mingru_out + (1 - sigmoid(proj)) * x_in +// next_state = mingru_out +struct MingruGateParams { + int H; + int B; +}; + +kernel void mingru_gate_inference( + device float* out [[buffer(0)]], + device float* next_state [[buffer(1)]], + const device float* combined [[buffer(2)]], + const device float* state_in [[buffer(3)]], + const device float* x_in [[buffer(4)]], + constant MingruGateParams& p [[buffer(5)]], + uint idx [[thread_position_in_grid]] +) { + int N = p.B * p.H; + if ((int)idx >= N) return; + + int b = (int)idx / p.H; + int h = (int)idx % 
p.H; + int base = b * 3 * p.H; + + float hidden = combined[base + h]; + float gate = combined[base + p.H + h]; + float proj = combined[base + 2 * p.H + h]; + float state = state_in[idx]; + float x = x_in[idx]; + + float gate_sig = sigmoid_f(gate); + float hidden_tilde = tilde_relu_fwd(hidden); + float mingru_out = lerp_f(state, hidden_tilde, gate_sig); + float proj_sig = sigmoid_f(proj); + + next_state[idx] = max(mingru_out, 1e-30f); + out[idx] = proj_sig * mingru_out + (1.0f - proj_sig) * x; +} + +constant int CHECKPOINT_INTERVAL = 4; + +struct ScanParams { + int T_seq; + int H; + int B; +}; + +kernel void mingru_scan_forward_checkpointed( + device float* out [[buffer(0)]], + device float* next_state [[buffer(1)]], + device float* a_star_buf [[buffer(2)]], + device float* s_buf [[buffer(3)]], + device float* log_values_buf [[buffer(4)]], + const device float* combined [[buffer(5)]], + const device float* state [[buffer(6)]], + const device float* input [[buffer(7)]], + constant ScanParams& p [[buffer(8)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx >= p.B * p.H) return; + + int b = (int)idx / p.H; + int h = (int)idx % p.H; + int bH = b * p.H; + int H3 = 3 * p.H; + int bHT = bH * p.T_seq; + int out_base = bHT + h; + int cbase = 3 * bHT; + + float a_star = 0.0f; + float log_value = 0.0f; + // fast:: matches CUDA __expf/__logf (kernels.cu:270-305) + float s = fast::log(state[bH + h]); + log_value = s; + + int T_out = p.T_seq + 1; + int buf_base = b * T_out * p.H + h; + int buf_curr = buf_base; + a_star_buf[buf_curr] = a_star; + s_buf[buf_curr] = s; + log_values_buf[buf_curr] = log_value; + + int out_curr = out_base; + int t_offset = 0; + + for (int t = 1; t < p.T_seq + 1; t++) { + float hidden_val = combined[cbase + h + t_offset]; + float gate_val = combined[cbase + p.H + h + t_offset]; + float proj_val = combined[cbase + 2 * p.H + h + t_offset]; + float x_val = input[out_base + (t - 1) * p.H]; + + float log_coeff_val; + 
log_coeffs_and_values_fwd(gate_val, hidden_val, log_coeff_val, log_value); + + a_star += log_coeff_val; + + float z = log_value - a_star; + float max_val = fmax(s, z); + s = max_val + log1p_f(fast::exp(-abs(s - z))); + + float scan_result = fast::exp(a_star + s); + float proj_sigmoid = sigmoid_f(proj_val); + + out[out_curr] = proj_sigmoid * scan_result + (1.0f - proj_sigmoid) * x_val; + + buf_curr += p.H; + out_curr += p.H; + t_offset += H3; + + if (t % CHECKPOINT_INTERVAL == 0) { + a_star_buf[buf_curr] = a_star; + s_buf[buf_curr] = s; + log_values_buf[buf_curr] = log_value; + } + } + + // Floor at 1e-30 to prevent log(0)=-inf on the next forward pass. + // exp(a_star+s) underflows to exactly 0.0f in fp32 when a_star+s < -87.3. + // A zero state causes log(0)=-inf → permanent -inf propagation through + // all subsequent scan steps. 1e-30 is well above fp32 denormal range + // and below any meaningful state value. + next_state[bH + h] = max(fast::exp(a_star + s), 1e-30f); +} + +kernel void mingru_scan_backward_checkpointed( + device float* grad_combined [[buffer(0)]], + device float* grad_state [[buffer(1)]], + device float* grad_input [[buffer(2)]], + const device float* grad_out [[buffer(3)]], + const device float* grad_next_state [[buffer(4)]], + const device float* combined [[buffer(5)]], + const device float* state [[buffer(6)]], + const device float* input [[buffer(7)]], + const device float* a_star_buf [[buffer(8)]], + const device float* s_buf [[buffer(9)]], + const device float* log_values_buf [[buffer(10)]], + constant ScanParams& p [[buffer(11)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx >= p.B * p.H) return; + + int b = (int)idx / p.H; + int h = (int)idx % p.H; + int bHT = b * p.H * p.T_seq; + int cbase = 3 * bHT; + int H3 = 3 * p.H; + int state_idx = b * p.H + h; + int out_base = bHT + h; + + int T_out = p.T_seq + 1; + int buf_base = b * T_out * p.H + h; + + float acc = 0.0f; + float s_val_next = 0.0f; + float carry_grad_a = 0.0f; + + 
for (int chunk_end = p.T_seq, chunk_start = 0; chunk_end > 0; chunk_end = chunk_start) { + chunk_start = ((chunk_end - 1) / CHECKPOINT_INTERVAL) * CHECKPOINT_INTERVAL; + int chunk_len = chunk_end - chunk_start; + + // Chunk storage in thread-local arrays (chunk_len is at most CHECKPOINT_INTERVAL) + float chunk_a_star[CHECKPOINT_INTERVAL]; + float chunk_s[CHECKPOINT_INTERVAL]; + float chunk_log_values[CHECKPOINT_INTERVAL]; + float chunk_hidden[CHECKPOINT_INTERVAL]; + float chunk_gate[CHECKPOINT_INTERVAL]; + + // Load checkpoint: chunk_start is a multiple of CHECKPOINT_INTERVAL, the only steps the forward pass stores (also valid when T_seq is not a multiple) + int ckpt_buf_idx = buf_base + chunk_start * p.H; + float recomp_a_star = a_star_buf[ckpt_buf_idx]; + float recomp_s = s_buf[ckpt_buf_idx]; + float recomp_log_value = log_values_buf[ckpt_buf_idx]; + + // Forward recompute within chunk + for (int i = 0; i < chunk_len; i++) { + int t = chunk_start + 1 + i; + int t_offset = (t - 1) * H3; + float hv = combined[cbase + h + t_offset]; + float gv = combined[cbase + p.H + h + t_offset]; + + float lc; + log_coeffs_and_values_fwd(gv, hv, lc, recomp_log_value); + recomp_a_star += lc; + + float z = recomp_log_value - recomp_a_star; + float mv = fmax(recomp_s, z); + recomp_s = mv + log1p_f(fast::exp(-abs(recomp_s - z))); + + chunk_a_star[i] = recomp_a_star; + chunk_s[i] = recomp_s; + chunk_log_values[i] = recomp_log_value; + chunk_hidden[i] = hv; + chunk_gate[i] = gv; + } + + // Backward through chunk + for (int i = chunk_len - 1; i >= 0; i--) { + int t = chunk_start + 1 + i; + int t_offset = (t - 1) * H3; + + float a_star_t = chunk_a_star[i]; + float s_t = chunk_s[i]; + float log_value_t = chunk_log_values[i]; + float hidden_val = chunk_hidden[i]; + float gate_val = chunk_gate[i]; + float proj_val = combined[cbase + 2 * p.H + h + t_offset]; + int input_idx = out_base + (t - 1) * p.H; + float x_val = input[input_idx]; + + float scan_result = fast::exp(a_star_t + s_t); + float z = log_value_t - a_star_t; + + float grad_out_val = grad_out[input_idx]; + float grad_scan_from_next = (t == p.T_seq) ? 
grad_next_state[state_idx] : 0.0f; + + float proj_sigmoid = sigmoid_f(proj_val); + float grad_scan_result = grad_scan_from_next + grad_out_val * proj_sigmoid; + float grad_proj = grad_out_val * (scan_result - x_val) * proj_sigmoid * (1.0f - proj_sigmoid); + grad_input[input_idx] = grad_out_val * (1.0f - proj_sigmoid); + + float grad_log_h = grad_scan_result * scan_result; + float grad_s = grad_log_h; + + if (t == p.T_seq) { + acc = grad_s; + } else { + acc = grad_s + acc * fast::exp(s_t - s_val_next); + } + float grad_z = acc * fast::exp(z - s_t); + s_val_next = s_t; + + float grad_a = grad_log_h + carry_grad_a - grad_z; + carry_grad_a = grad_a; + + float grad_g, grad_h; + log_coeffs_and_values_bwd(grad_a, grad_z, gate_val, hidden_val, grad_g, grad_h); + + grad_combined[cbase + h + t_offset] = grad_h; + grad_combined[cbase + p.H + h + t_offset] = grad_g; + grad_combined[cbase + 2 * p.H + h + t_offset] = grad_proj; + } + } + + // Gradient for initial state (t=0) + int ckpt_0_idx = buf_base; + float a_star_0 = a_star_buf[ckpt_0_idx]; + float s_0 = s_buf[ckpt_0_idx]; + float log_value_0 = log_values_buf[ckpt_0_idx]; + + acc = acc * fast::exp(s_0 - s_val_next); + float grad_z_0 = acc * fast::exp((log_value_0 - a_star_0) - s_0); + + grad_state[state_idx] = (state[state_idx] > 0.0f) ? (grad_z_0 / state[state_idx]) : 0.0f; +} + +struct SampleParams { + uint64_t seed; + uint offset; + int num_atns; + int num_atns_total; // sum of act_sizes + int B; + int logits_stride; + int logstd_stride; + int value_stride; + int is_continuous; // 1 for continuous, 0 for discrete + int mask_stride; // stride between rows in mask buffer (may differ from num_atns_total) +}; + +// Apply action mask to a logit: invalid actions get -1e9. 
+inline float masked_logit(float l, float m) { + if (m < 0.5f) l = -1e9f; + return l; +} + +kernel void sample_logits_kernel( + device float* actions [[buffer(0)]], + device float* logprobs [[buffer(1)]], + device float* value_out [[buffer(2)]], + const device float* logits [[buffer(3)]], + const device float* logstd [[buffer(4)]], + const device float* value [[buffer(5)]], + const device int* act_sizes [[buffer(6)]], + constant SampleParams& sp [[buffer(7)]], + const device float* action_mask [[buffer(8)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx >= sp.B) return; + + uint offset = sp.offset; + + // Generate Philox RNG state + uint4 counter = uint4((uint)idx, offset, 0u, 0u); + uint2 key = uint2((uint)(sp.seed & 0xFFFFFFFF), (uint)(sp.seed >> 32)); + uint4 rng_out = philox4x32_10(counter, key); + uint rng_idx = 0; + + int logits_base = (int)idx * sp.logits_stride; + float total_log_prob = 0.0f; + + if (sp.is_continuous) { + constexpr float LOG_2PI = 1.8378770664093453f; + int logstd_base = (int)idx * sp.logstd_stride; + + for (int h = 0; h < sp.num_atns; h++) { + float mean = logits[logits_base + h]; + float log_std = logstd[logstd_base + h]; + float std = exp(log_std); + + // Need 2 uniforms for Box-Muller + float u1 = philox_uniform(rng_idx, rng_out); + float u2 = philox_uniform(rng_idx, rng_out); + if (rng_idx >= 4) { + counter.z++; + rng_out = philox4x32_10(counter, key); + rng_idx = 0; + } + float noise = philox_normal(u1, u2); + float action = mean + std * noise; + + float normalized = (action - mean) / std; + float log_prob = -0.5f * normalized * normalized - 0.5f * LOG_2PI - log_std; + + actions[(int)idx * sp.num_atns + h] = action; + total_log_prob += log_prob; + } + } else { + // Discrete action sampling (multinomial) + // CUDA joint-ratio: accumulate scalar total_log_prob across heads + int logits_offset = 0; + + for (int h = 0; h < sp.num_atns; h++) { + int A = act_sizes[h]; + + // Mask base index for this env (mask_stride allows 
non-contiguous layout) + + // Max + logsumexp (with mask) + float max_val = -INFINITY; + for (int a = 0; a < A; a++) { + float l = masked_logit(logits[logits_base + logits_offset + a], + action_mask[mask_base + logits_offset + a]); + max_val = fmax(max_val, l); + } + float sum_exp = 0.0f; + for (int a = 0; a < A; a++) { + float l = masked_logit(logits[logits_base + logits_offset + a], + action_mask[mask_base + logits_offset + a]); + sum_exp += exp(l - max_val); + } + float logsumexp_val = max_val + log(sum_exp); + + // Random sample + float rand_val = philox_uniform(rng_idx, rng_out); + if (rng_idx >= 4) { + counter.z++; + rng_out = philox4x32_10(counter, key); + rng_idx = 0; + } + + // Inverse CDF sampling (with mask). If rand_val is never below the rounded cumsum, fall back to the last action with nonzero probability. + float cumsum = 0.0f; + int sampled_action = A - 1; + for (int a = 0; a < A; a++) { + float l = masked_logit(logits[logits_base + logits_offset + a], + action_mask[mask_base + logits_offset + a]); + float prob = exp(l - logsumexp_val); + cumsum += prob; + if (prob > 0.0f) { + sampled_action = a; + if (rand_val < cumsum) break; + } + } + + float log_prob = masked_logit( + logits[logits_base + logits_offset + sampled_action], + action_mask[mask_base + logits_offset + sampled_action]) - logsumexp_val; + + actions[(int)idx * sp.num_atns + h] = float(sampled_action); + total_log_prob += log_prob; + + logits_offset += A; + } + } + // Scalar joint logprob (matches CUDA kernels.cu:995) + logprobs[(int)idx] = total_log_prob; + value_out[idx] = value[(int)idx * sp.value_stride]; +} + +// recompute_logprobs_kernel: recomputes the joint discrete log-prob from stored logits and f32 actions. +// Used when CPU inference produces actions but logprobs need GPU-precision +// exp/log to match PPO training kernels. Without this, the importance ratio +// exp(new_logp - old_logp) sees a systematic IEEE vs fast::exp bias that +// causes NaN with overlap (stale weights amplify the mismatch). 
+ +struct RecomputeLogprobsParams { + int B; + int num_atns; + int logits_stride; // fused_cols per row + int mask_stride; // 0 = broadcast (all-ones) +}; + +kernel void recompute_logprobs_kernel( + device float* logprobs [[buffer(0)]], + const device float* logits [[buffer(1)]], + const device float* actions_f32 [[buffer(2)]], + const device int* act_sizes [[buffer(3)]], + const device float* action_mask [[buffer(4)]], + constant RecomputeLogprobsParams& rp [[buffer(5)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx >= rp.B) return; + + int logits_base = (int)idx * rp.logits_stride; + int mask_base = (rp.mask_stride == 0) ? 0 : (int)idx * rp.mask_stride; + + float total_log_prob = 0.0f; + int logits_offset = 0; + + for (int h = 0; h < rp.num_atns; h++) { + int A = act_sizes[h]; + int act = int(actions_f32[(int)idx * rp.num_atns + h]); + + // Max + logsumexp (with mask) + float max_val = -INFINITY; + for (int a = 0; a < A; a++) { + max_val = fmax(max_val, masked_logit( + logits[logits_base + logits_offset + a], + action_mask[mask_base + logits_offset + a])); + } + float sum_exp = 0.0f; + for (int a = 0; a < A; a++) { + sum_exp += exp(masked_logit( + logits[logits_base + logits_offset + a], + action_mask[mask_base + logits_offset + a]) - max_val); + } + float logsumexp_val = max_val + log(sum_exp); + + float head_lp = masked_logit( + logits[logits_base + logits_offset + act], + action_mask[mask_base + logits_offset + act]) - logsumexp_val; + total_log_prob += head_lp; + + logits_offset += A; + } + // Scalar joint logprob (matches CUDA kernels.cu:995) + logprobs[(int)idx] = total_log_prob; +} + +constant int PPO_THREADS = 256; +constant int LOSS_PG = 0; +constant int LOSS_VF = 1; +constant int LOSS_ENT = 2; +constant int LOSS_TOTAL = 3; +constant int LOSS_OLD_APPROX_KL = 4; +constant int LOSS_APPROX_KL = 5; +constant int LOSS_CLIPFRAC = 6; +constant int LOSS_N = 7; +constant int MAX_ATN_HEADS = 16; + +// Float atomic add via CAS loop (MSL has no native 
atomic<float>) +inline void atomic_add_float(device atomic_uint* addr, float val) { + uint expected = atomic_load_explicit(addr, memory_order_relaxed); + while (true) { + float current = as_type<float>(expected); + float desired = current + val; + uint desired_bits = as_type<uint>(desired); + if (atomic_compare_exchange_weak_explicit(addr, &expected, desired_bits, + memory_order_relaxed, memory_order_relaxed)) { + return; + } + } +} + +// PPO helper: compute logsumexp, entropy, log_prob for a single discrete head with masks. +// mask pointer + mask_offset index into the action mask for this head. +// Invalid actions (mask < 0.5) get logit = -1e9, matching rollout sampling. +inline void ppo_discrete_head( + const device float* logits, + int logits_base, int logits_stride_a, int logits_offset, + int A, int act, + const device float* mask, int mask_offset, + thread float& out_logsumexp, thread float& out_entropy, thread float& out_logp +) { + float max_logit = -INFINITY; + float sum = 0.0f; + float act_logit = 0.0f; + + for (int a = 0; a < A; a++) { + float l = logits[logits_base + (logits_offset + a) * logits_stride_a]; + if (mask[mask_offset + a] < 0.5f) l = -1e9f; + if (a == act) act_logit = l; + if (l > max_logit) { + sum *= exp(max_logit - l); + max_logit = l; + } + sum += exp(l - max_logit); + } + // Degenerate input (A == 0 or non-finite model output; a fully masked head stays finite at -1e9): propagate NaN + // so the corruption surfaces immediately in the PPO loss rather than + // silently producing logp=0 (ratio=1) which poisons gradients. 
+ if (!isfinite(max_logit) || !isfinite(sum) || sum <= 0.0f) { + out_logsumexp = NAN; + out_entropy = NAN; + out_logp = NAN; + return; + } + float lse = max_logit + log(sum); + + float ent = 0.0f; + for (int a = 0; a < A; a++) { + float l = logits[logits_base + (logits_offset + a) * logits_stride_a]; + if (mask[mask_offset + a] < 0.5f) l = -1e9f; + float logp = l - lse; + float p = exp(clamp(logp, -80.0f, 80.0f)); + ent -= p * logp; + } + + out_logsumexp = lse; + out_entropy = ent; + out_logp = act_logit - lse; +} + +// PPO helper: compute log_prob and entropy for a single continuous head +inline void ppo_continuous_head( + float mean, float log_std, float action, + thread float& out_logp, thread float& out_entropy +) { + constexpr float HALF_LOG_2PI = 0.9189385332046727f; + constexpr float HALF_1_PLUS_LOG_2PI = 1.4189385332046727f; + float std = exp(log_std); + float normalized = (action - mean) / std; + out_logp = -0.5f * normalized * normalized - HALF_LOG_2PI - log_std; + out_entropy = HALF_1_PLUS_LOG_2PI + log_std; +} + +struct PPOFusedParams { + int num_atns; + float clip_coef; + float vf_clip_coef; + float vf_coef; + float ent_coef; + int T_seq; + int A_total; + int N; + int logits_stride_n; + int logits_stride_t; + int logits_stride_a; + int values_stride_n; + int values_stride_t; + int is_continuous; + int num_atns_total; // sum of act_sizes, for mask buffer indexing + int mask_stride; // stride in floats between consecutive mask rows in obs +}; + +// Fused PPO forward + backward: computes loss partials AND gradients in one pass. +// Per-block partial sums written to ppo_partials, reduced by ppo_loss_reduce_kernel. 
+kernel void ppo_loss_fwd_bwd_kernel( + device float* ppo_partials [[buffer(0)]], + device float* grad_logits [[buffer(1)]], + device float* grad_logstd [[buffer(2)]], + device float* grad_values_pred [[buffer(3)]], + const device float* logits [[buffer(4)]], + const device float* logstd [[buffer(5)]], + const device float* values_pred [[buffer(6)]], + const device float* actions [[buffer(7)]], + const device float* old_logprobs [[buffer(8)]], + const device float* advantages [[buffer(9)]], + const device float* prio [[buffer(10)]], + const device float* values [[buffer(11)]], + const device float* returns_buf [[buffer(12)]], + const device float* adv_mean [[buffer(13)]], + const device float* adv_var [[buffer(14)]], + const device int* act_sizes [[buffer(15)]], + constant PPOFusedParams& pp [[buffer(16)]], + const device float* action_mask [[buffer(17)]], + device float* out_ratio [[buffer(18)]], + device float* out_newvalue [[buffer(19)]], + uint idx [[thread_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint block_id [[threadgroup_position_in_grid]] +) { + int total_elements = pp.N * pp.T_seq; + float inv_NT = 1.0f / float(total_elements); + + threadgroup float block_losses[7][PPO_THREADS]; + for (int c = 0; c < LOSS_N; c++) { + block_losses[c][tid] = 0.0f; + } + + // MSL has no goto — use if/else instead of the CUDA goto pattern + if ((int)idx < total_elements) { + int n = (int)idx / pp.T_seq; + int t = (int)idx % pp.T_seq; + int nt = n * pp.T_seq + t; + + int logits_base = n * pp.logits_stride_n + t * pp.logits_stride_t; + int values_idx = n * pp.values_stride_n + t * pp.values_stride_t; + int grad_logits_base = nt * pp.A_total; + + // Shared computation (used by both forward and backward) + float adv = advantages[nt]; + float w = prio[n]; + float val = values[nt]; + float ret = returns_buf[nt]; + float val_pred = values_pred[values_idx]; + out_newvalue[nt] = val_pred; + + float adv_std = sqrt(adv_var[0]); + float adv_normalized = (adv - 
adv_mean[0]) / (adv_std + 1e-8f); + + // grad_loss is always 1.0 + float dL = inv_NT; + float d_pg_loss = dL; + float d_entropy_term = dL * (-pp.ent_coef); + + // Value loss (forward) + value gradient (backward) + float v_error = val_pred - val; + float v_clipped = val + clamp(v_error, -pp.vf_clip_coef, pp.vf_clip_coef); + float v_loss_unclipped = (val_pred - ret) * (val_pred - ret); + float v_loss_clipped = (v_clipped - ret) * (v_clipped - ret); + float v_loss = 0.5f * fmax(v_loss_unclipped, v_loss_clipped); + + float d_val_pred = 0.0f; + if (v_loss_clipped > v_loss_unclipped) { + if (v_error >= -pp.vf_clip_coef && v_error <= pp.vf_clip_coef) { + d_val_pred = v_clipped - ret; + } + } else { + d_val_pred = val_pred - ret; + } + grad_values_pred[nt] = dL * w * pp.vf_coef * d_val_pred; + + // Policy loss + gradients. Both branches produce pg_loss, total_entropy, + // logratio, ratio — the block-loss accumulation is shared after the if/else. + float pg_loss, total_entropy, logratio, ratio; + + if (pp.is_continuous) { + // Continuous: old_logprobs is scalar (matches CUDA) + float old_logp = old_logprobs[nt]; + float total_log_prob = 0.0f; + total_entropy = 0.0f; + for (int h = 0; h < pp.num_atns; h++) { + float mean = logits[logits_base + h * pp.logits_stride_a]; + float log_std = logstd[logits_base + h * pp.logits_stride_a]; + float action = actions[nt * pp.num_atns + h]; + float lp, ent; + ppo_continuous_head(mean, log_std, action, lp, ent); + total_log_prob += lp; + total_entropy += ent; + } + + logratio = total_log_prob - old_logp; + ratio = exp(logratio); + out_ratio[nt] = ratio; + float ratio_clipped = clamp(ratio, 1.0f - pp.clip_coef, 1.0f + pp.clip_coef); + float wa = -w * adv_normalized; + pg_loss = fmax(wa * ratio, wa * ratio_clipped); + + // Backward: policy gradient + float d_ratio = wa * d_pg_loss; + if (wa * ratio_clipped > wa * ratio) { + if (ratio <= (1.0f - pp.clip_coef) || ratio >= (1.0f + pp.clip_coef)) + d_ratio = 0.0f; + } + float d_new_logp = 
d_ratio * ratio; + + for (int h = 0; h < pp.num_atns; h++) { + float mean = logits[logits_base + h * pp.logits_stride_a]; + float log_std = logstd[logits_base + h * pp.logits_stride_a]; + float std = exp(log_std); + float var = std * std; + float action = actions[nt * pp.num_atns + h]; + float diff = action - mean; + + grad_logits[grad_logits_base + h] = d_new_logp * diff / var; + grad_logstd[grad_logits_base + h] = d_new_logp * (diff * diff / var - 1.0f) + d_entropy_term; + } + } else { + // Discrete joint-ratio clipping (matches CUDA kernels.cu:738-807). + // Sum per-head logprobs into scalar, compute single joint ratio, clip once. + float head_logsumexp[MAX_ATN_HEADS]; + float head_entropy[MAX_ATN_HEADS]; + int head_act[MAX_ATN_HEADS]; + int mask_base = (int)idx * pp.mask_stride; + + int logits_offset = 0; + float total_log_prob = 0.0f; + total_entropy = 0.0f; + + for (int h = 0; h < pp.num_atns; h++) { + int A = act_sizes[h]; + int act = int(actions[nt * pp.num_atns + h]); + head_act[h] = act; + float lse, ent, lp; + ppo_discrete_head(logits, logits_base, pp.logits_stride_a, logits_offset, + A, act, action_mask, mask_base + logits_offset, lse, ent, lp); + head_logsumexp[h] = lse; + head_entropy[h] = ent; + total_log_prob += lp; + total_entropy += ent; + logits_offset += A; + } + + float old_logp = old_logprobs[nt]; + logratio = total_log_prob - old_logp; + ratio = exp(logratio); + out_ratio[nt] = ratio; + float ratio_clipped = clamp(ratio, 1.0f - pp.clip_coef, 1.0f + pp.clip_coef); + float wa = -w * adv_normalized; + pg_loss = fmax(wa * ratio, wa * ratio_clipped); + + // Backward: single joint d_new_logp shared across all heads + float d_ratio = wa * d_pg_loss; + if (wa * ratio_clipped > wa * ratio) { + if (ratio <= (1.0f - pp.clip_coef) || ratio >= (1.0f + pp.clip_coef)) + d_ratio = 0.0f; + } + float d_new_logp = d_ratio * ratio; + + // Gradient pass over logits (reuses head_logsumexp, head_entropy) + logits_offset = 0; + for (int h = 0; h < pp.num_atns; h++) 
{ + int A = act_sizes[h]; + int act = head_act[h]; + float lse = head_logsumexp[h]; + float ent = head_entropy[h]; + + for (int a = 0; a < A; a++) { + float raw_l = logits[logits_base + (logits_offset + a) * pp.logits_stride_a]; + float m = action_mask[mask_base + logits_offset + a]; + float l = (m < 0.5f) ? -1e9f : raw_l; + float logp = l - lse; + float p = exp(logp); + float d_logit = (a == act) ? d_new_logp : 0.0f; + d_logit -= p * d_new_logp; + d_logit += d_entropy_term * p * (-ent - logp); + if (m < 0.5f) d_logit = 0.0f; + grad_logits[grad_logits_base + logits_offset + a] = d_logit; + } + logits_offset += A; + } + } + + // Shared loss accumulation (both branches produce pg_loss, total_entropy, logratio, ratio) + block_losses[LOSS_PG][tid] = pg_loss * inv_NT; + block_losses[LOSS_VF][tid] = v_loss * inv_NT; + block_losses[LOSS_ENT][tid] = total_entropy * inv_NT; + block_losses[LOSS_TOTAL][tid] = (pg_loss + pp.vf_coef * v_loss - pp.ent_coef * total_entropy) * inv_NT; + block_losses[LOSS_OLD_APPROX_KL][tid] = (-logratio) * inv_NT; + block_losses[LOSS_APPROX_KL][tid] = ((ratio - 1.0f) - logratio) * inv_NT; + block_losses[LOSS_CLIPFRAC][tid] = (abs(ratio - 1.0f) > pp.clip_coef ? 
1.0f : 0.0f) * inv_NT; + } // end if (idx < total_elements) + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Block reduction (tree reduction) + for (int stride = PPO_THREADS / 2; stride > 0; stride >>= 1) { + if ((int)tid < stride) { + for (int c = 0; c < LOSS_N; c++) { + block_losses[c][tid] += block_losses[c][tid + stride]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (tid == 0) { + int base_out = (int)block_id * (LOSS_N + 1); + ppo_partials[base_out] = block_losses[LOSS_TOTAL][0]; + for (int c = 0; c < LOSS_N; c++) { + ppo_partials[base_out + 1 + c] = block_losses[c][0]; + } + } +} + +// Deterministic reduction of per-block PPO loss partials + count increment +struct PPOReduceParams { + int num_blocks; +}; + +kernel void ppo_loss_reduce_kernel( + device float* loss [[buffer(0)]], + device float* losses_acc [[buffer(1)]], + const device float* partials [[buffer(2)]], + constant PPOReduceParams& pp [[buffer(3)]], + uint tid [[thread_index_in_threadgroup]] +) { + if ((int)tid > LOSS_N) return; + + float sum = 0.0f; + for (int b = 0; b < pp.num_blocks; b++) { + sum += partials[b * (LOSS_N + 1) + (int)tid]; + } + + if (tid == 0) { + *loss += sum; + } else { + losses_acc[(int)tid - 1] += sum; + } + + // Fold add_scalar: increment epoch count + if (tid == 0) { + losses_acc[LOSS_N] += 1.0f; + } +} + +struct AdvantageParams { + float gamma; + float lambda; + float rho_clip; + float c_clip; + int num_steps; + int horizon; +}; + +kernel void puff_advantage_kernel( + const device float* values [[buffer(0)]], + const device float* rewards [[buffer(1)]], + const device float* dones [[buffer(2)]], + const device float* importance [[buffer(3)]], + device float* advantages [[buffer(4)]], + constant AdvantageParams& p [[buffer(5)]], + uint row [[thread_position_in_grid]] +) { + if ((int)row >= p.num_steps) return; + + int offset = (int)row * p.horizon; + const device float* v = values + offset; + const device float* r = rewards + offset; + 
const device float* d = dones + offset; + const device float* imp = importance + offset; + device float* adv = advantages + offset; + + float lastpufferlam = 0.0f; + for (int t = p.horizon - 2; t >= 0; t--) { + int t_next = t + 1; + float nextnonterminal = 1.0f - d[t_next]; + float rho_t = min(imp[t], p.rho_clip); + float c_t = min(imp[t], p.c_clip); + float delta = rho_t * (r[t_next] + p.gamma * v[t_next] * nextnonterminal - v[t]); + lastpufferlam = delta + p.gamma * p.lambda * c_t * lastpufferlam * nextnonterminal; + adv[t] = lastpufferlam; + } +} + +struct PrioParams { + float prio_alpha; + int stride; +}; + +kernel void prio_adv_reduction_kernel( + const device float* advantages [[buffer(0)]], + device float* prio_weights [[buffer(1)]], + constant PrioParams& pp [[buffer(2)]], + uint row [[threadgroup_position_in_grid]], + uint tx [[thread_index_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]] +) { + int offset = (int)row * pp.stride; + + float local_sum = 0.0f; + for (int t = (int)tx; t < pp.stride; t += 32) { + local_sum += abs(advantages[offset + t]); + } + + // Simdgroup reduction + local_sum = simd_sum(local_sum); + + if (simd_lane == 0 && tx < 32) { + // epsilon floor prevents pow(0, alpha>0) = 0 which would permanently + // exclude zero-advantage segments from sampling in sparse-reward envs + float pw = pow(local_sum + 1e-6f, pp.prio_alpha); + if (isnan(pw) || isinf(pw)) pw = 1e-6f; + prio_weights[row] = pw; + } +} + +struct PrioNormParams { + int length; +}; + +kernel void prio_normalize_kernel( + device float* prio_weights [[buffer(0)]], + constant PrioNormParams& pp [[buffer(1)]], + uint tx [[thread_index_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]] +) { + constexpr float eps = 1e-6f; + constexpr int NUM_WARPS = 8; // 256 / 32 + + threadgroup float shmem[NUM_WARPS]; + threadgroup float block_sum; + + float local_sum = 0.0f; + for (int t = (int)tx; t < pp.length; t 
+= 256) { + local_sum += prio_weights[t]; + } + + // Simdgroup reduction + local_sum = simd_sum(local_sum); + + if (simd_lane == 0) shmem[simd_id] = local_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_id == 0) { + float val = (simd_lane < NUM_WARPS) ? shmem[simd_lane] : 0.0f; + val = simd_sum(val); + // add eps * length so numerator and denominator are consistent: + // sum((w_i + eps)) = sum(w_i) + eps*length + if (tx == 0) block_sum = val + eps * float(pp.length); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = (int)tx; t < pp.length; t += 256) { + prio_weights[t] = (prio_weights[t] + eps) / block_sum; + } +} + +struct PrioImpParams { + int total_agents; + float anneal_beta; + int minibatch_segments; +}; + +struct PrioSampleParams { + uint64_t seed; + uint base_offset; + int total_segments; + int minibatch_segments; +}; + +kernel void prio_sample_kernel( + device int64_t* indices [[buffer(0)]], + const device float* prio_probs [[buffer(1)]], + constant PrioSampleParams& pp [[buffer(2)]], + uint tx [[thread_position_in_grid]] +) { + if ((int)tx >= pp.minibatch_segments) return; + + uint4 counter = uint4(pp.base_offset + tx, 0u, 0u, 0u); + uint2 key = uint2((uint)(pp.seed & 0xFFFFFFFF), (uint)(pp.seed >> 32)); + uint4 rng_out = philox4x32_10(counter, key); + uint rng_idx = 0; + float u = philox_uniform(rng_idx, rng_out); + + float cumsum = 0.0f; + int sampled = pp.total_segments - 1; + for (int i = 0; i < pp.total_segments; i++) { + cumsum += prio_probs[i]; + if (u <= cumsum) { + sampled = i; + break; + } + } + indices[tx] = sampled; +} + +kernel void prio_imp_weights_kernel( + const device int64_t* indices [[buffer(0)]], + const device float* prio_probs [[buffer(1)]], + device float* mb_prio [[buffer(2)]], + constant PrioImpParams& pp [[buffer(3)]], + uint tx [[thread_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_id 
[[simdgroup_index_in_threadgroup]] +) { + // Compute raw importance weights + float w = 1.0f; + if ((int)tx < pp.minibatch_segments) { + float value = prio_probs[indices[tx]] * float(pp.total_agents); + w = pow(value, -pp.anneal_beta); + } + + // Max-normalize: w_i /= max(w_j) so all weights are in [0, 1]. + // Prevents gradient amplification from undersampled segments. + constexpr int NUM_WARPS = 8; + threadgroup float shmem[NUM_WARPS]; + float local_max = ((int)tx < pp.minibatch_segments) ? w : 0.0f; + local_max = simd_max(local_max); + if (simd_lane == 0) shmem[simd_id] = local_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_id == 0) { + float val = (simd_lane < NUM_WARPS) ? shmem[simd_lane] : 0.0f; + val = simd_max(val); + if (simd_lane == 0) shmem[0] = val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float w_max = max(shmem[0], 1e-8f); + + if ((int)tx < pp.minibatch_segments) { + mb_prio[tx] = w / w_max; + } +} + +struct FillParams { + float val; + int n; +}; + +kernel void fill_f32( + device float* dst [[buffer(0)]], + constant FillParams& p [[buffer(1)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) dst[idx] = p.val; +} + +kernel void copy_f32( + device float* dst [[buffer(0)]], + device const float* src [[buffer(1)]], + constant int& n [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < n) dst[idx] = src[idx]; +} + +struct ClampParams { + float lo; + float hi; + int n; +}; + +kernel void clamp_f32( + device float* dst [[buffer(0)]], + constant ClampParams& p [[buffer(1)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) dst[idx] = clamp(dst[idx], p.lo, p.hi); +} + +struct ScaleParams { + float alpha; + int n; +}; + +kernel void scale_f32( + device float* dst [[buffer(0)]], + constant ScaleParams& p [[buffer(1)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) dst[idx] *= p.alpha; +} + +struct AxpyParams { + float alpha; + int n; +}; + +kernel 
void axpy_f32( + device float* dst [[buffer(0)]], + const device float* src [[buffer(1)]], + constant AxpyParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) dst[idx] += p.alpha * src[idx]; +} + +struct AddParams { + int n; +}; + +kernel void add_f32( + device float* dst [[buffer(0)]], + const device float* src [[buffer(1)]], + constant AddParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) dst[idx] += src[idx]; +} + +kernel void add_f16( + device half* dst [[buffer(0)]], + const device half* src [[buffer(1)]], + constant AddParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) dst[idx] = half(float(dst[idx]) + float(src[idx])); +} + +struct NesterovParams { + float mu; + int n; +}; + +// Fused Nesterov momentum: mb = mu*mb + gc; gc = gc + mu*mb (note: gc uses updated mb) +kernel void nesterov_f32( + device float* mb [[buffer(0)]], + device float* gc [[buffer(1)]], + constant NesterovParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx < p.n) { + float m = p.mu * mb[idx] + gc[idx]; + mb[idx] = m; + gc[idx] += p.mu * m; + } +} + +struct ScaleDevParams { + int n; +}; + +// Scale by device-side scalar: dst[i] *= *alpha_ptr +kernel void scale_f32_dev( + device float* dst [[buffer(0)]], + const device float* alpha_ptr [[buffer(1)]], + constant ScaleDevParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + float alpha = *alpha_ptr; + if ((int)idx < p.n) dst[idx] *= alpha; +} + +struct AxpyDevParams { + int n; +}; + +// dst += (*alpha) * src +kernel void axpy_f32_dev( + device float* dst [[buffer(0)]], + const device float* src [[buffer(1)]], + const device float* alpha_ptr [[buffer(2)]], + constant AxpyDevParams& p [[buffer(3)]], + uint idx [[thread_position_in_grid]] +) { + float alpha = *alpha_ptr; + if ((int)idx < p.n) dst[idx] += alpha * src[idx]; +} + +struct AddScalarParams { + float val; +}; + +// *ptr += 
val (single element) +kernel void add_scalar( + device float* ptr [[buffer(0)]], + constant AddScalarParams& p [[buffer(1)]], + uint idx [[thread_position_in_grid]] +) { + if (idx == 0) *ptr += p.val; +} + +// Reads LR from device, computes neg_lr = -lr +kernel void compute_lr_scalars_kernel( + const device float* lr [[buffer(0)]], + device float* neg_lr [[buffer(1)]], + uint idx [[thread_position_in_grid]] +) { + if (idx == 0) { + *neg_lr = -(*lr); + } +} + +struct MuonParams { + int n; + float weight_decay; + float scale; +}; + +// Weight update: wb = wb * (1 - lr * wd) - lr * scale * up (matches CUDA muon.cu) +kernel void muon_weight_update_kernel( + device float* wb [[buffer(0)]], + const device float* up [[buffer(1)]], + const device float* lr_ptr [[buffer(2)]], + constant MuonParams& p [[buffer(3)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx >= p.n) return; + float lr = *lr_ptr; + float wd_scale = 1.0f - lr * p.weight_decay; + wb[idx] = wb[idx] * wd_scale - lr * p.scale * up[idx]; +} + +struct TransposeParams { + int R; + int C; +}; + +// Transpose R x C matrix: dst[c*R + r] = src[r*C + c] +kernel void transpose_f32( + device float* dst [[buffer(0)]], + const device float* src [[buffer(1)]], + constant TransposeParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + if ((int)idx >= p.R * p.C) return; + dst[(idx % p.C) * p.R + idx / p.C] = src[idx]; +} + +struct Transpose01Params { + int A; + int B; + int C; +}; + +// Transpose dims 0 and 1 of (A, B, C) tensor: dst[b*A*C + a*C + c] = src[a*B*C + b*C + c] +kernel void transpose_01( + device float* dst [[buffer(0)]], + const device float* src [[buffer(1)]], + constant Transpose01Params& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + int total = p.A * p.B * p.C; + if ((int)idx >= total) return; + int a = (int)idx / (p.B * p.C); + int rem = (int)idx % (p.B * p.C); + int b = rem / p.C; + int c = rem % p.C; + dst[b * p.A * p.C + a * p.C + c] = src[idx]; +} + +// 
Transpose dims 0 and 1 for 8-byte elements (f64/u64), using uint2 pairs. +// Metal has no native double support, so we treat each 8-byte element as uint2. +// Same index math as transpose_01 — just operates on uint2 instead of float. +kernel void transpose_01_u64( + device uint2* dst [[buffer(0)]], + const device uint2* src [[buffer(1)]], + constant Transpose01Params& p [[buffer(2)]], + uint idx [[thread_position_in_grid]] +) { + int total = p.A * p.B * p.C; + if ((int)idx >= total) return; + int a = (int)idx / (p.B * p.C); + int rem = (int)idx % (p.B * p.C); + int b = rem / p.C; + int c = rem % p.C; + dst[b * p.A * p.C + a * p.C + c] = src[idx]; +} + +struct NormParams { + int n; +}; + +// Per-block sum of squares (partial reduction) +kernel void norm_f32_kernel( + device float* partials [[buffer(0)]], + const device float* src [[buffer(1)]], + constant NormParams& p [[buffer(2)]], + uint idx [[thread_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint block_id [[threadgroup_position_in_grid]], + uint grid_size [[threads_per_grid]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]] +) { + constexpr int NUM_WARPS = 8; // 256 / 32 + threadgroup float sdata[NUM_WARPS]; + float sum = 0.0f; + for (int i = (int)idx; i < p.n; i += (int)grid_size) { + sum += src[i] * src[i]; + } + sum = simd_sum(sum); + if (simd_lane == 0) sdata[simd_id] = sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_id == 0) { + sum = (simd_lane < NUM_WARPS) ? 
sdata[simd_lane] : 0.0f;
+        sum = simd_sum(sum);
+        if (simd_lane == 0) partials[block_id] = sum;
+    }
+}
+
+struct NormReduceParams {
+    int num_blocks;
+};
+
+// Reduce per-block partials to a single sum-of-squares value.
+//
+// Meant to be dispatched as ONE threadgroup (the result is written by the
+// first lane of the first simdgroup of whichever threadgroup runs).
+//
+//   out       : receives sum(partials[0 .. num_blocks-1])
+//   partials  : per-block sums produced by norm_f32_kernel
+//   p         : num_blocks = number of valid entries in partials
+//
+// Each thread accumulates a strided slice of `partials`, so num_blocks may
+// exceed the threadgroup size. Reading only partials[tid] would silently
+// drop every partial past the last thread.
+// The cross-simdgroup stage reads only the sdata slots that were actually
+// written, so a threadgroup with fewer than NUM_WARPS simdgroups does not
+// pick up uninitialized threadgroup memory.
+kernel void norm_reduce_kernel(
+    device float* out [[buffer(0)]],
+    const device float* partials [[buffer(1)]],
+    constant NormReduceParams& p [[buffer(2)]],
+    uint tid [[thread_index_in_threadgroup]],
+    uint simd_lane [[thread_index_in_simdgroup]],
+    uint simd_id [[simdgroup_index_in_threadgroup]],
+    uint tg_size [[threads_per_threadgroup]],
+    uint num_simd [[simdgroups_per_threadgroup]]
+) {
+    constexpr int NUM_WARPS = 8;
+    threadgroup float sdata[NUM_WARPS];
+
+    // Stage 0: per-thread strided accumulation over all partials
+    float val = 0.0f;
+    for (int i = (int)tid; i < p.num_blocks; i += (int)tg_size) {
+        val += partials[i];
+    }
+
+    // Stage 1: reduce within each simdgroup, publish one value per simdgroup
+    val = simd_sum(val);
+    if (simd_lane == 0) sdata[simd_id] = val;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Stage 2: first simdgroup reduces the per-simdgroup values
+    if (simd_id == 0) {
+        uint live = min(num_simd, (uint)NUM_WARPS);
+        val = (simd_lane < live) ? sdata[simd_lane] : 0.0f;
+        val = simd_sum(val);
+        if (simd_lane == 0) *out = val;
+    }
+}
+
+struct ClipByNormParams {
+    float max_norm;
+    float eps;
+    int n;
+};
+
+// Clip gradient by global norm: dst[i] *= min(max_norm / (sqrt(sum_sq) + eps), 1.0)
+// When clip_coef is 0 (norm overflow), zero directly to avoid inf*0=NaN.
+kernel void clip_by_norm_f32(
+    device float* dst [[buffer(0)]],
+    const device float* sum_sq_ptr [[buffer(1)]],
+    constant ClipByNormParams& p [[buffer(2)]],
+    uint idx [[thread_position_in_grid]]
+) {
+    float clip_coef = min(p.max_norm / (sqrt(*sum_sq_ptr) + p.eps), 1.0f);
+    if ((int)idx < p.n) {
+        dst[idx] = (clip_coef > 0.0f) ?
dst[idx] * clip_coef : 0.0f;
+    }
+}
+
+struct NormalizeParams {
+    float eps;
+    int n;
+};
+
+// dst[i] /= max(sqrt(*norm), eps) — matches CUDA (no cap)
+kernel void normalize_f32(
+    device float* dst [[buffer(0)]],
+    const device float* norm_ptr [[buffer(1)]],
+    constant NormalizeParams& p [[buffer(2)]],
+    uint idx [[thread_position_in_grid]]
+) {
+    float inv_norm = 1.0f / max(sqrt(*norm_ptr), p.eps);
+    if ((int)idx < p.n) dst[idx] = dst[idx] * inv_norm;
+}
+
+struct VarMeanParams {
+    int n;
+};
+
+// Compute variance and mean of a float array (single threadgroup).
+//
+//   src      : p.n input values
+//   var_out  : receives sum((x - mean)^2) / (n - 1)
+//   mean_out : receives sum(x) / n
+//
+// Two passes over src, each a simdgroup reduction followed by a reduction of
+// the per-simdgroup values through `sdata`. The loops step by 256 and sdata
+// has 8 slots, i.e. the kernel is written for a 256-thread threadgroup.
+// The divisor n - 1 means n == 1 evaluates 0 / 0.
+kernel void var_mean_kernel(
+    const device float* src [[buffer(0)]],
+    device float* var_out [[buffer(1)]],
+    device float* mean_out [[buffer(2)]],
+    constant VarMeanParams& p [[buffer(3)]],
+    uint tid [[thread_index_in_threadgroup]],
+    uint simd_lane [[thread_index_in_simdgroup]],
+    uint simd_id [[simdgroup_index_in_threadgroup]]
+) {
+    constexpr int NUM_WARPS = 8;
+    threadgroup float sdata[NUM_WARPS];
+
+    // Pass 1: compute mean
+    float sum = 0.0f;
+    for (int i = (int)tid; i < p.n; i += 256) sum += src[i];
+    sum = simd_sum(sum);
+    if (simd_lane == 0) sdata[simd_id] = sum;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (simd_id == 0) {
+        sum = (simd_lane < NUM_WARPS) ? sdata[simd_lane] : 0.0f;
+        sum = simd_sum(sum);
+        if (simd_lane == 0) sdata[0] = sum;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float mean = sdata[0] / float(p.n);
+    // Pass 2 overwrites sdata (including slot 0). Without this barrier a fast
+    // simdgroup could store its pass-2 partial into sdata[0] while a slower
+    // one is still reading the pass-1 total above, corrupting its mean.
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (tid == 0) *mean_out = mean;
+
+    // Pass 2: compute variance
+    float ss = 0.0f;
+    for (int i = (int)tid; i < p.n; i += 256) {
+        float d = src[i] - mean;
+        ss += d * d;
+    }
+    ss = simd_sum(ss);
+    if (simd_lane == 0) sdata[simd_id] = ss;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (simd_id == 0) {
+        ss = (simd_lane < NUM_WARPS) ? sdata[simd_lane] : 0.0f;
+        ss = simd_sum(ss);
+        if (simd_lane == 0) sdata[0] = ss;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (tid == 0) *var_out = sdata[0] / float(p.n - 1);
+}
+
+struct SumRowsParams {
+    int R;
+    int C;
+};
+
+// dst[c] = sum over rows of src[:, c]
+kernel void sum_rows_to_f32_kernel(
+    device float* dst [[buffer(0)]],
+    const device float* src [[buffer(1)]],
+    constant SumRowsParams& p [[buffer(2)]],
+    uint col [[thread_position_in_grid]]
+) {
+    if ((int)col >= p.C) return;
+    float sum = 0.0f;
+    for (int r = 0; r < p.R; r++) sum += src[r * p.C + (int)col];
+    dst[col] = sum;
+}
+
+// --- Sum rows fp16 (for bias/LN param grads) ---
+
+kernel void sum_rows_f16_kernel(
+    device half* dst [[buffer(0)]],
+    const device half* src [[buffer(1)]],
+    constant SumRowsParams& p [[buffer(2)]],
+    uint col [[thread_position_in_grid]]
+) {
+    if ((int)col >= p.C) return;
+    float sum = 0.0f;
+    for (int r = 0; r < p.R; r++) sum += float(src[r * p.C + (int)col]);
+    dst[col] = half(sum);
+}
+
+struct AssembleDecoderGradParams {
+    int B_TT;
+    int od;
+    int od_plus_1;
+};
+
+// Assemble gradient: dst = [grad_logits | grad_value] in fused layout
+kernel void assemble_decoder_grad_f32(
+    device float* dst [[buffer(0)]],
+    const device float* grad_logits [[buffer(1)]],
+    const device float* grad_value [[buffer(2)]],
+    constant AssembleDecoderGradParams& p [[buffer(3)]],
+    uint idx [[thread_position_in_grid]]
+) {
+    if ((int)idx >= p.B_TT * p.od_plus_1) return;
+    int row = (int)idx / p.od_plus_1;
+    int col = (int)idx % p.od_plus_1;
+    dst[idx] = (col < p.od) ? grad_logits[row * p.od + col] : grad_value[row];
+}
+
+struct SelectCopyParams {
+    int obs_row_bytes;
+    int act_row_bytes;
+    int lp_row_bytes;
+    int horizon;
+};
+
+// Minibatch assembly: copy observations, actions, logprobs, values+advantages+returns, prio
+// Channel 0 fuses obs gather + f32→f16 cast: reads f32 src, writes f16 directly to fp16_obs_out.
+// Dispatched as (minibatch_size, 5) threadgroups, each handles one channel for one row. +kernel void select_copy_kernel( + device char* mb_obs [[buffer(0)]], + device char* mb_actions [[buffer(1)]], + device char* mb_logprobs [[buffer(2)]], + device float* mb_values [[buffer(3)]], + device float* mb_advantages [[buffer(4)]], + device float* mb_returns [[buffer(5)]], + device float* mb_prio_out [[buffer(6)]], + const device char* src_obs [[buffer(7)]], + const device char* src_actions [[buffer(8)]], + const device char* src_logprobs [[buffer(9)]], + const device float* src_values [[buffer(10)]], + const device float* advantages [[buffer(11)]], + const device int64_t* idx [[buffer(12)]], + const device float* mb_prio_in [[buffer(13)]], + constant SelectCopyParams& p [[buffer(14)]], + device half* fp16_obs_out [[buffer(15)]], + uint2 group_id [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]] +) { + int mb = (int)group_id.x; + int ch = (int)group_id.y; + int src_row = (int)idx[mb]; + + if (ch == 0) { + // Fused obs gather + f32→f16 cast: copy f32 to mb_obs AND write f16 directly. + // mb_obs f32 copy is needed because PPO reads embedded action masks from it. 
+        const device float* sptr = (const device float*)(src_obs + (int64_t)src_row * p.obs_row_bytes);
+        int count = p.obs_row_bytes / 4; // number of floats
+        device float* f32ptr = (device float*)(mb_obs + (int64_t)mb * p.obs_row_bytes);
+        device half* f16ptr = fp16_obs_out + (int64_t)mb * count;
+        for (int i = (int)tid; i < count; i += 256) {
+            float val = sptr[i];
+            f32ptr[i] = val;
+            f16ptr[i] = half(val);
+        }
+    } else if (ch == 1) {
+        // Copy actions
+        const device int* sptr = (const device int*)(src_actions + (int64_t)src_row * p.act_row_bytes);
+        device int* dptr = (device int*)(mb_actions + (int64_t)mb * p.act_row_bytes);
+        for (int i = (int)tid; i < p.act_row_bytes / 4; i += 256) dptr[i] = sptr[i];
+    } else if (ch == 2) {
+        // Copy logprobs
+        const device int* sptr = (const device int*)(src_logprobs + (int64_t)src_row * p.lp_row_bytes);
+        device int* dptr = (device int*)(mb_logprobs + (int64_t)mb * p.lp_row_bytes);
+        for (int i = (int)tid; i < p.lp_row_bytes / 4; i += 256) dptr[i] = sptr[i];
+    } else if (ch == 3) {
+        // Copy values + advantages, compute returns = values + advantages
+        int srh = src_row * p.horizon;
+        int drh = mb * p.horizon;
+        for (int i = (int)tid; i < p.horizon; i += 256) {
+            float val = src_values[srh + i];
+            float adv = advantages[srh + i];
+            mb_values[drh + i] = val;
+            mb_advantages[drh + i] = adv;
+            mb_returns[drh + i] = val + adv;
+        }
+    } else if (ch == 4) {
+        // Copy prio weight
+        if (tid == 0) {
+            mb_prio_out[mb] = mb_prio_in[mb];
+        }
+    }
+}
+
+struct IndexCopyParams {
+    int num_idx;
+    int row_bytes;
+};
+
+// Indexed copy (scatter): for each i, copy src row i to dst row idx[i].
+//
+//   dst : destination rows, row pitch p.row_bytes
+//   idx : p.num_idx destination row numbers
+//   src : p.num_idx source rows, row pitch p.row_bytes
+//
+// One thread copies one whole row. The 32-bit word path is used only when
+// row_bytes is a multiple of 4: with any other pitch, row starts fall on
+// addresses that are not 4-byte aligned and a uint load/store through the
+// casted pointer would be misaligned.
+// NOTE(review): the word path assumes the buffer base addresses/offsets
+// bound by the host are themselves 4-byte aligned — confirm at the call site.
+kernel void index_copy_kernel(
+    device char* dst [[buffer(0)]],
+    const device int64_t* idx [[buffer(1)]],
+    const device char* src [[buffer(2)]],
+    constant IndexCopyParams& p [[buffer(3)]],
+    uint i [[thread_position_in_grid]]
+) {
+    if ((int)i >= p.num_idx) return;
+    int64_t dst_row = idx[i];
+    const device char* s = src + (int64_t)i * p.row_bytes;
+    device char* d = dst + dst_row * p.row_bytes;
+    if ((p.row_bytes & 3) == 0) {
+        // Every row starts on a 4-byte boundary: copy as 4-byte words
+        int words = p.row_bytes / 4;
+        const device uint* s4 = (const device uint*)s;
+        device uint* d4 = (device uint*)d;
+        for (int b = 0; b < words; b++) d4[b] = s4[b];
+    } else {
+        // Unaligned row pitch: byte-wise copy is the only safe access width
+        for (int b = 0; b < p.row_bytes; b++) d[b] = s[b];
+    }
+}
+
+// index_gather_kernel: dst[i] = src[idx[i]] (gather, inverse of index_copy)
+//
+// Same parameters as index_copy_kernel, with idx selecting the SOURCE row.
+// Uses the same alignment rule: word copy only when row_bytes % 4 == 0,
+// byte copy otherwise.
+kernel void index_gather_kernel(
+    device char* dst [[buffer(0)]],
+    const device int64_t* idx [[buffer(1)]],
+    const device char* src [[buffer(2)]],
+    constant IndexCopyParams& p [[buffer(3)]],
+    uint i [[thread_position_in_grid]]
+) {
+    if ((int)i >= p.num_idx) return;
+    int64_t src_row = idx[i];
+    const device char* s = src + src_row * p.row_bytes;
+    device char* d = dst + (int64_t)i * p.row_bytes;
+    if ((p.row_bytes & 3) == 0) {
+        int words = p.row_bytes / 4;
+        const device uint* s4 = (const device uint*)s;
+        device uint* d4 = (device uint*)d;
+        for (int b = 0; b < words; b++) d4[b] = s4[b];
+    } else {
+        for (int b = 0; b < p.row_bytes; b++) d[b] = s[b];
+    }
+}
+
+struct CastU8Params {
+    int n;
+};
+
+kernel void cast_u8_to_f32(
+    device float* dst [[buffer(0)]],
+    const device uchar* src [[buffer(1)]],
+    constant CastU8Params& p [[buffer(2)]],
+    uint idx [[thread_position_in_grid]]
+) {
+    if ((int)idx < p.n) dst[idx] = float(src[idx]);
+}
+
+// IEEE 754 f64→f32 bit cast. Metal has no native double, so we read each
+// 8-byte double as uint2 and extract
+// sign, exponent, mantissa via bit manipulation. Subnormals flush to zero.
+// Convert `count` binary64 values to binary32 with integer bit operations.
+// Each double is read as uint2 (.x = low 32 bits, .y = high 32 bits).
+//
+//   f64 zero / subnormal            -> signed zero
+//   result below f32 normal range   -> signed zero (no f32 subnormals produced)
+//   infinity, or exponent too large -> signed infinity
+//   NaN                             -> quiet NaN (sign and top payload bits kept)
+//   otherwise                       -> nearest f32, ties to even
+//
+// NaN must stay NaN: mapping it to infinity hides invalid inputs from any
+// later isnan() check. Rounding (rather than truncating the 29 dropped
+// mantissa bits) matches what a native double->float conversion produces.
+kernel void cast_f64_to_f32(
+    device const uint2* src [[buffer(0)]],
+    device float* dst [[buffer(1)]],
+    constant int& count [[buffer(2)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if ((int)gid >= count) return;
+    uint2 bits = src[gid];
+    uint hi = bits.y, lo = bits.x;
+    uint sign_bit = hi & 0x80000000u;
+    int biased_exp = int((hi >> 20) & 0x7FFu);
+    int exp_f32 = biased_exp - 1023 + 127;
+    // Top 23 bits of the 52-bit mantissa: 20 from hi + 3 from lo
+    uint mantissa = ((hi & 0xFFFFFu) << 3) | (lo >> 29);
+    // The 29 low mantissa bits that do not fit into f32
+    uint dropped = lo & 0x1FFFFFFFu;
+    uint result;
+    if (biased_exp == 0) {
+        result = sign_bit; // zero / subnormal → ±0
+    } else if (biased_exp == 0x7FF) {
+        bool is_nan = ((hi & 0xFFFFFu) | lo) != 0u;
+        // 0x7FC00000 sets the quiet bit so the result is NaN even when the
+        // surviving payload bits are all zero
+        result = is_nan ? (sign_bit | 0x7FC00000u | mantissa)
+                        : (sign_bit | 0x7F800000u);
+    } else if (exp_f32 <= 0) {
+        result = sign_bit; // underflow → ±0
+    } else if (exp_f32 >= 255) {
+        result = sign_bit | 0x7F800000u; // overflow → ±inf
+    } else {
+        result = sign_bit | (uint(exp_f32) << 23) | mantissa;
+        // Round to nearest, ties to even. Adding 1 to the packed word lets a
+        // mantissa carry bump the exponent (and reach inf at the very top).
+        bool round_up = (dropped > 0x10000000u) ||
+                        (dropped == 0x10000000u && (mantissa & 1u) != 0u);
+        if (round_up) result += 1u;
+    }
+    dst[gid] = as_type<float>(result);
+}
+
+// ============================================================================
+// Section 20: Tiled GEMM — C = alpha * op(A) @ op(B) + beta * C
+//
+// 64x64 tiled simdgroup_matrix GEMM for f32. Supports NT, NN, TN layouts
+// via trans_a/trans_b parameters. Runs on the compute encoder.
+// ============================================================================ + +struct GemmParams { + int M; // result rows + int N; // result columns + int K; // inner dimension + int lda; // leading dimension of A (physical columns) + int ldb; // leading dimension of B (physical columns) + int ldc; // leading dimension of C (= N) + float alpha; + float beta; + int trans_a; // 0 = no transpose, 1 = transpose + int trans_b; +}; + +constant int BM = 32; // tile rows (shared by ksplit and fp16 register GEMMs) +constant int BN = 32; // tile cols +constant int BK = 16; // tile K dimension +constant int TM = 4; // per-thread tile rows +constant int TN = 4; // per-thread tile cols + +// ============================================================================ +// Section 22: K-split GEMM — for tall-K backward weight-gradient GEMMs +// +// When M and N are small but K is large (e.g. M=128, N=128, K=32768), +// regular sgemm_reg has too few threadgroups (M/32 * N/32 = 16). K-split +// partitions K across multiple TGs in the Z axis, writing partial sums to +// a scratch buffer. A second kernel reduces the partials into C. 
+// ============================================================================ + +kernel void sgemm_ksplit( + device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* partials [[buffer(2)]], + constant GemmParams& p [[buffer(3)]], + constant int& k_per_split [[buffer(4)]], + uint3 group_id [[threadgroup_position_in_grid]], + uint3 local_id [[thread_position_in_threadgroup]] +) { + threadgroup float sA[BM][BK]; + threadgroup float sB[BK][BN]; + + int trow = (int)local_id.y; + int tcol = (int)local_id.x; + int tid = trow * (BN / TN) + tcol; + + float acc[TM][TN]; + for (int m = 0; m < TM; m++) + for (int n = 0; n < TN; n++) + acc[m][n] = 0.0f; + + int k_start = (int)group_id.z * k_per_split; + int k_end = min(k_start + k_per_split, p.K); + int num_k_tiles = (k_end - k_start + BK - 1) / BK; + + for (int kt = 0; kt < num_k_tiles; kt++) { + int k_base = k_start + kt * BK; + + for (int i = 0; i < (BM * BK) / 64; i++) { + int idx = tid + i * 64; + int r = idx / BK; + int c = idx % BK; + int gr = (int)group_id.y * BM + r; + int gc = k_base + c; + sA[r][c] = (gr < p.M && gc < k_end) + ? (p.trans_a ? A[gc * p.lda + gr] : A[gr * p.lda + gc]) + : 0.0f; + } + + for (int i = 0; i < (BK * BN) / 64; i++) { + int idx = tid + i * 64; + int r = idx / BN; + int c = idx % BN; + int gr = k_base + r; + int gc = (int)group_id.x * BN + c; + sB[r][c] = (gr < k_end && gc < p.N) + ? (p.trans_b ? 
B[gc * p.ldb + gr] : B[gr * p.ldb + gc])
+                : 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int k = 0; k < BK; k++) {
+            float a_reg[TM];
+            float b_reg[TN];
+            for (int m = 0; m < TM; m++) a_reg[m] = sA[trow * TM + m][k];
+            for (int n = 0; n < TN; n++) b_reg[n] = sB[k][tcol * TN + n];
+            for (int m = 0; m < TM; m++)
+                for (int n = 0; n < TN; n++)
+                    acc[m][n] += a_reg[m] * b_reg[n];
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // Write partial results: partials[split_idx * M * N + row * N + col]
+    int split_offset = (int)group_id.z * p.M * p.N;
+    int row_base = (int)group_id.y * BM + trow * TM;
+    int col_base = (int)group_id.x * BN + tcol * TN;
+    for (int m = 0; m < TM; m++) {
+        for (int n = 0; n < TN; n++) {
+            int r = row_base + m;
+            int c = col_base + n;
+            if (r < p.M && c < p.N) {
+                partials[split_offset + r * p.N + c] = acc[m][n];
+            }
+        }
+    }
+}
+
+struct ReduceKsplitParams {
+    int MN;
+    int num_splits;
+    float alpha;
+    float beta;
+};
+
+// Second stage of the K-split GEMM: fold the per-split partial products into C.
+//
+//   partials : num_splits consecutive slabs of MN floats (written by sgemm_ksplit)
+//   C        : MN output elements, updated in place
+//   p        : MN = M * N, num_splits, and the GEMM alpha / beta scalars
+//
+// C[i] = alpha * sum_s partials[s * MN + i] + beta * C[i]
+//
+// Splits are summed in ascending order by a single thread per element, so
+// the result does not depend on scheduling.
+// When beta == 0 the old contents of C are not read (the usual BLAS
+// convention): 0 * NaN and 0 * Inf are NaN, so multiplying a freshly
+// allocated or stale output buffer by zero would otherwise poison the result.
+kernel void reduce_ksplit(
+    device const float* partials [[buffer(0)]],
+    device float* C [[buffer(1)]],
+    constant ReduceKsplitParams& p [[buffer(2)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if ((int)gid >= p.MN) return;
+    float sum = 0.0f;
+    for (int s = 0; s < p.num_splits; s++) {
+        sum += partials[s * p.MN + (int)gid];
+    }
+    float scaled = p.alpha * sum;
+    C[gid] = (p.beta == 0.0f) ? scaled : (scaled + p.beta * C[gid]);
+}
+
+// Reads fp32 grad_logits and grad_value (from PPO kernel), writes fp16 grad_out.
+// grad_out[row * od1 + col] = (col < od) ? grad_logits[row * od + col] : grad_value[row]
+kernel void assemble_decoder_grad_f32_to_f16(
+    device half* grad_out [[buffer(0)]],
+    const device float* grad_logits [[buffer(1)]],
+    const device float* grad_value [[buffer(2)]],
+    constant AssembleDecoderGradParams& p [[buffer(3)]],
+    uint gid [[thread_position_in_grid]]
+) {
+    if ((int)gid >= p.B_TT * p.od_plus_1) return;
+    int row = (int)gid / p.od_plus_1;
+    int col = (int)gid % p.od_plus_1;
+    float val = (col < p.od) ?
grad_logits[row * p.od + col] : grad_value[row];
    // Clamp to fp16 range to prevent inf (Metal fp16 max ~65504, unlike CUDA bf16)
    grad_out[gid] = half(clamp(val, -65000.0f, 65000.0f));
}

// Element-wise float -> half conversion of `count` values (one thread per element).
// Threads with gid >= count do nothing, so the grid may be rounded up.
kernel void cast_f32_to_f16(
    device half* dst [[buffer(0)]],
    const device float* src [[buffer(1)]],
    constant int& count [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    if ((int)gid >= count) return;
    dst[gid] = half(src[gid]);
}

// Element-wise half -> float conversion of `count` values (one thread per element).
// Threads with gid >= count do nothing, so the grid may be rounded up.
kernel void cast_f16_to_f32(
    device float* dst [[buffer(0)]],
    const device half* src [[buffer(1)]],
    constant int& count [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    if ((int)gid >= count) return;
    dst[gid] = float(src[gid]);
}

// Forward log-space scan with half-precision I/O and float scan state.
//
// One thread per (b, h) pair walks t = 1..T_seq sequentially. For every step it
// reads three values from `combined` (offsets h, H + h, 2H + h inside a 3H-wide
// row: hidden, gate, proj) and one value from `input`, updates the running
// (a_star, s, log_value) triple, and writes the mixed output to `out`.
//
// The float side buffers a_star_buf / s_buf / log_values_buf are indexed as
// [b][t][h] with t in 0..T_seq. Slot t = 0 holds the initial state; after that
// only slots where t % CHECKPOINT_INTERVAL == 0 are written. The backward kernel
// recomputes everything in between from those checkpoints.
//
// ScanParams, CHECKPOINT_INTERVAL, log_coeffs_and_values_fwd, log1p_f and
// sigmoid_f are defined elsewhere in this shader source.
kernel void mingru_scan_forward_checkpointed_fp16(
    device half* out [[buffer(0)]],
    device half* next_state [[buffer(1)]],
    device float* a_star_buf [[buffer(2)]],
    device float* s_buf [[buffer(3)]],
    device float* log_values_buf [[buffer(4)]],
    const device half* combined [[buffer(5)]],
    const device half* state [[buffer(6)]],
    const device half* input [[buffer(7)]],
    constant ScanParams& p [[buffer(8)]],
    uint idx [[thread_position_in_grid]]
) {
    if ((int)idx >= p.B * p.H) return;

    int b = (int)idx / p.H;
    int h = (int)idx % p.H;
    int bH = b * p.H;
    int H3 = 3 * p.H;
    int bHT = bH * p.T_seq;
    int out_base = bHT + h;
    int cbase = 3 * bHT;

    // The scan runs in log space; the incoming state is used as log(state).
    float a_star = 0.0f;
    float log_value = 0.0f;
    float s = log(float(state[bH + h]));
    log_value = s;

    // Slot t = 0 of the checkpoint buffers stores the initial triple.
    int T_out = p.T_seq + 1;
    int buf_base = b * T_out * p.H + h;
    int buf_curr = buf_base;
    a_star_buf[buf_curr] = a_star;
    s_buf[buf_curr] = s;
    log_values_buf[buf_curr] = log_value;

    int out_curr = out_base;
    int t_offset = 0;

    for (int t = 1; t < p.T_seq + 1; t++) {
        float hidden_val = float(combined[cbase + h + t_offset]);
        float gate_val = float(combined[cbase + p.H + h + t_offset]);
        float proj_val = float(combined[cbase + 2 * p.H + h + t_offset]);
        float x_val = float(input[out_base + (t - 1) * p.H]);

        float log_coeff_val;
        log_coeffs_and_values_fwd(gate_val, hidden_val, log_coeff_val, log_value);

        a_star += log_coeff_val;

        // Numerically stable log(exp(s) + exp(z)).
        float z = log_value - a_star;
        float max_val = fmax(s, z);
        s = max_val + log1p_f(exp(-abs(s - z)));

        float scan_result = exp(a_star + s);
        float proj_sigmoid = sigmoid_f(proj_val);
        float out_val = proj_sigmoid * scan_result + (1.0f - proj_sigmoid) * x_val;

        // Clamp to the finite fp16 range before narrowing.
        out[out_curr] = half(clamp(out_val, -65000.0f, 65000.0f));

        buf_curr += p.H;
        out_curr += p.H;
        t_offset += H3;

        // buf_curr now addresses slot t; only checkpoint slots are written.
        if (t % CHECKPOINT_INTERVAL == 0) {
            a_star_buf[buf_curr] = a_star;
            s_buf[buf_curr] = s;
            log_values_buf[buf_curr] = log_value;
        }
    }

    // NOTE(review): the 1e-30 floor is below the smallest fp16 subnormal, so it
    // still narrows to 0 in half; the next call then sees log(0) = -inf. The scan
    // math above tolerates that, but confirm this is the intended behavior.
    float next_state_val = max(exp(a_star + s), 1e-30f);
    next_state[bH + h] = half(min(next_state_val, 65000.0f));
}

// Backward pass of the checkpointed log-space scan (half I/O, float math).
//
// One thread per (b, h) pair walks time in reverse, one chunk at a time. For
// each chunk it reloads the (a_star, s, log_value) checkpoint stored by the
// forward kernel at the chunk start, recomputes the per-step values into
// thread-local arrays, then runs the reverse recurrence over the chunk and
// writes grad_combined (hidden, gate, proj slots) and grad_input. After the
// last chunk it writes grad_state from the t = 0 checkpoint.
//
// Chunk boundaries are aligned to multiples of CHECKPOINT_INTERVAL because
// those are the only slots (besides t = 0) that the forward kernel writes.
// When T_seq is not a multiple of CHECKPOINT_INTERVAL the chunk nearest T_seq
// is simply shorter; when it is a multiple the chunking is the same as
// stepping back by CHECKPOINT_INTERVAL from T_seq.
kernel void mingru_scan_backward_checkpointed_fp16(
    device half* grad_combined [[buffer(0)]],
    device half* grad_state [[buffer(1)]],
    device half* grad_input [[buffer(2)]],
    const device half* grad_out [[buffer(3)]],
    const device half* grad_next_state [[buffer(4)]],
    const device half* combined [[buffer(5)]],
    const device half* state [[buffer(6)]],
    const device half* input [[buffer(7)]],
    const device float* a_star_buf [[buffer(8)]],
    const device float* s_buf [[buffer(9)]],
    const device float* log_values_buf [[buffer(10)]],
    constant ScanParams& p [[buffer(11)]],
    uint idx [[thread_position_in_grid]]
) {
    if ((int)idx >= p.B * p.H) return;

    int b = (int)idx / p.H;
    int h = (int)idx % p.H;
    int bHT = b * p.H * p.T_seq;
    int cbase = 3 * bHT;
    int H3 = 3 * p.H;
    int state_idx = b * p.H + h;
    int out_base = bHT + h;

    int T_out = p.T_seq + 1;
    int buf_base = b * T_out * p.H + h;

    // Values carried across time steps (and across chunk boundaries).
    float acc = 0.0f;
    float s_val_next = 0.0f;
    float carry_grad_a = 0.0f;

    for (int chunk_end = p.T_seq; chunk_end > 0; ) {
        // Largest checkpoint slot strictly below chunk_end.
        int chunk_start = ((chunk_end - 1) / CHECKPOINT_INTERVAL) * CHECKPOINT_INTERVAL;
        int chunk_len = chunk_end - chunk_start;   // 1..CHECKPOINT_INTERVAL

        float chunk_a_star[CHECKPOINT_INTERVAL];
        float chunk_s[CHECKPOINT_INTERVAL];
        float chunk_log_values[CHECKPOINT_INTERVAL];
        float chunk_hidden[CHECKPOINT_INTERVAL];
        float chunk_gate[CHECKPOINT_INTERVAL];

        // Restart the forward recurrence from the stored checkpoint.
        int ckpt_buf_idx = buf_base + chunk_start * p.H;
        float recomp_a_star = a_star_buf[ckpt_buf_idx];
        float recomp_s = s_buf[ckpt_buf_idx];
        float recomp_log_value = log_values_buf[ckpt_buf_idx];

        // Recompute steps chunk_start+1 .. chunk_end.
        for (int i = 0; i < chunk_len; i++) {
            int t = chunk_start + 1 + i;
            int t_offset = (t - 1) * H3;
            float hv = float(combined[cbase + h + t_offset]);
            float gv = float(combined[cbase + p.H + h + t_offset]);

            float lc;
            log_coeffs_and_values_fwd(gv, hv, lc, recomp_log_value);
            recomp_a_star += lc;

            float z = recomp_log_value - recomp_a_star;
            float mv = fmax(recomp_s, z);
            recomp_s = mv + log1p_f(exp(-abs(recomp_s - z)));

            chunk_a_star[i] = recomp_a_star;
            chunk_s[i] = recomp_s;
            chunk_log_values[i] = recomp_log_value;
            chunk_hidden[i] = hv;
            chunk_gate[i] = gv;
        }

        // Reverse pass over the chunk.
        for (int i = chunk_len - 1; i >= 0; i--) {
            int t = chunk_start + 1 + i;
            int t_offset = (t - 1) * H3;

            float a_star_t = chunk_a_star[i];
            float s_t = chunk_s[i];
            float log_value_t = chunk_log_values[i];
            float hidden_val = chunk_hidden[i];
            float gate_val = chunk_gate[i];
            float proj_val = float(combined[cbase + 2 * p.H + h + t_offset]);
            int input_idx = out_base + (t - 1) * p.H;
            float x_val = float(input[input_idx]);

            float scan_result = exp(a_star_t + s_t);
            float z = log_value_t - a_star_t;

            float grad_out_val = float(grad_out[input_idx]);
            // grad_next_state only feeds the final time step.
            float grad_scan_from_next = (t == p.T_seq) ? float(grad_next_state[state_idx]) : 0.0f;

            float proj_sigmoid = sigmoid_f(proj_val);
            float grad_scan_result = grad_scan_from_next + grad_out_val * proj_sigmoid;
            float grad_proj = grad_out_val * (scan_result - x_val) * proj_sigmoid * (1.0f - proj_sigmoid);
            float grad_input_val = grad_out_val * (1.0f - proj_sigmoid);
            grad_input[input_idx] = half(clamp(grad_input_val, -65000.0f, 65000.0f));

            float grad_log_h = grad_scan_result * scan_result;
            float grad_s = grad_log_h;

            if (t == p.T_seq) {
                acc = grad_s;
            } else {
                acc = grad_s + acc * exp(s_t - s_val_next);
            }
            float grad_z = acc * exp(z - s_t);
            s_val_next = s_t;

            float grad_a = grad_log_h + carry_grad_a - grad_z;
            carry_grad_a = grad_a;

            float grad_g, grad_h;
            log_coeffs_and_values_bwd(grad_a, grad_z, gate_val, hidden_val, grad_g, grad_h);

            // Clamp to fp16 range to prevent inf (Metal fp16 max ~65504)
            grad_combined[cbase + h + t_offset] = half(clamp(grad_h, -65000.0f, 65000.0f));
            grad_combined[cbase + p.H + h + t_offset] = half(clamp(grad_g, -65000.0f, 65000.0f));
            grad_combined[cbase + 2 * p.H + h + t_offset] = half(clamp(grad_proj, -65000.0f, 65000.0f));
        }

        chunk_end = chunk_start;
    }

    // Gradient w.r.t. the initial state, from the t = 0 checkpoint.
    int ckpt_0_idx = buf_base;
    float a_star_0 = a_star_buf[ckpt_0_idx];
    float s_0 = s_buf[ckpt_0_idx];
    float log_value_0 = log_values_buf[ckpt_0_idx];

    acc = acc * exp(s_0 - s_val_next);
    float grad_z_0 = acc * exp((log_value_0 - a_star_0) - s_0);

    // d/dstate of log(state) is 1/state; a non-positive state gets zero gradient.
    float state_val = float(state[state_idx]);
    float grad_state_val = (state_val > 0.0f) ? (grad_z_0 / state_val) : 0.0f;
    grad_state[state_idx] = half(clamp(grad_state_val, -65000.0f, 65000.0f));
}

// ============================================================================
// Section 26: Steel GEMM — C = alpha * op(A) @ op(B) + beta * C
//
// MLX-inspired 64x64 tiled GEMM using simdgroup_matrix (Apple Silicon M3+).
// 4 simdgroups in 2x2 layout (128 threads), each computing 32x32 output
// via a 4x4 grid of 8x8 simdgroup_matrix multiply-accumulate operations.
//
// Hot loop uses direct device memory loads (ThunderMittens-validated:
// Apple Silicon's L2 cache provides effective data reuse without explicit
// threadgroup staging, eliminating barrier overhead). K-remainder and edge
// tile stores use threadgroup fallback.
//
// ============================================================================

// fp32 tiled GEMM. Each threadgroup (4 simdgroups, 128 threads) produces one
// 64x64 tile of C; `tgid.y` selects the tile row and `tgid.x` the tile column.
// GemmParams (defined elsewhere in this shader source) supplies M, N, K, the
// leading dimensions lda/ldb/ldc, the trans_a/trans_b flags and alpha/beta.
//
// NOTE(review): the main loop only has branches for NT, NN and TN. With
// trans_a && trans_b only the K-remainder contributes, so that combination is
// presumably never dispatched — confirm against the host code.
// NOTE(review): the main-loop simdgroup_load calls are not bounds-checked, so
// edge tiles read rows/columns past M/N; this assumes the buffers are large
// enough (padded) for that — TODO confirm.
kernel void steel_gemm(
    device const float* A [[buffer(0)]],
    device const float* B [[buffer(1)]],
    device float* C [[buffer(2)]],
    constant GemmParams& p [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint sgid [[simdgroup_index_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]]
) {
    // 64x64 output tile, 4 simdgroups (2x2), each 32x32 = 4x4 grid of 8x8
    constexpr int BM = 64, BN = 64;
    constexpr int WM = 2, WN = 2;
    constexpr int TM = BM / (8 * WM);  // 4
    constexpr int TN = BN / (8 * WN);  // 4

    const int bm = (int)tgid.y * BM;
    const int bn = (int)tgid.x * BN;
    const int wm = (int)(sgid / WN);
    const int wn = (int)(sgid % WN);
    const int tid = (int)(sgid * 32 + lane);

    // Per-simdgroup 32x32 output starts at (sm, sn)
    const int sm = bm + wm * 32;
    const int sn = bn + wn * 32;

    // 4x4 accumulator grid of 8x8 simdgroup matrices (16 per simd group)
    simdgroup_float8x8 acc[TM][TN];
    for (int i = 0; i < TM; i++)
        for (int j = 0; j < TN; j++)
            acc[i][j] = simdgroup_float8x8(0);

    const int K_aligned = (p.K / 8) * 8;

    // ---- Main loop: direct device loads, no threadgroup, no barriers ----
    // Each simd group loads its own A and B fragments independently.
    // L2 cache provides cross-simdgroup data reuse (validated by
    // ThunderMittens: 9% faster than MLX's threadgroup approach on M2 Pro).

    if (!p.trans_a && p.trans_b) {
        // NT: C = A(M,K) @ B(N,K)^T — forward pass, Muon
        for (int k = 0; k < K_aligned; k += 8) {
            simdgroup_float8x8 a_frag[TM];
            for (int i = 0; i < TM; i++)
                simdgroup_load(a_frag[i], A + (long)(sm + i*8) * p.lda + k, p.lda);
            for (int j = 0; j < TN; j++) {
                simdgroup_float8x8 b_frag;
                simdgroup_load(b_frag, B + (long)(sn + j*8) * p.ldb + k, p.ldb,
                               ulong2(0,0), true);
                for (int i = 0; i < TM; i++)
                    simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]);
            }
        }
    } else if (!p.trans_a && !p.trans_b) {
        // NN: C = A(M,K) @ B(K,N) — backward input grad, Muon addmm
        for (int k = 0; k < K_aligned; k += 8) {
            simdgroup_float8x8 a_frag[TM];
            for (int i = 0; i < TM; i++)
                simdgroup_load(a_frag[i], A + (long)(sm + i*8) * p.lda + k, p.lda);
            for (int j = 0; j < TN; j++) {
                simdgroup_float8x8 b_frag;
                simdgroup_load(b_frag, B + (long)k * p.ldb + sn + j*8, p.ldb);
                for (int i = 0; i < TM; i++)
                    simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]);
            }
        }
    } else if (p.trans_a && !p.trans_b) {
        // TN: C = A(K,M)^T @ B(K,N) — backward weight grad
        for (int k = 0; k < K_aligned; k += 8) {
            simdgroup_float8x8 a_frag[TM];
            for (int i = 0; i < TM; i++)
                simdgroup_load(a_frag[i], A + (long)k * p.lda + sm + i*8, p.lda,
                               ulong2(0,0), true);
            for (int j = 0; j < TN; j++) {
                simdgroup_float8x8 b_frag;
                simdgroup_load(b_frag, B + (long)k * p.ldb + sn + j*8, p.ldb);
                for (int i = 0; i < TM; i++)
                    simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]);
            }
        }
    }

    // Single threadgroup allocation shared between K-remainder and store phases.
    // The two phases run one after the other and are separated by a barrier
    // (see the slow store path), so the memory can be reused.
    // The compiler doubles explicit threadgroup memory (simdgroup register spill),
    // so separate arrays for sA+sB+sC exceed 32KB. Aliasing them into one buffer
    // keeps total at BM*BN*4*2 = 32768 = limit.
    constexpr int SMEM_STRIDE = BN;
    threadgroup float _smem[BM * SMEM_STRIDE];

    // ---- K-remainder: threadgroup fallback for last partial chunk ----
    // Only triggered when K % 8 != 0 (e.g., K=373). Loads remaining
    // elements into zero-padded 8-wide threadgroup tiles.
    // Reinterprets _smem as sA (BM×9 float) and sB (8×(BN+1) float).
    if (K_aligned < p.K) {
        threadgroup float* sA = _smem;              // BM*9 = 576 floats
        threadgroup float* sB = _smem + BM * 9;     // 8*(BN+1) = 520 floats
        constexpr int sB_stride = BN + 1;

        int k = K_aligned;
        int rem = p.K - k;

        // Cooperative load A: BM × 8 (zero-padded beyond rem)
        int total_a = BM * 8;
        for (int idx = tid; idx < total_a; idx += 128) {
            int r = idx / 8;
            int c = idx % 8;
            int gr = bm + r;
            int gc = k + c;
            sA[r * 9 + c] = (gr < p.M && c < rem)
                ? (p.trans_a ? A[(long)gc * p.lda + gr] : A[(long)gr * p.lda + gc])
                : 0.0f;
        }

        // Cooperative load B: 8 × BN (zero-padded beyond rem)
        int total_b = 8 * BN;
        for (int idx = tid; idx < total_b; idx += 128) {
            int r = idx / BN;
            int c = idx % BN;
            int gr = k + r;
            int gc = bn + c;
            sB[r * sB_stride + c] = (r < rem && gc < p.N)
                ? (p.trans_b ? B[(long)gc * p.ldb + gr] : B[(long)gr * p.ldb + gc])
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        simdgroup_float8x8 a_frag[TM];
        for (int i = 0; i < TM; i++)
            simdgroup_load(a_frag[i], sA + (wm * 32 + i * 8) * 9, 9);
        for (int j = 0; j < TN; j++) {
            simdgroup_float8x8 b_frag;
            simdgroup_load(b_frag, sB + (wn * 32 + j * 8), sB_stride);
            for (int i = 0; i < TM; i++)
                simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]);
        }
    }

    // ---- Store results ----
    // Fast path: direct simdgroup_store for interior tiles with alpha=1, beta=0.
    // Slow path: threadgroup staging via _smem for edge tiles and non-trivial alpha/beta.
    // CRITICAL: condition must be uniform across the threadgroup — all simdgroups
    // must take the same branch, otherwise the threadgroup_barrier in the slow
    // path causes undefined behavior (some threads skip it).
    bool fast_store = (p.alpha == 1.0f && p.beta == 0.0f
                       && bm + BM <= p.M && bn + BN <= p.N);
    if (fast_store) {
        for (int i = 0; i < TM; i++)
            for (int j = 0; j < TN; j++)
                simdgroup_store(acc[i][j],
                    C + (long)(sm + i*8) * p.ldc + sn + j*8, p.ldc);
    } else {
        // The staging area aliases sA/sB from the K-remainder phase. Wait until
        // every simdgroup has finished loading its remainder fragments before
        // any simdgroup overwrites _smem; without this a fast simdgroup's
        // accumulator writes can corrupt a slower simdgroup's sA/sB reads.
        // (Uniform: fast_store is the same for the whole threadgroup.)
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (int i = 0; i < TM; i++)
            for (int j = 0; j < TN; j++)
                simdgroup_store(acc[i][j],
                    _smem + (wm*32 + i*8) * SMEM_STRIDE + (wn*32 + j*8),
                    SMEM_STRIDE);
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Cooperative store to device: each thread handles BM*BN/128 = 32 elements
        for (int idx = tid; idx < BM * BN; idx += 128) {
            int r = idx / BN;
            int c = idx % BN;
            int gr = bm + r;
            int gc = bn + c;
            if (gr < p.M && gc < p.N) {
                long out_idx = (long)gr * p.ldc + gc;
                float val = _smem[r * SMEM_STRIDE + c];
                // beta == 0 must not read C (it may hold uninitialized data).
                if (p.beta == 0.0f)
                    C[out_idx] = p.alpha * val;
                else
                    C[out_idx] = p.alpha * val + p.beta * C[out_idx];
            }
        }
    }
}

// Used when N doesn't meet tensor_ops alignment (N%32!=0).
// One threadgroup per output row, threads partition N columns.
// C(M,N) = A(M,K) @ B(N,K)^T, all row-major.

// Problem size for small_gemm_nt_f32: C is M x N, the shared inner dimension is K.
struct SmallGemmParams {
    uint M;
    uint N;
    uint K;
};

// Dense fp32 GEMM fallback: C[m][n] = sum_k A[m][k] * B[n][k].
//
// A is read as M rows of K floats, B as N rows of K floats and C is written as
// M rows of N floats, all tightly packed (row stride == K or N).
// Threadgroup `tgid` computes output row m = tgid and thread `tid` computes
// column n = tid, so each thread produces exactly one element of C.
// NOTE(review): columns >= threads-per-threadgroup are never written, so the
// host presumably dispatches at least N threads per threadgroup — confirm.
kernel void small_gemm_nt_f32(
    const device float* A [[buffer(0)]],
    const device float* B [[buffer(1)]],
    device float* C [[buffer(2)]],
    constant SmallGemmParams& p [[buffer(3)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // tgid = output row, tid = output column
    uint m = tgid;
    if (m >= p.M || tid >= p.N) return;

    // 64-bit offsets so large M*K / N*K / M*N products cannot wrap in uint.
    const device float* a_row = A + (ulong)m * p.K;
    const device float* b_row = B + (ulong)tid * p.K;

    float sum = 0.0f;
    // Vectorized accumulation over the largest multiple of 4 that fits in K.
    // packed_float4 only needs 4-byte alignment, so these loads stay valid when
    // K is not a multiple of 4 and the rows are therefore not 16-byte aligned.
    uint K4 = p.K & ~3u;
    for (uint k = 0; k < K4; k += 4) {
        float4 a4 = float4(*reinterpret_cast<const device packed_float4*>(a_row + k));
        float4 b4 = float4(*reinterpret_cast<const device packed_float4*>(b_row + k));
        sum += dot(a4, b4);
    }
    // Scalar tail for the last K % 4 elements.
    for (uint k = K4; k < p.K; k++)
        sum += a_row[k] * b_row[k];

    C[(ulong)m * p.N + tid] = sum;
}

// ============================================================================
// Section 27: Steel GEMM fp16 — half I/O, float accumulation
//
// Same 64x64 tiled GEMM as steel_gemm but with half-precision inputs/outputs.
// Uses simdgroup_half8x8 for loads, simdgroup_float8x8 for accumulation
// (mixed-precision multiply_accumulate). Stores back as half.
+// ============================================================================ + +kernel void steel_gemm_f16( + device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant GemmParams& p [[buffer(3)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint sgid [[simdgroup_index_in_threadgroup]], + uint lane [[thread_index_in_simdgroup]] +) { + constexpr int BM = 64, BN = 64; + constexpr int WM = 2, WN = 2; + constexpr int TM = BM / (8 * WM); // 4 + constexpr int TN = BN / (8 * WN); // 4 + + const int bm = (int)tgid.y * BM; + const int bn = (int)tgid.x * BN; + const int wm = (int)(sgid / WN); + const int wn = (int)(sgid % WN); + const int tid = (int)(sgid * 32 + lane); + + const int sm = bm + wm * 32; + const int sn = bn + wn * 32; + + // f32 accumulators for numerical stability + simdgroup_float8x8 acc[TM][TN]; + for (int i = 0; i < TM; i++) + for (int j = 0; j < TN; j++) + acc[i][j] = simdgroup_float8x8(0); + + const int K_aligned = (p.K / 8) * 8; + + if (!p.trans_a && p.trans_b) { + // NT: C = A(M,K) @ B(N,K)^T + for (int k = 0; k < K_aligned; k += 8) { + simdgroup_half8x8 a_frag[TM]; + for (int i = 0; i < TM; i++) + simdgroup_load(a_frag[i], A + (long)(sm + i*8) * p.lda + k, p.lda); + for (int j = 0; j < TN; j++) { + simdgroup_half8x8 b_frag; + simdgroup_load(b_frag, B + (long)(sn + j*8) * p.ldb + k, p.ldb, + ulong2(0,0), true); + for (int i = 0; i < TM; i++) + simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]); + } + } + } else if (!p.trans_a && !p.trans_b) { + // NN: C = A(M,K) @ B(K,N) + for (int k = 0; k < K_aligned; k += 8) { + simdgroup_half8x8 a_frag[TM]; + for (int i = 0; i < TM; i++) + simdgroup_load(a_frag[i], A + (long)(sm + i*8) * p.lda + k, p.lda); + for (int j = 0; j < TN; j++) { + simdgroup_half8x8 b_frag; + simdgroup_load(b_frag, B + (long)k * p.ldb + sn + j*8, p.ldb); + for (int i = 0; i < TM; i++) + simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, 
acc[i][j]); + } + } + } else if (p.trans_a && !p.trans_b) { + // TN: C = A(K,M)^T @ B(K,N) + for (int k = 0; k < K_aligned; k += 8) { + simdgroup_half8x8 a_frag[TM]; + for (int i = 0; i < TM; i++) + simdgroup_load(a_frag[i], A + (long)k * p.lda + sm + i*8, p.lda, + ulong2(0,0), true); + for (int j = 0; j < TN; j++) { + simdgroup_half8x8 b_frag; + simdgroup_load(b_frag, B + (long)k * p.ldb + sn + j*8, p.ldb); + for (int i = 0; i < TM; i++) + simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]); + } + } + } + + // Single threadgroup allocation shared between K-remainder and store phases. + // These phases are sequential (barriers between), so memory is safely reused. + // Store needs BM*BN floats = 16384 bytes (largest consumer). + // K-remainder needs BM*9 + 8*(BN+1) halves = 2192 bytes (fits inside). + // Without aliasing, the compiler allocates all three arrays statically and + // exceeds the 32KB threadgroup memory limit. The compiler doubles explicit + // threadgroup memory (simdgroup register spill), so BM*BN*4*2 = 32768 = limit. + // No +1 stride padding: minor bank conflict on simdgroup_store vs. correctness. + constexpr int SMEM_STRIDE = BN; + threadgroup float _smem[BM * SMEM_STRIDE]; + + // K-remainder: reinterpret _smem as half arrays for partial-K accumulation + if (K_aligned < p.K) { + threadgroup half* sA = (threadgroup half*)_smem; + threadgroup half* sB = sA + BM * 9; + constexpr int sB_stride = BN + 1; + + int k = K_aligned; + int rem = p.K - k; + + int total_a = BM * 8; + for (int idx = tid; idx < total_a; idx += 128) { + int r = idx / 8; + int c = idx % 8; + int gr = bm + r; + int gc = k + c; + sA[r * 9 + c] = (gr < p.M && c < rem) + ? (p.trans_a ? A[(long)gc * p.lda + gr] : A[(long)gr * p.lda + gc]) + : half(0); + } + + int total_b = 8 * BN; + for (int idx = tid; idx < total_b; idx += 128) { + int r = idx / BN; + int c = idx % BN; + int gr = k + r; + int gc = bn + c; + sB[r * sB_stride + c] = (r < rem && gc < p.N) + ? 
(p.trans_b ? B[(long)gc * p.ldb + gr] : B[(long)gr * p.ldb + gc]) + : half(0); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + simdgroup_half8x8 a_frag[TM]; + for (int i = 0; i < TM; i++) + simdgroup_load(a_frag[i], sA + (wm * 32 + i * 8) * 9, 9); + for (int j = 0; j < TN; j++) { + simdgroup_half8x8 b_frag; + simdgroup_load(b_frag, sB + (wn * 32 + j * 8), sB_stride); + for (int i = 0; i < TM; i++) + simdgroup_multiply_accumulate(acc[i][j], a_frag[i], b_frag, acc[i][j]); + } + } + + // Barrier: K-remainder reads _smem as half, store writes it as float. + // Without this, a fast simdgroup's float writes corrupt a slow one's half reads. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Store: f32 acc → half output via _smem staging (reuses same memory) + // simdgroup_store of float8x8 requires float* destination, so we stage + // through _smem then convert element-by-element to half for output. + { + for (int i = 0; i < TM; i++) + for (int j = 0; j < TN; j++) + simdgroup_store(acc[i][j], + _smem + (wm*32 + i*8) * SMEM_STRIDE + (wn*32 + j*8), + SMEM_STRIDE); + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int idx = tid; idx < BM * BN; idx += 128) { + int r = idx / BN; + int c = idx % BN; + int gr = bm + r; + int gc = bn + c; + if (gr < p.M && gc < p.N) { + long out_idx = (long)gr * p.ldc + gc; + float val = _smem[r * SMEM_STRIDE + c]; + if (p.beta == 0.0f) + C[out_idx] = half(p.alpha * val); + else + C[out_idx] = half(p.alpha * val + p.beta * float(C[out_idx])); + } + } + } +} + +)METAL"; +} + +#endif // PUFFERLIB_METAL_SHADER_SRC_H diff --git a/src/puf_types.h b/src/puf_types.h new file mode 100644 index 0000000000..92aa6d81a0 --- /dev/null +++ b/src/puf_types.h @@ -0,0 +1,596 @@ +#ifndef PUFFERLIB_PUF_TYPES_H +#define PUFFERLIB_PUF_TYPES_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensor.h" + +inline int puf_ndim(const int64_t* shape) { + int n = 0; + while (n < PUF_MAX_DIMS && 
shape[n] != 0) n++; + return n; +} + +inline int64_t puf_numel(const int64_t* shape) { + int64_t n = 1; + for (int i = 0; i < PUF_MAX_DIMS && shape[i] != 0; i++) n *= shape[i]; + return n; +} + +inline int64_t puf_batch_size(const int64_t* shape) { + int n = puf_ndim(shape); + int64_t b = 1; + for (int i = 0; i < n - 2; i++) b *= shape[i]; + return b; +} + +using std::vector; + +#define PUF_HD +typedef void *cudaStream_t; +#define CUDA_STREAM_T_DEFINED + +constexpr bool USE_BF16 = false; + +struct PufTensor { + char *bytes = nullptr; + int64_t shape[PUF_MAX_DIMS] = {}; + int dtype_size = 0; + + PUF_HD int ndim() const { + return puf_ndim(shape); + } + + PUF_HD int64_t numel() const { + return puf_numel(shape); + } + + PufTensor squeeze(int dim) { + int n = ndim(); + shape[dim + 1] *= shape[dim]; + for (int i = dim; i < n - 1; i++) + shape[i] = shape[i + 1]; + shape[n - 1] = 0; + return *this; + } + + PufTensor unsqueeze(int dim, int64_t d0, int64_t d1) { + assert(d0 * d1 == shape[dim] && "unsqueeze: d0 * d1 must equal shape[dim]"); + int n = ndim(); + for (int i = n; i > dim; i--) + shape[i] = shape[i - 1]; + shape[dim] = d0; + shape[dim + 1] = d1; + return *this; + } + + int64_t batch_size() const { + return puf_batch_size(shape); + } + + const char *dtype_name() const { + switch (dtype_size) { + case 1: + return "i8"; + case 2: + return "f16"; + case 4: + return "f32"; + case 8: + return "f64"; + default: + return "?"; + } + } + + const char *repr() const { + static char buf[256]; + if (!bytes) { + snprintf(buf, sizeof(buf), "PufTensor(empty)"); + return buf; + } + int pos = snprintf(buf, sizeof(buf), "PufTensor(%s, [", dtype_name()); + for (int i = 0; i < ndim() && pos < (int)sizeof(buf) - 32; i++) { + pos += snprintf(buf + pos, sizeof(buf) - pos, "%s%lld", i ? 
", " : "", + (long long)shape[i]); + } + snprintf(buf + pos, sizeof(buf) - pos, "], %lld elems)", + (long long)puf_numel(shape)); + return buf; + } +}; + +enum LossIdx { + LOSS_PG = 0, + LOSS_VF = 1, + LOSS_ENT = 2, + LOSS_TOTAL = 3, + LOSS_OLD_APPROX_KL = 4, + LOSS_APPROX_KL = 5, + LOSS_CLIPFRAC = 6, + LOSS_N = 7, + NUM_LOSSES = 8, +}; + +struct PrefixScan { + void *combined_ptr = nullptr; + void *state_ptr = nullptr; + void *input_ptr = nullptr; + int B = 0, T = 0, H = 0; + FloatTensor a_star, s_vals, log_values_buf; + PrecisionTensor out, next_state; + PrecisionTensor grad_combined, grad_state, grad_input; +}; + +struct AllocEntry { + void **data_ptr; + int64_t *shape; + int elem_size; +}; + +struct Allocator { + std::vector regs; + void *mem = nullptr; + int64_t total_elems = 0; + + void create() { + int64_t total_bytes = 0; + total_elems = 0; + + for (auto &e : regs) { + total_bytes = (total_bytes + 15) & ~15; + int64_t n = puf_numel(e.shape); + total_bytes += n * e.elem_size; + total_elems += n; + } + + if (total_bytes > 0) { +#ifdef __CUDACC__ + cudaMalloc(&mem, total_bytes); + cudaMemset(mem, 0, total_bytes); +#elif defined(WITH_METAL) + posix_memalign(&mem, 16384, total_bytes); + memset(mem, 0, total_bytes); +#else + mem = calloc(1, total_bytes); +#endif + int64_t offset = 0; + for (auto &e : regs) { + offset = (offset + 15) & ~15; + *e.data_ptr = (char *)mem + offset; + offset += puf_numel(e.shape) * e.elem_size; + } + } + } + + void destroy() { +#ifdef __CUDACC__ + if (mem) { + cudaFree(mem); + mem = nullptr; + } +#else + if (mem) { + free(mem); + mem = nullptr; + } +#endif + } +}; + +inline void alloc_register(Allocator *a, FloatTensor *t) { + a->regs.push_back({(void **)&t->data, t->shape, (int)sizeof(float)}); +} +inline void alloc_register(Allocator *a, PrecisionTensor *t) { + assert((t->dtype_size == 2 || t->dtype_size == 4) && + "alloc_register: unsupported precision tensor dtype"); + a->regs.push_back({(void **)&t->data, t->shape, t->dtype_size}); 
+} +inline void alloc_register(Allocator *a, IntTensor *t) { + a->regs.push_back({(void **)&t->data, t->shape, (int)sizeof(int)}); +} +inline void alloc_register(Allocator *a, LongTensor *t) { + a->regs.push_back({(void **)&t->data, t->shape, (int)sizeof(long)}); +} +inline void alloc_register(Allocator *a, PufTensor *t) { + a->regs.push_back({(void **)&t->bytes, t->shape, t->dtype_size}); +} + +struct AllocSet { + Allocator params, grads, acts; + int esz = 0; + void create() { + params.create(); + grads.create(); + acts.create(); + } + void destroy() { + params.destroy(); + grads.destroy(); + acts.destroy(); + } +}; + +struct PrioBuffers { + FloatTensor prio_probs, cdf, mb_prio; + LongTensor idx; +}; + +inline void register_prio_buffers(PrioBuffers &bufs, Allocator &alloc, int S, + int minibatch_segments) { + bufs = (PrioBuffers){ + .prio_probs = {.shape = {S}}, + .cdf = {.shape = {S}}, + .mb_prio = {.shape = {minibatch_segments, 1}}, + .idx = {.shape = {minibatch_segments}}, + }; + alloc_register(&alloc, &bufs.prio_probs); + alloc_register(&alloc, &bufs.cdf); + alloc_register(&alloc, &bufs.mb_prio); + alloc_register(&alloc, &bufs.idx); +} + +struct PPOBuffersPuf { + FloatTensor loss_output; + FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch; +}; + +inline void register_ppo_buffers(PPOBuffersPuf &bufs, Allocator &alloc, int N, + int T, int A_total, bool is_continuous) { + bufs = (PPOBuffersPuf){ + .loss_output = {.shape = {1}}, + .grad_logits = {.shape = {N, T, A_total}}, + .grad_values = {.shape = {N, T, 1}}, + .grad_logstd = {.shape = {N, T, A_total}}, + .adv_scratch = {.shape = {2}}, + }; + alloc_register(&alloc, &bufs.loss_output); + alloc_register(&alloc, &bufs.grad_logits); + alloc_register(&alloc, &bufs.grad_values); + if (is_continuous) + alloc_register(&alloc, &bufs.grad_logstd); + alloc_register(&alloc, &bufs.adv_scratch); +} + +struct RolloutBuf { + FloatTensor observations; + FloatTensor actions; + FloatTensor values; + FloatTensor 
logprobs; + FloatTensor rewards; + FloatTensor terminals; + FloatTensor ratio; + FloatTensor importance; +}; + +inline void register_rollout_buffers(RolloutBuf &bufs, Allocator &alloc, int H, + int S, int input_size, int num_atns) { + bufs.observations = {.shape = {H, S, input_size}}; + bufs.actions = {.shape = {H, S, num_atns}}; + bufs.values = {.shape = {H, S}}; + bufs.logprobs = {.shape = {H, S}}; + bufs.rewards = {.shape = {H, S}}; + bufs.terminals = {.shape = {H, S}}; + bufs.ratio = {.shape = {H, S}}; + bufs.importance = {.shape = {H, S}}; + alloc_register(&alloc, &bufs.observations); + alloc_register(&alloc, &bufs.actions); + alloc_register(&alloc, &bufs.values); + alloc_register(&alloc, &bufs.logprobs); + alloc_register(&alloc, &bufs.rewards); + alloc_register(&alloc, &bufs.terminals); + alloc_register(&alloc, &bufs.ratio); + alloc_register(&alloc, &bufs.importance); +} + +struct TrainGraph { + FloatTensor mb_obs; + FloatTensor mb_state; + FloatTensor mb_actions; + FloatTensor mb_logprobs; + FloatTensor mb_advantages; + FloatTensor mb_prio; + FloatTensor mb_values; + FloatTensor mb_returns; + FloatTensor mb_ratio; + FloatTensor mb_newvalue; +}; + +inline void register_train_buffers(TrainGraph &bufs, Allocator &alloc, int S, + int H, int input_size, int hidden_size, + int num_atns, int num_layers) { + bufs.mb_obs = {.shape = {S, H, input_size}}; + bufs.mb_state = {.shape = {num_layers, S, 1, hidden_size}}; + bufs.mb_actions = {.shape = {S, H, num_atns}}; + bufs.mb_logprobs = {.shape = {S, H}}; + bufs.mb_advantages = {.shape = {S, H}}; + bufs.mb_prio = {.shape = {S, 1}}; + bufs.mb_values = {.shape = {S, H}}; + bufs.mb_returns = {.shape = {S, H}}; + bufs.mb_ratio = {.shape = {S, H}}; + bufs.mb_newvalue = {.shape = {S, H, 1}}; + alloc_register(&alloc, &bufs.mb_obs); + alloc_register(&alloc, &bufs.mb_state); + alloc_register(&alloc, &bufs.mb_actions); + alloc_register(&alloc, &bufs.mb_logprobs); + alloc_register(&alloc, &bufs.mb_advantages); + 
alloc_register(&alloc, &bufs.mb_prio); + alloc_register(&alloc, &bufs.mb_values); + alloc_register(&alloc, &bufs.mb_returns); + alloc_register(&alloc, &bufs.mb_ratio); + alloc_register(&alloc, &bufs.mb_newvalue); +} + +typedef void (*init_weights_fn)(void *weights, uint64_t *seed, + cudaStream_t stream); +typedef void (*reg_params_fn)(void *weights, Allocator *alloc, int esz); +typedef void (*reg_train_fn)(void *weights, void *buf, Allocator *acts, + Allocator *grads, int B_TT, int precision); +typedef void (*reg_rollout_fn)(void *weights, void *buf, Allocator *alloc, + int B); +typedef PrecisionTensor (*forward_fn)(void *weights, void *activations, + PrecisionTensor input, cudaStream_t stream); +typedef void (*encoder_backward_fn)(void *weights, void *activations, + PrecisionTensor grad, cudaStream_t stream); +typedef PrecisionTensor (*decoder_backward_fn)(void *weights, void *activations, + FloatTensor grad_logits, + FloatTensor grad_logstd, + FloatTensor grad_value, + cudaStream_t stream); +typedef PrecisionTensor (*network_forward_fn)(void *weights, PrecisionTensor x, + PrecisionTensor state, void *activations, + cudaStream_t stream); +typedef PrecisionTensor (*network_forward_train_fn)(void *weights, PrecisionTensor x, + PrecisionTensor state, + void *activations, + cudaStream_t stream); +typedef PrecisionTensor (*network_backward_fn)(void *weights, PrecisionTensor grad, + void *activations, + cudaStream_t stream); + +struct Encoder { + forward_fn forward; + encoder_backward_fn backward; + init_weights_fn init_weights; + reg_params_fn reg_params; + reg_train_fn reg_train; + reg_rollout_fn reg_rollout; +}; + +struct Decoder { + forward_fn forward; + decoder_backward_fn backward; + init_weights_fn init_weights; + reg_params_fn reg_params; + reg_train_fn reg_train; + reg_rollout_fn reg_rollout; +}; + +struct Network { + network_forward_fn forward; + network_forward_train_fn forward_train; + network_backward_fn backward; + init_weights_fn init_weights; + 
reg_params_fn reg_params; + reg_train_fn reg_train; + reg_rollout_fn reg_rollout; +}; + +struct EncoderWeights { + PrecisionTensor weight; + int in_dim, out_dim; +}; +struct EncoderActivations { + PrecisionTensor out; + PrecisionTensor saved_input; + PrecisionTensor wgrad; +}; + +struct DecoderWeights { + PrecisionTensor weight; + PrecisionTensor logstd; + int hidden_dim, output_dim; + bool continuous; +}; +struct DecoderActivations { + PrecisionTensor out; + PrecisionTensor grad_out; + PrecisionTensor saved_input; + PrecisionTensor grad_input; + PrecisionTensor wgrad; + PrecisionTensor logstd_scratch; +}; + +struct MinGRUActivations { + int num_layers; + vector combined; + PrecisionTensor out; + PrecisionTensor next_state; + vector saved_inputs; + vector scan_bufs; + vector combined_bufs; + vector wgrad_scratch; + PrecisionTensor grad_input_buf; + PrecisionTensor grad_next_state; +}; + +struct MinGRUWeights { + int hidden, num_layers, horizon; + vector weights; +}; + +struct Policy { + Encoder encoder; + Decoder decoder; + Network network; + int input_dim, hidden_dim, output_dim; + int num_atns; +}; + +struct PolicyActivations { + void *encoder; + void *decoder; + void *network; +}; +struct PolicyWeights { + void *encoder; + void *decoder; + void *network; +}; + +inline PrecisionTensor *puf_squeeze(PrecisionTensor *t, int dim) { + int n = puf_ndim(t->shape); + t->shape[dim + 1] *= t->shape[dim]; + for (int i = dim; i < n - 1; i++) t->shape[i] = t->shape[i + 1]; + t->shape[n - 1] = 0; + return t; +} + +inline FloatTensor *puf_squeeze(FloatTensor *t, int dim) { + int n = puf_ndim(t->shape); + t->shape[dim + 1] *= t->shape[dim]; + for (int i = dim; i < n - 1; i++) t->shape[i] = t->shape[i + 1]; + t->shape[n - 1] = 0; + return t; +} + +inline PrecisionTensor *puf_unsqueeze(PrecisionTensor *t, int dim, int64_t d0, int64_t d1) { + int n = puf_ndim(t->shape); + for (int i = n; i > dim; i--) t->shape[i] = t->shape[i - 1]; + t->shape[dim] = d0; + t->shape[dim + 1] = d1; + 
return t; +} + +inline PrecisionTensor policy_forward(Policy *p, PolicyWeights &w, + PolicyActivations &activations, + PrecisionTensor obs, + PrecisionTensor state, + cudaStream_t stream) { + PrecisionTensor enc_out = + p->encoder.forward(w.encoder, activations.encoder, obs, stream); + PrecisionTensor h = p->network.forward(w.network, enc_out, state, + activations.network, stream); + return p->decoder.forward(w.decoder, activations.decoder, h, stream); +} + +inline PrecisionTensor policy_forward_train(Policy *p, PolicyWeights &w, + PolicyActivations &activations, + PrecisionTensor x, + PrecisionTensor state, + cudaStream_t stream) { + int B = x.shape[0], TT = x.shape[1]; + PrecisionTensor h = + p->encoder.forward(w.encoder, activations.encoder, *puf_squeeze(&x, 0), stream); + h = p->network.forward_train(w.network, *puf_unsqueeze(&h, 0, B, TT), state, + activations.network, stream); + PrecisionTensor dec_out = + p->decoder.forward(w.decoder, activations.decoder, *puf_squeeze(&h, 0), stream); + return *puf_unsqueeze(&dec_out, 0, B, TT); +} + +inline void policy_backward(Policy *p, PolicyWeights &w, + PolicyActivations &activations, + FloatTensor grad_logits, FloatTensor grad_logstd, + FloatTensor grad_value, cudaStream_t stream) { + int B = grad_logits.shape[0], TT = grad_logits.shape[1]; + PrecisionTensor grad_h = p->decoder.backward(w.decoder, activations.decoder, + *puf_squeeze(&grad_logits, 0), grad_logstd, + *puf_squeeze(&grad_value, 0), stream); + grad_h = p->network.backward(w.network, *puf_unsqueeze(&grad_h, 0, B, TT), + activations.network, stream); + p->encoder.backward(w.encoder, activations.encoder, grad_h, stream); +} + +inline float cosine_annealing(float lr_base, float lr_min, int t, int T) { + if (T == 0) + return lr_base; + float ratio = (float)t / (float)T; + ratio = std::max(0.0f, std::min(1.0f, ratio)); + return lr_min + 0.5f * (lr_base - lr_min) * (1.0f + std::cos(M_PI * ratio)); +} + +static constexpr double ns_coeffs[5][3] = { + {4.0848, 
-6.8946, 2.9270}, {3.9505, -6.3029, 2.6377}, + {3.7418, -5.5913, 2.3037}, {2.8769, -3.1427, 1.2046}, + {2.8366, -3.0525, 1.2012}, +}; + +struct NSScratch { + PufTensor x, A, gram, tmp; + PufTensor result_f32; + float *norm_ptr = nullptr; + int64_t max_M = 0; + int64_t max_N = 0; +}; + +inline PufTensor ns_slice(PufTensor &buf, int64_t rows, int64_t cols) { + return { + .bytes = buf.bytes, .shape = {rows, cols}, .dtype_size = buf.dtype_size}; +} + +struct Muon { + double momentum; + double weight_decay; + float lr_val_init; + int ns_iters; + float *lr_ptr; + float *lr_derived_ptr; + FloatTensor lr_puf, lr_derived_puf; + FloatTensor ns_norm_puf; + FloatTensor wb_puf, mb_puf, gc_puf, up_puf; + NSScratch ns; + Allocator *param_alloc; +}; + +inline PufTensor puf_slice(PufTensor &p, int t, int start, int count) { + if (p.ndim() == 3) { + int64_t S = p.shape[1], F = p.shape[2]; + return {.bytes = p.bytes + (t * S + start) * F * p.dtype_size, + .shape = {count, F}, + .dtype_size = p.dtype_size}; + } else { + int64_t S = p.shape[1]; + return {.bytes = p.bytes + (t * S + start) * p.dtype_size, + .shape = {count}, + .dtype_size = p.dtype_size}; + } +} + +inline FloatTensor puf_slice(FloatTensor &p, int t, int start, int count) { + if (puf_ndim(p.shape) == 3) { + int64_t S = p.shape[1], F = p.shape[2]; + return {.data = p.data + (t * S + start) * F, .shape = {count, F}}; + } else { + int64_t S = p.shape[1]; + return {.data = p.data + (t * S + start), .shape = {count}}; + } +} + +inline PrecisionTensor mingru_state_layer(PrecisionTensor &state, int i) { + int64_t B = state.shape[1], H = state.shape[2]; + PrecisionTensor layer = {}; + layer.data = (decltype(state.data))((char *)state.data + + i * B * H * state.dtype_size); + layer.shape[0] = B; + layer.shape[1] = H; + layer.dtype_size = state.dtype_size; + return layer; +} + +struct EnvBuf { + PufTensor obs; + int obs_raw_dtype; + PufTensor actions; + FloatTensor rewards; + FloatTensor terminals; +}; + +#endif // 
PUFFERLIB_PUF_TYPES_H diff --git a/src/tensor.h b/src/tensor.h index 8bf00d1c5f..2f01429e6a 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -29,7 +29,15 @@ typedef struct { typedef struct { precision_t* data; int64_t shape[PUF_MAX_DIMS]; + int dtype_size; } PrecisionTensor; +#else +typedef struct { + float* data; + int64_t shape[PUF_MAX_DIMS]; + int dtype_size; +} PrecisionTensor; +#define PRECISION_SIZE ((int)sizeof(float)) #endif #endif // PUFFERLIB_TENSOR_H diff --git a/src/vecenv.h b/src/vecenv.h index 6464de89f8..71998f6e0c 100644 --- a/src/vecenv.h +++ b/src/vecenv.h @@ -1,13 +1,13 @@ -// vecenv.h - Static env binding: types + implementation -// Types/declarations always available (for pufferlib.cu). -// Implementations compiled only when OBS_SIZE is defined (by binding.c). - #pragma once #include #include #include #include +#ifdef __APPLE__ +#include +#include +#endif #ifdef __cplusplus extern "C" { @@ -15,6 +15,12 @@ extern "C" { #include "tensor.h" +#define FLOAT 1 +#define INT 2 +#define UNSIGNED_CHAR 3 +#define DOUBLE 4 +#define CHAR 5 + // Dict types typedef struct { const char* key; @@ -46,7 +52,6 @@ static inline DictItem* dict_get_unsafe(Dict* dict, const char* key) { static inline DictItem* dict_get(Dict* dict, const char* key) { DictItem* item = dict_get_unsafe(dict, key); - if (item == NULL) printf("dict_get failed to find key: %s\n", key); assert(item != NULL); return item; } @@ -63,8 +68,9 @@ static inline void dict_set(Dict* dict, const char* key, double value) { dict->size++; } -// Forward declare CUDA stream type +#ifndef CUDA_STREAM_T_DEFINED typedef struct CUstream_st* cudaStream_t; +#endif // Threading state typedef struct StaticThreading StaticThreading; @@ -119,6 +125,7 @@ void static_vec_read_profile(StaticVec* vec, float out[NUM_EVAL_PROF]); // Env info int get_obs_size(void); +int get_obs_type(void); // legacy compat (Metal path) int get_num_atns(void); int* get_act_sizes(void); int get_num_act_sizes(void); @@ -146,12 +153,23 @@ int 
my_put(void* env, Dict* kwargs); #ifdef OBS_SIZE +#ifndef OBS_TENSOR_T + #if OBS_TYPE == FLOAT + #define OBS_TENSOR_T FloatTensor + #elif OBS_TYPE == UNSIGNED_CHAR + #define OBS_TENSOR_T ByteTensor + #elif OBS_TYPE == INT + #define OBS_TENSOR_T IntTensor + #else + #define OBS_TENSOR_T FloatTensor + #endif +#endif + static inline size_t obs_element_size(void) { OBS_TENSOR_T t; return sizeof(*t.data); } -// Usually near the top, after any #includes #define _STRINGIFY(x) #x #define STRINGIFY(x) _STRINGIFY(x) const char dtype_symbol[] = STRINGIFY(OBS_TENSOR_T); @@ -194,12 +212,17 @@ void my_log(Log* log, Dict* out); struct StaticThreading { - atomic_int* buffer_states; atomic_int shutdown; int num_threads; int num_buffers; pthread_t* threads; float* accum; // [num_buffers * NUM_EVAL_PROF] per-buffer timing in ms +#ifdef __APPLE__ + dispatch_semaphore_t* buf_ready; // main signals -> worker wakes + dispatch_semaphore_t* buf_done; // worker signals -> main wakes +#else + atomic_int* buffer_states; // spin-wait fallback for Linux/CUDA +#endif }; typedef struct StaticOMPArg { @@ -226,23 +249,32 @@ static void* static_omp_threadmanager(void* arg) { thread_init(ctx, buf); } +#ifdef __APPLE__ + // pin rollout threads to P-cores for deterministic scheduling + pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0); +#endif + int agents_per_buffer = vec->agents_per_buffer; int agent_start = buf * agents_per_buffer; int env_start = vec->buffer_env_starts[buf]; int env_count = vec->buffer_env_counts[buf]; - atomic_int* buffer_states = threading->buffer_states; int num_workers = threading->num_threads / vec->buffers; if (num_workers < 1) num_workers = 1; Env* envs = (Env*)vec->envs; - printf("Num workers: %d\n", num_workers); while (true) { +#ifdef __APPLE__ + dispatch_semaphore_wait(threading->buf_ready[buf], DISPATCH_TIME_FOREVER); + if (atomic_load(&threading->shutdown)) return NULL; +#else + atomic_int* buffer_states = threading->buffer_states; while 
(atomic_load(&buffer_states[buf]) != OMP_RUNNING) { if (atomic_load(&threading->shutdown)) { return NULL; } } +#endif cudaStream_t stream = vec->streams[buf]; float* my_accum = &threading->accum[buf * NUM_EVAL_PROF]; @@ -288,26 +320,42 @@ static void* static_omp_threadmanager(void* arg) { cudaMemcpyHostToDevice, stream); } cudaStreamSynchronize(stream); +#ifdef __APPLE__ + dispatch_semaphore_signal(threading->buf_done[buf]); +#else atomic_store(&buffer_states[buf], OMP_WAITING); +#endif } } void static_vec_omp_step(StaticVec* vec) { StaticThreading* threading = vec->threading; - for (int buf = 0; buf < vec->buffers; buf++) { +#ifdef __APPLE__ + for (int buf = 0; buf < vec->buffers; buf++) + dispatch_semaphore_signal(threading->buf_ready[buf]); + for (int buf = 0; buf < vec->buffers; buf++) + dispatch_semaphore_wait(threading->buf_done[buf], DISPATCH_TIME_FOREVER); +#else + for (int buf = 0; buf < vec->buffers; buf++) atomic_store(&threading->buffer_states[buf], OMP_RUNNING); - } - for (int buf = 0; buf < vec->buffers; buf++) { + for (int buf = 0; buf < vec->buffers; buf++) while (atomic_load(&threading->buffer_states[buf]) != OMP_WAITING) {} - } +#endif } void static_vec_seq_step(StaticVec* vec) { StaticThreading* threading = vec->threading; +#ifdef __APPLE__ + for (int buf = 0; buf < vec->buffers; buf++) { + dispatch_semaphore_signal(threading->buf_ready[buf]); + dispatch_semaphore_wait(threading->buf_done[buf], DISPATCH_TIME_FOREVER); + } +#else for (int buf = 0; buf < vec->buffers; buf++) { atomic_store(&threading->buffer_states[buf], OMP_RUNNING); while (atomic_load(&threading->buffer_states[buf]) != OMP_WAITING) {} } +#endif } // Optional: Initialize all envs at once (for shared state, variable agents per env, etc.) 
@@ -461,9 +509,18 @@ void create_static_threads(StaticVec* vec, int num_threads, int horizon, vec->threading = (StaticThreading*)calloc(1, sizeof(StaticThreading)); vec->threading->num_threads = num_threads; vec->threading->num_buffers = vec->buffers; - vec->threading->buffer_states = (atomic_int*)calloc(vec->buffers, sizeof(atomic_int)); vec->threading->threads = (pthread_t*)calloc(vec->buffers, sizeof(pthread_t)); vec->threading->accum = (float*)calloc(vec->buffers * NUM_EVAL_PROF, sizeof(float)); +#ifdef __APPLE__ + vec->threading->buf_ready = (dispatch_semaphore_t*)calloc(vec->buffers, sizeof(dispatch_semaphore_t)); + vec->threading->buf_done = (dispatch_semaphore_t*)calloc(vec->buffers, sizeof(dispatch_semaphore_t)); + for (int i = 0; i < vec->buffers; i++) { + vec->threading->buf_ready[i] = dispatch_semaphore_create(0); + vec->threading->buf_done[i] = dispatch_semaphore_create(0); + } +#else + vec->threading->buffer_states = (atomic_int*)calloc(vec->buffers, sizeof(atomic_int)); +#endif // Streams are now created by pufferlib.cu (PyTorch-managed streams) // Do NOT create streams here - they've already been set up @@ -485,6 +542,11 @@ void static_vec_close(StaticVec* vec) { if (vec->threading != NULL) { atomic_store(&vec->threading->shutdown, 1); +#ifdef __APPLE__ + // Wake all waiting workers so they can check shutdown flag and exit. 
+ for (int i = 0; i < vec->buffers; i++) + dispatch_semaphore_signal(vec->threading->buf_ready[i]); +#endif for (int i = 0; i < vec->buffers; i++) { pthread_join(vec->threading->threads[i], NULL); } @@ -498,7 +560,16 @@ void static_vec_close(StaticVec* vec) { my_vec_close(envs); free(vec->envs); if (vec->threading != NULL) { +#ifdef __APPLE__ + for (int i = 0; i < vec->buffers; i++) { + dispatch_release(vec->threading->buf_ready[i]); + dispatch_release(vec->threading->buf_done[i]); + } + free(vec->threading->buf_ready); + free(vec->threading->buf_done); +#else free(vec->threading->buffer_states); +#endif free(vec->threading->threads); free(vec->threading->accum); free(vec->threading); @@ -596,6 +667,18 @@ void static_vec_render(StaticVec* vec, int env_id) { } int get_obs_size(void) { return OBS_SIZE; } +#ifdef OBS_TYPE +int get_obs_type(void) { return OBS_TYPE; } +#else +int get_obs_type(void) { + if (strcmp(dtype_symbol, "FloatTensor") == 0) return FLOAT; + if (strcmp(dtype_symbol, "ByteTensor") == 0) return UNSIGNED_CHAR; + if (strcmp(dtype_symbol, "IntTensor") == 0) return INT; + if (strcmp(dtype_symbol, "LongTensor") == 0) return INT; + assert(false && "Unsupported observation tensor type"); + return FLOAT; +} +#endif int get_num_atns(void) { return NUM_ATNS; } static int _act_sizes[] = ACT_SIZES; int* get_act_sizes(void) { return _act_sizes; } From 8949be71a7d0bcb66be056d75501281bebb90e8e Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Sun, 19 Apr 2026 14:10:03 +0300 Subject: [PATCH 2/6] drop dead metal muon knobs --- config/default.ini | 2 -- src/metal_bindings.mm | 5 ----- src/metal_kernels.mm | 26 ++++++++++---------------- src/metal_pufferlib.mm | 5 +---- src/metal_shader_src.h | 4 +--- src/puf_types.h | 2 -- 6 files changed, 12 insertions(+), 32 deletions(-) diff --git a/config/default.ini b/config/default.ini index 8fc61c21dc..ed8c237c21 100644 --- a/config/default.ini +++ b/config/default.ini @@ -54,13 +54,11 @@ vf_clip_coef = 0.2 max_grad_norm = 1.5 
ent_coef = 0.001 beta1 = 0.95 -weight_decay = 0.0 beta2 = 0.999 eps = 1e-12 overlap = 0 cpu_inference = 0 train_fp16 = 0 -ns_iters = 5 minibatch_size = 8192 horizon = 64 vtrace_rho_clip = 1.0 diff --git a/src/metal_bindings.mm b/src/metal_bindings.mm index 94f867b8c6..1cb52f74f5 100644 --- a/src/metal_bindings.mm +++ b/src/metal_bindings.mm @@ -406,7 +406,6 @@ static void vec_close(VecEnv& ve) { hypers.min_lr_ratio = get_config(train_kwargs, "min_lr_ratio"); hypers.anneal_lr = get_config(train_kwargs, "anneal_lr"); hypers.beta1 = get_config(train_kwargs, "beta1"); - hypers.weight_decay = get_config(train_kwargs, "weight_decay"); hypers.minibatch_size = get_config(train_kwargs, "minibatch_size"); hypers.replay_ratio = get_config(train_kwargs, "replay_ratio"); hypers.total_timesteps = get_config(train_kwargs, "total_timesteps"); @@ -435,8 +434,6 @@ static void vec_close(VecEnv& ve) { hypers.train_fp16 = (train_kwargs.contains("train_fp16") && get_config(train_kwargs, "train_fp16") > 0) || (args.contains("train_fp16") && get_config(args, "train_fp16") > 0); - hypers.ns_iters = train_kwargs.contains("ns_iters") ? (int)get_config(train_kwargs, "ns_iters") - : args.contains("ns_iters") ? (int)get_config(args, "ns_iters") : 5; hypers.gpu_id = args.contains("gpu_id") ? 
(int)get_config(args, "gpu_id") : 0; mtl_enable_gpu_timing(hypers.profile); @@ -497,7 +494,6 @@ static void vec_close(VecEnv& ve) { .def_readwrite("min_lr_ratio", &HypersT::min_lr_ratio) .def_readwrite("anneal_lr", &HypersT::anneal_lr) .def_readwrite("beta1", &HypersT::beta1) - .def_readwrite("weight_decay", &HypersT::weight_decay) .def_readwrite("total_timesteps", &HypersT::total_timesteps) .def_readwrite("max_grad_norm", &HypersT::max_grad_norm) .def_readwrite("clip_coef", &HypersT::clip_coef) @@ -515,7 +511,6 @@ static void vec_close(VecEnv& ve) { .def_readwrite("overlap", &HypersT::overlap) .def_readwrite("cpu_inference", &HypersT::cpu_inference) .def_readwrite("train_fp16", &HypersT::train_fp16) - .def_readwrite("ns_iters", &HypersT::ns_iters) .def_readwrite("gpu_id", &HypersT::gpu_id); py::class_(m, "FloatTensor") diff --git a/src/metal_kernels.mm b/src/metal_kernels.mm index 47ae1d6598..37a6d2c2bb 100644 --- a/src/metal_kernels.mm +++ b/src/metal_kernels.mm @@ -976,9 +976,8 @@ void mtl_select_copy(RolloutBuf &rollouts, TrainGraph &graph, } void mtl_muon_weight_update(float *weights, const float *updates, - const float *lr_ptr, float weight_decay, - float scale, int count, - cudaStream_t stream) { + const float *lr_ptr, float scale, int count, + cudaStream_t stream) { MetalStream *ms = mtl_resolve_stream(stream); ms->compute_encoder(); auto pso = mtl_pipeline("muon_weight_update_kernel"); @@ -988,13 +987,14 @@ void mtl_muon_weight_update(float *weights, const float *updates, mtl_set_ptr(ms, lr_ptr, 2); struct { int count; - float weight_decay; float scale; - } params = {count, weight_decay, scale}; + } params = {count, scale}; mtl_set_params(ms, params, 3); mtl_dispatch_1d(ms, pso, count); } +static constexpr int kMuonNsIters = 5; + // ============================================================================ // Kaiming uniform init (CPU-side, matches CUDA puf_kaiming_init) // @@ -1022,11 +1022,8 @@ void puf_kaiming_init(PufTensor &dst, float gain, uint64_t 
seed, } void muon_init(Muon *m, Allocator *param_alloc, FloatTensor weight_buffer, - double lr_val, double momentum, double weight_decay, - int ns_iters, Allocator &alloc) { + double lr_val, double momentum, Allocator &alloc) { m->momentum = momentum; - m->weight_decay = weight_decay; - m->ns_iters = (ns_iters > 0 && ns_iters <= 5) ? ns_iters : 5; m->lr_val_init = (float)lr_val; m->lr_ptr = nullptr; m->lr_derived_ptr = nullptr; @@ -1133,8 +1130,8 @@ void muon_step(Muon *m, cudaStream_t stream) { mtl_barrier(ms); // Newton-Schulz iterations - for (int i = 0; i < m->ns_iters; ++i) { - int ci = i * 4 / (m->ns_iters - 1 + (m->ns_iters == 1)); + for (int i = 0; i < kMuonNsIters; ++i) { + int ci = i * 4 / (kMuonNsIters - 1 + (kMuonNsIters == 1)); float a = (float)ns_coeffs[ci][0], b = (float)ns_coeffs[ci][1], c = (float)ns_coeffs[ci][2]; PufTensor &src = (i % 2 == 0) ? x : tmp; @@ -1151,7 +1148,7 @@ void muon_step(Muon *m, cudaStream_t stream) { mtl_barrier(ms); } - PufTensor &result_precision = (m->ns_iters % 2 == 0) ? x : tmp; + PufTensor &result_precision = (kMuonNsIters % 2 == 0) ? x : tmp; // Scale matches CUDA models.cu:1233: sqrt(max(1.0, R/C)). // For tall matrices (R>C), scale up by sqrt(R/C) to compensate for @@ -1189,10 +1186,7 @@ void muon_step(Muon *m, cudaStream_t stream) { offset += puf_numel(e.shape); } - // Apply weight update: w = w * (1 - lr*wd) - lr * up - // Scale is already baked into up_puf during NS loop, so pass scale=1.0 here. 
- mtl_muon_weight_update(m->wb_puf.data, m->up_puf.data, - m->lr_ptr, (float)m->weight_decay, 1.0f, + mtl_muon_weight_update(m->wb_puf.data, m->up_puf.data, m->lr_ptr, 1.0f, (int)puf_numel(m->wb_puf.shape), stream); mtl_barrier(ms); } diff --git a/src/metal_pufferlib.mm b/src/metal_pufferlib.mm index 726796c315..7f3f32215a 100644 --- a/src/metal_pufferlib.mm +++ b/src/metal_pufferlib.mm @@ -85,7 +85,6 @@ static inline void cpu_cast_to_f32(float* dst, const T* src, int count) { bool anneal_lr; // Optimizer (Muon only — Adam removed) float beta1; - float weight_decay; // Training int minibatch_size; float replay_ratio; @@ -111,7 +110,6 @@ static inline void cpu_cast_to_f32(float* dst, const T* src, int count) { bool overlap; // async training overlap: train on separate GPU queue bool cpu_inference; // CPU forward pass during rollout (no GPU sync) bool train_fp16; // fp16 activations/grads during training (rollout stays fp32) - int ns_iters; // Newton-Schulz iterations in muon optimizer (1-5, default 5) // Single GPU (Metal has no multi-GPU, but kept for upstream compat) int gpu_id; // Threading @@ -1140,8 +1138,7 @@ static void sync_pending_train(PuffeRL& pufferl) { // Optimizer init (register buffers with shared allocator) muon_init(pufferl->muon, &fp32_params, - pufferl->param_fp32_puf, lr, beta1, (double)hypers.weight_decay, - hypers.ns_iters, alloc); + pufferl->param_fp32_puf, lr, beta1, alloc); // Single allocation for all registered buffers alloc.create(); diff --git a/src/metal_shader_src.h b/src/metal_shader_src.h index 561c60ba17..f9931eb313 100644 --- a/src/metal_shader_src.h +++ b/src/metal_shader_src.h @@ -1313,7 +1313,6 @@ kernel void compute_lr_scalars_kernel( struct MuonParams { int n; - float weight_decay; float scale; }; @@ -1327,8 +1326,7 @@ kernel void muon_weight_update_kernel( ) { if ((int)idx >= p.n) return; float lr = *lr_ptr; - float wd_scale = 1.0f - lr * p.weight_decay; - wb[idx] = wb[idx] * wd_scale - lr * p.scale * up[idx]; + wb[idx] = 
wb[idx] - lr * p.scale * up[idx]; } struct TransposeParams { diff --git a/src/puf_types.h b/src/puf_types.h index 92aa6d81a0..de807a536f 100644 --- a/src/puf_types.h +++ b/src/puf_types.h @@ -538,9 +538,7 @@ inline PufTensor ns_slice(PufTensor &buf, int64_t rows, int64_t cols) { struct Muon { double momentum; - double weight_decay; float lr_val_init; - int ns_iters; float *lr_ptr; float *lr_derived_ptr; FloatTensor lr_puf, lr_derived_puf; From 57f6a60307a0143da73741a29da57734acbfb1eb Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Fri, 1 May 2026 18:08:21 +0300 Subject: [PATCH 3/6] pufferl: cap minibatch_size to total_agents during eval eval forces horizon=1 so the per-eval batch is total_agents*1. if the trained model used minibatch > total_agents (e.g. trained at agents=512 horizon=64 minibatch=8192, evaluated at agents=512 horizon=1 = 512), the backend's divisibility check trips. cap minibatch to the eval batch so puffer eval works for any sane training config. --- pufferlib/pufferl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 8fc0c03a89..b567c3decd 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -403,6 +403,11 @@ def eval(env_name, args=None, load_path=None): args = args or load_config(env_name) args['reset_state'] = False args['train']['horizon'] = 1 + # Eval batches are total_agents*1, so cap minibatch to that to satisfy + # the divisibility check. Training-time minibatch may be larger. + eval_batch = args['vec']['total_agents'] + if args['train']['minibatch_size'] > eval_batch: + args['train']['minibatch_size'] = eval_batch backend = _resolve_backend(args) pufferl = backend.create_pufferl(args) From a728372426161fcb4da9734d9b68840da01db65c Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Fri, 1 May 2026 18:08:21 +0300 Subject: [PATCH 4/6] pufferl: write trial config stub at trial start for post-mortem trial logs were only written at the end of training. 
when a trial hung or crashed, the config that triggered it was lost. observed during a wordle sweep that stalled for ~4h with zero trace of which protein suggestion deadlocked the worker. now write {config, metrics={}, status='pending'} at trial start, then overwrite with {config, metrics, status='completed'} at end. hung trials leave the pending stub behind so we can grep logs// for status: pending and see exactly what config killed it. uses default=str on the json dump to handle bytes (nccl_id is b'' before being popped a few lines later) and any numpy scalars that protein may leak into args. without it, the stub-write itself crashes. --- pufferlib/pufferl.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index b567c3decd..9c5e9a3702 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -209,6 +209,14 @@ def _train(env_name, args, sweep_obj=None, result_queue=None, verbose=False): log_dir = os.path.join(args['log_dir'], args['env_name']) os.makedirs(log_dir, exist_ok=True) + # Write config-only stub at trial start so hung/crashed trials leave a + # post-mortem trace. Overwritten with full {config, metrics} at trial end. + # default=str handles bytes (nccl_id is b'' before being popped below) and + # any numpy scalars that protein may have leaked into args. 
+ log_path = os.path.join(log_dir, run_id + '.json') + with open(log_path, 'w') as f: + json.dump({**args, 'metrics': {}, 'status': 'pending'}, f, default=str) + try: pufferl = backend.create_pufferl(args) except RuntimeError as e: @@ -295,11 +303,9 @@ def _train(env_name, args, sweep_obj=None, result_queue=None, verbose=False): for k in metrics: metrics[k][-1] = all_logs[-1][k] - # Save own log: config + downsampled results - log_dir = os.path.join(args['log_dir'], args['env_name']) - os.makedirs(log_dir, exist_ok=True) - with open(os.path.join(log_dir, run_id + '.json'), 'w') as f: - json.dump({**args, 'metrics': metrics}, f) + # Save own log: config + downsampled results (overwrites pending stub) + with open(log_path, 'w') as f: + json.dump({**args, 'metrics': metrics, 'status': 'completed'}, f, default=str) if args['wandb']: if sweep_obj is None and model_path: # Don't spam uploads during sweeps From 67da3acceccd718ff8306597ee5e1ccf8fbff33b Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Fri, 1 May 2026 18:25:55 +0300 Subject: [PATCH 5/6] sync metal backend with latest fixes from inferno branch Upstream-4.0 sync, Metal perf wins (encoder alias, fp16 select_copy gate, action buffer f32, col2im atomic-free), Muon registration fixes, const_ring overflow handling, Metal sync timeout bump, unified bindings API, and a runtime ENV_NAME mismatch guard so the static lib and bindings can never disagree about which environment they were compiled for. New test tests/test_metal_const_ring.mm locks in the const_ring reservation helper, the int-config rounding behavior, and the divisibility validator. 
--- src/bindings.cu | 22 +++++- src/bindings_cpu.cpp | 17 ++++- src/metal_bindings.mm | 130 +++++++++++++++++++++++++++------ src/metal_kernels.mm | 16 ++-- src/metal_platform.h | 83 ++++++++++++++++++++- src/metal_platform.mm | 36 ++++++--- src/metal_pufferlib.mm | 5 +- src/metal_shader_src.h | 70 +++++++++++------- src/pufferlib.cu | 2 +- src/vecenv.h | 8 ++ tests/test_metal_const_ring.mm | 34 +++++++++ 11 files changed, 344 insertions(+), 79 deletions(-) create mode 100644 tests/test_metal_const_ring.mm diff --git a/src/bindings.cu b/src/bindings.cu index 4469cb512c..9f463ed7f2 100644 --- a/src/bindings.cu +++ b/src/bindings.cu @@ -2,6 +2,9 @@ #include #include +#include +#include +#include #include "pufferlib.cu" #define _PUFFER_STRINGIFY(x) #x @@ -9,6 +12,16 @@ namespace py = pybind11; +static void assert_static_env_name_matches(void) { + const char* binding_env_name = PUFFER_STRINGIFY(ENV_NAME); + const char* static_env_name = get_static_env_name(); + if (strcmp(binding_env_name, static_env_name) != 0) { + throw std::runtime_error( + std::string("compiled _C env mismatch: binding env_name=") + + binding_env_name + ", static_env_name=" + static_env_name); + } +} + // Wrapper functions for Python bindings pybind11::dict puf_log(pybind11::object pufferl_obj) { auto& pufferl = pufferl_obj.cast(); @@ -106,7 +119,7 @@ pybind11::dict puf_eval_log(pybind11::object pufferl_obj) { pufferl.last_log_step = pufferl.global_step; pybind11::dict env_dict; - Dict* env_out = create_dict(32); + Dict* env_out = create_dict(64); static_vec_eval_log(pufferl.vec, env_out); for (int i = 0; i < env_out->size; i++) { env_dict[env_out->items[i].key] = env_out->items[i].value; @@ -248,7 +261,7 @@ Dict* py_dict_to_c_dict(py::dict py_dict) { } // ============================================================================ -// Python-facing VecEnv: wraps StaticVec for use from python_pufferl.py. +// Python-facing VecEnv wrapper. 
// After vec_step(), GPU buffers are current — Python wraps them zero-copy // with torch.from_blob(ptr, shape, dtype, device='cuda'). // ============================================================================ @@ -318,7 +331,7 @@ void cpu_vec_step_py(VecEnv& ve, long long actions_ptr) { } py::dict vec_log(VecEnv& ve) { - Dict* out = create_dict(32); + Dict* out = create_dict(64); static_vec_log(ve.vec, out); py::dict result; for (int i = 0; i < out->size; i++) { @@ -410,6 +423,8 @@ std::unique_ptr create_pufferl(py::dict args) { } PYBIND11_MODULE(_C, m) { + assert_static_env_name_matches(); + // Multi-GPU: generate NCCL unique ID (call on rank 0, pass bytes to all ranks) m.def("get_nccl_id", []() { ncclUniqueId id; @@ -454,6 +469,7 @@ PYBIND11_MODULE(_C, m) { m.attr("precision_bytes") = (int)sizeof(precision_t); m.attr("env_name") = PUFFER_STRINGIFY(ENV_NAME); + m.attr("static_env_name") = get_static_env_name(); m.attr("gpu") = 1; // Core functions diff --git a/src/bindings_cpu.cpp b/src/bindings_cpu.cpp index 5ba4dc81e5..3fbb69c2d4 100644 --- a/src/bindings_cpu.cpp +++ b/src/bindings_cpu.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include #define _PUFFER_STRINGIFY(x) #x #define PUFFER_STRINGIFY(x) _PUFFER_STRINGIFY(x) @@ -12,6 +14,16 @@ namespace py = pybind11; +static void assert_static_env_name_matches(void) { + const char* binding_env_name = PUFFER_STRINGIFY(ENV_NAME); + const char* static_env_name = get_static_env_name(); + if (strcmp(binding_env_name, static_env_name) != 0) { + throw std::runtime_error( + std::string("compiled _C env mismatch: binding env_name=") + + binding_env_name + ", static_env_name=" + static_env_name); + } +} + // Stub out CUDA functions that the static lib references (dead code when gpu=0) extern "C" { typedef int cudaError_t; @@ -141,7 +153,7 @@ static void cpu_vec_step_py(VecEnv& ve, long long actions_ptr) { } static py::dict vec_log(VecEnv& ve) { - Dict* out = create_dict(32); + Dict* out = create_dict(64); 
static_vec_log(ve.vec, out); py::dict result; for (int i = 0; i < out->size; i++) @@ -161,8 +173,11 @@ static void vec_close(VecEnv& ve) { // ============================================================================ PYBIND11_MODULE(_C, m) { + assert_static_env_name_matches(); + m.attr("precision_bytes") = 4; m.attr("env_name") = PUFFER_STRINGIFY(ENV_NAME); + m.attr("static_env_name") = get_static_env_name(); m.attr("gpu") = 0; m.def("puff_advantage_cpu", &py_puff_advantage_cpu); diff --git a/src/metal_bindings.mm b/src/metal_bindings.mm index 1cb52f74f5..893ca3b4a2 100644 --- a/src/metal_bindings.mm +++ b/src/metal_bindings.mm @@ -1,6 +1,7 @@ #include "metal_pufferlib.mm" #include +#include #include #include #include @@ -38,6 +39,16 @@ static double wall_clock() { return std::string(buffer); } +static void assert_static_env_name_matches(void) { + const char* binding_env_name = PUFFER_STRINGIFY(ENV_NAME); + const char* static_env_name = get_static_env_name(); + if (strcmp(binding_env_name, static_env_name) != 0) { + throw std::runtime_error( + std::string("compiled _C env mismatch: binding env_name=") + + binding_env_name + ", static_env_name=" + static_env_name); + } +} + static py::dict get_utilization(int gpu_id) { (void)gpu_id; py::dict result; @@ -223,8 +234,17 @@ static void save_weights(py::object pufferl_obj, const std::string& path) { if (!f) { throw std::runtime_error("Failed to open " + path + " for writing"); } - fwrite(pufferl.alloc_fp32.params.mem, 1, nbytes, f); - fclose(f); + size_t expected = (size_t)nbytes; + size_t written = fwrite(pufferl.alloc_fp32.params.mem, 1, expected, f); + int close_result = fclose(f); + if (written != expected) { + throw std::runtime_error( + "Failed to write " + path + ": expected " + std::to_string(expected) + + " bytes, wrote " + std::to_string(written)); + } + if (close_result != 0) { + throw std::runtime_error("Failed to close " + path + " after writing"); + } } static void load_weights(py::object pufferl_obj, 
const std::string& path) { @@ -299,17 +319,67 @@ static void py_puff_advantage_cpu( } static double get_config(py::dict& kwargs, const char* key) { - assert(kwargs.contains(key) && "Missing config key"); - return kwargs[key].cast(); + if (!kwargs.contains(key)) { + throw std::invalid_argument(std::string("Missing config key: ") + key); + } + try { + return kwargs[key].cast(); + } catch (const py::cast_error&) { + throw std::invalid_argument(std::string(key) + " must be numeric"); + } +} + +static int get_config_int(py::dict& kwargs, const char* key) { + return mtl_parse_int_config_value(key, get_config(kwargs, key)); +} + +static int get_config_positive_int(py::dict& kwargs, const char* key) { + return mtl_validate_positive_config_value(key, get_config_int(kwargs, key)); } -static Dict* py_dict_to_c_dict(py::dict py_dict) { +static long get_config_positive_long(py::dict& kwargs, const char* key) { + double value = get_config(kwargs, key); + if (!std::isfinite(value)) { + throw std::invalid_argument(std::string(key) + " must be a finite positive integer"); + } + double rounded = std::round(value); + if (rounded <= 0.0) { + throw std::invalid_argument(std::string(key) + " must be a positive integer"); + } + if (rounded > (double)std::numeric_limits::max()) { + throw std::invalid_argument(std::string(key) + " is outside long range"); + } + return (long)rounded; +} + +static uint64_t get_config_uint64(py::dict& kwargs, const char* key) { + double value = get_config(kwargs, key); + if (!std::isfinite(value) || std::trunc(value) != value || value < 0.0) { + throw std::invalid_argument(std::string(key) + " must be a non-negative integer"); + } + if (value > (double)std::numeric_limits::max()) { + throw std::invalid_argument(std::string(key) + " is outside uint64 range"); + } + return (uint64_t)value; +} + +static bool is_python_side_channel_env_key(const char* key) { + return std::strcmp(key, "record_best_replay_path") == 0 || + std::strcmp(key, "play_replay_path") == 0; 
+} + +static Dict* py_dict_to_c_dict(py::dict py_dict, bool is_env_dict) { Dict* c_dict = create_dict(py_dict.size()); for (auto item : py_dict) { const char* key = PyUnicode_AsUTF8(item.first.ptr()); + if (!key) { + throw std::invalid_argument("Config dict keys must be strings"); + } try { dict_set(c_dict, key, item.second.cast()); } catch (const py::cast_error&) { + if (is_env_dict && is_python_side_channel_env_key(key)) continue; + throw std::invalid_argument(std::string(key) + " must be numeric"); } } return c_dict; @@ -330,10 +400,12 @@ static double get_config(py::dict& kwargs, const char* key) { py::dict vec_kwargs = args["vec"].cast(); py::dict env_kwargs = args["env"].cast(); - int total_agents = (int)get_config(vec_kwargs, "total_agents"); - int num_buffers = (int)get_config(vec_kwargs, "num_buffers"); - Dict* vec_dict = py_dict_to_c_dict(vec_kwargs); - Dict* env_dict = py_dict_to_c_dict(env_kwargs); + int total_agents = get_config_positive_int(vec_kwargs, "total_agents"); + int num_buffers = get_config_positive_int(vec_kwargs, "num_buffers"); + mtl_validate_divisible_config_values( + "total_agents", total_agents, "num_buffers", num_buffers); + Dict* vec_dict = py_dict_to_c_dict(vec_kwargs, false); + Dict* env_dict = py_dict_to_c_dict(env_kwargs, true); auto ve = std::make_unique(); { @@ -394,21 +466,21 @@ static void vec_close(VecEnv& ve) { py::dict policy_kwargs = args["policy"].cast(); HypersT hypers; - hypers.total_agents = get_config(vec_kwargs, "total_agents"); - hypers.num_buffers = get_config(vec_kwargs, "num_buffers"); - hypers.num_threads = get_config(vec_kwargs, "num_threads"); - hypers.horizon = get_config(train_kwargs, "horizon"); - hypers.hidden_size = get_config(policy_kwargs, "hidden_size"); - hypers.num_layers = get_config(policy_kwargs, "num_layers"); - hypers.seed = args.contains("seed") ? (uint64_t)get_config(args, "seed") - : train_kwargs.contains("seed") ? 
(uint64_t)get_config(train_kwargs, "seed") : 42; + hypers.total_agents = get_config_positive_int(vec_kwargs, "total_agents"); + hypers.num_buffers = get_config_positive_int(vec_kwargs, "num_buffers"); + hypers.num_threads = get_config_positive_int(vec_kwargs, "num_threads"); + hypers.horizon = get_config_positive_int(train_kwargs, "horizon"); + hypers.hidden_size = get_config_positive_int(policy_kwargs, "hidden_size"); + hypers.num_layers = get_config_positive_int(policy_kwargs, "num_layers"); + hypers.seed = args.contains("seed") ? get_config_uint64(args, "seed") + : train_kwargs.contains("seed") ? get_config_uint64(train_kwargs, "seed") : 42; hypers.lr = get_config(train_kwargs, "learning_rate"); hypers.min_lr_ratio = get_config(train_kwargs, "min_lr_ratio"); hypers.anneal_lr = get_config(train_kwargs, "anneal_lr"); hypers.beta1 = get_config(train_kwargs, "beta1"); - hypers.minibatch_size = get_config(train_kwargs, "minibatch_size"); + hypers.minibatch_size = get_config_positive_int(train_kwargs, "minibatch_size"); hypers.replay_ratio = get_config(train_kwargs, "replay_ratio"); - hypers.total_timesteps = get_config(train_kwargs, "total_timesteps"); + hypers.total_timesteps = get_config_positive_long(train_kwargs, "total_timesteps"); hypers.max_grad_norm = get_config(train_kwargs, "max_grad_norm"); hypers.clip_coef = get_config(train_kwargs, "clip_coef"); hypers.vf_clip_coef = get_config(train_kwargs, "vf_clip_coef"); @@ -434,13 +506,24 @@ static void vec_close(VecEnv& ve) { hypers.train_fp16 = (train_kwargs.contains("train_fp16") && get_config(train_kwargs, "train_fp16") > 0) || (args.contains("train_fp16") && get_config(args, "train_fp16") > 0); - hypers.gpu_id = args.contains("gpu_id") ? (int)get_config(args, "gpu_id") : 0; + hypers.gpu_id = args.contains("gpu_id") ? 
get_config_int(args, "gpu_id") : 0; + mtl_validate_divisible_config_values( + "total_agents", hypers.total_agents, "num_buffers", hypers.num_buffers); + mtl_validate_divisible_config_values( + "minibatch_size", hypers.minibatch_size, "horizon", hypers.horizon); + long long batch_size_long = (long long)hypers.total_agents * (long long)hypers.horizon; + if (batch_size_long > (long long)std::numeric_limits::max()) { + throw std::invalid_argument("total_agents * horizon is outside int range"); + } + int batch_size = (int)batch_size_long; + mtl_validate_divisible_config_values( + "total_agents * horizon", batch_size, "minibatch_size", hypers.minibatch_size); mtl_enable_gpu_timing(hypers.profile); std::string env_name = args["env_name"].cast(); - Dict* vec_dict = py_dict_to_c_dict(vec_kwargs); - Dict* env_dict = py_dict_to_c_dict(env_kwargs); + Dict* vec_dict = py_dict_to_c_dict(vec_kwargs, false); + Dict* env_dict = py_dict_to_c_dict(env_kwargs, true); std::unique_ptr pufferl; { @@ -452,6 +535,8 @@ static void vec_close(VecEnv& ve) { } PYBIND11_MODULE(_C, m) { + assert_static_env_name_matches(); + m.def("get_nccl_id", []() -> py::bytes { throw std::runtime_error("Metal backend does not support multi-GPU"); }); @@ -459,6 +544,7 @@ static void vec_close(VecEnv& ve) { m.attr("precision_bytes") = 4; m.attr("env_name") = PUFFER_STRINGIFY(ENV_NAME); + m.attr("static_env_name") = get_static_env_name(); m.attr("gpu") = 0; m.def("log", &puf_log); diff --git a/src/metal_kernels.mm b/src/metal_kernels.mm index 37a6d2c2bb..d91fbd59e9 100644 --- a/src/metal_kernels.mm +++ b/src/metal_kernels.mm @@ -931,7 +931,7 @@ void prio_sample(int minibatch_segments, int total_agents, void mtl_select_copy(RolloutBuf &rollouts, TrainGraph &graph, const int64_t *idx, const float *advantages, const float *mb_prio, int mb_segs, - void *fp16_obs_out, cudaStream_t stream) { + void *fp16_obs_out, bool train_fp16, cudaStream_t stream) { int obs_row_bytes = (int)(puf_numel(rollouts.observations.shape) / 
rollouts.observations.shape[0]) * (int)sizeof(float); @@ -964,8 +964,8 @@ void mtl_select_copy(RolloutBuf &rollouts, TrainGraph &graph, mtl_set_ptr(ms, mb_prio, 13); struct { - int obs_row_bytes, act_row_bytes, lp_row_bytes, horizon; - } params = {obs_row_bytes, act_row_bytes, lp_row_bytes, horizon}; + int obs_row_bytes, act_row_bytes, lp_row_bytes, horizon, train_fp16; + } params = {obs_row_bytes, act_row_bytes, lp_row_bytes, horizon, train_fp16 ? 1 : 0}; mtl_set_params(ms, params, 14); mtl_set_ptr(ms, fp16_obs_out, 15); @@ -1196,10 +1196,10 @@ static PrecisionTensor encoder_forward(void *w, void *activations, EncoderWeights *ew = (EncoderWeights *)w; EncoderActivations *a = (EncoderActivations *)activations; MetalStream *ms = mtl_resolve_stream(stream); - if (a->saved_input.data) { - PufTensor dst = to_puf(a->saved_input), src = to_puf(input); - puf_copy(dst, src, stream); - } + /* Alias input into saved_input for the backward weight-grad GEMM. + * mb_obs persists untouched between forward and backward of the same + * minibatch, so a pointer alias is equivalent to a copy. */ + a->saved_input.data = input.data; PufTensor inp = to_puf(input), wt = to_puf(ew->weight), out = to_puf(a->out); puf_mm(inp, wt, out, stream); @@ -1241,7 +1241,7 @@ static void encoder_reg_train(void *w, void *activations, .wgrad = {.shape = {ew->out_dim, ew->in_dim}, .dtype_size = precision}, }; alloc_register(acts, &a->out); - alloc_register(acts, &a->saved_input); + /* saved_input is aliased to the encoder input in encoder_forward; no allocation. 
*/ alloc_register(grads, &a->wgrad); } diff --git a/src/metal_platform.h b/src/metal_platform.h index 09894b60d7..5e36d67368 100644 --- a/src/metal_platform.h +++ b/src/metal_platform.h @@ -10,6 +10,13 @@ #include "puf_types.h" #include #include +#include +#include +#include +#include +#include +#include +#include #include struct MetalStream { @@ -75,7 +82,11 @@ static inline MetalStream *mtl_resolve_stream(cudaStream_t s) { static inline void mtl_ensure_stream_synced(cudaStream_t s) { MetalStream *ms = mtl_resolve_stream(s); - if (ms->enc_active || ms->pending_work) ms->sync(); + if (ms->flushed) { + ms->wait_completed(); + } else if (ms->enc_active || ms->pending_work) { + ms->sync(); + } } void *mtl_create_stream(); @@ -123,6 +134,62 @@ inline void mtl_set_tensor(MetalStream *ms, const PufTensor &t, id mtl_buffer_for_ptr(const void *ptr, NSUInteger *out_offset); +inline bool mtl_const_ring_reserve_range(NSUInteger current_offset, + NSUInteger raw_size, + NSUInteger *next_offset) { + if (raw_size > MTL_CONST_RING_SIZE || + current_offset > MTL_CONST_RING_SIZE) { + return false; + } + + NSUInteger aligned = (raw_size + 15) & ~(NSUInteger)15; + if (aligned > MTL_CONST_RING_SIZE - current_offset) { + return false; + } + + *next_offset = current_offset + aligned; + return true; +} + +inline int mtl_parse_int_config_value(const char *key, double value) { + if (!std::isfinite(value)) { + throw std::invalid_argument(std::string(key) + " must be a finite integer"); + } + double rounded = std::round(value); + if (rounded < (double)std::numeric_limits::min() || + rounded > (double)std::numeric_limits::max()) { + throw std::invalid_argument(std::string(key) + " is outside int range"); + } + return (int)rounded; +} + +inline int mtl_validate_nonzero_config_value(const char *key, int value) { + if (value == 0) { + throw std::invalid_argument(std::string(key) + " must be nonzero"); + } + return value; +} + +inline int mtl_validate_positive_config_value(const char *key, int 
value) { + if (value <= 0) { + throw std::invalid_argument(std::string(key) + " must be positive"); + } + return value; +} + +inline void mtl_validate_divisible_config_values(const char *numerator_key, + int numerator, + const char *denominator_key, + int denominator) { + mtl_validate_nonzero_config_value(denominator_key, denominator); + if (numerator % denominator != 0) { + throw std::invalid_argument(std::string(numerator_key) + " must be divisible by " + + denominator_key + ": " + + std::to_string(numerator) + " % " + + std::to_string(denominator) + " != 0"); + } +} + inline void mtl_set_tensor(MetalStream *ms, const FloatTensor &t, uint32_t index) { NSUInteger offset; @@ -135,14 +202,22 @@ inline void mtl_set_tensor(MetalStream *ms, const FloatTensor &t, } template inline void mtl_set_params(MetalStream *ms, const T ¶ms, uint32_t index) { - NSUInteger aligned = (sizeof(T) + 15) & ~15; - assert(ms->const_ring_offset + aligned <= MTL_CONST_RING_SIZE); + NSUInteger next_offset = 0; + if (!mtl_const_ring_reserve_range(ms->const_ring_offset, sizeof(T), + &next_offset)) { + std::fprintf(stderr, + "mtl_set_params: constant ring overflow: offset=%llu size=%llu capacity=%llu\n", + (unsigned long long)ms->const_ring_offset, + (unsigned long long)sizeof(T), + (unsigned long long)MTL_CONST_RING_SIZE); + std::abort(); + } memcpy((char *)[ms->const_ring contents] + ms->const_ring_offset, ¶ms, sizeof(T)); uint64_t addr = ms->const_ring.gpuAddress + ms->const_ring_offset; [ms->arg_table setAddress:addr atIndex:index]; ms->bound_addresses[index] = addr; - ms->const_ring_offset += aligned; + ms->const_ring_offset = next_offset; } inline void mtl_dispatch_1d(MetalStream *ms, id pso, diff --git a/src/metal_platform.mm b/src/metal_platform.mm index 9b0ac710d6..5125a1e7f2 100644 --- a/src/metal_platform.mm +++ b/src/metal_platform.mm @@ -279,7 +279,12 @@ kernel void tensor_ops_gemm_tn_f16( static bool g_gpu_timing_enabled = false; static double g_gpu_exec_ns = 0.0; static double 
g_sched_wait_ns = 0.0; -static constexpr NSUInteger kMetalSyncTimeoutMs = 300000; // 5 min — worst-case sweep configs push 3K+ minibatches per epoch +static constexpr NSUInteger kMetalSyncTimeoutMs = 300000; + +static void mtl_abort_sync_timeout(const char *where) { + std::fprintf(stderr, "Metal sync timeout in %s\n", where); + std::abort(); +} static double mach_to_ns(uint64_t ticks) { if (g_timebase.denom == 0) mach_timebase_info(&g_timebase); @@ -311,7 +316,7 @@ static double mach_to_ns(uint64_t ticks) { [q signalEvent:sync_event value:val]; BOOL signaled = [sync_event waitUntilSignaledValue:val timeoutMS:kMetalSyncTimeoutMs]; if (!signaled) { - assert(false && "Metal sync timeout in MetalStream::sync"); + mtl_abort_sync_timeout("MetalStream::sync"); } if (gpu_start > 0 && gpu_end > 0) { g_gpu_exec_ns += (gpu_end - gpu_start) * 1e9; @@ -322,7 +327,7 @@ static double mach_to_ns(uint64_t ticks) { [q signalEvent:sync_event value:val]; BOOL signaled = [sync_event waitUntilSignaledValue:val timeoutMS:kMetalSyncTimeoutMs]; if (!signaled) { - assert(false && "Metal sync timeout in MetalStream::sync"); + mtl_abort_sync_timeout("MetalStream::sync"); } } uint64_t t1 = mach_absolute_time(); @@ -357,7 +362,19 @@ static double mach_to_ns(uint64_t ticks) { id bufs[] = { cmd }; id q = (this == &ctx->train_stream) ? 
ctx->train_queue : ctx->queue; + uint64_t t0 = mach_absolute_time(); + uint64_t val = ++sync_event_value; [q commit:bufs count:1]; + [q signalEvent:sync_event value:val]; + BOOL signaled = [sync_event waitUntilSignaledValue:val timeoutMS:kMetalSyncTimeoutMs]; + if (!signaled) { + mtl_abort_sync_timeout("MetalStream::commit_chunk"); + } + uint64_t t1 = mach_absolute_time(); + g_sync_count++; + g_sync_total_ns += mach_to_ns(t1 - t0); + pending_work = false; + flushed = false; cmd = [ctx->device newCommandBuffer]; assert(cmd && "Failed to allocate Metal command buffer for chunked training"); @@ -369,7 +386,7 @@ static double mach_to_ns(uint64_t ticks) { uint64_t t0 = mach_absolute_time(); BOOL signaled = [sync_event waitUntilSignaledValue:flush_event_val timeoutMS:kMetalSyncTimeoutMs]; if (!signaled) { - assert(false && "Metal sync timeout in MetalStream::wait_completed"); + mtl_abort_sync_timeout("MetalStream::wait_completed"); } uint64_t t1 = mach_absolute_time(); g_sync_count++; @@ -1246,16 +1263,13 @@ int cudaFreeHost(void *ptr) { int cudaSetDevice(int /*device*/) { return 0; } int cudaDeviceSynchronize(void) { - if (g_ctx.stream.enc_active) - g_ctx.stream.sync(); + mtl_ensure_stream_synced((cudaStream_t)&g_ctx.stream); + mtl_ensure_stream_synced((cudaStream_t)&g_ctx.train_stream); return 0; } -int cudaStreamSynchronize(void * /*stream*/) { - // No-op on Metal. GPU work is already synced inside net_callback_wrapper - // (ensure_gpu_synced under mutex). The vecenv memcpys are also no-ops - // (unified memory). Calling sync() here would race with other buffer - // threads that hold the GPU mutex and have an active encoder. 
+int cudaStreamSynchronize(void *stream) { + mtl_ensure_stream_synced((cudaStream_t)stream); return 0; } diff --git a/src/metal_pufferlib.mm b/src/metal_pufferlib.mm index 7f3f32215a..5fa3277f78 100644 --- a/src/metal_pufferlib.mm +++ b/src/metal_pufferlib.mm @@ -590,7 +590,7 @@ void train_impl(PuffeRL& pufferl) { pufferl.advantages_puf.data, pufferl.prio_bufs.mb_prio.data, minibatch_segments, - pufferl.fp16_obs_buf.bytes, s); + pufferl.fp16_obs_buf.bytes, pufferl.train_fp16, s); // gather masks from train_masks into mb_masks using same priority indices. // reuses index_copy_kernel as a gather: dst[i] = src[idx[i]]. if (pufferl.has_mask) { @@ -772,9 +772,6 @@ void train_impl(PuffeRL& pufferl) { MetalStream* mts = (MetalStream*)ts; for (int mb = 0; mb < total_minibatches; ++mb) { run_minibatch(ts, train_rng_offset, false); - // Commit current command buffer when ring is >75% full to prevent - // overflow on high replay_ratio configs. Metal queue serial execution - // guarantees the GPU finishes reading ring data before we overwrite it. if (mb + 1 < total_minibatches && mts->const_ring_offset > MTL_CONST_RING_SIZE * 3 / 4) { mts->commit_chunk(); diff --git a/src/metal_shader_src.h b/src/metal_shader_src.h index f9931eb313..1b29eb2860 100644 --- a/src/metal_shader_src.h +++ b/src/metal_shader_src.h @@ -425,9 +425,8 @@ struct SampleParams { int mask_stride; // stride between rows in mask buffer (may differ from num_atns_total) }; -// Apply action mask to a logit: invalid actions get -1e9. 
inline float masked_logit(float l, float m) { - if (m < 0.5f) l = -1e9f; + if (m < 0.5f) l = -INFINITY; return l; } @@ -495,11 +494,19 @@ kernel void sample_logits_kernel( // Max + logsumexp (with mask) float max_val = -INFINITY; + bool has_valid_action = false; for (int a = 0; a < A; a++) { + has_valid_action = has_valid_action || action_mask[mask_base + logits_offset + a] >= 0.5f; float l = masked_logit(logits[logits_base + logits_offset + a], action_mask[mask_base + logits_offset + a]); max_val = fmax(max_val, l); } + if (!has_valid_action) { + actions[(int)idx * sp.num_atns + h] = NAN; + total_log_prob = NAN; + logits_offset += A; + continue; + } float sum_exp = 0.0f; for (int a = 0; a < A; a++) { float l = masked_logit(logits[logits_base + logits_offset + a], @@ -581,11 +588,18 @@ kernel void recompute_logprobs_kernel( // Max + logsumexp (with mask) float max_val = -INFINITY; + bool has_valid_action = false; for (int a = 0; a < A; a++) { + has_valid_action = has_valid_action || action_mask[mask_base + logits_offset + a] >= 0.5f; max_val = fmax(max_val, masked_logit( logits[logits_base + logits_offset + a], action_mask[mask_base + logits_offset + a])); } + if (!has_valid_action) { + total_log_prob = NAN; + logits_offset += A; + continue; + } float sum_exp = 0.0f; for (int a = 0; a < A; a++) { sum_exp += exp(masked_logit( @@ -630,9 +644,6 @@ inline void atomic_add_float(device atomic_uint* addr, float val) { } } -// PPO helper: compute logsumexp, entropy, log_prob for a single discrete head with masks. -// mask pointer + mask_offset index into the action mask for this head. -// Invalid actions (mask < 0.5) get logit = -1e9, matching rollout sampling. 
inline void ppo_discrete_head( const device float* logits, int logits_base, int logits_stride_a, int logits_offset, @@ -643,21 +654,21 @@ inline void ppo_discrete_head( float max_logit = -INFINITY; float sum = 0.0f; float act_logit = 0.0f; + bool has_valid_action = false; for (int a = 0; a < A; a++) { - float l = logits[logits_base + (logits_offset + a) * logits_stride_a]; - if (mask[mask_offset + a] < 0.5f) l = -1e9f; + float m = mask[mask_offset + a]; + float l = masked_logit(logits[logits_base + (logits_offset + a) * logits_stride_a], m); if (a == act) act_logit = l; + if (m < 0.5f) continue; + has_valid_action = true; if (l > max_logit) { sum *= exp(max_logit - l); max_logit = l; } sum += exp(l - max_logit); } - // Degenerate input (all masked or non-finite model output): propagate NaN - // so the corruption surfaces immediately in the PPO loss rather than - // silently producing logp=0 (ratio=1) which poisons gradients. - if (!isfinite(max_logit) || !isfinite(sum) || sum <= 0.0f) { + if (!has_valid_action || !isfinite(max_logit) || !isfinite(sum) || sum <= 0.0f) { out_logsumexp = NAN; out_entropy = NAN; out_logp = NAN; @@ -667,8 +678,9 @@ inline void ppo_discrete_head( float ent = 0.0f; for (int a = 0; a < A; a++) { - float l = logits[logits_base + (logits_offset + a) * logits_stride_a]; - if (mask[mask_offset + a] < 0.5f) l = -1e9f; + if (mask[mask_offset + a] < 0.5f) continue; + float l = masked_logit(logits[logits_base + (logits_offset + a) * logits_stride_a], + mask[mask_offset + a]); float logp = l - lse; float p = exp(clamp(logp, -80.0f, 80.0f)); ent -= p * logp; @@ -887,13 +899,16 @@ kernel void ppo_loss_fwd_bwd_kernel( for (int a = 0; a < A; a++) { float raw_l = logits[logits_base + (logits_offset + a) * pp.logits_stride_a]; float m = action_mask[mask_base + logits_offset + a]; - float l = (m < 0.5f) ? 
-1e9f : raw_l; + if (m < 0.5f) { + grad_logits[grad_logits_base + logits_offset + a] = 0.0f; + continue; + } + float l = masked_logit(raw_l, m); float logp = l - lse; float p = exp(logp); float d_logit = (a == act) ? d_new_logp : 0.0f; d_logit -= p * d_new_logp; d_logit += d_entropy_term * p * (-ent - logp); - if (m < 0.5f) d_logit = 0.0f; grad_logits[grad_logits_base + logits_offset + a] = d_logit; } logits_offset += A; @@ -1586,10 +1601,11 @@ struct SelectCopyParams { int act_row_bytes; int lp_row_bytes; int horizon; + int train_fp16; // 1: encoder reads fp16_obs_out (skip mb_obs f32 write); 0: encoder reads mb_obs (skip f16 write) }; // Minibatch assembly: copy observations, actions, logprobs, values+advantages+returns, prio -// Channel 0 fuses obs gather + f32→f16 cast: reads f32 src, writes f16 directly to fp16_obs_out. +// Channel 0 fuses obs gather + f32→f16 cast. Only the variant the encoder will read gets written. // Dispatched as (minibatch_size, 5) threadgroups, each handles one channel for one row. kernel void select_copy_kernel( device char* mb_obs [[buffer(0)]], @@ -1616,16 +1632,20 @@ kernel void select_copy_kernel( int src_row = (int)idx[mb]; if (ch == 0) { - // Fused obs gather + f32→f16 cast: copy f32 to mb_obs AND write f16 directly. - // mb_obs f32 copy is needed because PPO reads embedded action masks from it. + // Fused obs gather + (optional) f32→f16 cast. Encoder reads exactly one variant + // depending on train_fp16; the other is dead-on-arrival, so skip writing it. 
const device float* sptr = (const device float*)(src_obs + (int64_t)src_row * p.obs_row_bytes); - int count = p.obs_row_bytes / 4; // number of floats - device float* f32ptr = (device float*)(mb_obs + (int64_t)mb * p.obs_row_bytes); - device half* f16ptr = fp16_obs_out + (int64_t)mb * count; - for (int i = (int)tid; i < count; i += 256) { - float val = sptr[i]; - f32ptr[i] = val; - f16ptr[i] = half(val); + int count = p.obs_row_bytes / 4; + if (p.train_fp16) { + device half* f16ptr = fp16_obs_out + (int64_t)mb * count; + for (int i = (int)tid; i < count; i += 256) { + f16ptr[i] = half(sptr[i]); + } + } else { + device float* f32ptr = (device float*)(mb_obs + (int64_t)mb * p.obs_row_bytes); + for (int i = (int)tid; i < count; i += 256) { + f32ptr[i] = sptr[i]; + } } } else if (ch == 1) { // Copy actions diff --git a/src/pufferlib.cu b/src/pufferlib.cu index 6c513c97b7..1d7c5966cb 100644 --- a/src/pufferlib.cu +++ b/src/pufferlib.cu @@ -330,7 +330,7 @@ typedef struct { } PuffeRL; Dict* log_environments_impl(PuffeRL& pufferl) { - Dict* out = create_dict(32); + Dict* out = create_dict(64); static_vec_log(pufferl.vec, out); return out; } diff --git a/src/vecenv.h b/src/vecenv.h index 71998f6e0c..74350f2463 100644 --- a/src/vecenv.h +++ b/src/vecenv.h @@ -122,6 +122,7 @@ void static_vec_omp_step(StaticVec* vec); void static_vec_seq_step(StaticVec* vec); void static_vec_render(StaticVec* vec, int env_id); void static_vec_read_profile(StaticVec* vec, float out[NUM_EVAL_PROF]); +const char* get_static_env_name(void); // Env info int get_obs_size(void); @@ -172,8 +173,15 @@ static inline size_t obs_element_size(void) { #define _STRINGIFY(x) #x #define STRINGIFY(x) _STRINGIFY(x) +#ifndef ENV_NAME +#define ENV_NAME unknown +#endif const char dtype_symbol[] = STRINGIFY(OBS_TENSOR_T); +const char* get_static_env_name(void) { + return STRINGIFY(ENV_NAME); +} + #include #include #include diff --git a/tests/test_metal_const_ring.mm b/tests/test_metal_const_ring.mm new file mode 
100644 index 0000000000..52214808e7 --- /dev/null +++ b/tests/test_metal_const_ring.mm @@ -0,0 +1,34 @@ +/** + * @file test_metal_const_ring.mm + * @brief Tests constant-ring reservation bounds. + */ + +#include "src/metal_platform.h" + +int main(void) { + NSUInteger next_offset = 0; + + if (!mtl_const_ring_reserve_range(0, 1, &next_offset)) return 1; + if (next_offset != 16) return 2; + + if (!mtl_const_ring_reserve_range(MTL_CONST_RING_SIZE - 16, 16, &next_offset)) return 3; + if (next_offset != MTL_CONST_RING_SIZE) return 4; + + if (mtl_const_ring_reserve_range(MTL_CONST_RING_SIZE - 16, 17, &next_offset)) return 5; + + if (mtl_parse_int_config_value("total_agents", 0.0) != 0) return 6; + if (mtl_parse_int_config_value("total_agents", 128.0) != 128) return 7; + + if (mtl_parse_int_config_value("horizon", 4.5) != 5) return 8; + if (mtl_parse_int_config_value("horizon", 4.4) != 4) return 81; + + bool rejected_nondivisible = false; + try { + mtl_validate_divisible_config_values("total_agents", 10, "num_buffers", 3); + } catch (const std::invalid_argument&) { + rejected_nondivisible = true; + } + if (!rejected_nondivisible) return 9; + + return 0; +} From c6cdb07b3961e8f3cd624176743ef319dde95e36 Mon Sep 17 00:00:00 2001 From: Valtteri Valo Date: Fri, 1 May 2026 18:26:01 +0300 Subject: [PATCH 6/6] build.sh: pass ENV_NAME when compiling static lib The static lib defines get_static_env_name() which returns STRINGIFY(ENV_NAME). Without -DENV_NAME the macro defaults to "unknown", which now trips the runtime guard in PYBIND11_MODULE that compares static_env_name against the binding's ENV_NAME. The static obj already needs to know its env for the binding source it includes, so this just propagates the same define. --- build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/build.sh b/build.sh index 68740090ed..abd3751297 100755 --- a/build.sh +++ b/build.sh @@ -241,6 +241,7 @@ STATIC_CFLAGS=( -I. 
-Isrc -I"$SRC_DIR" -Ivendor -I./"$RAYLIB_NAME"/include -DPLATFORM_DESKTOP + -DENV_NAME="$ENV" -fvisibility=hidden -fPIC )