diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 708479f79..97184fc43 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -97,6 +97,56 @@ jobs: name: katago-macos-opencl path: cpp/katago + build-macos-metal: + runs-on: macos-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + brew install ninja zlib libzip + brew tap chinchangyang/katagocoreml-cpp + brew install katagocoreml + + - name: Cache CMake build + uses: actions/cache@v4 + with: + path: | + cpp/CMakeCache.txt + cpp/CMakeFiles + cpp/build.ninja + cpp/.ninja_deps + cpp/.ninja_log + key: ${{ runner.os }}-cmake-metal-${{ hashFiles('**/CMakeLists.txt') }} + restore-keys: | + ${{ runner.os }}-cmake-metal- + + - name: Configure CMake + working-directory: cpp + run: | + cmake . -G Ninja -DUSE_BACKEND=METAL -DCMAKE_BUILD_TYPE=Release + + - name: Build + working-directory: cpp + run: | + ninja + + - name: Run tests + working-directory: cpp + run: | + ./katago runtests + + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-macos-metal + path: cpp/katago + build-windows: runs-on: windows-latest permissions: diff --git a/.gitignore b/.gitignore index 2e933d553..b509ec1ff 100644 --- a/.gitignore +++ b/.gitignore @@ -90,3 +90,7 @@ cpp/.ninja_log cpp/build.ninja cpp/KataGoSwift.* cpp/include/KataGoSwift/KataGoSwift-swift.h + +# For CoreML Backend +cpp/KataGoCoreML.* +cpp/include/KataGoCoreML/KataGoCoreML-swift.h diff --git a/Compiling.md b/Compiling.md index 648fea548..acc1b3cbd 100644 --- a/Compiling.md +++ b/Compiling.md @@ -118,8 +118,12 @@ As also mentioned in the instructions below but repeated here for visibility, if * If using OpenCL, you will want to verify that KataGo is picking up the correct device (e.g. some systems may have both an Intel CPU OpenCL and GPU OpenCL, if KataGo appears to pick the wrong one, you can correct this by specifying `openclGpuToUse` in `configs/gtp_example.cfg`). ## MacOS - * TLDR: + * TLDR (Metal backend - recommended for most users, hybrid CPU+GPU+Neural Engine for maximum throughput): ``` + # First, install the katagocoreml library via Homebrew + brew tap chinchangyang/katagocoreml-cpp + brew install katagocoreml + git clone https://github.com/lightvector/KataGo.git cd KataGo/cpp # If you get missing library errors, install the appropriate packages using your system package manager and try again. @@ -132,6 +136,7 @@ As also mentioned in the instructions below but repeated here for visibility, if * CMake with a minimum version of 3.18.2: `brew install cmake`. * AppleClang and Swift compilers: `xcode-select --install`. * If using the Metal backend, [Ninja](https://ninja-build.org): `brew install ninja` + * If using the Metal backend, katagocoreml library: `brew tap chinchangyang/katagocoreml-cpp && brew install katagocoreml` * libzip: `brew install libzip`. * If you want to do self-play training and research, probably Google perftools `brew install gperftools` for TCMalloc or some other better malloc implementation. For unknown reasons, the allocation pattern in self-play with large numbers of threads and parallel games causes a lot of memory fragmentation under glibc malloc that will eventually run your machine out of memory, but better mallocs handle it fine. * If compiling to contribute to public distributed training runs, OpenSSL is required (`brew install openssl`). diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 254d23233..4578732ae 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.18.2) -if(USE_BACKEND STREQUAL "METAL") +if(USE_BACKEND STREQUAL "METAL" OR USE_BACKEND STREQUAL "COREML") project(katago LANGUAGES CXX Swift) else() project(katago) @@ -32,7 +32,7 @@ endif() set(BUILD_DISTRIBUTED 0 CACHE BOOL "Build with http support for contributing to distributed training") set(USE_BACKEND CACHE STRING "Neural net backend") string(TOUPPER "${USE_BACKEND}" USE_BACKEND) -set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN) +set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN METAL) set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc") set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe") @@ -97,7 +97,7 @@ elseif(USE_BACKEND STREQUAL "TENSORRT") message(FATAL_ERROR "Combining USE_CACHE_TENSORRT_PLAN with BUILD_DISTRIBUTED is not supported - it would consume excessive disk space and might worsen performance every time models are updated. Use only one at a time in a given build of KataGo.") endif() elseif(USE_BACKEND STREQUAL "METAL") - message(STATUS "-DUSE_BACKEND=METAL, using Metal backend.") + message(STATUS "-DUSE_BACKEND=METAL, using Metal backend with hybrid MPSGraph + CoreML execution.") if(NOT "${CMAKE_GENERATOR}" STREQUAL "Ninja") message(FATAL_ERROR "Bidirectional C++ Interop requires Ninja generator. Have ${CMAKE_GENERATOR}") endif() @@ -107,6 +107,8 @@ elseif(USE_BACKEND STREQUAL "METAL") if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") message(FATAL_ERROR "Project requires building with AppleClang. Have ${CMAKE_CXX_COMPILER_ID}") endif() + find_package(PkgConfig REQUIRED) + pkg_check_modules(KATAGOCOREML REQUIRED katagocoreml) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/external/macos/cmake/modules") include(InitializeSwift) include(AddSwift) @@ -115,11 +117,11 @@ elseif(USE_BACKEND STREQUAL "METAL") neuralnet/metalbackend.cpp ) add_library(KataGoSwift STATIC - neuralnet/metalbackend.swift) + neuralnet/metalbackend.swift + neuralnet/metallayers.swift) _swift_generate_cxx_header( KataGoSwift - "${CMAKE_CURRENT_BINARY_DIR}/include/KataGoSwift/KataGoSwift-swift.h" - SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/neuralnet/metalbackend.swift") + "${CMAKE_CURRENT_BINARY_DIR}/include/KataGoSwift/KataGoSwift-swift.h") target_include_directories(KataGoSwift PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/include") set_target_properties(KataGoSwift PROPERTIES Swift_MODULE_NAME "KataGoSwift") target_compile_options(KataGoSwift PUBLIC @@ -399,9 +401,14 @@ elseif(USE_BACKEND STREQUAL "TENSORRT") target_link_libraries(katago CUDA::cudart_static ${TENSORRT_LIBRARY}) elseif(USE_BACKEND STREQUAL "METAL") target_compile_definitions(katago PRIVATE USE_METAL_BACKEND) - target_link_libraries(katago KataGoSwift) + target_include_directories(katago PRIVATE ${KATAGOCOREML_INCLUDE_DIRS}) + find_library(KATAGOCOREML_LIB katagocoreml HINTS /usr/local/lib REQUIRED) + target_link_directories(katago PRIVATE ${KATAGOCOREML_LIBRARY_DIRS}) + target_link_libraries(katago KataGoSwift ${KATAGOCOREML_LIB} ${KATAGOCOREML_LDFLAGS} + "-framework MetalPerformanceShaders" + "-framework MetalPerformanceShadersGraph") if("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64") - message(WARNING "You are currently running cmake on an Intel-based processor. It is known that running KataGo in this configuration may encounter performance issues. It is recommended to switch to a cmake version designed for ARM64 architecture for optimal performance.") + message(WARNING "Metal backend may not work optimally on Intel. ARM64 architecture is recommended.") endif() elseif(USE_BACKEND STREQUAL "OPENCL") target_compile_definitions(katago PRIVATE USE_OPENCL_BACKEND) diff --git a/cpp/command/benchmark.cpp b/cpp/command/benchmark.cpp index 3100fb1b1..b29b9325b 100644 --- a/cpp/command/benchmark.cpp +++ b/cpp/command/benchmark.cpp @@ -260,6 +260,9 @@ int MainCmds::benchmark(const vector& args) { #ifdef USE_METAL_BACKEND cout << "You are currently using the Metal version of KataGo." << endl; #endif +#ifdef USE_COREML_BACKEND + cout << "You are currently using the Core ML version of KataGo." << endl; +#endif #ifdef USE_OPENCL_BACKEND cout << "You are currently using the OpenCL version of KataGo." << endl; cout << "If you have a strong GPU capable of FP16 tensor cores (e.g. RTX2080), " diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 58ff2c4a3..76e236537 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -5,62 +5,156 @@ #include "../neuralnet/nninputs.h" #include "../neuralnet/nninterface.h" #include "../neuralnet/metalbackend.h" -#include "../core/test.h" -/// Converts a ConvLayerDesc instance from C++ to Swift by creating a new SWConvLayerDesc instance with the same properties. -/// - Parameter desc: The ConvLayerDesc instance to convert. -/// - Returns: A SWConvLayerDesc instance with the same properties as the input ConvLayerDesc. -SWConvLayerDesc MetalProcess::convLayerDescToSwift(const ConvLayerDesc * desc) { +#include +#include +#include +#include +#include +#include // For getpid() - SWConvLayerDesc swDesc = createSWConvLayerDesc(desc->convYSize, - desc->convXSize, - desc->inChannels, - desc->outChannels, - desc->dilationY, - desc->dilationX, - (float*)desc->weights.data()); +using namespace std; - return swDesc; +//------------------------------------------------------------------------------ +// CoreML Model Conversion - Native C++ using katagocoreml library +//------------------------------------------------------------------------------ + +namespace gfs = ghc::filesystem; + +// Minimum batch sizes for hybrid execution mode. +// Hybrid splits batches between CoreML (CPU+ANE) and MPSGraph (GPU). +// When batch is too small to split, prefer MPSGraph-only for stability: +// MPSGraph has more predictable latency and avoids CoreML dispatch overhead. +static constexpr int MIN_COREML_BATCH = 1; +static constexpr int MIN_MPSGRAPH_BATCH = 1; + +namespace CoreMLConversion { + +// Get temp directory for model conversion +static string getTempDirectory() { + gfs::path tempDir = gfs::temp_directory_path() / "katago_coreml"; + std::error_code ec; + gfs::create_directories(tempDir, ec); + if(ec) { + throw runtime_error("Failed to create temp directory: " + ec.message()); + } + return tempDir.string(); } -/// Converts a BatchNormLayerDesc instance from C++ to Swift by creating a new SWBatchNormLayerDesc instance with the same properties. -/// - Parameter desc: The BatchNormLayerDesc instance to convert. -/// - Returns: A SWBatchNormLayerDesc instance with the same properties as the input BatchNormLayerDesc. -SWBatchNormLayerDesc MetalProcess::batchNormLayerDescToSwift(const BatchNormLayerDesc * desc) { +// Generate unique temporary path for model conversion +static string generateTempPath(int serverThreadIdx) { + auto now = chrono::steady_clock::now().time_since_epoch().count(); + return getTempDirectory() + "/model_" + to_string(getpid()) + "_" + + to_string(serverThreadIdx) + "_" + to_string(now) + ".mlpackage"; +} - SWBatchNormLayerDesc swDesc = - createSWBatchNormLayerDesc(desc->numChannels, - (float*)desc->mergedScale.data(), - (float*)desc->mergedBias.data()); +// CoreML model metadata constants +static const string COREML_MODEL_AUTHOR = "KataGo"; +static const string COREML_MODEL_LICENSE = "See original model file for license terms"; - return swDesc; +// Convert KataGo model to CoreML in temp directory, returns path to .mlpackage +// The caller (Swift side) is responsible for deleting the temp file after loading +static string convertModelToTemp( + const string& modelPath, + int boardX, + int boardY, + bool useFP16, + bool optimizeMask, + int maxBatchSize, + int serverThreadIdx +) { + // maxBatchSize is validated upstream: cfg.getInt("nnMaxBatchSize", 1, 65536) in setup.cpp + // and NNEvaluator constructor throws if maxBatchSize <= 0. Assert for defensive documentation. + assert(maxBatchSize >= 1); + + string tempPath = generateTempPath(serverThreadIdx); + cerr << "Metal backend " << serverThreadIdx << ": Converting model to " << tempPath << endl; + + katagocoreml::ConversionOptions opts; + opts.board_x_size = boardX; + opts.board_y_size = boardY; + opts.compute_precision = useFP16 ? "FLOAT16" : "FLOAT32"; + opts.optimize_identity_mask = optimizeMask; + opts.min_batch_size = 1; + opts.max_batch_size = maxBatchSize; + opts.author = COREML_MODEL_AUTHOR; + opts.license = COREML_MODEL_LICENSE; + + try { + katagocoreml::KataGoConverter::convert(modelPath, tempPath, opts); + } catch(const exception& e) { + // Clean up partial conversion on failure + std::error_code ec; + gfs::remove_all(tempPath, ec); + if(ec) { + cerr << "Metal backend " << serverThreadIdx << ": Warning: Failed to clean up partial conversion at " << tempPath << ": " << ec.message() << endl; + } + throw runtime_error(string("Metal backend ") + to_string(serverThreadIdx) + ": Core ML model conversion failed: " + e.what()); + } + + cerr << "Metal backend " << serverThreadIdx << ": Conversion completed" << endl; + return tempPath; } -/// Convert an activation layer description from C++ to Swift -/// - Parameter desc: An activation layer description -ActivationKind MetalProcess::activationLayerDescToSwift(const ActivationLayerDesc * desc) { +} // namespace CoreMLConversion + +//------------------------------------------------------------------------------ +// Model Descriptor Conversion - C++ to Swift types for MPSGraph +//------------------------------------------------------------------------------ + +namespace MetalProcess { + +/// Converts a ConvLayerDesc instance from C++ to Swift +SWConvLayerDesc convLayerDescToSwift(const ConvLayerDesc* desc) { + return createSWConvLayerDesc( + desc->convYSize, + desc->convXSize, + desc->inChannels, + desc->outChannels, + desc->dilationY, + desc->dilationX, + (float*)desc->weights.data()); +} - switch (desc->activation) { +/// Converts a BatchNormLayerDesc instance from C++ to Swift +SWBatchNormLayerDesc batchNormLayerDescToSwift(const BatchNormLayerDesc* desc) { + return createSWBatchNormLayerDesc( + desc->numChannels, + (float*)desc->mergedScale.data(), + (float*)desc->mergedBias.data()); +} + +/// Convert an activation layer description from C++ to Swift +ActivationKind activationLayerDescToSwift(const ActivationLayerDesc* desc) { + switch(desc->activation) { case ACTIVATION_RELU: return ActivationKind::relu(); case ACTIVATION_MISH: return ActivationKind::mish(); case ACTIVATION_MISH_SCALE8: - testAssert(false); // Metal does not use scaled mish activations due to no fp16 - return ActivationKind::identity(); // Placeholder for compilation + return ActivationKind::identity(); // Metal/CoreML does not use scaled mish case ACTIVATION_IDENTITY: return ActivationKind::identity(); default: - testAssert(false); - return ActivationKind::identity(); // Placeholder for compilation + return ActivationKind::identity(); } } -/// Convert a residual block description from C++ to Swift -/// - Parameter desc: A residual block description -/// - Returns: The residual block description converted to SWResidualBlockDesc -SWResidualBlockDesc MetalProcess::residualBlockDescToSwift(const ResidualBlockDesc * desc) { +/// Convert a matrix multiplication layer description from C++ to Swift +SWMatMulLayerDesc matMulLayerDescToSwift(const MatMulLayerDesc* desc) { + return createSWMatMulLayerDesc( + desc->inChannels, + desc->outChannels, + (float*)desc->weights.data()); +} +/// Convert a matrix bias layer description from C++ to Swift +SWMatBiasLayerDesc matBiasLayerDescToSwift(const MatBiasLayerDesc* desc) { + return createSWMatBiasLayerDesc(desc->numChannels, (float*)desc->weights.data()); +} + +/// Convert a residual block description from C++ to Swift +SWResidualBlockDesc residualBlockDescToSwift(const ResidualBlockDesc* desc) { SWBatchNormLayerDesc preBN = batchNormLayerDescToSwift(&desc->preBN); ActivationKind preActivationKind = activationLayerDescToSwift(&desc->preActivation); SWConvLayerDesc regularConv = convLayerDescToSwift(&desc->regularConv); @@ -68,34 +162,17 @@ SWResidualBlockDesc MetalProcess::residualBlockDescToSwift(const ResidualBlockDe ActivationKind midActivationKind = activationLayerDescToSwift(&desc->midActivation); SWConvLayerDesc finalConv = convLayerDescToSwift(&desc->finalConv); - SWResidualBlockDesc swDesc = - createSWResidualBlockDesc(preBN, - preActivationKind, - regularConv, - midBN, - midActivationKind, - finalConv); - - return swDesc; -} - -/// Convert a matrix multiplication layer description from C++ to Swift -/// - Parameter desc: A matrix multiplication layer description -/// - Returns: The matrix multiplication layer description converted to SWMatMulLayerDesc -SWMatMulLayerDesc MetalProcess::matMulLayerDescToSwift(const MatMulLayerDesc * desc) { - - SWMatMulLayerDesc swDesc = createSWMatMulLayerDesc(desc->inChannels, - desc->outChannels, - (float*)desc->weights.data()); - - return swDesc; + return createSWResidualBlockDesc( + preBN, + preActivationKind, + regularConv, + midBN, + midActivationKind, + finalConv); } /// Convert a global pooling residual block description from C++ to Swift -/// - Parameter desc: A global pooling residual block description -/// - Returns: The global pooling residual block description converted to SWGlobalPoolingResidualBlockDesc -SWGlobalPoolingResidualBlockDesc MetalProcess::globalPoolingResidualBlockDescToSwift(const GlobalPoolingResidualBlockDesc* desc) { - +SWGlobalPoolingResidualBlockDesc globalPoolingResidualBlockDescToSwift(const GlobalPoolingResidualBlockDesc* desc) { SWBatchNormLayerDesc preBN = batchNormLayerDescToSwift(&desc->preBN); ActivationKind preActivationKind = activationLayerDescToSwift(&desc->preActivation); SWConvLayerDesc regularConv = convLayerDescToSwift(&desc->regularConv); @@ -107,37 +184,53 @@ SWGlobalPoolingResidualBlockDesc MetalProcess::globalPoolingResidualBlockDescToS ActivationKind midActivationKind = activationLayerDescToSwift(&desc->midActivation); SWConvLayerDesc finalConv = convLayerDescToSwift(&desc->finalConv); - SWGlobalPoolingResidualBlockDesc swDesc = - createSWGlobalPoolingResidualBlockDesc(preBN, - preActivationKind, - regularConv, - gpoolConv, - gpoolBN, - gpoolActivationKind, - gpoolToBiasMul, - midBN, - midActivationKind, - finalConv); + return createSWGlobalPoolingResidualBlockDesc( + preBN, + preActivationKind, + regularConv, + gpoolConv, + gpoolBN, + gpoolActivationKind, + gpoolToBiasMul, + midBN, + midActivationKind, + finalConv); +} + +// Forward declaration for mutual recursion +swift::Array residualBlocksToSwift(const vector>& blocks); + +/// Convert a nested bottleneck residual block description from C++ to Swift +SWNestedBottleneckResidualBlockDesc nestedBottleneckResidualBlockDescToSwift(const NestedBottleneckResidualBlockDesc* desc) { + SWBatchNormLayerDesc preBN = batchNormLayerDescToSwift(&desc->preBN); + ActivationKind preActivationKind = activationLayerDescToSwift(&desc->preActivation); + SWConvLayerDesc preConv = convLayerDescToSwift(&desc->preConv); + auto swBlocks = residualBlocksToSwift(desc->blocks); + SWBatchNormLayerDesc postBN = batchNormLayerDescToSwift(&desc->postBN); + ActivationKind postActivationKind = activationLayerDescToSwift(&desc->postActivation); + SWConvLayerDesc postConv = convLayerDescToSwift(&desc->postConv); - return swDesc; + return createSWNestedBottleneckResidualBlockDesc( + preBN, + preActivationKind, + preConv, + swBlocks, + postBN, + postActivationKind, + postConv); } /// Convert residual blocks from C++ to Swift -/// - Parameters: -/// - blocks: Residual blocks -/// - swBlocks: A pointer to an array of BlockDescriptor -swift::Array MetalProcess::residualBlocksToSwift(const vector>& blocks) { - +swift::Array residualBlocksToSwift(const vector>& blocks) { auto builder = createBlockDescriptorBuilder(); - for (int i = 0; i < blocks.size(); i++) { + for(size_t i = 0; i < blocks.size(); i++) { + void* blockDesc = blocks[i].second.get(); - void * blockDesc = blocks[i].second.get(); - - if (blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { BlockDescriptor descriptor = globalPoolingResidualBlockDescToSwift((GlobalPoolingResidualBlockDesc*)blockDesc); builder.enque(descriptor); - } else if (blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + } else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { BlockDescriptor descriptor = nestedBottleneckResidualBlockDescToSwift((NestedBottleneckResidualBlockDesc*)blockDesc); builder.enque(descriptor); } else { @@ -149,35 +242,8 @@ swift::Array MetalProcess::residualBlocksToSwift(const vector

preBN); - ActivationKind preActivationKind = activationLayerDescToSwift(&desc->preActivation); - SWConvLayerDesc preConv = convLayerDescToSwift(&desc->preConv); - auto swBlocks = residualBlocksToSwift(desc->blocks); - SWBatchNormLayerDesc postBN = batchNormLayerDescToSwift(&desc->postBN); - ActivationKind postActivationKind = activationLayerDescToSwift(&desc->postActivation); - SWConvLayerDesc postConv = convLayerDescToSwift(&desc->postConv); - - SWNestedBottleneckResidualBlockDesc swDesc = - createSWNestedBottleneckResidualBlockDesc(preBN, - preActivationKind, - preConv, - swBlocks, - postBN, - postActivationKind, - postConv); - - return swDesc; -} - /// Convert a SGF metadata encoder description from C++ to Swift -/// - Parameter desc: A SGF metadata encoder description -/// - Returns: The SGF metadata encoder description converted to SWSGFMetadataEncoderDesc -swift::Optional MetalProcess::sGFMetadataEncoderDescToSwift(const SGFMetadataEncoderDesc * desc) { - +swift::Optional sGFMetadataEncoderDescToSwift(const SGFMetadataEncoderDesc* desc) { SWMatMulLayerDesc mul1 = matMulLayerDescToSwift(&desc->mul1); SWMatBiasLayerDesc bias1 = matBiasLayerDescToSwift(&desc->bias1); ActivationKind act1 = activationLayerDescToSwift(&desc->act1); @@ -186,24 +252,20 @@ swift::Optional MetalProcess::sGFMetadataEncoderDescTo ActivationKind act2 = activationLayerDescToSwift(&desc->act2); SWMatMulLayerDesc mul3 = matMulLayerDescToSwift(&desc->mul3); - auto swSGFMetadataEncoderDesc = createSWSGFMetadataEncoderDesc(desc->metaEncoderVersion, - desc->numInputMetaChannels, - mul1, - bias1, - act1, - mul2, - bias2, - act2, - mul3); - - return swSGFMetadataEncoderDesc; + return createSWSGFMetadataEncoderDesc( + desc->metaEncoderVersion, + desc->numInputMetaChannels, + mul1, + bias1, + act1, + mul2, + bias2, + act2, + mul3); } /// Convert a trunk description from C++ to Swift -/// - Parameter trunk: A trunk description -/// - Returns: The trunk description converted to SWTrunkDesc -SWTrunkDesc MetalProcess::trunkDescToSwift(const TrunkDesc * trunk) { - +SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { SWConvLayerDesc initialConv = convLayerDescToSwift(&trunk->initialConv); SWMatMulLayerDesc initialMatMul = matMulLayerDescToSwift(&trunk->initialMatMul); auto sgfMetadataEncoder = sGFMetadataEncoderDescToSwift(&trunk->sgfMetadataEncoder); @@ -211,26 +273,22 @@ SWTrunkDesc MetalProcess::trunkDescToSwift(const TrunkDesc * trunk) { SWBatchNormLayerDesc trunkTipBN = batchNormLayerDescToSwift(&trunk->trunkTipBN); ActivationKind trunkTipActivation = activationLayerDescToSwift(&trunk->trunkTipActivation); - SWTrunkDesc swTrunkDesc = createSWTrunkDesc(trunk->modelVersion, - trunk->trunkNumChannels, - trunk->midNumChannels, - trunk->regularNumChannels, - trunk->gpoolNumChannels, - initialConv, - initialMatMul, - sgfMetadataEncoder, - swBlocks, - trunkTipBN, - trunkTipActivation); - - return swTrunkDesc; + return createSWTrunkDesc( + trunk->modelVersion, + trunk->trunkNumChannels, + trunk->midNumChannels, + trunk->regularNumChannels, + trunk->gpoolNumChannels, + initialConv, + initialMatMul, + sgfMetadataEncoder, + swBlocks, + trunkTipBN, + trunkTipActivation); } /// Convert a policy head description from C++ to Swift -/// - Parameter policyHead: A policy head description -/// - Returns: The policy head description converted to SWPolicyHeadDesc -SWPolicyHeadDesc MetalProcess::policyHeadDescToSwift(const PolicyHeadDesc * policyHead) { - +SWPolicyHeadDesc policyHeadDescToSwift(const PolicyHeadDesc* policyHead) { SWConvLayerDesc p1Conv = convLayerDescToSwift(&policyHead->p1Conv); SWConvLayerDesc g1Conv = convLayerDescToSwift(&policyHead->g1Conv); SWBatchNormLayerDesc g1BN = batchNormLayerDescToSwift(&policyHead->g1BN); @@ -244,38 +302,24 @@ SWPolicyHeadDesc MetalProcess::policyHeadDescToSwift(const PolicyHeadDesc * poli ActivationKind passActivation = activationLayerDescToSwift(&policyHead->passActivation); SWMatMulLayerDesc gpoolToPassMul2 = matMulLayerDescToSwift(&policyHead->gpoolToPassMul2); - SWPolicyHeadDesc swPolicyHead = createSWPolicyHeadDesc(policyHead->modelVersion, - p1Conv, - g1Conv, - g1BN, - g1Activation, - gpoolToBiasMul, - p1BN, - p1Activation, - p2Conv, - gpoolToPassMul, - gpoolToPassBias, - passActivation, - gpoolToPassMul2); - - return swPolicyHead; -} - -/// Convert a matrix bias layer description from C++ to Swift -/// - Parameter desc: A matrix bias layer description -/// - Returns: The matrix bias layer description converted to SWMatBiasLayerDesc -SWMatBiasLayerDesc MetalProcess::matBiasLayerDescToSwift(const MatBiasLayerDesc * desc) { - - SWMatBiasLayerDesc swDesc = createSWMatBiasLayerDesc(desc->numChannels, (float*)desc->weights.data()); - - return swDesc; + return createSWPolicyHeadDesc( + policyHead->modelVersion, + p1Conv, + g1Conv, + g1BN, + g1Activation, + gpoolToBiasMul, + p1BN, + p1Activation, + p2Conv, + gpoolToPassMul, + gpoolToPassBias, + passActivation, + gpoolToPassMul2); } /// Convert a value head description from C++ to Swift -/// - Parameter valueHead: A value head description -/// - Returns: The value head description converted to SWValueHeadDesc -SWValueHeadDesc MetalProcess::valueHeadDescToSwift(const ValueHeadDesc * valueHead) { - +SWValueHeadDesc valueHeadDescToSwift(const ValueHeadDesc* valueHead) { SWConvLayerDesc v1Conv = convLayerDescToSwift(&valueHead->v1Conv); SWBatchNormLayerDesc v1BN = batchNormLayerDescToSwift(&valueHead->v1BN); ActivationKind v1Activation = activationLayerDescToSwift(&valueHead->v1Activation); @@ -288,136 +332,90 @@ SWValueHeadDesc MetalProcess::valueHeadDescToSwift(const ValueHeadDesc * valueHe SWMatBiasLayerDesc sv3Bias = matBiasLayerDescToSwift(&valueHead->sv3Bias); SWConvLayerDesc vOwnershipConv = convLayerDescToSwift(&valueHead->vOwnershipConv); - SWValueHeadDesc swDesc = createSWValueHeadDesc(valueHead->modelVersion, - v1Conv, - v1BN, - v1Activation, - v2Mul, - v2Bias, - v2Activation, - v3Mul, - v3Bias, - sv3Mul, - sv3Bias, - vOwnershipConv); - - return swDesc; -} - -SWModelDesc MetalProcess::modelDescToSwift(const ModelDesc* modelDesc) { - return createSWModelDesc(modelDesc->modelVersion, - swift::String(modelDesc->name), - modelDesc->numInputChannels, - modelDesc->numInputGlobalChannels, - modelDesc->numInputMetaChannels, - modelDesc->numValueChannels, - modelDesc->numScoreValueChannels, - modelDesc->numOwnershipChannels, - trunkDescToSwift(&modelDesc->trunk), - policyHeadDescToSwift(&modelDesc->policyHead), - valueHeadDescToSwift(&modelDesc->valueHead)); -} - -//--------------------------------------------------------------------------------------------------------- - -/** - * @brief This function initializes the global state of the NeuralNet class upon program startup. - * This function should be called only once upon program startup. It ensures that the global state - * of the NeuralNet class is properly initialized, enabling it to function correctly throughout - * the lifetime of the program. - * Note that this function does not take any input parameters or return any values. - */ + return createSWValueHeadDesc( + valueHead->modelVersion, + v1Conv, + v1BN, + v1Activation, + v2Mul, + v2Bias, + v2Activation, + v3Mul, + v3Bias, + sv3Mul, + sv3Bias, + vOwnershipConv); +} + +/// Convert a model description from C++ to Swift +SWModelDesc modelDescToSwift(const ModelDesc* modelDesc) { + return createSWModelDesc( + modelDesc->modelVersion, + swift::String(modelDesc->name), + modelDesc->numInputChannels, + modelDesc->numInputGlobalChannels, + modelDesc->numInputMetaChannels, + modelDesc->numValueChannels, + modelDesc->numScoreValueChannels, + modelDesc->numOwnershipChannels, + modelDesc->numPolicyChannels, + trunkDescToSwift(&modelDesc->trunk), + policyHeadDescToSwift(&modelDesc->policyHead), + valueHeadDescToSwift(&modelDesc->valueHead)); +} + +} // namespace MetalProcess + +//------------------------------------------------------------------------------ +// LoadedModel implementation +//------------------------------------------------------------------------------ + +LoadedModel::LoadedModel(const string& fileName, const string& expectedSha256) { + modelPath = fileName; + ModelDesc::loadFromFileMaybeGZipped(fileName, modelDesc, expectedSha256); +} + +//------------------------------------------------------------------------------ +// NeuralNet namespace - Global functions +//------------------------------------------------------------------------------ + void NeuralNet::globalInitialize() { - // Do nothing. + // No global initialization needed for Metal backend } -/** - * @brief This function cleans up the global state of the NeuralNet class at program termination. - * This function should be called once at program termination. It ensures that the global state of - * the NeuralNet class is properly cleaned up, freeing any resources that were allocated during the - * lifetime of the program. - * Note that this function does not take any input parameters or return any values. - */ void NeuralNet::globalCleanup() { - // Do nothing. -} - -/** - * @brief Loads a neural network model from a file. - * This function creates a LoadedModel object by loading a neural network model from a file specified by - * the `file` parameter and expected SHA-256 hash specified by the `expectedSha256` parameter. The LoadedModel - * object is returned as a pointer. - * @param file The name of the file containing the neural network model. - * @param expectedSha256 The expected SHA-256 hash of the model file. - * @return A pointer to the LoadedModel object created by loading the model file. - */ + // No cleanup needed - temp files are deleted immediately after loading +} + LoadedModel* NeuralNet::loadModelFile(const string& file, const string& expectedSha256) { LoadedModel* loadedModel = new LoadedModel(file, expectedSha256); return loadedModel; } -/** - * @brief Frees memory used by a LoadedModel object. - * This function deallocates memory used by a LoadedModel object specified by the `loadedModel` parameter. - * @param loadedModel A pointer to the LoadedModel object to deallocate memory for. - */ void NeuralNet::freeLoadedModel(LoadedModel* loadedModel) { delete loadedModel; } -/** - * @brief Retrieves the model description associated with the loaded model. - * - * This function accesses the model description from a given LoadedModel instance. - * It returns a constant reference to the ModelDesc, which contains details - * about the structure and parameters of the neural network model. - * - * @param loadedModel Pointer to the LoadedModel instance from which to retrieve - * the model description. This should not be null. - * @return const ModelDesc& A constant reference to the model description of - * the loaded model. - */ const ModelDesc& NeuralNet::getModelDesc(const LoadedModel* loadedModel) { return loadedModel->modelDesc; } +//------------------------------------------------------------------------------ +// ComputeContext implementation //------------------------------------------------------------------------------ ComputeContext::ComputeContext(int nnX, int nnY, enabled_t useFP16Mode, enabled_t useNHWCMode): -metalComputeContext(createMetalComputeContext(nnX, nnY)) { +metalContext(createMetalComputeContext(nnX, nnY, useFP16Mode != enabled_t::False)) { this->useFP16Mode = useFP16Mode; - - SWEnable swUseFP16Mode = - (useFP16Mode == enabled_t::False) ? SWEnable::False() : - (useFP16Mode == enabled_t::True) ? SWEnable::True() : - SWEnable::Auto(); - - SWEnable swUseNHWCMode = - (useNHWCMode == enabled_t::False) ? SWEnable::False() : - (useNHWCMode == enabled_t::True) ? SWEnable::True() : - SWEnable::Auto(); + this->nnXLen = nnX; + this->nnYLen = nnY; + // Metal backend only supports NCHW layout (MPSGraph native format) + (void)useNHWCMode; } ComputeContext::~ComputeContext() { } -/** - * @brief Creates a ComputeContext object for computing neural network operations. - * This function creates a ComputeContext object by setting configuration settings for neural network computations, - * such as whether to use half-precision floating-point (FP16) mode and whether to use the NHWC format for input - * tensors. The ComputeContext object is returned as a pointer. - * @param gpuIdxs (Unused) A vector of GPU indices to use for computations. - * @param logger (Unused) A pointer to a Logger object to use for logging messages. - * @param nnXLen The width of the input tensor. - * @param nnYLen The height of the input tensor. - * @param openCLTunerFile (Unused) The name of a file containing OpenCL tuning parameters. - * @param homeDataDirOverride (Unused) A directory to use for storing data. - * @param openCLReTunePerBoardSize (Unused) Whether to re-tune OpenCL parameters for different board sizes. - * @param useFP16Mode Whether to use half-precision floating-point (FP16) mode for computations. - * @param useNHWCMode Whether to use the NHWC format for input tensors. - * @param loadedModel (Unused) A pointer to a LoadedModel object containing a loaded neural network model. - * @return A pointer to the ComputeContext object created. - */ ComputeContext* NeuralNet::createComputeContext( const vector& gpuIdxs, Logger* logger, @@ -440,29 +438,148 @@ ComputeContext* NeuralNet::createComputeContext( return new ComputeContext(nnXLen, nnYLen, useFP16Mode, useNHWCMode); } -/** - * @brief Frees memory used by a ComputeContext object. - * This function deallocates memory used by a ComputeContext object specified by the `computeContext` parameter. - * @param computeContext A pointer to the ComputeContext object to deallocate memory for. - */ void NeuralNet::freeComputeContext(ComputeContext* computeContext) { delete computeContext; } -//-------------------------------------------------------------- +//------------------------------------------------------------------------------ +// ComputeHandle implementation +//------------------------------------------------------------------------------ + +static mutex computeHandleMutex; + +// Helper function to convert model and create hybrid compute handle +// This is needed because Swift Optional doesn't support assignment in C++ +static swift::Optional convertAndCreateHybridHandle( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int serverThreadIdx +) { + auto metalContext = context->metalContext; + int nnXLen = metalContext.getNnXLen(); + int nnYLen = metalContext.getNnYLen(); + bool useFP16 = (context->useFP16Mode != enabled_t::False); + bool optimizeMask = requireExactNNLen; + + // Convert model to CoreML format in temp directory + // The Swift side will delete the temp file after loading + string coremlModelPath = CoreMLConversion::convertModelToTemp( + loadedModel->modelPath, + nnXLen, + nnYLen, + useFP16, + optimizeMask, + maxBatchSize, + serverThreadIdx + ); + + // Convert model descriptor to Swift format for MPSGraph path + SWModelDesc swModelDesc = MetalProcess::modelDescToSwift(&loadedModel->modelDesc); + + // Create hybrid compute handle (CoreML on CPU+ANE, MPSGraph on GPU) + return createHybridComputeHandle( + swift::String(coremlModelPath), + swModelDesc, + serverThreadIdx, + requireExactNNLen, + loadedModel->modelDesc.numInputChannels, + loadedModel->modelDesc.numInputGlobalChannels, + loadedModel->modelDesc.numInputMetaChannels, + loadedModel->modelDesc.numPolicyChannels, + loadedModel->modelDesc.numValueChannels, + loadedModel->modelDesc.numScoreValueChannels, + loadedModel->modelDesc.numOwnershipChannels, + metalContext + ); +} + +// Helper function to create hybrid handle if FP16 mode with sufficient batch size, otherwise returns none +static swift::Optional createHybridHandleIfNeeded( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int serverThreadIdx +) { + if(context->useFP16Mode == enabled_t::False) { + // FP32 mode - don't create hybrid handle + return swift::Optional::none(); + } + + // Hybrid mode splits batches: CoreML takes max(1, ...), MPSGraph takes remainder + // Minimum samples for meaningful split = 1 (CoreML) + 1 (MPSGraph) = 2 + // If batch can't be split, prefer MPSGraph-only for stability + if(maxBatchSize < MIN_COREML_BATCH + MIN_MPSGRAPH_BATCH) { + return swift::Optional::none(); + } -ComputeHandle::ComputeHandle(ComputeContext* context, - const LoadedModel* loadedModel, - bool inputsUseNHWC, - int gpuIdx, - int serverThreadIdx): -metalhandle(maybeCreateMetalComputeHandle((gpuIdx < 100), - serverThreadIdx, - MetalProcess::modelDescToSwift(&loadedModel->modelDesc), - context->metalComputeContext)) { + // FP16 mode with sufficient batch size: Use hybrid execution (CoreML on CPU+ANE, MPSGraph on GPU) + return convertAndCreateHybridHandle(context, loadedModel, requireExactNNLen, maxBatchSize, serverThreadIdx); +} + +// Helper function to create MPSGraph-only handle when needed +// Used when: (1) useFP16=false to avoid slow FP32 CoreML, or (2) batch too small for hybrid split +static swift::Optional createMPSGraphHandleIfNeeded( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int serverThreadIdx +) { + // Use MPSGraph-only when: + // 1. FP32 mode (CoreML FP32 on CPU+ANE is slow), OR + // 2. Batch too small to split (hybrid requires minCoreML + minMPSGraph samples) + bool batchTooSmallForHybrid = maxBatchSize < MIN_COREML_BATCH + MIN_MPSGRAPH_BATCH; + + if(context->useFP16Mode != enabled_t::False && !batchTooSmallForHybrid) { + // FP16 mode with sufficient batch - hybrid handle will be created instead + return swift::Optional::none(); + } + + // Log reason for MPSGraph-only mode + if(batchTooSmallForHybrid) { + cerr << "Metal backend " << serverThreadIdx << ": Batch size " << maxBatchSize + << " too small for hybrid split - using MPSGraph GPU-only" << endl; + } else { + cerr << "Metal backend " << serverThreadIdx << ": FP32 mode - using MPSGraph GPU-only (skipping CoreML converter)" << endl; + } + + // Convert model descriptor to Swift format for MPSGraph path + // Note: No CoreML conversion needed - MPSGraph reads weights directly + SWModelDesc swModelDesc = MetalProcess::modelDescToSwift(&loadedModel->modelDesc); + + // Create MPSGraph-only handle (GPU only) + return createMPSGraphOnlyHandle( + swModelDesc, + serverThreadIdx, + requireExactNNLen, + context->metalContext + ); +} + +ComputeHandle::ComputeHandle( + ComputeContext* context, + const LoadedModel* loadedModel, + bool inputsUseNHWC, + int gpuIdx, + int serverThreadIdx, + bool requireExactNNLen, + int maxBatchSize): +hybridHandle(createHybridHandleIfNeeded(context, loadedModel, requireExactNNLen, maxBatchSize, serverThreadIdx)), +mpsGraphOnlyHandle(createMPSGraphHandleIfNeeded(context, loadedModel, requireExactNNLen, maxBatchSize, serverThreadIdx)) { + bool hasHybrid = static_cast(hybridHandle); + bool hasMPSGraph = static_cast(mpsGraphOnlyHandle); + if(hasHybrid && hasMPSGraph) { + throw runtime_error("Metal backend: Logic error - both hybridHandle and mpsGraphOnlyHandle are valid"); + } + if(!hasHybrid && !hasMPSGraph) { + throw runtime_error("Metal backend: Failed to create compute handle - both CoreML and MPSGraph initialization failed (check logs above for details)"); + } const ModelDesc* modelDesc = &loadedModel->modelDesc; - auto metalContext = context->metalComputeContext; + auto metalContext = context->metalContext; nnXLen = metalContext.getNnXLen(); nnYLen = metalContext.getNnYLen(); @@ -470,34 +587,13 @@ metalhandle(maybeCreateMetalComputeHandle((gpuIdx < 100), version = modelDesc->modelVersion; metaEncoderVersion = modelDesc->metaEncoderVersion; this->inputsUseNHWC = inputsUseNHWC; - - /* Use FP16 mode if the model supports it and the user has not explicitly - * disabled it. */ + this->requireExactNNLen = requireExactNNLen; useFP16 = (context->useFP16Mode != enabled_t::False); - - (void)serverThreadIdx; } ComputeHandle::~ComputeHandle() { } -static mutex computeHandleMutex; - -/** - * @brief Create a new ComputeHandle object for performing neural network computations. - * This function creates a new ComputeHandle object for performing neural network computations, - * using the specified parameters and settings. The object is allocated on the heap using the - * 'new' operator and returned as a pointer. - * @param context A pointer to the ComputeContext object to use for computation. - * @param loadedModel A pointer to the LoadedModel object containing the neural network model to use. - * @param logger A pointer to the Logger object to use for logging messages. - * @param maxBatchSize The maximum batch size to use for computation. - * @param requireExactNNLen Whether the neural network length must match the input data length exactly. - * @param inputsUseNHWC Whether the input data uses NHWC format. - * @param gpuIdxForThisThread The index of the GPU to use for computation. - * @param serverThreadIdx The index of the server thread to use for computation. - * @return A pointer to the newly-created ComputeHandle object. - */ ComputeHandle* NeuralNet::createComputeHandle( ComputeContext* context, const LoadedModel* loadedModel, @@ -509,63 +605,38 @@ ComputeHandle* NeuralNet::createComputeHandle( int serverThreadIdx) { (void)logger; - (void)maxBatchSize; - // Current implementation always tolerates excess nn len - (void)requireExactNNLen; - // Transfer the default GPU index into physical GPU index 0 int gpuIdx = (gpuIdxForThisThread == -1) ? 0 : gpuIdxForThisThread; ComputeHandle* handle = nullptr; { lock_guard lock(computeHandleMutex); - handle = new ComputeHandle(context, loadedModel, inputsUseNHWC, gpuIdx, serverThreadIdx); + handle = new ComputeHandle(context, loadedModel, inputsUseNHWC, gpuIdx, serverThreadIdx, requireExactNNLen, maxBatchSize); } return handle; } -/** - * @brief Free the memory used by a ComputeHandle object. - * This function frees the memory used by the specified ComputeHandle object, which was - * previously allocated on the heap using the 'new' operator. - * @param handle A pointer to the ComputeHandle object to free. - */ void NeuralNet::freeComputeHandle(ComputeHandle* handle) { delete handle; } -/** - * @brief Check whether a ComputeHandle object is using 16-bit floating-point precision. - * This function checks whether the specified ComputeHandle object is using 16-bit floating-point - * precision for computation, and returns a boolean value indicating the result. - * @param handle A pointer to the ComputeHandle object to check. - * @return True if the ComputeHandle object is using 16-bit floating-point precision, false otherwise. - */ bool NeuralNet::isUsingFP16(const ComputeHandle* handle) { return handle->useFP16; } +//------------------------------------------------------------------------------ +// Device information //------------------------------------------------------------------------------ -/** - * @brief Print information about the available devices. - */ void NeuralNet::printDevices() { printMetalDevices(); } -//-------------------------------------------------------------- +//------------------------------------------------------------------------------ +// InputBuffers implementation +//------------------------------------------------------------------------------ -/** - * @brief Construct a new InputBuffers object for storing input data for neural network computation. - * This constructor initializes a new InputBuffers object for storing input data for neural network - * computation, based on the specified parameters and settings. - * @param loadedModel A pointer to the LoadedModel object containing the neural network model to use. - * @param maxBatchSz The maximum batch size to use for computation. - * @param nnXLen The x length of the neural network computation context. - * @param nnYLen The y length of the neural network computation context. - */ InputBuffers::InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int nnXLen, int nnYLen) { const ModelDesc& m = loadedModel->modelDesc; @@ -587,6 +658,7 @@ InputBuffers::InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int n singleOwnershipResultElts = (size_t)m.numOwnershipChannels * nnXLen * nnYLen; singleOwnerMapElts = (size_t)m.numOwnershipChannels * nnXLen * nnYLen; singleScoreValuesResultElts = (size_t)m.numScoreValueChannels; + singleMaskElts = (size_t)nnXLen * nnYLen; assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); @@ -603,10 +675,10 @@ InputBuffers::InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int n ownershipResultBufferElts = (size_t)maxBatchSize * singleOwnershipResultElts; ownerMapBufferElts = (size_t)maxBatchSz * singleOwnerMapElts; scoreValuesResultBufferElts = (size_t)maxBatchSize * singleScoreValuesResultElts; + userInputMaskBufferElts = (size_t)maxBatchSize * singleMaskElts; rowSpatialBuffer = new float[rowSpatialBufferElts]; userInputBuffer = new float[userInputBufferElts]; - // Zero out the input buffer for arbitrary board sizes memset(&userInputBuffer[0], 0, userInputBufferElts * sizeof(userInputBuffer[0])); userInputGlobalBuffer = new float[userInputGlobalBufferElts]; @@ -618,13 +690,10 @@ InputBuffers::InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int n ownershipResults = new float[ownershipResultBufferElts]; ownerMapBuffer = new float[ownerMapBufferElts]; scoreValuesResults = new float[scoreValuesResultBufferElts]; + userInputMaskBuffer = new float[userInputMaskBufferElts]; + memset(&userInputMaskBuffer[0], 0, userInputMaskBufferElts * sizeof(userInputMaskBuffer[0])); } -/** - * @brief Destroy the InputBuffers object and free all associated memory. - * This destructor destroys the InputBuffers object and frees all memory associated with it, - * including all input and output buffers used for neural network computation. - */ InputBuffers::~InputBuffers() { delete[] rowSpatialBuffer; delete[] userInputBuffer; @@ -637,48 +706,25 @@ InputBuffers::~InputBuffers() { delete[] ownershipResults; delete[] ownerMapBuffer; delete[] scoreValuesResults; + delete[] userInputMaskBuffer; } -/** - * @brief Create a new InputBuffers object for storing input data for neural network computation. - * This function creates a new InputBuffers object for storing input data for neural network computation, - * using the specified parameters and settings. The object is allocated on the heap using the 'new' operator - * and returned as a pointer. - * @param loadedModel A pointer to the LoadedModel object containing the neural network model to use. - * @param maxBatchSize The maximum batch size to use for computation. - * @param nnXLen The x length of the neural network computation context. - * @param nnYLen The y length of the neural network computation context. - * @return A pointer to the newly-created InputBuffers object. - */ InputBuffers* NeuralNet::createInputBuffers(const LoadedModel* loadedModel, int maxBatchSize, int nnXLen, int nnYLen) { return new InputBuffers(loadedModel, maxBatchSize, nnXLen, nnYLen); } -/** - * @brief Free the memory used by an InputBuffers object. - * This function frees the memory used by the specified InputBuffers object, which was - * previously allocated on the heap using the 'new' operator. - * @param inputBuffers A pointer to the InputBuffers object to free. - */ void NeuralNet::freeInputBuffers(InputBuffers* inputBuffers) { delete inputBuffers; } -//-------------------------------------------------------------- +//------------------------------------------------------------------------------ +// MetalProcess namespace - Helper functions +//------------------------------------------------------------------------------ void MetalProcess::copyRowData(float* dest, const float* src, size_t numElements) { copy(src, src + numElements, dest); } -/** - * @brief Convert input data from NHWC format to NCHW format in-place if necessary. - * - * @param rowSpatialInput Pointer to the input data (single batch element assumed). - * @param C Number of channels. - * @param H Height. - * @param W Width. - * @param inputsUseNHWC Flag indicating if the input data is currently in NHWC format. - */ void MetalProcess::convertNCHW( float* rowSpatialInput, const int C, @@ -766,6 +812,11 @@ void MetalProcess::processRowData(size_t row, ComputeHandle* gpuHandle, InputBuf nnYLen, nnXLen, gpuHandle->inputsUseNHWC); + + // Copy first channel of spatial input (mask) to dedicated mask buffer + // After NCHW conversion, the first nnXLen*nnYLen elements are the mask channel + float* rowMaskInput = &inputBuffers->userInputMaskBuffer[inputBuffers->singleMaskElts * row]; + copy(rowSpatialInput, rowSpatialInput + inputBuffers->singleMaskElts, rowMaskInput); } float MetalProcess::policyOptimismCalc(const double policyOptimism, const float p, const float pOpt) { @@ -782,7 +833,7 @@ void MetalProcess::processOptimism( float* targetBuffer = &buffers.policyProbsBuffer[row * singlePolicyResultElts]; float* policyOutputBuf = &buffers.policyResults[row * singlePolicyResultElts * buffers.policyResultChannels]; - for(auto i = 0; i < singlePolicyResultElts; ++i) { + for(size_t i = 0; i < singlePolicyResultElts; ++i) { const float p = policyOutputBuf[i]; const float pOpt = policyOutputBuf[i + singlePolicyResultElts]; targetBuffer[i] = MetalProcess::policyOptimismCalc(policyOptimism, p, pOpt); @@ -801,7 +852,6 @@ void MetalProcess::processPolicy( size_t row) { auto& buffers = *inputBuffers; float* targetBuffer = &buffers.policyResults[row * buffers.singlePolicyResultElts * buffers.policyResultChannels]; - const auto symmetry = inputBuf->symmetry; const auto policyOptimism = inputBuf->policyOptimism; if(buffers.policyResultChannels == 1) { @@ -813,7 +863,7 @@ void MetalProcess::processPolicy( } SymmetryHelpers::copyOutputsWithSymmetry( - targetBuffer, currentOutput->policyProbs, 1, gpuHandle->nnYLen, gpuHandle->nnXLen, symmetry); + targetBuffer, currentOutput->policyProbs, 1, gpuHandle->nnYLen, gpuHandle->nnXLen, inputBuf->symmetry); } void MetalProcess::processValue( @@ -839,7 +889,6 @@ void MetalProcess::processOwnership( const size_t singleOwnershipResultElts = inputBuffers->singleOwnershipResultElts; const size_t ownershipOutputBufOffset = row * singleOwnershipResultElts; - // Copy ownership results with symmetry if available if(currentOutput->whiteOwnerMap != nullptr) { const float* ownershipOutputBuf = &inputBuffers->ownershipResults[ownershipOutputBufOffset]; SymmetryHelpers::copyOutputsWithSymmetry( @@ -890,7 +939,6 @@ void MetalProcess::processScoreValues( size_t numScoreValueChannels = inputBuffers->singleScoreValuesResultElts; assert(numScoreValueChannels == 1); currentOutput->whiteScoreMean = currentScoreValueData[0]; - //Version 3 neural nets don't have any second moment currentOutput, implicitly already folding it in, so we just use the mean squared currentOutput->whiteScoreMeanSq = currentOutput->whiteScoreMean * currentOutput->whiteScoreMean; currentOutput->whiteLead = currentOutput->whiteScoreMean; currentOutput->varTimeLeft = 0; @@ -914,16 +962,6 @@ void MetalProcess::processRow( MetalProcess::processScoreValues(inputBuffers, currentOutput, gpuHandle->version, row); } -/** - * @brief Compute the neural network output using Metal API and the specified input data and GPU handle. - * This function computes the neural network output using the Metal API and the specified input data and ComputeHandle - * object for GPU acceleration. The computed output is stored in the specified vector of NNOutput pointers. - * @param gpuHandle A pointer to the ComputeHandle object to use for GPU computation. - * @param inputBuffers A pointer to the InputBuffers object containing the input data for computation. - * @param numBatchEltsFilled The number of batch elements filled in the input buffer. - * @param inputBufs An array of pointers to NNResultBuf objects containing the neural network input data. - * @param outputs A vector of NNOutput pointers to store the computed output. - */ void MetalProcess::getMetalOutput( ComputeHandle* gpuHandle, InputBuffers* inputBuffers, @@ -935,47 +973,57 @@ void MetalProcess::getMetalOutput( int batchSize = numBatchEltsFilled; assert(batchSize <= inputBuffers->maxBatchSize); - assert((NNModelVersion::getNumSpatialFeatures(gpuHandle->version) * gpuHandle->nnXLen * gpuHandle->nnYLen) <= inputBuffers->singleInputElts); - assert(NNModelVersion::getNumGlobalFeatures(gpuHandle->version) == inputBuffers->singleInputGlobalElts); + assert((NNModelVersion::getNumSpatialFeatures(gpuHandle->version) * gpuHandle->nnXLen * gpuHandle->nnYLen) <= (int)inputBuffers->singleInputElts); + assert(NNModelVersion::getNumGlobalFeatures(gpuHandle->version) == (int)inputBuffers->singleInputGlobalElts); if(gpuHandle->metaEncoderVersion > 0) { - assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == inputBuffers->singleInputMetaElts); + assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == (int)inputBuffers->singleInputMetaElts); } assert(inputBuffers->singleValueResultElts == 3); - for(size_t row = 0; row < batchSize; row++) { + for(int row = 0; row < batchSize; row++) { MetalProcess::processRowData(row, gpuHandle, inputBuffers, inputBufs); } - auto metalHandle = gpuHandle->metalhandle; - assert(metalHandle); - - metalHandle.get().apply(inputBuffers->userInputBuffer, - inputBuffers->userInputGlobalBuffer, - inputBuffers->userInputMetaBuffer, - inputBuffers->policyResults, - inputBuffers->policyPassResults, - inputBuffers->valueResults, - inputBuffers->scoreValuesResults, - inputBuffers->ownershipResults, - batchSize); + // Dispatch to appropriate handle based on mode + if(gpuHandle->hybridHandle) { + // FP16 mode: Use hybrid execution (CoreML on CPU+ANE, MPSGraph on GPU) + // Mask buffer has correct stride (singleMaskElts = H*W per batch element) + // When requireExactNNLen is true, mask operations can be optimized (optimize_identity_mask) + gpuHandle->hybridHandle.get().apply( + inputBuffers->userInputBuffer, + inputBuffers->userInputGlobalBuffer, + inputBuffers->userInputMetaBuffer, + inputBuffers->userInputMaskBuffer, // Dedicated mask buffer with correct stride + inputBuffers->policyResults, + inputBuffers->policyPassResults, + inputBuffers->valueResults, + inputBuffers->scoreValuesResults, + inputBuffers->ownershipResults, + batchSize); + } else if(gpuHandle->mpsGraphOnlyHandle) { + // FP32 mode: Use MPSGraph only (GPU-only) + // Mask is extracted internally from channel 0 of spatial input via strided reads + gpuHandle->mpsGraphOnlyHandle.get().apply( + inputBuffers->userInputBuffer, + inputBuffers->userInputGlobalBuffer, + inputBuffers->userInputMetaBuffer, + inputBuffers->policyResults, + inputBuffers->policyPassResults, + inputBuffers->valueResults, + inputBuffers->scoreValuesResults, + inputBuffers->ownershipResults, + batchSize); + } else { + throw runtime_error("Metal backend: No valid compute handle available"); + } - for(size_t row = 0; row < batchSize; row++) { + for(int row = 0; row < batchSize; row++) { MetalProcess::processRow(row, gpuHandle, inputBuffers, inputBufs, outputs); } } -/** - * @brief Compute the neural network output using the specified input data and GPU handle. - * This function computes the neural network output using the specified input data and ComputeHandle object - * for GPU acceleration. The computed output is stored in the specified vector of NNOutput pointers. - * @param gpuHandle A pointer to the ComputeHandle object to use for GPU computation. - * @param inputBuffers A pointer to the InputBuffers object containing the input data for computation. - * @param numBatchEltsFilled The number of batch elements filled in the input buffer. - * @param inputBufs An array of pointers to NNResultBuf objects containing the neural network input data. - * @param outputs A vector of NNOutput pointers to store the computed output. - */ void NeuralNet::getOutput( ComputeHandle* gpuHandle, InputBuffers* inputBuffers, @@ -986,41 +1034,254 @@ void NeuralNet::getOutput( MetalProcess::getMetalOutput(gpuHandle, inputBuffers, numBatchEltsFilled, inputBufs, outputs); } -bool MetalProcess::testEvaluateConv(const ConvLayerDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - vector& outputBuffer) { +//------------------------------------------------------------------------------ +// Test functions - Metal backend uses NCHW layout (not NHWC) +//------------------------------------------------------------------------------ + +namespace MetalProcess { + +// Helper function to compute merged scale and bias from raw values +// This is needed because test descriptors are created manually without computing merged values +static void computeMergedBatchNormValues( + const BatchNormLayerDesc* desc, + vector& mergedScale, + vector& mergedBias) { + + int numChannels = desc->numChannels; + mergedScale.resize(numChannels); + mergedBias.resize(numChannels); + + // If merged values are already computed, use them + if(!desc->mergedScale.empty() && !desc->mergedBias.empty()) { + mergedScale = desc->mergedScale; + mergedBias = desc->mergedBias; + return; + } + + // Otherwise compute from raw values: mergedScale = scale / sqrt(variance + epsilon) + // mergedBias = bias - mergedScale * mean + // Note: Use scale/bias values from vectors if available, regardless of hasScale/hasBias flags + // This matches how desc.cpp computes merged values during model loading + for(int c = 0; c < numChannels; c++) { + float scale = c < (int)desc->scale.size() ? desc->scale[c] : 1.0f; + float bias = c < (int)desc->bias.size() ? desc->bias[c] : 0.0f; + float mean = c < (int)desc->mean.size() ? desc->mean[c] : 0.0f; + float variance = c < (int)desc->variance.size() ? desc->variance[c] : 1.0f; + float epsilon = desc->epsilon; + + mergedScale[c] = scale / sqrt(variance + epsilon); + mergedBias[c] = bias - mergedScale[c] * mean; + } +} + +// Helper to convert BatchNormLayerDesc to Swift with computed merged values +static SWBatchNormLayerDesc batchNormLayerDescToSwiftWithMerge( + const BatchNormLayerDesc* desc, + vector& mergedScaleStorage, + vector& mergedBiasStorage) { + + computeMergedBatchNormValues(desc, mergedScaleStorage, mergedBiasStorage); + + return createSWBatchNormLayerDesc( + desc->numChannels, + mergedScaleStorage.data(), + mergedBiasStorage.data()); +} + +// Helper to convert ResidualBlockDesc to Swift with computed merged values +static SWResidualBlockDesc residualBlockDescToSwiftWithMerge( + const ResidualBlockDesc* desc, + vector& mergedScalePreBN, + vector& mergedBiasPreBN, + vector& mergedScaleMidBN, + vector& mergedBiasMidBN) { + + computeMergedBatchNormValues(&desc->preBN, mergedScalePreBN, mergedBiasPreBN); + computeMergedBatchNormValues(&desc->midBN, mergedScaleMidBN, mergedBiasMidBN); + + SWBatchNormLayerDesc preBN = createSWBatchNormLayerDesc( + desc->preBN.numChannels, + mergedScalePreBN.data(), + mergedBiasPreBN.data()); + + ActivationKind preActivationKind = MetalProcess::activationLayerDescToSwift(&desc->preActivation); + SWConvLayerDesc regularConv = MetalProcess::convLayerDescToSwift(&desc->regularConv); + + SWBatchNormLayerDesc midBN = createSWBatchNormLayerDesc( + desc->midBN.numChannels, + mergedScaleMidBN.data(), + mergedBiasMidBN.data()); + + ActivationKind midActivationKind = MetalProcess::activationLayerDescToSwift(&desc->midActivation); + SWConvLayerDesc finalConv = MetalProcess::convLayerDescToSwift(&desc->finalConv); + + return createSWResidualBlockDesc( + preBN, + preActivationKind, + regularConv, + midBN, + midActivationKind, + finalConv); +} + +// Helper to convert GlobalPoolingResidualBlockDesc to Swift with computed merged values +static SWGlobalPoolingResidualBlockDesc globalPoolingResidualBlockDescToSwiftWithMerge( + const GlobalPoolingResidualBlockDesc* desc, + vector& mergedScalePreBN, + vector& mergedBiasPreBN, + vector& mergedScaleMidBN, + vector& mergedBiasMidBN, + vector& mergedScaleGpoolBN, + vector& mergedBiasGpoolBN) { + + computeMergedBatchNormValues(&desc->preBN, mergedScalePreBN, mergedBiasPreBN); + computeMergedBatchNormValues(&desc->gpoolBN, mergedScaleGpoolBN, mergedBiasGpoolBN); + computeMergedBatchNormValues(&desc->midBN, mergedScaleMidBN, mergedBiasMidBN); + + SWBatchNormLayerDesc preBN = createSWBatchNormLayerDesc( + desc->preBN.numChannels, + mergedScalePreBN.data(), + mergedBiasPreBN.data()); + + ActivationKind preActivationKind = MetalProcess::activationLayerDescToSwift(&desc->preActivation); + SWConvLayerDesc regularConv = MetalProcess::convLayerDescToSwift(&desc->regularConv); + SWConvLayerDesc gpoolConv = MetalProcess::convLayerDescToSwift(&desc->gpoolConv); + + SWBatchNormLayerDesc gpoolBN = createSWBatchNormLayerDesc( + desc->gpoolBN.numChannels, + mergedScaleGpoolBN.data(), + mergedBiasGpoolBN.data()); + + ActivationKind gpoolActivationKind = MetalProcess::activationLayerDescToSwift(&desc->gpoolActivation); + SWMatMulLayerDesc gpoolToBiasMul = MetalProcess::matMulLayerDescToSwift(&desc->gpoolToBiasMul); + + SWBatchNormLayerDesc midBN = createSWBatchNormLayerDesc( + desc->midBN.numChannels, + mergedScaleMidBN.data(), + mergedBiasMidBN.data()); + + ActivationKind midActivationKind = MetalProcess::activationLayerDescToSwift(&desc->midActivation); + SWConvLayerDesc finalConv = MetalProcess::convLayerDescToSwift(&desc->finalConv); + + return createSWGlobalPoolingResidualBlockDesc( + preBN, + preActivationKind, + regularConv, + gpoolConv, + gpoolBN, + gpoolActivationKind, + gpoolToBiasMul, + midBN, + midActivationKind, + finalConv); +} + +bool testEvaluateConv( + const ConvLayerDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + const vector& inputBuffer, + vector& outputBuffer) { + + SWConvLayerDesc swDesc = MetalProcess::convLayerDescToSwift(desc); size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->outChannels; outputBuffer.resize(numOutputFloats); - testConvLayer(convLayerDescToSwift(desc), - nnXLen, - nnYLen, - batchSize, - (float*)inputBuffer.data(), - (float*)outputBuffer.data()); - - return true; -} - -/** - * @brief Evaluate a convolutional layer using Metal API for testing purposes. - * This function evaluates a convolutional layer using the Metal API for testing purposes. - * The input buffer and output buffer are specified as vectors of floats, and the result of the computation - * is stored in the output buffer. The function returns true if the evaluation is implemented. - * @param desc A pointer to the ConvLayerDesc object describing the convolutional layer to evaluate. - * @param batchSize The batch size to use for computation. - * @param nnXLen The x length of the neural network computation context. - * @param nnYLen The y length of the neural network computation context. - * @param useFP16 A boolean indicating whether to use half-precision floating point format for computation. - * @param useNHWC A boolean indicating whether to use NHWC layout for input and output buffers. - * @param inputBuffer A vector of floats containing the input buffer data. - * @param outputBuffer A vector of floats to store the computed output. - * @return true if the convolutional layer evaluation is implemented, false otherwise. - */ + return testConvLayer( + swDesc, + batchSize, + nnXLen, + nnYLen, + (float*)inputBuffer.data(), + outputBuffer.data()); +} + +bool testEvaluateBatchNorm( + const BatchNormLayerDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer) { + + vector mergedScaleStorage; + vector mergedBiasStorage; + SWBatchNormLayerDesc swDesc = batchNormLayerDescToSwiftWithMerge(desc, mergedScaleStorage, mergedBiasStorage); + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->numChannels; + outputBuffer.resize(numOutputFloats); + + return testBatchNormLayer( + swDesc, + batchSize, + nnXLen, + nnYLen, + (float*)inputBuffer.data(), + (float*)maskBuffer.data(), + outputBuffer.data()); +} + +bool testEvaluateResidualBlock( + const ResidualBlockDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer) { + + vector mergedScalePreBN, mergedBiasPreBN; + vector mergedScaleMidBN, mergedBiasMidBN; + SWResidualBlockDesc swDesc = residualBlockDescToSwiftWithMerge( + desc, mergedScalePreBN, mergedBiasPreBN, mergedScaleMidBN, mergedBiasMidBN); + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->preBN.numChannels; + outputBuffer.resize(numOutputFloats); + + return testResidualBlock( + swDesc, + batchSize, + nnXLen, + nnYLen, + (float*)inputBuffer.data(), + (float*)maskBuffer.data(), + outputBuffer.data()); +} + +bool testEvaluateGlobalPoolingResidualBlock( + const GlobalPoolingResidualBlockDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer) { + + vector mergedScalePreBN, mergedBiasPreBN; + vector mergedScaleMidBN, mergedBiasMidBN; + vector mergedScaleGpoolBN, mergedBiasGpoolBN; + SWGlobalPoolingResidualBlockDesc swDesc = globalPoolingResidualBlockDescToSwiftWithMerge( + desc, mergedScalePreBN, mergedBiasPreBN, mergedScaleMidBN, mergedBiasMidBN, + mergedScaleGpoolBN, mergedBiasGpoolBN); + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->preBN.numChannels; + outputBuffer.resize(numOutputFloats); + + return testGlobalPoolingResidualBlock( + swDesc, + batchSize, + nnXLen, + nnYLen, + (float*)inputBuffer.data(), + (float*)maskBuffer.data(), + outputBuffer.data()); +} + +} // namespace MetalProcess + bool NeuralNet::testEvaluateConv( const ConvLayerDesc* desc, int batchSize, @@ -1031,49 +1292,16 @@ bool NeuralNet::testEvaluateConv( const vector& inputBuffer, vector& outputBuffer) { + // Metal backend only supports NCHW layout + if(useNHWC) + return false; + + // useFP16 is ignored - MPSGraph tests use FP32 (void)useFP16; - (void)useNHWC; + return MetalProcess::testEvaluateConv(desc, batchSize, nnXLen, nnYLen, inputBuffer, outputBuffer); } -bool MetalProcess::testEvaluateBatchNorm(const BatchNormLayerDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - const vector& maskBuffer, - vector& outputBuffer) { - - size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->numChannels; - outputBuffer.resize(numOutputFloats); - - testBatchNormLayer(batchNormLayerDescToSwift(desc), - nnXLen, - nnYLen, - batchSize, - (float*)inputBuffer.data(), - (float*)maskBuffer.data(), - (float*)outputBuffer.data()); - - return true; -} - -/** - * @brief Evaluate a batch normalization layer using Metal API for testing purposes. - * This function evaluates a batch normalization layer using the Metal API for testing purposes. - * The input buffer and output buffer are specified as vectors of floats, and the result of the computation - * is stored in the output buffer. The function returns true if the evaluation is implemented. - * @param desc A pointer to the BatchNormLayerDesc object describing the batch normalization layer to evaluate. - * @param batchSize The batch size to use for computation. - * @param nnXLen The x length of the neural network computation context. - * @param nnYLen The y length of the neural network computation context. - * @param useFP16 A boolean indicating whether to use half-precision floating point format for computation. - * @param useNHWC A boolean indicating whether to use NHWC layout for input and output buffers. - * @param inputBuffer A vector of floats containing the input buffer data. - * @param maskBuffer A vector of floats containing the mask buffer data. Mask should be in 'NHW' format (no "C" channel). - * @param outputBuffer A vector of floats to store the computed output. - * @return true if the batch normalization layer evaluation is implemented, false otherwise. - */ bool NeuralNet::testEvaluateBatchNorm( const BatchNormLayerDesc* desc, int batchSize, @@ -1085,49 +1313,16 @@ bool NeuralNet::testEvaluateBatchNorm( const vector& maskBuffer, vector& outputBuffer) { + // Metal backend only supports NCHW layout + if(useNHWC) + return false; + + // useFP16 is ignored - MPSGraph tests use FP32 (void)useFP16; - (void)useNHWC; + return MetalProcess::testEvaluateBatchNorm(desc, batchSize, nnXLen, nnYLen, inputBuffer, maskBuffer, outputBuffer); } -bool MetalProcess::testEvaluateResidualBlock(const ResidualBlockDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - const vector& maskBuffer, - vector& outputBuffer) { - - size_t numTrunkFloats = (size_t)batchSize * nnXLen * nnYLen * desc->preBN.numChannels; - outputBuffer.resize(numTrunkFloats); - - testResidualBlock(residualBlockDescToSwift(desc), - batchSize, - nnXLen, - nnYLen, - (float*)inputBuffer.data(), - (float*)maskBuffer.data(), - (float*)outputBuffer.data()); - - return true; -} - -/** - * @brief Evaluate a residual block using Metal API for testing purposes. - * This function evaluates a residual block using the Metal API for testing purposes. - * The input buffer and output buffer are specified as vectors of floats, and the result of the computation - * is stored in the output buffer. The function returns true if the evaluation is implemented. - * @param desc A pointer to the ResidualBlockDesc object describing the residual block to evaluate. - * @param batchSize The batch size to use for computation. - * @param nnXLen The x length of the neural network computation context. - * @param nnYLen The y length of the neural network computation context. - * @param useFP16 A boolean indicating whether to use half-precision floating point format for computation. - * @param useNHWC A boolean indicating whether to use NHWC layout for input and output buffers. - * @param inputBuffer A vector of floats containing the input buffer data. - * @param maskBuffer A vector of floats containing the mask buffer data. - * @param outputBuffer A vector of floats to store the computed output. - * @return true if the residual block evaluation is implemented, false otherwise. - */ bool NeuralNet::testEvaluateResidualBlock( const ResidualBlockDesc* desc, int batchSize, @@ -1139,50 +1334,16 @@ bool NeuralNet::testEvaluateResidualBlock( const vector& maskBuffer, vector& outputBuffer) { + // Metal backend only supports NCHW layout + if(useNHWC) + return false; + + // useFP16 is ignored - MPSGraph tests use FP32 (void)useFP16; - (void)useNHWC; + return MetalProcess::testEvaluateResidualBlock(desc, batchSize, nnXLen, nnYLen, inputBuffer, maskBuffer, outputBuffer); } -bool MetalProcess::testEvaluateGlobalPoolingResidualBlock(const GlobalPoolingResidualBlockDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - const vector& maskBuffer, - vector& outputBuffer) { - - size_t numTrunkFloats = (size_t)batchSize * nnXLen * nnYLen * desc->preBN.numChannels; - outputBuffer.resize(numTrunkFloats); - - testGlobalPoolingResidualBlock(globalPoolingResidualBlockDescToSwift(desc), - batchSize, - nnXLen, - nnYLen, - (float*)inputBuffer.data(), - (float*)maskBuffer.data(), - (float*)outputBuffer.data()); - - return true; -} - -/** - * @brief Evaluate a global pooling residual block using Metal API for testing purposes. - * This function evaluates a global pooling residual block using the Metal API for testing purposes. - * The input buffer and output buffer are specified as vectors of floats, and the result of the computation - * is stored in the output buffer. The function returns true if the evaluation is implemented. - * @param desc A pointer to the GlobalPoolingResidualBlockDesc object describing the global pooling residual block to - * evaluate. - * @param batchSize The batch size to use for computation. - * @param nnXLen The x length of the neural network computation context. - * @param nnYLen The y length of the neural network computation context. - * @param useFP16 A boolean indicating whether to use half-precision floating point format for computation. - * @param useNHWC A boolean indicating whether to use NHWC layout for input and output buffers. - * @param inputBuffer A vector of floats containing the input buffer data. - * @param maskBuffer A vector of floats containing the mask buffer data. - * @param outputBuffer A vector of floats to store the computed output. - * @return true if the global pooling residual block evaluation is implemented, false otherwise. - */ bool NeuralNet::testEvaluateGlobalPoolingResidualBlock( const GlobalPoolingResidualBlockDesc* desc, int batchSize, @@ -1194,9 +1355,14 @@ bool NeuralNet::testEvaluateGlobalPoolingResidualBlock( const vector& maskBuffer, vector& outputBuffer) { + // Metal backend only supports NCHW layout + if(useNHWC) + return false; + + // useFP16 is ignored - MPSGraph tests use FP32 (void)useFP16; - (void)useNHWC; + return MetalProcess::testEvaluateGlobalPoolingResidualBlock(desc, batchSize, nnXLen, nnYLen, inputBuffer, maskBuffer, outputBuffer); } -#endif // USE_METAL_BACKEND +#endif // USE_METAL_BACKEND diff --git a/cpp/neuralnet/metalbackend.h b/cpp/neuralnet/metalbackend.h index 34e44b8e7..12cc6b0c0 100644 --- a/cpp/neuralnet/metalbackend.h +++ b/cpp/neuralnet/metalbackend.h @@ -1,4 +1,7 @@ -#pragma once +#ifndef NEURALNET_METALBACKEND_H_ +#define NEURALNET_METALBACKEND_H_ + +#ifdef USE_METAL_BACKEND #include #include "desc.h" @@ -13,51 +16,6 @@ using namespace std; using namespace KataGoSwift; namespace MetalProcess { -SWConvLayerDesc convLayerDescToSwift(const ConvLayerDesc * desc); -SWBatchNormLayerDesc batchNormLayerDescToSwift(const BatchNormLayerDesc * desc); -ActivationKind activationLayerDescToSwift(const ActivationLayerDesc * desc); -SWResidualBlockDesc residualBlockDescToSwift(const ResidualBlockDesc * desc); -SWMatMulLayerDesc matMulLayerDescToSwift(const MatMulLayerDesc * desc); -SWGlobalPoolingResidualBlockDesc globalPoolingResidualBlockDescToSwift(const GlobalPoolingResidualBlockDesc* desc); -swift::Array residualBlocksToSwift(const vector>& blocks); -SWNestedBottleneckResidualBlockDesc nestedBottleneckResidualBlockDescToSwift(const NestedBottleneckResidualBlockDesc* desc); -swift::Optional sGFMetadataEncoderDescToSwift(const SGFMetadataEncoderDesc * desc); -SWTrunkDesc trunkDescToSwift(const TrunkDesc * trunk); -SWPolicyHeadDesc policyHeadDescToSwift(const PolicyHeadDesc * policyHead); -SWMatBiasLayerDesc matBiasLayerDescToSwift(const MatBiasLayerDesc * desc); -SWValueHeadDesc valueHeadDescToSwift(const ValueHeadDesc * valueHead); -SWModelDesc modelDescToSwift(const ModelDesc* modelDesc); - -bool testEvaluateConv(const ConvLayerDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - vector& outputBuffer); - -bool testEvaluateBatchNorm(const BatchNormLayerDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - const vector& maskBuffer, - vector& outputBuffer); - -bool testEvaluateResidualBlock(const ResidualBlockDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - const vector& maskBuffer, - vector& outputBuffer); - -bool testEvaluateGlobalPoolingResidualBlock(const GlobalPoolingResidualBlockDesc* desc, - int batchSize, - int nnXLen, - int nnYLen, - const vector& inputBuffer, - const vector& maskBuffer, - vector& outputBuffer); void copyRowData(float* dest, const float* src, size_t numElements); void convertNCHW(float* rowSpatialInput, int C, int H, int W, bool inputsUseNHWC); @@ -89,63 +47,44 @@ void processRow(size_t row, vector& outputs); void getMetalOutput(ComputeHandle* gpuHandle, - InputBuffers* inputBuffers, - int numBatchEltsFilled, - NNResultBuf** inputBufs, - vector& outputs); -}; + InputBuffers* inputBuffers, + int numBatchEltsFilled, + NNResultBuf** inputBufs, + vector& outputs); +} /** * @brief Represents a loaded neural network model. * A LoadedModel object contains a ModelDesc object that describes the characteristics of the loaded model. - * The default constructor, copy constructor, and assignment operator are deleted to prevent - * creation of an uninitialized LoadedModel object, copying of the loaded model, and potential memory leaks. + * For Metal backend, we also store the model path for on-demand conversion. */ struct LoadedModel { /** * @brief The description of the loaded model. - * The modelDesc field is a ModelDesc object that describes the characteristics of the loaded model. */ ModelDesc modelDesc; + /** + * @brief Path to the original .bin.gz model file for conversion. + */ + string modelPath; + /** * @brief Construct a new Loaded Model object - * This constructor loads a machine learning model from a file and sets the modelDesc field to the - * characteristics of the loaded model. + * This constructor loads a machine learning model from a file and sets the modelDesc field. * @param fileName The name of the file containing the machine learning model. * @param expectedSha256 The expected SHA-256 hash of the model file. */ - LoadedModel(const string& fileName, const string& expectedSha256) - { - ModelDesc::loadFromFileMaybeGZipped(fileName, modelDesc, expectedSha256); - } + LoadedModel(const string& fileName, const string& expectedSha256); - /** - * @brief Delete the default constructor - * The default constructor is deleted to prevent creation of an uninitialized LoadedModel object. - */ LoadedModel() = delete; - - /** - * @brief Delete the copy constructor - * The copy constructor is deleted to prevent copying of the loaded model. - */ LoadedModel(const LoadedModel&) = delete; - - /** - * @brief Delete the assignment operator - * The assignment operator is deleted to prevent copying of the loaded model. - */ LoadedModel& operator=(const LoadedModel&) = delete; }; /** - * @brief Context for computing neural network operations. - * A ComputeContext object contains configuration settings for neural network computations, such as - * whether to use half-precision floating-point (FP16) mode and whether to use the NHWC format for - * input tensors. The default constructor, copy constructor, and assignment operator are deleted - * to prevent creation of an uninitialized ComputeContext object, copying of the object, and potential - * memory leaks. + * @brief Context for computing neural network operations using Metal. + * Contains global configuration settings for neural network computations. */ struct ComputeContext { /** @@ -154,64 +93,47 @@ struct ComputeContext { enabled_t useFP16Mode; /** - * @brief ComputeContext ID + * @brief The width of the neural network input. */ - int identifier; + int nnXLen; /** - * @brief Metal compute context instance + * @brief The height of the neural network input. */ - MetalComputeContext metalComputeContext; + int nnYLen; + + /** + * @brief Metal compute context instance from Swift. + */ + MetalComputeContext metalContext; /** * @brief Constructs a ComputeContext object. - * This constructor creates a ComputeContext object and sets the configuration settings for neural network - * computations, including whether to use FP16 mode and whether to use the NHWC format for input tensors. * @param nnX The width of the input tensor. * @param nnY The height of the input tensor. - * @param useFP16Mode Whether to use half-precision floating-point (FP16) mode for computations. + * @param useFP16Mode Whether to use half-precision floating-point (FP16) mode. * @param useNHWCMode Whether to use the NHWC format for input tensors. */ ComputeContext(int nnX, int nnY, enabled_t useFP16Mode, enabled_t useNHWCMode); - /** - * @brief Destroys the ComputeContext object. - */ ~ComputeContext(); - - /** - * @brief Deletes the default constructor. - */ ComputeContext() = delete; - - /** - * @brief Deletes the copy constructor. - */ ComputeContext(const ComputeContext&) = delete; - - /** - * @brief Deletes the copy constructor. - * - * @return ComputeContext& - */ ComputeContext& operator=(const ComputeContext&) = delete; }; /** - * @brief A handle for performing neural network computations. - * This struct represents a handle for computing neural network operations. It contains various - * parameters and settings that determine how the computation is performed. + * @brief A handle for performing neural network computations using Metal. + * This struct represents a per-thread handle for computing neural network operations. */ struct ComputeHandle { - int identifier; - /** - * @brief The x length of the neural network computation context. + * @brief The x length of the neural network. */ int nnXLen; /** - * @brief The y length of the neural network computation context. + * @brief The y length of the neural network. */ int nnYLen; @@ -236,53 +158,55 @@ struct ComputeHandle { bool inputsUseNHWC; /** - * @brief Whether to use 16-bit floating-point precision for computation. + * @brief Whether to use 16-bit floating-point precision. */ bool useFP16; /** - * @brief The Metal handle instance. + * @brief Whether exact neural net length is required (enables mask optimization). + */ + bool requireExactNNLen; + + /** + * @brief The hybrid compute handle instance from Swift. + * This handle dispatches work to both CoreML (CPU+ANE) and MPSGraph (GPU). */ - swift::Optional metalhandle; + swift::Optional hybridHandle; + + /** + * @brief The MPSGraph-only handle instance from Swift (used for FP32 mode). + * This handle dispatches work only to GPU, avoiding slow FP32 CPU+ANE execution. + */ + swift::Optional mpsGraphOnlyHandle; /** * @brief Construct a new ComputeHandle object. - * This constructor initializes a new ComputeHandle object with the specified parameters and settings. * @param context The ComputeContext object to use for computation. - * @param loadedModel A pointer to the LoadedModel object containing the neural network model to use. + * @param loadedModel A pointer to the LoadedModel object. * @param inputsUseNHWC Whether the input data uses NHWC format. - * @param gpuIdx The index of the GPU to use for computation. - * @param serverThreadIdx The index of the server thread to use for computation. + * @param gpuIdx The index of the GPU to use. + * @param serverThreadIdx The index of the server thread. + * @param requireExactNNLen Whether exact NN length is required. + * @param maxBatchSize Maximum batch size for dynamic batch support. */ ComputeHandle( ComputeContext* context, const LoadedModel* loadedModel, bool inputsUseNHWC, int gpuIdx, - int serverThreadIdx); + int serverThreadIdx, + bool requireExactNNLen, + int maxBatchSize); - /** - * @brief Destroy the ComputeHandle object. - * This destructor frees any resources that were allocated for the ComputeHandle object. - */ ~ComputeHandle(); - - /** - * @brief Delete the default constructor. - */ ComputeHandle() = delete; - - /** - * @brief Delete the copy constructor. - */ ComputeHandle(const ComputeHandle&) = delete; - - /** - * @brief Delete the assignment operator. - */ ComputeHandle& operator=(const ComputeHandle&) = delete; }; +/** + * @brief Input and output buffers for neural network inference. + */ struct InputBuffers { int maxBatchSize; size_t policyResultChannels; @@ -298,6 +222,7 @@ struct InputBuffers { size_t singleOwnershipResultElts; size_t singleOwnerMapElts; size_t singleScoreValuesResultElts; + size_t singleMaskElts; size_t rowSpatialBufferElts; size_t userInputBufferElts; @@ -310,6 +235,7 @@ struct InputBuffers { size_t ownershipResultBufferElts; size_t ownerMapBufferElts; size_t scoreValuesResultBufferElts; + size_t userInputMaskBufferElts; float* rowSpatialBuffer; float* userInputBuffer; @@ -322,6 +248,7 @@ struct InputBuffers { float* ownershipResults; float* ownerMapBuffer; float* scoreValuesResults; + float* userInputMaskBuffer; InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int nnXLen, int nnYLen); ~InputBuffers(); @@ -329,3 +256,7 @@ struct InputBuffers { InputBuffers(const InputBuffers&) = delete; InputBuffers& operator=(const InputBuffers&) = delete; }; + +#endif // USE_METAL_BACKEND + +#endif // NEURALNET_METALBACKEND_H_ diff --git a/cpp/neuralnet/metalbackend.swift b/cpp/neuralnet/metalbackend.swift index 97c6e181d..8f209d054 100644 --- a/cpp/neuralnet/metalbackend.swift +++ b/cpp/neuralnet/metalbackend.swift @@ -1,3032 +1,681 @@ import Foundation +import CoreML import MetalPerformanceShaders import MetalPerformanceShadersGraph /// A class that handles output to standard error. class StandardError: TextOutputStream { - /// Outputs the specified string to the standard error stream. func write(_ string: String) { - /// Tries to write the UTF-8 encoded contents of the string to the standard error file handle. try? FileHandle.standardError.write(contentsOf: Data(string.utf8)) } } -/// A function to print error messages +/// Print to standard error func printError(_ item: Any) { - // Create an instance of StandardError to direct output to the standard error stream var instance = StandardError() - // Output the provided item to the standard error using the created instance print(item, to: &instance) } -/// An extension to the Data struct for handling float data with optional FP16 conversion. -extension Data { - /// Initializes a new Data instance using an UnsafeMutablePointer, with optional conversion to FP16 format. - /// - Parameters: - /// - floatsNoCopy: An UnsafeMutablePointer containing the float data. - /// - shape: An array of NSNumber objects representing the shape of the data. - init( - floatsNoCopy: UnsafeMutablePointer, - shape: [NSNumber] - ) { - self.init( - bytesNoCopy: floatsNoCopy, - count: shape.countBytesOfFloat32(), - deallocator: .none) - } -} - -/// Extension to MPSNDArray to convert from MPSGraphTensor, and to read/write bytes from/to UnsafeMutableRawPointer -extension MPSNDArray { - /// Read bytes from the buffer - /// - Parameter buffer: The buffer to read - func readBytes(_ buffer: UnsafeMutableRawPointer) { - self.readBytes(buffer, strideBytes: nil) - } - - /// Write bytes to the buffer - /// - Parameter buffer: The buffer to write - func writeBytes(_ buffer: UnsafeMutableRawPointer) { - self.writeBytes(buffer, strideBytes: nil) - } -} - -/// Extension to Array to count number of elements and bytes -extension Array where Element == NSNumber { - /// Count number of elements - /// - Returns: Number of elements - func countElements() -> Int { - return reduce(1, { $0 * $1.intValue }) - } - - /// Count number of bytes - /// - Parameter dataType: The data type - /// - Returns: Number of bytes - func countBytesOfFloat32() -> Int { - return countElements() * MemoryLayout.size - } -} - -/// Extension to MPSGraph to the mish activation function -extension MPSGraph { - /// This function applies the Mish activation function on the input tensor `x`. The Mish function is defined as - /// x * tanh(Softplus(x)), where Softplus(x) is defined as log(1 + exp(min(x, 10.39))) if x < 10.39 and x otherwise. - /// When FP16 is later used, the threshold of softplus will need to be modified to 10.39, which is different from - /// the original 20. This is because exp(10.39) = 32532.666936 < 32767.0 < 65504.0, so the result of exp(10.39) can - /// be represented by float16. If the threshold of softplus is 20, the result of exp(20) is 485165195.40979004, - /// which is out of range of float16. - /// - Parameter tensor: The input tensor of mish activation function - /// - Returns: The output tensor of mish activation function - func mish(tensor: MPSGraphTensor) -> MPSGraphTensor { - assert(tensor.dataType == .float32) - - let one = 1.0 - let threshold = 20.0 - let thresholdTensor = constant(threshold, dataType: tensor.dataType) - let minimumTensor = minimum(tensor, thresholdTensor, name: nil) - let expTensor = exponent(with: minimumTensor, name: nil) - let oneTensor = constant(one, dataType: tensor.dataType) - let addTensor = addition(expTensor, oneTensor, name: nil) - let logTensor = logarithm(with: addTensor, name: nil) - let lessTensor = lessThan(tensor, thresholdTensor, name: nil) - let selectTensor = select( - predicate: lessTensor, trueTensor: logTensor, falseTensor: tensor, name: nil) - let tanhTensor = tanh(with: selectTensor, name: nil) - let mulTensor = multiplication(tensor, tanhTensor, name: nil) - - return mulTensor - } -} - -/// A structure that represents the input shape -struct InputShape { - /// Create a shape for the input tensor - /// - Parameters: - /// - batchSize: Batch size - /// - numChannels: Number of channels - /// - nnYLen: Y length - /// - nnXLen: X length - /// - Returns: The shape - static func create( - batchSize: NSNumber, - numChannels: NSNumber, - nnYLen: NSNumber, - nnXLen: NSNumber - ) -> [NSNumber] { - let shape = [ - batchSize, - numChannels, - nnYLen, - nnXLen, - ] - return shape - } - - /// Get the channel axis - /// - Returns: The channel axis - static func getChannelAxis() -> Int { - return 1 - } - - /// Get the HW axes - /// - Returns: The HW axes - static func getHWAxes() -> [NSNumber] { - let hwAxes = [2, 3] as [NSNumber] - return hwAxes - } -} - -/// A structure that represents the input layer -struct InputLayer { - let tensor: MPSGraphTensor - let shape: [NSNumber] - - /// Initialize a InputLayer object - /// - Parameters: - /// - graph: The graph - /// - nnXLen: X length - /// - nnYLen: Y length - /// - numChannels: Number of channels - /// - dataType: Data type - init( - graph: MPSGraph, - nnXLen: NSNumber, - nnYLen: NSNumber, - numChannels: NSNumber, - dataType: MPSDataType = .float32 - ) { - shape = InputShape.create( - batchSize: -1, - numChannels: numChannels, - nnYLen: nnYLen, - nnXLen: nnXLen) - - self.tensor = graph.placeholder( - shape: shape, - dataType: dataType, - name: nil) - - assert(self.tensor.shape?.count == 4) - } -} - -/// A structure that represents an input global layer for a neural network model. -struct InputGlobalLayer { - let tensor: MPSGraphTensor - let shape: [NSNumber] - - /// Initializes an InputGlobalLayer object with a graph, batch size, number of global features, data type, and input shape. - /// - Parameters: - /// - graph: The graph. - /// - numGlobalFeatures: The number of global features. - /// - dataType: The data type. - init( - graph: MPSGraph, - numGlobalFeatures: NSNumber, - dataType: MPSDataType = .float32 - ) { - shape = InputShape.create( - batchSize: -1, - numChannels: numGlobalFeatures, - nnYLen: 1, - nnXLen: 1) - - self.tensor = graph.placeholder( - shape: shape, - dataType: dataType, - name: nil) - - assert(self.tensor.shape?.count == 4) - } -} - -/// A structure representing the input meta layer for a neural network graph. -struct InputMetaLayer { - /// A `MPSGraphTensor` representing the placeholder tensor in the graph. - let tensor: MPSGraphTensor - /// An array of `NSNumber` representing the shape of the tensor placeholder. - let shape: [NSNumber] - - /// Initializes a new `InputMetaLayer` instance with the given graph and number of meta features. - /// - /// - Parameters: - /// - graph: The `MPSGraph` instance where the placeholder tensor will be created. - /// - numMetaFeatures: The number of meta features (channels) for the input tensor. - /// - dataType: The data type - /// - /// This initializer sets the shape of the input tensor using a helper function `InputShape.create` with - /// a dynamic batch size (-1), the specified number of channels, and a spatial size of 1x1 (nnYLen and nnXLen). - /// It also creates a placeholder tensor in the MPS graph with the specified shape and data type `float32`. - init( - graph: MPSGraph, - numMetaFeatures: NSNumber, - dataType: MPSDataType = .float32 - ) { - // Define the shape of the input tensor with dynamic batch size, specified number of channels, and spatial dimensions 1x1. - shape = InputShape.create( - batchSize: -1, - numChannels: numMetaFeatures, - nnYLen: 1, - nnXLen: 1) - - // Create a placeholder tensor in the graph with the above-defined shape and data type float32. - self.tensor = graph.placeholder( - shape: shape, - dataType: dataType, - name: nil) - } -} - -/// A structure that represents a mask layer for a neural network model. -struct MaskLayer { - let tensor: MPSGraphTensor - let shape: [NSNumber] - - /// Initializes a MaskLayer object with a graph, batch size, x and y lengths, data type, and input shape. - /// - Parameters: - /// - graph: The graph. - /// - nnXLen: The length of the x-axis. - /// - nnYLen: The length of the y-axis. - /// - dataType: The data type. - init( - graph: MPSGraph, - nnXLen: NSNumber, - nnYLen: NSNumber, - dataType: MPSDataType = .float32 - ) { - shape = InputShape.create( - batchSize: -1, - numChannels: 1, - nnYLen: nnYLen, - nnXLen: nnXLen) - - self.tensor = graph.placeholder( - shape: shape, - dataType: dataType, - name: nil) - - assert(self.tensor.shape?.count == 4) - } -} - -/// A structure that represents a layer which performs the summation operation on a mask layer. -struct MaskSumLayer { - let tensor: MPSGraphTensor - - /// Initializes a MaskSumLayer object with a given tensor. - /// - Parameter tensor: The tensor to use for the layer. - init(tensor: MPSGraphTensor) { - self.tensor = tensor - assert(self.tensor.shape?.count == 4) - } - - /// Initializes a MaskSumLayer object with a graph, a mask layer, and a boolean flag indicating whether to use NHWC or NCHW format. - /// - Parameters: - /// - graph: The graph. - /// - maskTensor: The mask tensor. - init( - graph: MPSGraph, - maskTensor: MPSGraphTensor - ) { - let hwAxes = InputShape.getHWAxes() - - self.tensor = graph.reductionSum( - with: maskTensor, - axes: hwAxes, - name: nil) - - assert(self.tensor.shape?.count == 4) - } -} - -/// A structure that represents a layer which performs square root, subtraction, and multiplication operations on a MaskSumLayer object. -struct MaskSumSqrtS14M01Layer { - let tensor: MPSGraphTensor - - /// Initializes a MaskSumSqrtS14M01Layer object with a given tensor. - /// - Parameter tensor: The tensor to use for the layer. - init(tensor: MPSGraphTensor) { - self.tensor = tensor - assert(self.tensor.shape?.count == 4) - } - - /// Initializes a MaskSumSqrtS14M01Layer object with a graph, a MaskSumLayer object, and a boolean flag indicating whether to use 16-bit floating-point data type. - /// - Parameters: - /// - graph: The graph. - /// - maskSum: The MaskSumLayer object. - init( - graph: MPSGraph, - maskSum: MaskSumLayer - ) { - let sqrtMaskSum = graph.squareRoot(with: maskSum.tensor, name: nil) - - let fourTeen = graph.constant( - 14.0, - shape: [1], - dataType: maskSum.tensor.dataType) - - let subtracted = graph.subtraction(sqrtMaskSum, fourTeen, name: nil) - - let zeroPointone = graph.constant( - 0.1, - shape: [1], - dataType: maskSum.tensor.dataType) - - self.tensor = graph.multiplication( - subtracted, - zeroPointone, - name: nil) - - assert(self.tensor.shape?.count == 4) - } -} - -/// A structure that represents a layer which performs squaring and subtraction operations on a MaskSumSqrtS14M01Layer object. -struct MaskSumSqrtS14M01SquareS01Layer { - let tensor: MPSGraphTensor - - /// Initializes a MaskSumSqrtS14M01SquareS01Layer object with a given tensor. - /// - Parameter tensor: The tensor to use for the layer. - init(tensor: MPSGraphTensor) { - self.tensor = tensor - assert(self.tensor.shape?.count == 4) - } - - /// Initializes a MaskSumSqrtS14M01SquareS01Layer object with a graph, a MaskSumSqrtS14M01Layer object, and a boolean flag indicating whether to use 16-bit floating-point data type. - /// - Parameters: - /// - graph: The graph. - /// - maskSumSqrtS14M01: The MaskSumSqrtS14M01Layer object. - init( - graph: MPSGraph, - maskSumSqrtS14M01: MaskSumSqrtS14M01Layer - ) { - let squared = graph.square(with: maskSumSqrtS14M01.tensor, name: nil) - - let zeroPointone = graph.constant( - 0.1, - shape: [1], - dataType: maskSumSqrtS14M01.tensor.dataType) - - self.tensor = graph.subtraction( - squared, - zeroPointone, - name: nil) - - assert(self.tensor.shape?.count == 4) - } -} - -/// A Swift structure that represents a network tester, which tests various neural network configurations. -struct NetworkTester { - - /// A static function that tests a custom neural network configuration with the given parameters. - /// - Parameters: - /// - batchSize: The number of input batches. - /// - nnXLen: The width of the input tensor. - /// - nnYLen: The height of the input tensor. - /// - numChannels: The number of channels in the input tensor. - /// - input: A pointer to the input data. - /// - mask: A pointer to the mask data. - /// - output: A pointer to the output data. - /// - networkBuilder: A closure that takes an MPSGraph, InputLayer, and MaskLayer, and returns an MPSGraphTensor representing the custom network configuration. - static func test( - batchSize: NSNumber, - nnXLen: NSNumber, - nnYLen: NSNumber, - numChannels: NSNumber, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer, - networkBuilder: (MPSGraph, InputLayer, MaskLayer) -> MPSGraphTensor - ) { - - // Create a Metal device. - let device = MTLCreateSystemDefaultDevice()! - - // Create a MPSGraph. - let graph = MPSGraph() - - // Create the input and mask layers. - let inputLayer = InputLayer( - graph: graph, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: numChannels) - - let maskLayer = MaskLayer( - graph: graph, - nnXLen: nnXLen, - nnYLen: nnYLen) - - // Build the custom network configuration using the provided networkBuilder closure. - let resultTensor = networkBuilder(graph, inputLayer, maskLayer) - - // Create input shape - let inputShape = InputShape.create( - batchSize: batchSize, - numChannels: numChannels, - nnYLen: nnYLen, - nnXLen: nnXLen) - - // Create MPSNDArrayDescriptors from the input shape. - let sourceDescriptor = MPSNDArrayDescriptor( - dataType: inputLayer.tensor.dataType, - shape: inputShape) - - // Create MPSNDArray from the source descriptor. - let sourceArray = MPSNDArray( - device: device, - descriptor: sourceDescriptor) - - // Create a mask shape - let maskShape = InputShape.create( - batchSize: batchSize, - numChannels: 1, - nnYLen: nnYLen, - nnXLen: nnXLen) - - // Create MPSNDArrayDescriptors from the mask shape. - let maskDescriptor = MPSNDArrayDescriptor( - dataType: maskLayer.tensor.dataType, - shape: maskShape) - - // Create MPSNDArray from the mask descriptor. - let maskArray = MPSNDArray( - device: device, - descriptor: maskDescriptor) - - // Write input and mask data to their respective MPSNDArrays, converting to FP16 if necessary. - sourceArray.writeBytes(input) - maskArray.writeBytes(mask) - - // Create MPSGraphTensorData objects from the source and mask arrays. - let sourceTensorData = MPSGraphTensorData(sourceArray) - let maskTensorData = MPSGraphTensorData(maskArray) - - // Execute the graph and fetch the result. - let fetch = graph.run( - feeds: [ - inputLayer.tensor: sourceTensorData, - maskLayer.tensor: maskTensorData, - ], - targetTensors: [resultTensor], - targetOperations: nil) - - // Read the output data from the result tensor, converting from FP16 to FP32 if necessary. - fetch[resultTensor]?.mpsndarray().readBytes(output) - } -} - -/// A struct that represents a description of convolutional layer. -public struct SWConvLayerDesc { - let convYSize: NSNumber - let convXSize: NSNumber - let inChannels: NSNumber - let outChannels: NSNumber - let dilationY: Int - let dilationX: Int - let weights: UnsafeMutablePointer - - /// Initializes a SWConvLayerDesc object. - /// - Parameters: - /// - convYSize: The Y size of the convolution. - /// - convXSize: The X size of the convolution. - /// - inChannels: The number of input channels. - /// - outChannels: The number of output channels. - /// - dilationY: The dilation in the Y direction. - /// - dilationX: The dilation in the X direction. - /// - weights: A pointer to the weights. - init( - convYSize: NSNumber, - convXSize: NSNumber, - inChannels: NSNumber, - outChannels: NSNumber, - dilationY: Int, - dilationX: Int, - weights: UnsafeMutablePointer - ) { - self.convYSize = convYSize - self.convXSize = convXSize - self.inChannels = inChannels - self.outChannels = outChannels - self.dilationY = dilationY - self.dilationX = dilationX - self.weights = weights - } -} - -public func createSWConvLayerDesc( - convYSize: Int32, - convXSize: Int32, - inChannels: Int32, - outChannels: Int32, - dilationY: Int32, - dilationX: Int32, - weights: UnsafeMutablePointer -) -> SWConvLayerDesc { - return SWConvLayerDesc( - convYSize: convYSize as NSNumber, - convXSize: convXSize as NSNumber, - inChannels: inChannels as NSNumber, - outChannels: outChannels as NSNumber, - dilationY: Int(dilationY), - dilationX: Int(dilationX), - weights: weights) -} - -/// A class that represents a convolutional layer using MPSGraph -class ConvLayer { - /// The result tensor of the convolutional operation - let resultTensor: MPSGraphTensor - /// The convolution 2D operation descriptor - let convDescriptor = MPSGraphConvolution2DOpDescriptor( - strideInX: 1, - strideInY: 1, - dilationRateInX: 1, - dilationRateInY: 1, - groups: 1, - paddingStyle: .TF_SAME, - dataLayout: .NCHW, - weightsLayout: .OIHW)! - - /// Class method that tests the convolutional layer by running a forward pass - /// - Parameters: - /// - descriptor: A descriptor for the convolutional layer - /// - nnXLen: The width of the input tensor - /// - nnYLen: The height of the input tensor - /// - batchSize: The batch size of the input tensor - /// - input: A pointer to the input tensor data - /// - output: A pointer to the output tensor data - class func test( - descriptor: SWConvLayerDesc, - nnXLen: NSNumber, - nnYLen: NSNumber, - batchSize: NSNumber, - input: UnsafeMutablePointer, - output: UnsafeMutablePointer - ) { - let device = MTLCreateSystemDefaultDevice()! - let graph = MPSGraph() - - let source = InputLayer( - graph: graph, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.inChannels) - - let conv = ConvLayer( - graph: graph, - sourceTensor: source.tensor, - descriptor: descriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let inputShape = InputShape.create( - batchSize: batchSize, - numChannels: descriptor.inChannels, - nnYLen: nnYLen, - nnXLen: nnXLen) - - let sourceDescriptor = MPSNDArrayDescriptor( - dataType: source.tensor.dataType, - shape: inputShape) - - let sourceArray = MPSNDArray( - device: device, - descriptor: sourceDescriptor) - - sourceArray.writeBytes(input) - let sourceTensorData = MPSGraphTensorData(sourceArray) - - let fetch = graph.run( - feeds: [source.tensor: sourceTensorData], - targetTensors: [conv.resultTensor], - targetOperations: nil) - - fetch[conv.resultTensor]?.mpsndarray().readBytes(output) - } - - /// Initializes a ConvLayer object - /// - Parameters: - /// - graph: An MPSGraph object - /// - sourceTensor: The input tensor for the convolutional layer - /// - descriptor: A descriptor for the convolutional layer - /// - nnXLen: The width of the input tensor - /// - nnYLen: The height of the input tensor - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - descriptor: SWConvLayerDesc, - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - let weightsShape = [ - descriptor.outChannels, - descriptor.inChannels, - descriptor.convYSize, - descriptor.convXSize, - ] - - let weightsData = Data( - floatsNoCopy: descriptor.weights, - shape: weightsShape) - - let weightsTensor = graph.constant( - weightsData, - shape: weightsShape, - dataType: sourceTensor.dataType) - - resultTensor = graph.convolution2D( - sourceTensor, - weights: weightsTensor, - descriptor: convDescriptor, - name: nil) - - assert(resultTensor.shape?.count == 4) - } -} - -public func testConvLayer( - descriptor: SWConvLayerDesc, - nnXLen: Int32, - nnYLen: Int32, - batchSize: Int32, - input: UnsafeMutablePointer, - output: UnsafeMutablePointer -) { - ConvLayer.test( - descriptor: descriptor, - nnXLen: nnXLen as NSNumber, - nnYLen: nnYLen as NSNumber, - batchSize: batchSize as NSNumber, - input: input, - output: output) -} - -/// A struct that represents a description of a batch normalization layer. -public struct SWBatchNormLayerDesc { - let numChannels: NSNumber - let mergedScale: UnsafeMutablePointer - let mergedBias: UnsafeMutablePointer - - /// Initializes a SWBatchNormLayerDesc object. - /// - Parameters: - /// - numChannels: The number of channels in the input tensor. - /// - mergedScale: A pointer to the merged scale. - /// - mergedBias: A pointer to the merged bias. - init( - numChannels: NSNumber, - mergedScale: UnsafeMutablePointer, - mergedBias: UnsafeMutablePointer - ) { - self.numChannels = numChannels - self.mergedScale = mergedScale - self.mergedBias = mergedBias - } -} - -public func createSWBatchNormLayerDesc( - numChannels: Int32, - mergedScale: UnsafeMutablePointer, - mergedBias: UnsafeMutablePointer -) -> SWBatchNormLayerDesc { - return SWBatchNormLayerDesc( - numChannels: numChannels as NSNumber, - mergedScale: mergedScale, - mergedBias: mergedBias) -} - -/// A class that represents a batch normalization layer. -class BatchNormLayer { - let resultTensor: MPSGraphTensor - - /// Executes a test for the batch normalization layer. - /// - Parameters: - /// - descriptor: The description of the batch normalization layer. - /// - nnXLen: The width of the input tensor. - /// - nnYLen: The height of the input tensor. - /// - batchSize: The number of input batches. - /// - input: A pointer to the input data. - /// - mask: A pointer to the mask data. - /// - output: A pointer to the output data. - class func test( - descriptor: SWBatchNormLayerDesc, - nnXLen: NSNumber, - nnYLen: NSNumber, - batchSize: NSNumber, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer - ) { - - NetworkTester.test( - batchSize: batchSize, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.numChannels, - input: input, - mask: mask, - output: output - ) { graph, inputLayer, maskLayer in - - let batchNorm = BatchNormLayer( - graph: graph, - sourceTensor: inputLayer.tensor, - maskTensor: maskLayer.tensor, - descriptor: descriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - return batchNorm.resultTensor - } - } - - /// Initializes a BatchNormLayer object with the specified parameters, and computes the normalized and masked result tensor. - /// - Parameters: - /// - graph: The MPSGraph object used to build the BatchNormLayer. - /// - sourceTensor: The input tensor to the BatchNormLayer. - /// - maskTensor: The mask tensor to apply to the normalized tensor. - /// - descriptor: The BatchNormLayer descriptor containing parameters such as the number of channels, mean, variance, scale, and bias. - /// - nnXLen: The length of the input tensor in the X direction. - /// - nnYLen: The length of the input tensor in the Y direction. - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - descriptor: SWBatchNormLayerDesc, - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - let scaleBiasShape = InputShape.create( - batchSize: 1, - numChannels: descriptor.numChannels, - nnYLen: 1, - nnXLen: 1) - - let mergedScaleData = Data( - floatsNoCopy: descriptor.mergedScale, - shape: scaleBiasShape) - - let mergedBiasData = Data( - floatsNoCopy: descriptor.mergedBias, - shape: scaleBiasShape) - - let scaleTensor = graph.constant( - mergedScaleData, - shape: scaleBiasShape, - dataType: sourceTensor.dataType) - - let biasTensor = graph.constant( - mergedBiasData, - shape: scaleBiasShape, - dataType: sourceTensor.dataType) - - let scaled = graph.multiplication( - sourceTensor, - scaleTensor, - name: nil) - - let normalized = graph.addition( - scaled, - biasTensor, - name: nil) - - resultTensor = graph.multiplication( - normalized, - maskTensor, - name: nil) - - assert(resultTensor.shape?.count == 4) - } -} - -public func testBatchNormLayer( - descriptor: SWBatchNormLayerDesc, - nnXLen: Int32, - nnYLen: Int32, - batchSize: Int32, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer -) { - BatchNormLayer.test( - descriptor: descriptor, - nnXLen: nnXLen as NSNumber, - nnYLen: nnYLen as NSNumber, - batchSize: batchSize as NSNumber, - input: input, - mask: mask, - output: output) -} - -/// An enumeration of the different kinds of activation function. -public enum ActivationKind { - case identity - case relu - case mish -} - -/// A structure that represents an activation layer -struct ActivationLayer { - let resultTensor: MPSGraphTensor - - /// Initialize an ActivationLayer object - /// - Parameters: - /// - graph: The MPSGraph - /// - sourceTensor: The input tensor - /// - activationKind: The activation kind - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - activationKind: ActivationKind - ) { - - switch activationKind { - case .relu: - resultTensor = graph.reLU(with: sourceTensor, name: nil) - case .mish: - resultTensor = graph.mish(tensor: sourceTensor) - default: - resultTensor = sourceTensor - } - - assert(resultTensor.shape == sourceTensor.shape) - } -} - -/// A class that represents a residual block in a convolutional neural network. -public class SWResidualBlockDesc: BlockDescriptor { - /// A description of the batch normalization layer that is applied before the first convolutional layer. - let preBN: SWBatchNormLayerDesc - - /// The type of activation function that is applied before the first convolutional layer. - let preActivation: ActivationKind - - /// A description of the convolutional layer that is applied in the middle of the residual block. - let regularConv: SWConvLayerDesc - - /// A description of the batch normalization layer that is applied after the middle convolutional layer. - let midBN: SWBatchNormLayerDesc - - /// The type of activation function that is applied after the middle convolutional layer. - let midActivation: ActivationKind - - /// A description of the convolutional layer that is applied at the end of the residual block. - let finalConv: SWConvLayerDesc - - /// Initializes a `SWResidualBlockDesc` object. - /// - Parameters: - /// - preBN: A description of the batch normalization layer that is applied before the first convolutional layer. - /// - preActivation: The type of activation function that is applied before the first convolutional layer. - /// - regularConv: A description of the convolutional layer that is applied in the middle of the residual block. - /// - midBN: A description of the batch normalization layer that is applied after the middle convolutional layer. - /// - midActivation: The type of activation function that is applied after the middle convolutional layer. - /// - finalConv: A description of the convolutional layer that is applied at the end of the residual block. - init( - preBN: SWBatchNormLayerDesc, - preActivation: ActivationKind, - regularConv: SWConvLayerDesc, - midBN: SWBatchNormLayerDesc, - midActivation: ActivationKind, - finalConv: SWConvLayerDesc - ) { - self.preBN = preBN - self.preActivation = preActivation - self.regularConv = regularConv - self.midBN = midBN - self.midActivation = midActivation - self.finalConv = finalConv - } -} - -public func createSWResidualBlockDesc( - preBN: SWBatchNormLayerDesc, - preActivation: ActivationKind, - regularConv: SWConvLayerDesc, - midBN: SWBatchNormLayerDesc, - midActivation: ActivationKind, - finalConv: SWConvLayerDesc -) -> SWResidualBlockDesc { - return SWResidualBlockDesc( - preBN: preBN, - preActivation: preActivation, - regularConv: regularConv, - midBN: midBN, - midActivation: midActivation, - finalConv: finalConv) -} - -/// A class that represents a Residual Block layer -class ResidualBlock { - let resultTensor: MPSGraphTensor - - /// A function that runs tests on the Residual Block layer - /// - /// - Parameters: - /// - descriptor: The Residual Block descriptor - /// - batchSize: Batch size - /// - nnXLen: X length - /// - nnYLen: Y length - /// - input: The input float32 pointer - /// - mask: The mask float32 pointer - /// - output: The output float32 pointer - class func test( - descriptor: SWResidualBlockDesc, - batchSize: NSNumber, - nnXLen: NSNumber, - nnYLen: NSNumber, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer - ) { - - NetworkTester.test( - batchSize: batchSize, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.preBN.numChannels, - input: input, - mask: mask, - output: output - ) { graph, inputLayer, maskLayer in - - let block = ResidualBlock( - graph: graph, - sourceTensor: inputLayer.tensor, - maskTensor: maskLayer.tensor, - descriptor: descriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - return block.resultTensor - } - } - - /// Initialize a ResidualBlock object - /// - /// - Parameters: - /// - graph: The MPSGraph - /// - sourceTensor: The input tensor - /// - maskTensor: The mask tensor - /// - descriptor: The Residual Block descriptor - /// - nnXLen: X length - /// - nnYLen: Y length - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - descriptor: SWResidualBlockDesc, - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - let preBN = BatchNormLayer( - graph: graph, - sourceTensor: sourceTensor, - maskTensor: maskTensor, - descriptor: descriptor.preBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let preActivation = ActivationLayer( - graph: graph, - sourceTensor: preBN.resultTensor, - activationKind: descriptor.preActivation) - - let regularConv = ConvLayer( - graph: graph, - sourceTensor: preActivation.resultTensor, - descriptor: descriptor.regularConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let midBN = BatchNormLayer( - graph: graph, - sourceTensor: regularConv.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.midBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let midActivation = ActivationLayer( - graph: graph, - sourceTensor: midBN.resultTensor, - activationKind: descriptor.midActivation) - - let finalConv = ConvLayer( - graph: graph, - sourceTensor: midActivation.resultTensor, - descriptor: descriptor.finalConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - resultTensor = graph.addition( - sourceTensor, - finalConv.resultTensor, - name: nil) - - assert(resultTensor.shape?.count == 4) - } -} - -public func testResidualBlock( - descriptor: SWResidualBlockDesc, - batchSize: Int32, - nnXLen: Int32, - nnYLen: Int32, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer -) { - ResidualBlock.test( - descriptor: descriptor, - batchSize: batchSize as NSNumber, - nnXLen: nnXLen as NSNumber, - nnYLen: nnYLen as NSNumber, - input: input, - mask: mask, - output: output) -} - -/// A structure that represents a global pooling layer -struct GlobalPoolingLayer { - /// The resulting tensor after applying the global pooling operation - let resultTensor: MPSGraphTensor - - /// Initialize a GlobalPoolingLayer object - /// - Parameters: - /// - graph: The graph - /// - sourceTensor: The source tensor to be pooled - /// - maskTensor: The mask tensor - /// - maskSumTensor: The sum of the mask - /// - maskSumSqrtS14M01Tensor: The multiplication of subtraction of square root of the sum of the mask - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor - ) { - let hwAxes = InputShape.getHWAxes() - let channelAxis = InputShape.getChannelAxis() - - let sumTensor = graph.reductionSum( - with: sourceTensor, - axes: hwAxes, - name: nil) - - let meanTensor = graph.division(sumTensor, maskSumTensor, name: nil) - - let meanMaskTensor = graph.multiplication( - meanTensor, - maskSumSqrtS14M01Tensor, - name: nil) - - let oneTensor = graph.constant(1.0, dataType: sourceTensor.dataType) - let maskM1Tensor = graph.subtraction(maskTensor, oneTensor, name: nil) - let addition = graph.addition(sourceTensor, maskM1Tensor, name: nil) - - let maxTensor = graph.reductionMaximum( - with: addition, - axes: hwAxes, - name: nil) - - resultTensor = graph.concatTensors( - [ - meanTensor, - meanMaskTensor, - maxTensor, - ], - dimension: channelAxis, - name: nil) - - assert(resultTensor.shape?.count == 4) - assert(resultTensor.shape?[2] == 1) - assert(resultTensor.shape?[3] == 1) - } -} - -/// A structure that represents a layer that performs global pooling on the input tensor -struct GlobalPoolingValueLayer { - let resultTensor: MPSGraphTensor - - /// Initialize a GlobalPoolingValueLayer object - /// - Parameters: - /// - graph: The graph - /// - sourceTensor: The input tensor - /// - maskSumTensor: The sum of the mask - /// - maskSumSqrtS14M01Tensor: The multiplication of subtraction of square root of the sum of the mask - /// - maskSumSqrtS14M01SquareS01Tensor: The subtraction of square of multiplication of subtraction of square root of the sum of the mask - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor, - maskSumSqrtS14M01SquareS01Tensor: MPSGraphTensor - ) { - let hwAxes = InputShape.getHWAxes() - let channelAxis = InputShape.getChannelAxis() - - let sumTensor = graph.reductionSum( - with: sourceTensor, - axes: hwAxes, - name: nil) - - let meanTensor = graph.division(sumTensor, maskSumTensor, name: nil) - - let meanMaskTensor = graph.multiplication( - meanTensor, - maskSumSqrtS14M01Tensor, - name: nil) - - let meanMaskSquareTensor = graph.multiplication( - meanTensor, - maskSumSqrtS14M01SquareS01Tensor, - name: nil) - - resultTensor = graph.concatTensors( - [ - meanTensor, - meanMaskTensor, - meanMaskSquareTensor, - ], - dimension: channelAxis, - name: nil) - - assert(resultTensor.shape?.count == 4) - assert(resultTensor.shape?[2] == 1) - assert(resultTensor.shape?[3] == 1) - } -} - -/// A struct that represents a matrix multiplication layer descriptor -public struct SWMatMulLayerDesc { - /// The number of input channels - let inChannels: NSNumber - /// The number of output channels - let outChannels: NSNumber - /// The weights used for the matrix multiplication - let weights: UnsafeMutablePointer - - /// Initialize a SWMatMulLayerDesc object - /// - Parameters: - /// - inChannels: The number of input channels - /// - outChannels: The number of output channels - /// - weights: The weights used for the matrix multiplication - init( - inChannels: NSNumber, - outChannels: NSNumber, - weights: UnsafeMutablePointer - ) { - self.inChannels = inChannels - self.outChannels = outChannels - self.weights = weights - } -} - -public func createSWMatMulLayerDesc( - inChannels: Int32, - outChannels: Int32, - weights: UnsafeMutablePointer -) -> SWMatMulLayerDesc { - return SWMatMulLayerDesc( - inChannels: inChannels as NSNumber, - outChannels: outChannels as NSNumber, - weights: weights) -} - -/// A structure representing a matrix multiplication layer. -struct MatMulLayer { - /// The resulting tensor from the layer. - let resultTensor: MPSGraphTensor - - /// Initializes a MatMulLayer object. - /// - Parameters: - /// - graph: The graph. - /// - descriptor: The matrix multiplication layer descriptor. - /// - sourceTensor: The input tensor to the layer. - init( - graph: MPSGraph, - descriptor: SWMatMulLayerDesc, - sourceTensor: MPSGraphTensor - ) { - - assert( - (sourceTensor.shape?.count == 4) || (sourceTensor.shape?[1] == descriptor.inChannels)) - assert( - (sourceTensor.shape?.count == 2) || (sourceTensor.shape?[1] == descriptor.inChannels)) - - let weightsShape = [ - descriptor.inChannels, - descriptor.outChannels, - ] - - let weightsData = Data( - floatsNoCopy: descriptor.weights, - shape: weightsShape) - - let weightsTensor = graph.constant( - weightsData, - shape: weightsShape, - dataType: sourceTensor.dataType) - - let shape = [-1, descriptor.inChannels] - - let reshapedSource = graph.reshape( - sourceTensor, - shape: shape, - name: nil) - - resultTensor = graph.matrixMultiplication( - primary: reshapedSource, - secondary: weightsTensor, - name: nil) - - assert(resultTensor.shape?.count == 2) - } -} - -/// An Objective-C class that represents the bias layer description used in Swift. -public struct SWMatBiasLayerDesc { - /// The number of channels. - let numChannels: NSNumber - /// The pointer to the weights. - let weights: UnsafeMutablePointer - - /// Initialize an instance of SWMatBiasLayerDesc. - /// - Parameters: - /// - numChannels: The number of channels. - /// - weights: The pointer to the weights. - init( - numChannels: NSNumber, - weights: UnsafeMutablePointer - ) { - self.numChannels = numChannels - self.weights = weights - } -} - -public func createSWMatBiasLayerDesc( - numChannels: Int32, - weights: UnsafeMutablePointer -) -> SWMatBiasLayerDesc { - return SWMatBiasLayerDesc( - numChannels: numChannels as NSNumber, - weights: weights) -} - -/// A structure that performs matrix bias operations -struct MatBiasLayer { - /// The resulting tensor from the layer. - let resultTensor: MPSGraphTensor - - /// Initializes a MatBiasLayer object. - /// - Parameters: - /// - graph: The graph. - /// - descriptor: The descriptor that contains information about the layer - /// - sourceTensor: The input tensor to the layer. - init( - graph: MPSGraph, - descriptor: SWMatBiasLayerDesc, - sourceTensor: MPSGraphTensor - ) { - - assert( - (sourceTensor.shape?.count == 2) && (sourceTensor.shape?[1] == descriptor.numChannels)) - - let weightsShape = [1, descriptor.numChannels] - - let weightsData = Data( - floatsNoCopy: descriptor.weights, - shape: weightsShape) - - let weightsTensor = graph.constant( - weightsData, - shape: weightsShape, - dataType: sourceTensor.dataType) - - resultTensor = graph.addition( - sourceTensor, - weightsTensor, - name: nil) - } -} - -/// A structure that performs bias operations in NC coordinates. -struct AddNCBiasLayer { - /// The resulting tensor from the layer. - let resultTensor: MPSGraphTensor - - /// Initializes an AddNCBiasLayer object. - /// - Parameters: - /// - graph: The graph. - /// - sourceTensor: The input tensor to the layer. - /// - biasTensor: The bias tensor. - /// - nnXLen: The x length. - /// - nnYLen: The y length. - /// - numChannels: The number of channels. - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - biasTensor: MPSGraphTensor, - nnXLen: NSNumber, - nnYLen: NSNumber, - numChannels: NSNumber - ) { - let shape = InputShape.create( - batchSize: -1, - numChannels: numChannels, - nnYLen: 1, - nnXLen: 1) - - assert(biasTensor.shape?[1] == shape[1]) - - let reshaped = graph.reshape(biasTensor, shape: shape, name: nil) - resultTensor = graph.addition(sourceTensor, reshaped, name: nil) - - assert(resultTensor.shape?.count == 4) - assert(resultTensor.shape?[2] == nnYLen) - assert(resultTensor.shape?[3] == nnXLen) - } -} - -/// A class that represents a residual block with global pooling. -public class SWGlobalPoolingResidualBlockDesc: BlockDescriptor { - /// The batch normalization layer before the residual block. - let preBN: SWBatchNormLayerDesc - - /// The pre-activation function of the residual block. - let preActivation: ActivationKind - - /// The regular convolutional layer in the residual block. - let regularConv: SWConvLayerDesc - - /// The convolutional layer for global pooling. - let gpoolConv: SWConvLayerDesc - - /// The batch normalization layer after the global pooling convolutional layer. - let gpoolBN: SWBatchNormLayerDesc - - /// The activation function after the global pooling batch normalization layer. - let gpoolActivation: ActivationKind - - /// The matrix multiplication layer that multiplies the global pooled output with a bias. - let gpoolToBiasMul: SWMatMulLayerDesc - - /// The batch normalization layer after the matrix multiplication layer. - let midBN: SWBatchNormLayerDesc - - /// The activation function after the mid batch normalization layer. - let midActivation: ActivationKind - - /// The final convolutional layer in the residual block. - let finalConv: SWConvLayerDesc - - /// Initialize a SWGlobalPoolingResidualBlockDesc object. - /// - Parameters: - /// - preBN: The batch normalization layer before the residual block. - /// - preActivation: The pre-activation function of the residual block. - /// - regularConv: The regular convolutional layer in the residual block. - /// - gpoolConv: The convolutional layer for global pooling. - /// - gpoolBN: The batch normalization layer after the global pooling convolutional layer. - /// - gpoolActivation: The activation function after the global pooling batch normalization layer. - /// - gpoolToBiasMul: The matrix multiplication layer that multiplies the global pooled output with a bias. - /// - midBN: The batch normalization layer after the matrix multiplication layer. - /// - midActivation: The activation function after the mid batch normalization layer. - /// - finalConv: The final convolutional layer in the residual block. - init( - preBN: SWBatchNormLayerDesc, - preActivation: ActivationKind, - regularConv: SWConvLayerDesc, - gpoolConv: SWConvLayerDesc, - gpoolBN: SWBatchNormLayerDesc, - gpoolActivation: ActivationKind, - gpoolToBiasMul: SWMatMulLayerDesc, - midBN: SWBatchNormLayerDesc, - midActivation: ActivationKind, - finalConv: SWConvLayerDesc - ) { - self.preBN = preBN - self.preActivation = preActivation - self.regularConv = regularConv - self.gpoolConv = gpoolConv - self.gpoolBN = gpoolBN - self.gpoolActivation = gpoolActivation - self.gpoolToBiasMul = gpoolToBiasMul - self.midBN = midBN - self.midActivation = midActivation - self.finalConv = finalConv - } -} - -public func createSWGlobalPoolingResidualBlockDesc( - preBN: SWBatchNormLayerDesc, - preActivation: ActivationKind, - regularConv: SWConvLayerDesc, - gpoolConv: SWConvLayerDesc, - gpoolBN: SWBatchNormLayerDesc, - gpoolActivation: ActivationKind, - gpoolToBiasMul: SWMatMulLayerDesc, - midBN: SWBatchNormLayerDesc, - midActivation: ActivationKind, - finalConv: SWConvLayerDesc -) -> SWGlobalPoolingResidualBlockDesc { - - return SWGlobalPoolingResidualBlockDesc( - preBN: preBN, - preActivation: preActivation, - regularConv: regularConv, - gpoolConv: gpoolConv, - gpoolBN: gpoolBN, - gpoolActivation: gpoolActivation, - gpoolToBiasMul: gpoolToBiasMul, - midBN: midBN, - midActivation: midActivation, - finalConv: finalConv) -} - -/// A class representing a residual block with global pooling -class GlobalPoolingResidualBlock { - let resultTensor: MPSGraphTensor - - /// A method to test the global pooling residual block - /// - /// - Parameters: - /// - descriptor: The descriptor of the global pooling residual block - /// - batchSize: The batch size - /// - nnXLen: The X length - /// - nnYLen: The Y length - /// - input: The input pointer - /// - mask: The mask pointer - /// - output: The output pointer - class func test( - descriptor: SWGlobalPoolingResidualBlockDesc, - batchSize: NSNumber, - nnXLen: NSNumber, - nnYLen: NSNumber, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer - ) { - - NetworkTester.test( - batchSize: batchSize, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.preBN.numChannels, - input: input, - mask: mask, - output: output - ) { graph, inputLayer, maskLayer in - - let maskSum = MaskSumLayer( - graph: graph, - maskTensor: maskLayer.tensor) - - let maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer( - graph: graph, - maskSum: maskSum) - - let block = - GlobalPoolingResidualBlock( - graph: graph, - sourceTensor: inputLayer.tensor, - maskTensor: maskLayer.tensor, - maskSumTensor: maskSum.tensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, - descriptor: descriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - return block.resultTensor - } - } - - /// Initialize a GlobalPoolingResidualBlock object - /// - /// - Parameters: - /// - graph: The graph - /// - sourceTensor: The source tensor - /// - maskTensor: The mask tensor - /// - maskSumTensor: The mask sum tensor - /// - maskSumSqrtS14M01Tensor: The mask sum square tensor - /// - descriptor: The descriptor of the global pooling residual block - /// - nnXLen: The X length - /// - nnYLen: The Y length - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor, - descriptor: SWGlobalPoolingResidualBlockDesc, - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - let maskSum = MaskSumLayer(tensor: maskSumTensor) - let maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer(tensor: maskSumSqrtS14M01Tensor) - - let preBN = BatchNormLayer( - graph: graph, - sourceTensor: sourceTensor, - maskTensor: maskTensor, - descriptor: descriptor.preBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let preActivation = ActivationLayer( - graph: graph, - sourceTensor: preBN.resultTensor, - activationKind: descriptor.preActivation) - - let regularConv = ConvLayer( - graph: graph, - sourceTensor: preActivation.resultTensor, - descriptor: descriptor.regularConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let gpoolConv = ConvLayer( - graph: graph, - sourceTensor: preActivation.resultTensor, - descriptor: descriptor.gpoolConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let gpoolBN = BatchNormLayer( - graph: graph, - sourceTensor: gpoolConv.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.gpoolBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let gpoolActivation = ActivationLayer( - graph: graph, - sourceTensor: gpoolBN.resultTensor, - activationKind: descriptor.gpoolActivation) - - let gpoolConcat = GlobalPoolingLayer( - graph: graph, - sourceTensor: gpoolActivation.resultTensor, - maskTensor: maskTensor, - maskSumTensor: maskSum.tensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor) - - assert(gpoolConcat.resultTensor.shape?[1] == descriptor.gpoolToBiasMul.inChannels) - - let gpoolToBiasMul = MatMulLayer( - graph: graph, - descriptor: descriptor.gpoolToBiasMul, - sourceTensor: gpoolConcat.resultTensor) - - let added = AddNCBiasLayer( - graph: graph, - sourceTensor: regularConv.resultTensor, - biasTensor: gpoolToBiasMul.resultTensor, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.gpoolToBiasMul.outChannels) - - let midBN = BatchNormLayer( - graph: graph, - sourceTensor: added.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.midBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let midActivation = ActivationLayer( - graph: graph, - sourceTensor: midBN.resultTensor, - activationKind: descriptor.midActivation) - - let finalConv = ConvLayer( - graph: graph, - sourceTensor: midActivation.resultTensor, - descriptor: descriptor.finalConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - resultTensor = graph.addition( - sourceTensor, - finalConv.resultTensor, - name: nil) - - assert(resultTensor.shape?.count == 4) - } -} - -public func testGlobalPoolingResidualBlock( - descriptor: SWGlobalPoolingResidualBlockDesc, - batchSize: Int32, - nnXLen: Int32, - nnYLen: Int32, - input: UnsafeMutablePointer, - mask: UnsafeMutablePointer, - output: UnsafeMutablePointer -) { - GlobalPoolingResidualBlock.test( - descriptor: descriptor, - batchSize: batchSize as NSNumber, - nnXLen: nnXLen as NSNumber, - nnYLen: nnYLen as NSNumber, - input: input, - mask: mask, - output: output) -} - -/// A class that represents a nested bottleneck residual block -public class SWNestedBottleneckResidualBlockDesc: BlockDescriptor { - /// The batch normalization layer before the residual block. - let preBN: SWBatchNormLayerDesc - - /// The pre-activation function of the residual block. - let preActivation: ActivationKind - - /// The convolutional layer before the residual block. - let preConv: SWConvLayerDesc - - /// The list of blocks that make up the trunk - let blockDescriptors: [BlockDescriptor] - - /// The batch normalization layer after the residual block. - let postBN: SWBatchNormLayerDesc - - /// The activation function after the post batch normalization layer. - let postActivation: ActivationKind - - /// The convolutional layer after the post activation layer. - let postConv: SWConvLayerDesc - - /// Initialize a SWNestedBottleneckResidualBlockDesc object. - /// - Parameters: - /// - preBN: The batch normalization layer before the residual block. - /// - preActivation: The pre-activation function of the residual block. - /// - preConv: The convolutional layer before the residual block. - /// - postBN: The batch normalization layer after the residual block. - /// - postActivation: The activation function after the post batch normalization layer. - /// - postConv: The convolutional layer after the post activation layer. - init( - preBN: SWBatchNormLayerDesc, - preActivation: ActivationKind, - preConv: SWConvLayerDesc, - blockDescriptors: [BlockDescriptor], - postBN: SWBatchNormLayerDesc, - postActivation: ActivationKind, - postConv: SWConvLayerDesc - ) { - self.preBN = preBN - self.preActivation = preActivation - self.preConv = preConv - self.blockDescriptors = blockDescriptors - self.postBN = postBN - self.postActivation = postActivation - self.postConv = postConv - } -} - -public func createSWNestedBottleneckResidualBlockDesc( - preBN: SWBatchNormLayerDesc, - preActivation: ActivationKind, - preConv: SWConvLayerDesc, - blockDescriptors: [BlockDescriptor], - postBN: SWBatchNormLayerDesc, - postActivation: ActivationKind, - postConv: SWConvLayerDesc -) -> SWNestedBottleneckResidualBlockDesc { - return SWNestedBottleneckResidualBlockDesc( - preBN: preBN, - preActivation: preActivation, - preConv: preConv, - blockDescriptors: blockDescriptors, - postBN: postBN, - postActivation: postActivation, - postConv: postConv) -} - -public class BlockDescriptor { -} - -public class BlockDescriptorBuilder { - public var blockDescriptors: [BlockDescriptor] = [] - - public func enque(with descriptor: BlockDescriptor) { - blockDescriptors.append(descriptor) - } -} - -public func createBlockDescriptorBuilder() -> BlockDescriptorBuilder { - return BlockDescriptorBuilder() -} - -/// A structure that represents a block stack -struct BlockStack { - /// The resulting tensor after processing the block stack - let resultTensor: MPSGraphTensor - - /// Process block descriptors - /// - Parameters: - /// - graph: The MPSGraph - /// - sourceTensor: The input tensor - /// - maskTensor: The mask tensor - /// - maskSumTensor: The sum of the mask tensor - /// - maskSumSqrtS14M01Tensor: The square root of the sum of the mask tensor - /// - blockDescriptors: The block descriptors - /// - index: The index of the block descriptor - /// - nnXLen: X length - /// - nnYLen: Y length - /// - Returns: The result tensor - static func processBlockDescriptors( - _ graph: MPSGraph, - _ sourceTensor: MPSGraphTensor, - _ maskTensor: MPSGraphTensor, - _ maskSumTensor: MPSGraphTensor, - _ maskSumSqrtS14M01Tensor: MPSGraphTensor, - _ blockDescriptors: [BlockDescriptor], - _ index: Int, - _ nnXLen: NSNumber, - _ nnYLen: NSNumber - ) -> MPSGraphTensor { - guard index < blockDescriptors.count else { - return sourceTensor - } - - let blockDescriptor = blockDescriptors[index] - let blockInput: MPSGraphTensor - - switch blockDescriptor { - case let globalPoolingDescriptor as SWGlobalPoolingResidualBlockDesc: - let globalPooling = GlobalPoolingResidualBlock( - graph: graph, - sourceTensor: sourceTensor, - maskTensor: maskTensor, - maskSumTensor: maskSumTensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, - descriptor: globalPoolingDescriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - blockInput = globalPooling.resultTensor - case let nestedBottleneckDescriptor as SWNestedBottleneckResidualBlockDesc: - let nestedBottleneck = NestedBottleneckResidualBlock( - graph: graph, - sourceTensor: sourceTensor, - maskTensor: maskTensor, - maskSumTensor: maskSumTensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, - descriptor: nestedBottleneckDescriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - blockInput = nestedBottleneck.resultTensor - case let residualBlockDescriptor as SWResidualBlockDesc: - let ordinary = ResidualBlock( - graph: graph, - sourceTensor: sourceTensor, - maskTensor: maskTensor, - descriptor: residualBlockDescriptor, - nnXLen: nnXLen, - nnYLen: nnYLen) - - blockInput = ordinary.resultTensor - default: - blockInput = sourceTensor - } - - return processBlockDescriptors( - graph, - blockInput, - maskTensor, - maskSumTensor, - maskSumSqrtS14M01Tensor, - blockDescriptors, - index + 1, - nnXLen, - nnYLen) - } - - /// Initialize a BlockStack object - /// - Parameters: - /// - graph: The MPSGraph - /// - sourceTensor: The input tensor - /// - maskTensor: The mask tensor - /// - maskSumTensor: The sum of the mask tensor - /// - maskSumSqrtS14M01Tensor: The square root of the sum of the mask tensor - /// - blockDescriptors: The block descriptors - /// - nnXLen: X length - /// - nnYLen: Y length - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor, - blockDescriptors: [BlockDescriptor], - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - resultTensor = BlockStack.processBlockDescriptors( - graph, - sourceTensor, - maskTensor, - maskSumTensor, - maskSumSqrtS14M01Tensor, - blockDescriptors, - 0, - nnXLen, - nnYLen) - } -} - -/// A structure that represents a nested bottleneck residual block -struct NestedBottleneckResidualBlock { - /// The resulting tensor after processing the nested bottleneck residual block - let resultTensor: MPSGraphTensor - - /// Initialize a ResidualBlock object - /// - /// - Parameters: - /// - graph: The MPSGraph - /// - sourceTensor: The input tensor - /// - maskTensor: The mask tensor - /// - maskSumTensor: The sum of the mask tensor - /// - maskSumSqrtS14M01Tensor: The square root of the sum of the mask tensor - /// - descriptor: The nested bottleneck residual block descriptor - /// - nnXLen: X length - /// - nnYLen: Y length - init( - graph: MPSGraph, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor, - descriptor: SWNestedBottleneckResidualBlockDesc, - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - - let preBN = BatchNormLayer( - graph: graph, - sourceTensor: sourceTensor, - maskTensor: maskTensor, - descriptor: descriptor.preBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let preActivation = ActivationLayer( - graph: graph, - sourceTensor: preBN.resultTensor, - activationKind: descriptor.preActivation) - - let preConv = ConvLayer( - graph: graph, - sourceTensor: preActivation.resultTensor, - descriptor: descriptor.preConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let blocks = BlockStack( - graph: graph, - sourceTensor: preConv.resultTensor, - maskTensor: maskTensor, - maskSumTensor: maskSumTensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, - blockDescriptors: descriptor.blockDescriptors, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let postBN = BatchNormLayer( - graph: graph, - sourceTensor: blocks.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.postBN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let postActivation = ActivationLayer( - graph: graph, - sourceTensor: postBN.resultTensor, - activationKind: descriptor.postActivation) - - let postConv = ConvLayer( - graph: graph, - sourceTensor: postActivation.resultTensor, - descriptor: descriptor.postConv, - nnXLen: nnXLen, - nnYLen: nnYLen) +// NOTE: Model caching and conversion are now handled in C++ using the native katagocoreml library. +// The Python-based CoreMLConverter and ModelCacheManager have been removed to eliminate Python dependency. - resultTensor = graph.addition( - sourceTensor, - postConv.resultTensor, - name: nil) +/// Context storing board dimensions and settings +public class MetalComputeContext { + public let nnXLen: Int32 + public let nnYLen: Int32 + public let useFP16: Bool - assert(resultTensor.shape?.count == 4) + init(nnXLen: Int32, nnYLen: Int32, useFP16: Bool) { + self.nnXLen = nnXLen + self.nnYLen = nnYLen + self.useFP16 = useFP16 } } -/// Class representing the description of the SGF Metadata Encoder. -/// -/// This encoder consists of three matrix multiplication layers, each followed by a bias and an activation function. -public class SWSGFMetadataEncoderDesc { - /// Version of the SGF Metadata Encoder. - let version: Int +/// Create a Metal compute context +public func createMetalComputeContext( + nnXLen: Int32, + nnYLen: Int32, + useFP16: Bool +) -> MetalComputeContext { + return MetalComputeContext(nnXLen: nnXLen, nnYLen: nnYLen, useFP16: useFP16) +} - /// Number of input metadata channels. +/// Handle that wraps the loaded MLModel for inference +public class CoreMLComputeHandle { + let model: MLModel + let nnXLen: Int32 + let nnYLen: Int32 + let optimizeIdentityMask: Bool + let numInputChannels: Int + let numInputGlobalChannels: Int let numInputMetaChannels: Int + let numPolicyChannels: Int + let numValueChannels: Int + let numScoreValueChannels: Int + let numOwnershipChannels: Int + + /// Model input/output names matching KataGoCoremltools output + struct IONames { + static let spatialInput = "spatial_input" + static let globalInput = "global_input" + static let inputMask = "input_mask" + static let metaInput = "meta_input" + + static let policyOutput = "policy_p2_conv" + static let policyPassOutput = "policy_pass" + static let valueOutput = "value_v3_bias" + static let ownershipOutput = "value_ownership_conv" + static let scoreValueOutput = "value_sv3_bias" + } - /// Description of the first multiplication layer. - let mul1: SWMatMulLayerDesc - - /// Description of the bias for the first layer. - let bias1: SWMatBiasLayerDesc - - /// Activation kind for the first layer. - let act1: ActivationKind - - /// Description of the second multiplication layer. - let mul2: SWMatMulLayerDesc - - /// Description of the bias for the second layer. - let bias2: SWMatBiasLayerDesc - - /// Activation kind for the second layer. - let act2: ActivationKind - - /// Description of the third multiplication layer. - let mul3: SWMatMulLayerDesc - - /// Initializes a new instance of the `SWSGFMetadataEncoderDesc` class. - /// - /// - Parameters: - /// - version: The version of the SGF Metadata Encoder. - /// - numInputMetaChannels: The number of input metadata channels. - /// - mul1: Description of the first multiplication layer. - /// - bias1: Description of the bias for the first layer. - /// - act1: Activation kind for the first layer. - /// - mul2: Description of the second multiplication layer. - /// - bias2: Description of the bias for the second layer. - /// - act2: Activation kind for the second layer. - /// - mul3: Description of the third multiplication layer. - init( - version: Int, - numInputMetaChannels: Int, - mul1: SWMatMulLayerDesc, - bias1: SWMatBiasLayerDesc, - act1: ActivationKind, - mul2: SWMatMulLayerDesc, - bias2: SWMatBiasLayerDesc, - act2: ActivationKind, - mul3: SWMatMulLayerDesc - ) { - self.version = version + init(model: MLModel, nnXLen: Int32, nnYLen: Int32, + optimizeIdentityMask: Bool, + numInputChannels: Int, + numInputGlobalChannels: Int, + numInputMetaChannels: Int, + numPolicyChannels: Int, + numValueChannels: Int, + numScoreValueChannels: Int, + numOwnershipChannels: Int) { + self.model = model + self.nnXLen = nnXLen + self.nnYLen = nnYLen + self.optimizeIdentityMask = optimizeIdentityMask + self.numInputChannels = numInputChannels + self.numInputGlobalChannels = numInputGlobalChannels self.numInputMetaChannels = numInputMetaChannels - self.mul1 = mul1 - self.bias1 = bias1 - self.act1 = act1 - self.mul2 = mul2 - self.bias2 = bias2 - self.act2 = act2 - self.mul3 = mul3 + self.numPolicyChannels = numPolicyChannels + self.numValueChannels = numValueChannels + self.numScoreValueChannels = numScoreValueChannels + self.numOwnershipChannels = numOwnershipChannels } -} - -/// Creates an instance of `SWSGFMetadataEncoderDesc` using the specified parameters. -/// -/// - Parameters: -/// - version: An `Int32` representing the version of the encoder descriptor. -/// - numInputMetaChannels: An `Int32` specifying the number of input metadata channels. -/// - mul1: A `SWMatMulLayerDesc` representing the description of the first matrix multiplication layer. -/// - bias1: A `SWMatBiasLayerDesc` representing the description of the bias for the first layer. -/// - act1: An `ActivationKind` specifying the activation function applied after the first layer. -/// - mul2: A `SWMatMulLayerDesc` representing the description of the second matrix multiplication layer. -/// - bias2: A `SWMatBiasLayerDesc` representing the description of the bias for the second layer. -/// - act2: An `ActivationKind` specifying the activation function applied after the second layer. -/// - mul3: A `SWMatMulLayerDesc` representing the description of the third matrix multiplication layer. -/// -/// - Returns: -/// An instance of `SWSGFMetadataEncoderDesc` initialized with the provided parameters. -public func createSWSGFMetadataEncoderDesc( - version: Int32, - numInputMetaChannels: Int32, - mul1: SWMatMulLayerDesc, - bias1: SWMatBiasLayerDesc, - act1: ActivationKind, - mul2: SWMatMulLayerDesc, - bias2: SWMatBiasLayerDesc, - act2: ActivationKind, - mul3: SWMatMulLayerDesc -) -> SWSGFMetadataEncoderDesc? { - return SWSGFMetadataEncoderDesc( - version: Int(version), - numInputMetaChannels: Int(numInputMetaChannels), - mul1: mul1, - bias1: bias1, - act1: act1, - mul2: mul2, - bias2: bias2, - act2: act2, - mul3: mul3) -} -/// A class that describes SGF metadata encoder. -/// SGFMetadataEncoder takes a graph, a descriptor object defining various parameters for the encoding process, -/// and an input tensor, and performs a sequence of matrix multiplications, bias additions, and activation functions -/// to produce a final encoded tensor. -class SGFMetadataEncoder { - /// The resulting tensor after encoding the metadata. - let resultTensor: MPSGraphTensor - - /// Initializes an `SGFMetadataEncoder` instance and performs the encoding process. - /// - /// - Parameters: - /// - graph: The computational graph object used to define and manage tensor operations. - /// - descriptor: An object holding all the required parameters, including matrix multiplication, biases, - /// and activation functions for each layer. - /// - sourceTensor: The initial input tensor containing the metadata to be encoded. - init( - graph: MPSGraph, - descriptor: SWSGFMetadataEncoderDesc, - sourceTensor: MPSGraphTensor + /// Run inference on a batch of inputs + public func apply( + spatialInput: UnsafeMutablePointer, + globalInput: UnsafeMutablePointer, + metaInput: UnsafeMutablePointer, + maskInput: UnsafeMutablePointer, + policy: UnsafeMutablePointer, + policyPass: UnsafeMutablePointer, + value: UnsafeMutablePointer, + scoreValue: UnsafeMutablePointer, + ownership: UnsafeMutablePointer, + batchSize: Int ) { + // Process batch elements in parallel using Grand Central Dispatch + // Each inference is independent, reading/writing to different buffer offsets + DispatchQueue.concurrentPerform(iterations: batchSize) { b in + autoreleasepool { + do { + try runSingleInference( + batchIndex: b, + spatialInput: spatialInput, + globalInput: globalInput, + metaInput: metaInput, + maskInput: maskInput, + policy: policy, + policyPass: policyPass, + value: value, + scoreValue: scoreValue, + ownership: ownership + ) + } catch { + printError("Metal backend: CoreML inference error: \(error)") + } + } + } + } - // First matrix multiplication layer. - let mul1 = MatMulLayer( - graph: graph, - descriptor: descriptor.mul1, - sourceTensor: sourceTensor) - - // Adding bias to the result of the first matrix multiplication. - let bias1 = MatBiasLayer( - graph: graph, - descriptor: descriptor.bias1, - sourceTensor: mul1.resultTensor) - - // Applying the first activation function to the biased tensor. - let act1 = ActivationLayer( - graph: graph, - sourceTensor: bias1.resultTensor, - activationKind: descriptor.act1) - - // Second matrix multiplication layer taking the output of the first activation layer. - let mul2 = MatMulLayer( - graph: graph, - descriptor: descriptor.mul2, - sourceTensor: act1.resultTensor) - - // Adding bias to the result of the second matrix multiplication. - let bias2 = MatBiasLayer( - graph: graph, - descriptor: descriptor.bias2, - sourceTensor: mul2.resultTensor) - - // Applying the second activation function to the biased tensor. - let act2 = ActivationLayer( - graph: graph, - sourceTensor: bias2.resultTensor, - activationKind: descriptor.act2) - - // Third and final matrix multiplication layer taking the output of the second activation layer. - let mul3 = MatMulLayer( - graph: graph, - descriptor: descriptor.mul3, - sourceTensor: act2.resultTensor) + private func runSingleInference( + batchIndex: Int, + spatialInput: UnsafeMutablePointer, + globalInput: UnsafeMutablePointer, + metaInput: UnsafeMutablePointer, + maskInput: UnsafeMutablePointer, + policy: UnsafeMutablePointer, + policyPass: UnsafeMutablePointer, + value: UnsafeMutablePointer, + scoreValue: UnsafeMutablePointer, + ownership: UnsafeMutablePointer + ) throws { + let spatialSize = Int(nnXLen) * Int(nnYLen) * numInputChannels + let spatialOffset = batchIndex * spatialSize + + // Create MLMultiArray for spatial input (1, C, H, W) + let spatialArray = try MLMultiArray( + shape: [1, NSNumber(value: numInputChannels), + NSNumber(value: nnYLen), NSNumber(value: nnXLen)], + dataType: .float32) + + // Copy spatial data using fast memcpy + let spatialPtr = spatialArray.dataPointer.assumingMemoryBound(to: Float32.self) + memcpy(spatialPtr, spatialInput.advanced(by: spatialOffset), spatialSize * MemoryLayout.size) + + // Create global input array (1, C) - rank 2 as expected by converter + let globalArray = try MLMultiArray( + shape: [1, NSNumber(value: numInputGlobalChannels)], + dataType: .float32) + let globalPtr = globalArray.dataPointer.assumingMemoryBound(to: Float32.self) + let globalOffset = batchIndex * numInputGlobalChannels + memcpy(globalPtr, globalInput.advanced(by: globalOffset), numInputGlobalChannels * MemoryLayout.size) + + // Build feature provider dictionary + var inputDict: [String: MLFeatureValue] = [ + IONames.spatialInput: MLFeatureValue(multiArray: spatialArray), + IONames.globalInput: MLFeatureValue(multiArray: globalArray) + ] - // Setting the final result tensor to the output of the last matrix multiplication layer. - resultTensor = mul3.resultTensor + // Add mask input (always required, even with optimize_identity_mask=True) + // When optimize_identity_mask=True, the mask is still required as input but + // internal mask operations are optimized away for ~6.5% speedup + let maskArray = try MLMultiArray( + shape: [1, 1, NSNumber(value: nnYLen), NSNumber(value: nnXLen)], + dataType: .float32) + let maskPtr = maskArray.dataPointer.assumingMemoryBound(to: Float32.self) + let maskSize = Int(nnXLen) * Int(nnYLen) + let maskOffset = batchIndex * maskSize + memcpy(maskPtr, maskInput.advanced(by: maskOffset), maskSize * MemoryLayout.size) + inputDict[IONames.inputMask] = MLFeatureValue(multiArray: maskArray) + + // Add meta input if model has it + if numInputMetaChannels > 0 { + let metaArray = try MLMultiArray( + shape: [1, NSNumber(value: numInputMetaChannels)], + dataType: .float32) + let metaPtr = metaArray.dataPointer.assumingMemoryBound(to: Float32.self) + let metaOffset = batchIndex * numInputMetaChannels + memcpy(metaPtr, metaInput.advanced(by: metaOffset), numInputMetaChannels * MemoryLayout.size) + inputDict[IONames.metaInput] = MLFeatureValue(multiArray: metaArray) + } - assert(resultTensor.shape?.count == 2) + // Run prediction + let featureProvider = try MLDictionaryFeatureProvider(dictionary: inputDict) + let prediction = try model.prediction(from: featureProvider) + + // Extract outputs and copy to output buffers + extractOutputs( + prediction: prediction, + batchIndex: batchIndex, + policy: policy, + policyPass: policyPass, + value: value, + scoreValue: scoreValue, + ownership: ownership + ) } -} -/// A class that describes a trunk for a neural network -public class SWTrunkDesc { - /// The version of the ResNet trunk - let version: Int - /// Number of channels for the trunk - let trunkNumChannels: NSNumber - /// Number of channels for the mid section - let midNumChannels: NSNumber - /// Number of channels for the regular section - let regularNumChannels: NSNumber - /// Number of channels for the global pooling section - let gpoolNumChannels: NSNumber - /// The description of the initial convolutional layer - let initialConv: SWConvLayerDesc - /// The description of the initial matrix multiplication layer - let initialMatMul: SWMatMulLayerDesc - /// The description of the SGF metadata encoder - let sgfMetadataEncoder: SWSGFMetadataEncoderDesc? - /// The list of blocks that make up the trunk - let blockDescriptors: [BlockDescriptor] - /// The description of the batch normalization layer that is applied at the end of the trunk - let trunkTipBN: SWBatchNormLayerDesc - /// The activation function that is applied at the end of the trunk - let trunkTipActivation: ActivationKind - - /// Initializes a SWTrunkDesc object - /// - Parameters: - /// - version: The version of the ResNet trunk - /// - trunkNumChannels: Number of channels for the trunk - /// - midNumChannels: Number of channels for the mid section - /// - regularNumChannels: Number of channels for the regular section - /// - gpoolNumChannels: Number of channels for the global pooling section - /// - initialConv: The description of the initial convolutional layer - /// - initialMatMul: The description of the initial matrix multiplication layer - /// - sgfMetadataEncoder: The description of the SGF metadata encoder - /// - blockDescriptors: The list of blocks that make up the trunk - /// - trunkTipBN: The description of the batch normalization layer that is applied at the end of the trunk - /// - trunkTipActivation: The activation function that is applied at the end of the trunk - init( - version: Int, - trunkNumChannels: NSNumber, - midNumChannels: NSNumber, - regularNumChannels: NSNumber, - gpoolNumChannels: NSNumber, - initialConv: SWConvLayerDesc, - initialMatMul: SWMatMulLayerDesc, - sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, - blockDescriptors: [BlockDescriptor], - trunkTipBN: SWBatchNormLayerDesc, - trunkTipActivation: ActivationKind + /// Copy MLMultiArray data to destination buffer, respecting strides. + /// Core ML may return non-contiguous arrays, especially for spatial outputs after GPU computation. + private func copyMultiArray( + _ array: MLMultiArray, + to dest: UnsafeMutablePointer, + destOffset: Int ) { - self.version = version - self.trunkNumChannels = trunkNumChannels - self.midNumChannels = midNumChannels - self.regularNumChannels = regularNumChannels - self.gpoolNumChannels = gpoolNumChannels - self.initialConv = initialConv - self.initialMatMul = initialMatMul - self.sgfMetadataEncoder = sgfMetadataEncoder - self.blockDescriptors = blockDescriptors - self.trunkTipBN = trunkTipBN - self.trunkTipActivation = trunkTipActivation - } -} - -public func createSWTrunkDesc( - version: Int32, - trunkNumChannels: Int32, - midNumChannels: Int32, - regularNumChannels: Int32, - gpoolNumChannels: Int32, - initialConv: SWConvLayerDesc, - initialMatMul: SWMatMulLayerDesc, - sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, - blockDescriptors: [BlockDescriptor], - trunkTipBN: SWBatchNormLayerDesc, - trunkTipActivation: ActivationKind -) -> SWTrunkDesc { - return SWTrunkDesc( - version: Int(version), - trunkNumChannels: trunkNumChannels as NSNumber, - midNumChannels: midNumChannels as NSNumber, - regularNumChannels: regularNumChannels as NSNumber, - gpoolNumChannels: gpoolNumChannels as NSNumber, - initialConv: initialConv, - initialMatMul: initialMatMul, - sgfMetadataEncoder: sgfMetadataEncoder, - blockDescriptors: blockDescriptors, - trunkTipBN: trunkTipBN, - trunkTipActivation: trunkTipActivation) -} - -/// A structure representing a ResNet trunk for a neural network -struct Trunk { - /// The resulting tensor after processing the trunk - let resultTensor: MPSGraphTensor - - /// Returns the block source tensor by processing the input meta tensor, if available, and adding a bias term. - /// - /// - Parameters: - /// - graph: The Metal Performance Shaders (MPS) graph. - /// - descriptor: The SGF metadata encoder descriptor. - /// - initialAdd: The initial add operation result tensor. - /// - inputMetaTensor: The input meta tensor. - /// - nnXLen: The X length of the neural network (NN). - /// - nnYLen: The Y length of the neural network (NN). - /// - numChannels: The number of channels of the initial add operation result tensor. - /// - /// - Returns: - /// - blockSourceTensor: The processed block source tensor. - /// - /// This function is used to get the block source tensor by processing the input meta tensor, if available. - /// If the input meta tensor is not available, it returns the result tensor from the initial add operation. - /// The function uses SGF metadata encoder and AddNCBiasLayer to process the input meta tensor. - static func getBlockSourceTensor( - graph: MPSGraph, - descriptor: SWSGFMetadataEncoderDesc?, - initialAdd: AddNCBiasLayer, - inputMetaTensor: MPSGraphTensor?, - nnXLen: NSNumber, - nnYLen: NSNumber, - numChannels: NSNumber - ) -> MPSGraphTensor { - var blockSourceTensor: MPSGraphTensor - - if let inputMetaTensor, - let descriptor, descriptor.numInputMetaChannels > 0 - { - let encoded = SGFMetadataEncoder( - graph: graph, - descriptor: descriptor, - sourceTensor: inputMetaTensor) + let shape = array.shape.map { $0.intValue } + let strides = array.strides.map { $0.intValue } + let ptr = array.dataPointer.assumingMemoryBound(to: Float32.self) + let totalElements = shape.reduce(1, *) + + // Check if contiguous (strides match expected for row-major C-order) + var isContiguous = true + var expectedStride = 1 + for i in (0...size) + } else { + // Slow path: copy with strides (handles non-contiguous layouts) + copyWithStrides( + from: ptr, + to: dest, + destOffset: destOffset, + shape: shape, + strides: strides, + dim: 0, + srcOffset: 0, + destIdx: 0 + ) + } + } - blockSourceTensor = encodedAdd.resultTensor + /// Recursively copy array elements respecting strides (NCHW order) + @discardableResult + private func copyWithStrides( + from src: UnsafePointer, + to dest: UnsafeMutablePointer, + destOffset: Int, + shape: [Int], + strides: [Int], + dim: Int, + srcOffset: Int, + destIdx: Int + ) -> Int { + var currentDestIdx = destIdx + + if dim == shape.count - 1 { + // Innermost dimension: copy elements + for i in 0.., + policyPass: UnsafeMutablePointer, + value: UnsafeMutablePointer, + scoreValue: UnsafeMutablePointer, + ownership: UnsafeMutablePointer ) { + // Extract policy output (1, policyChannels, H, W) + // Must use stride-aware copy as Core ML may return non-contiguous arrays + if let policyArray = prediction.featureValue(for: IONames.policyOutput)?.multiArrayValue { + let policyOffset = batchIndex * Int(nnXLen) * Int(nnYLen) * numPolicyChannels + copyMultiArray(policyArray, to: policy, destOffset: policyOffset) + } - let initialConv = ConvLayer( - graph: graph, - sourceTensor: inputTensor, - descriptor: descriptor.initialConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let initialMatMul = MatMulLayer( - graph: graph, - descriptor: descriptor.initialMatMul, - sourceTensor: inputGlobalTensor) - - let initialAdd = AddNCBiasLayer( - graph: graph, - sourceTensor: initialConv.resultTensor, - biasTensor: initialMatMul.resultTensor, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.initialMatMul.outChannels) - - let blockSourceTensor = Trunk.getBlockSourceTensor( - graph: graph, - descriptor: descriptor.sgfMetadataEncoder, - initialAdd: initialAdd, - inputMetaTensor: inputMetaTensor, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.initialMatMul.outChannels) - - let blocks = BlockStack( - graph: graph, - sourceTensor: blockSourceTensor, - maskTensor: maskTensor, - maskSumTensor: maskSumTensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, - blockDescriptors: descriptor.blockDescriptors, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let trunkTipBN = BatchNormLayer( - graph: graph, - sourceTensor: blocks.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.trunkTipBN, - nnXLen: nnXLen, - nnYLen: nnYLen) + // Extract policy pass output (1, numPolicyChannels) + if let passArray = prediction.featureValue(for: IONames.policyPassOutput)?.multiArrayValue { + let passOffset = batchIndex * numPolicyChannels + copyMultiArray(passArray, to: policyPass, destOffset: passOffset) + } - let trunkTipActivation = ActivationLayer( - graph: graph, - sourceTensor: trunkTipBN.resultTensor, - activationKind: descriptor.trunkTipActivation) + // Extract value output (1, 3) + if let valueArray = prediction.featureValue(for: IONames.valueOutput)?.multiArrayValue { + let valueOffset = batchIndex * numValueChannels + copyMultiArray(valueArray, to: value, destOffset: valueOffset) + } - resultTensor = trunkTipActivation.resultTensor + // Extract score value output (1, numScoreValueChannels) + if let svArray = prediction.featureValue(for: IONames.scoreValueOutput)?.multiArrayValue { + let svOffset = batchIndex * numScoreValueChannels + copyMultiArray(svArray, to: scoreValue, destOffset: svOffset) + } - assert(resultTensor.shape?.count == 4) + // Extract ownership output (1, 1, H, W) + // Must use stride-aware copy as Core ML may return non-contiguous arrays + if let ownArray = prediction.featureValue(for: IONames.ownershipOutput)?.multiArrayValue { + let ownOffset = batchIndex * Int(nnXLen) * Int(nnYLen) * numOwnershipChannels + copyMultiArray(ownArray, to: ownership, destOffset: ownOffset) + } } } -/// A class that describes a policy head for a neural network, responsible for predicting -/// the best moves for the current player and the opposing player on the subsequent turn. -public struct SWPolicyHeadDesc { - /// The version of the policy head - let version: Int - /// The 1x1 convolution layer for P - let p1Conv: SWConvLayerDesc - /// The 1x1 convolution layer for G - let g1Conv: SWConvLayerDesc - /// The batch normalization layer for G - let g1BN: SWBatchNormLayerDesc - /// The activation function for G - let g1Activation: ActivationKind - /// The global pooling bias structure that pools the output of G to bias the output of P - let gpoolToBiasMul: SWMatMulLayerDesc - /// The batch normalization layer for P - let p1BN: SWBatchNormLayerDesc - /// The activation function for P - let p1Activation: ActivationKind - /// The 1x1 convolution layer with 2 channels for outputting two policy distributions - let p2Conv: SWConvLayerDesc - /// The fully connected linear layer for outputting logits for the pass move - let gpoolToPassMul: SWMatMulLayerDesc - /// The description of the bias layer that is applied to the output of the matrix multiplication layer for model version >= 15 - let gpoolToPassBias: SWMatBiasLayerDesc? - /// The activation function for the bias layer in model version >= 15 - let passActivation: ActivationKind? - /// The fully connected linear layer for outputting logits for the pass move in model version >= 15 - let gpoolToPassMul2: SWMatMulLayerDesc? - - /// Initializes a SWPolicyHeadDesc object with the given parameters - /// - Parameters: - /// - version: The version of the policy head - /// - p1Conv: The 1x1 convolution layer for P - /// - g1Conv: The 1x1 convolution layer for G - /// - g1BN: The batch normalization layer for G - /// - g1Activation: The activation function for G - /// - gpoolToBiasMul: The global pooling bias structure that pools the output of G to bias the output of P - /// - p1BN: The batch normalization layer for P - /// - p1Activation: The activation function for P - /// - p2Conv: The 1x1 convolution layer with 2 channels for outputting two policy distributions - /// - gpoolToPassMul: The fully connected linear layer for outputting logits for the pass move - init( - version: Int, - p1Conv: SWConvLayerDesc, - g1Conv: SWConvLayerDesc, - g1BN: SWBatchNormLayerDesc, - g1Activation: ActivationKind, - gpoolToBiasMul: SWMatMulLayerDesc, - p1BN: SWBatchNormLayerDesc, - p1Activation: ActivationKind, - p2Conv: SWConvLayerDesc, - gpoolToPassMul: SWMatMulLayerDesc, - gpoolToPassBias: SWMatBiasLayerDesc?, - passActivation: ActivationKind?, - gpoolToPassMul2: SWMatMulLayerDesc? - ) { - self.version = version - self.p1Conv = p1Conv - self.g1Conv = g1Conv - self.g1BN = g1BN - self.g1Activation = g1Activation - self.gpoolToBiasMul = gpoolToBiasMul - self.p1BN = p1BN - self.p1Activation = p1Activation - self.p2Conv = p2Conv - self.gpoolToPassMul = gpoolToPassMul - self.gpoolToPassBias = gpoolToPassBias - self.passActivation = passActivation - self.gpoolToPassMul2 = gpoolToPassMul2 - - assert( - (version >= 15) - || ((gpoolToPassBias == nil) && (passActivation == nil) && (gpoolToPassMul2 == nil)) - ) - assert( - (version < 15) - || ((gpoolToPassBias != nil) && (passActivation != nil) && (gpoolToPassMul2 != nil)) - ) +/// Delete the source .mlpackage after compilation +/// CoreML caches the compiled model, so the source is no longer needed +private func deleteSourceModel(at url: URL, serverThreadIdx: Int) { + do { + try FileManager.default.removeItem(at: url) + printError("Metal backend \(serverThreadIdx): Deleted temp model") + } catch { + printError("Metal backend \(serverThreadIdx): Warning: Failed to delete temp model: \(error)") } } -public func createSWPolicyHeadDesc( - version: Int32, - p1Conv: SWConvLayerDesc, - g1Conv: SWConvLayerDesc, - g1BN: SWBatchNormLayerDesc, - g1Activation: ActivationKind, - gpoolToBiasMul: SWMatMulLayerDesc, - p1BN: SWBatchNormLayerDesc, - p1Activation: ActivationKind, - p2Conv: SWConvLayerDesc, - gpoolToPassMul: SWMatMulLayerDesc, - gpoolToPassBias: SWMatBiasLayerDesc, - passActivation: ActivationKind, - gpoolToPassMul2: SWMatMulLayerDesc -) -> SWPolicyHeadDesc { - if version >= 15 { - return SWPolicyHeadDesc( - version: Int(version), - p1Conv: p1Conv, - g1Conv: g1Conv, - g1BN: g1BN, - g1Activation: g1Activation, - gpoolToBiasMul: gpoolToBiasMul, - p1BN: p1BN, - p1Activation: p1Activation, - p2Conv: p2Conv, - gpoolToPassMul: gpoolToPassMul, - gpoolToPassBias: gpoolToPassBias, - passActivation: passActivation, - gpoolToPassMul2: gpoolToPassMul2) - } else { - return SWPolicyHeadDesc( - version: Int(version), - p1Conv: p1Conv, - g1Conv: g1Conv, - g1BN: g1BN, - g1Activation: g1Activation, - gpoolToBiasMul: gpoolToBiasMul, - p1BN: p1BN, - p1Activation: p1Activation, - p2Conv: p2Conv, - gpoolToPassMul: gpoolToPassMul, - gpoolToPassBias: nil, - passActivation: nil, - gpoolToPassMul2: nil) +/// Create compute handle - loads pre-converted Core ML model +/// Model conversion is now handled in C++ using the native katagocoreml library +public func createCoreMLComputeHandle( + coremlModelPath: String, + serverThreadIdx: Int, + requireExactNNLen: Bool, + numInputChannels: Int32, + numInputGlobalChannels: Int32, + numInputMetaChannels: Int32, + numPolicyChannels: Int32, + numValueChannels: Int32, + numScoreValueChannels: Int32, + numOwnershipChannels: Int32, + context: MetalComputeContext +) -> CoreMLComputeHandle? { + + let optimizeMask = requireExactNNLen // When true: skips internal mask operations (~6.5% speedup) + let mlpackagePath = URL(fileURLWithPath: coremlModelPath) + + // Ensure temp file is deleted regardless of success/failure + defer { deleteSourceModel(at: mlpackagePath, serverThreadIdx: serverThreadIdx) } + + // Load Core ML model (already converted by C++ katagocoreml library) + do { + let config = MLModelConfiguration() + config.computeUnits = .cpuAndNeuralEngine // Exclude GPU for hybrid mode + + printError("Metal backend \(serverThreadIdx): Compiling model...") + let compiledURL = try MLModel.compileModel(at: mlpackagePath) + + printError("Metal backend \(serverThreadIdx): Loading compiled model...") + let model = try MLModel(contentsOf: compiledURL, configuration: config) + + printError("Metal backend \(serverThreadIdx): Model loaded successfully, \(context.nnXLen)x\(context.nnYLen)") + + return CoreMLComputeHandle( + model: model, + nnXLen: context.nnXLen, + nnYLen: context.nnYLen, + optimizeIdentityMask: optimizeMask, + numInputChannels: Int(numInputChannels), + numInputGlobalChannels: Int(numInputGlobalChannels), + numInputMetaChannels: Int(numInputMetaChannels), + numPolicyChannels: Int(numPolicyChannels), + numValueChannels: Int(numValueChannels), + numScoreValueChannels: Int(numScoreValueChannels), + numOwnershipChannels: Int(numOwnershipChannels) + ) + } catch { + printError("Metal backend: Failed to load model: \(error)") + return nil } } -/// A structure that represents a policy head of a neural network. -struct PolicyHead { - /// The tensor that holds the policy prediction of the neural network - let policyTensor: MPSGraphTensor - /// The tensor that holds the policy pass of the neural network - let policyPassTensor: MPSGraphTensor - - /// Initializes a PolicyHead object - /// - Parameters: - /// - graph: The MPSGraph object to which the policy head is added - /// - descriptor: The description of the policy head - /// - sourceTensor: The input tensor to the policy head - /// - maskTensor: The mask tensor for the input tensor - /// - maskSumTensor: The sum of the mask tensor - /// - maskSumSqrtS14M01Tensor: The square root of the sum of the mask tensor and a small epsilon - /// - nnXLen: The number of X pixels in the input tensor - /// - nnYLen: The number of Y pixels in the input tensor - init( - graph: MPSGraph, - descriptor: SWPolicyHeadDesc, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor, - nnXLen: NSNumber, - nnYLen: NSNumber - ) { - - let p1Conv = ConvLayer( - graph: graph, - sourceTensor: sourceTensor, - descriptor: descriptor.p1Conv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let g1Conv = ConvLayer( - graph: graph, - sourceTensor: sourceTensor, - descriptor: descriptor.g1Conv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let g1BN = BatchNormLayer( - graph: graph, - sourceTensor: g1Conv.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.g1BN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let g1Activation = ActivationLayer( - graph: graph, - sourceTensor: g1BN.resultTensor, - activationKind: descriptor.g1Activation) - - let g1Concat = GlobalPoolingLayer( - graph: graph, - sourceTensor: g1Activation.resultTensor, - maskTensor: maskTensor, - maskSumTensor: maskSumTensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor) - - assert(g1Concat.resultTensor.shape?[1] == descriptor.gpoolToBiasMul.inChannels) - - let gpoolToBiasMul = MatMulLayer( - graph: graph, - descriptor: descriptor.gpoolToBiasMul, - sourceTensor: g1Concat.resultTensor) - - let added = AddNCBiasLayer( - graph: graph, - sourceTensor: p1Conv.resultTensor, - biasTensor: gpoolToBiasMul.resultTensor, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.gpoolToBiasMul.outChannels) - - let p1BN = BatchNormLayer( - graph: graph, - sourceTensor: added.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.p1BN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let p1Activation = ActivationLayer( - graph: graph, - sourceTensor: p1BN.resultTensor, - activationKind: descriptor.p1Activation) - - let p2Conv = ConvLayer( - graph: graph, - sourceTensor: p1Activation.resultTensor, - descriptor: descriptor.p2Conv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - policyTensor = p2Conv.resultTensor - - assert(g1Concat.resultTensor.shape?[1] == descriptor.gpoolToPassMul.inChannels) +/// Print available Core ML compute units +public func printMetalDevices() { + printError("Metal backend: Hybrid mode - CoreML (CPU+ANE) + MPSGraph (GPU)") +} - let gpoolToPassMul = MatMulLayer( - graph: graph, - descriptor: descriptor.gpoolToPassMul, - sourceTensor: g1Concat.resultTensor) +// MARK: - Throughput Tracker for Adaptive Batch Sizing - if let gpoolToPassBias = descriptor.gpoolToPassBias, - let passActivation = descriptor.passActivation, - let gpoolToPassMul2 = descriptor.gpoolToPassMul2 - { - assert(descriptor.version >= 15) +/// Tracks throughput for CoreML and MPSGraph paths to adaptively adjust batch split ratio. +/// +/// # Thread Safety +/// +/// This class is thread-safe by design without requiring explicit locks: +/// +/// 1. **Single-Owner Access**: Each server thread owns its own `ComputeHandle` → +/// `HybridComputeHandle` → `ThroughputTracker` instance. There is no sharing +/// of `ThroughputTracker` instances between server threads. +/// +/// 2. **Disjoint Field Access**: Within a single `HybridComputeHandle.apply()` call, +/// concurrent dispatch queues access disjoint fields: +/// - `coremlQueue.async` calls `updateCoreML()` → writes `coreMLSamplesPerSec`, `totalCoreMLSamples` +/// - `mpsGraphQueue.async` calls `updateMPSGraph()` → writes `mpsGraphSamplesPerSec`, `totalMPSGraphSamples` +/// +/// Both read `warmupComplete`, `stableAlpha`, and `warmupAlpha`, but these are either +/// `let` constants or only written sequentially after `group.wait()`. +/// +/// 3. **Sequential Barrier**: `group.wait()` in `apply()` ensures all concurrent throughput +/// updates complete before `recordBatch()`, `shouldLogAndMark()`, or `getDiagnosticStats()` +/// are called. These methods run sequentially on the calling thread. +/// +/// Because of these invariants, no locks are needed. Removing `NSLock` was intentional +/// as it was unnecessary overhead given the access patterns above. +public class ThroughputTracker { + private var coreMLSamplesPerSec: Double = 0.9 // Warm-start: initial ratio ~0.47 (closer to optimal ~0.45) + private var mpsGraphSamplesPerSec: Double = 1.0 + + // Diagnostic fields + private var batchCount: Int = 0 + private var totalCoreMLSamples: Int = 0 + private var totalMPSGraphSamples: Int = 0 + private var ratioHistory: [Float] = [] + private let maxHistorySize = 100 // Keep last 100 ratios for analysis + private var lastLogBatchCount: Int = 0 + private let logInterval: Int = 50 // Log every N batches + + // Adaptive alpha parameters + private var warmupComplete: Bool = false + private let warmupAlpha: Double = 0.25 // Faster adaptation during warmup + private let stableAlpha: Double = 0.10 // Slower adaptation after convergence + private let warmupBatches: Int = 100 // Min batches before checking warmup transition + private let warmupVarianceThreshold: Double = 0.005 // Variance threshold for warmup completion + + /// Update CoreML throughput measurement with adaptive alpha + public func updateCoreML(samples: Int, duration: TimeInterval) { + guard duration > 0, samples > 0 else { return } + let newRate = Double(samples) / duration + let effectiveAlpha = warmupComplete ? stableAlpha : warmupAlpha + coreMLSamplesPerSec = effectiveAlpha * newRate + (1 - effectiveAlpha) * coreMLSamplesPerSec + totalCoreMLSamples += samples + } - let gpoolToPassBiasLayer = MatBiasLayer( - graph: graph, - descriptor: gpoolToPassBias, - sourceTensor: gpoolToPassMul.resultTensor) + /// Update MPSGraph throughput measurement with adaptive alpha + public func updateMPSGraph(samples: Int, duration: TimeInterval) { + guard duration > 0, samples > 0 else { return } + let newRate = Double(samples) / duration + let effectiveAlpha = warmupComplete ? stableAlpha : warmupAlpha + mpsGraphSamplesPerSec = effectiveAlpha * newRate + (1 - effectiveAlpha) * mpsGraphSamplesPerSec + totalMPSGraphSamples += samples + } - let passActivationLayer = ActivationLayer( - graph: graph, - sourceTensor: gpoolToPassBiasLayer.resultTensor, - activationKind: passActivation) + /// Get optimal CoreML ratio (0.0 to 1.0) based on measured throughput + public func getOptimalCoreMLRatio() -> Float { + let total = coreMLSamplesPerSec + mpsGraphSamplesPerSec + return total > 0 ? Float(coreMLSamplesPerSec / total) : 0.5 + } - let gpoolToPassMul2Layer = MatMulLayer( - graph: graph, - descriptor: gpoolToPassMul2, - sourceTensor: passActivationLayer.resultTensor) + /// Get current throughput stats for logging + public func getStats() -> (coreML: Double, mpsGraph: Double, ratio: Float) { + return (coreMLSamplesPerSec, mpsGraphSamplesPerSec, getOptimalCoreMLRatio()) + } - policyPassTensor = gpoolToPassMul2Layer.resultTensor - } else { - assert(descriptor.version < 15) - policyPassTensor = gpoolToPassMul.resultTensor + /// Record a batch for diagnostics (call after each apply) + public func recordBatch(ratio: Float) { + batchCount += 1 + if ratioHistory.count >= maxHistorySize { + ratioHistory.removeFirst() + } + ratioHistory.append(ratio) + // Check warmup transition + if !warmupComplete && batchCount >= warmupBatches && computeRatioVariance() < Float(warmupVarianceThreshold) { + warmupComplete = true } - - assert(policyTensor.shape?.count == 4) - assert(policyPassTensor.shape?.count == 2) } -} -/// A struct that describes the value head of a neural network -public struct SWValueHeadDesc { - /// The version of the value head - let version: Int - /// The description of the first convolutional layer in the value head - let v1Conv: SWConvLayerDesc - /// The description of the batch normalization layer after the first convolutional layer in the value head - let v1BN: SWBatchNormLayerDesc - /// The activation function that is applied after the first batch normalization layer in the value head - let v1Activation: ActivationKind - /// The description of the matrix multiplication layer that is applied to the output of the first convolutional layer in the value head - let v2Mul: SWMatMulLayerDesc - /// The description of the bias layer that is applied to the output of the matrix multiplication layer in the value head - let v2Bias: SWMatBiasLayerDesc - /// The activation function that is applied after the bias layer in the value head - let v2Activation: ActivationKind - /// The description of the matrix multiplication layer that is applied to the output of the bias layer in the value head - let v3Mul: SWMatMulLayerDesc - /// The description of the bias layer that is applied to the output of the matrix multiplication layer in the value head - let v3Bias: SWMatBiasLayerDesc - /// The description of the matrix multiplication layer that is applied to the output of the third bias layer in the value head - let sv3Mul: SWMatMulLayerDesc - /// The description of the bias layer that is applied to the output of the matrix multiplication layer in the value head - let sv3Bias: SWMatBiasLayerDesc - /// The description of the convolutional layer that is applied to the board ownership map in the value head - let vOwnershipConv: SWConvLayerDesc - - /// Initializes a SWValueHeadDesc object - /// - Parameters: - /// - version: The version of the value head - /// - v1Conv: The description of the first convolutional layer in the value head - /// - v1BN: The description of the batch normalization layer after the first convolutional layer in the value head - /// - v1Activation: The activation function that is applied after the first batch normalization layer in the value head - /// - v2Mul: The description of the matrix multiplication layer that is applied to the output of the first convolutional layer in the value head - /// - v2Bias: The description of the bias layer that is applied to the output of the matrix multiplication layer in the value head - /// - v2Activation: The activation function that is applied after the bias layer in the value head - /// - v3Mul: The description of the matrix multiplication layer that is applied to the output of the bias layer in the value head - /// - v3Bias: The description of the bias layer that is applied to the output of the matrix multiplication layer in the value head - /// - sv3Mul: The description of the matrix multiplication layer that is applied to the output of the third bias layer in the value head - /// - sv3Bias: The description of the bias layer that is applied to the output of the matrix multiplication layer in the value head - /// - vOwnershipConv: The description of the convolutional layer that is applied to the board ownership map in the value head - init( - version: Int, - v1Conv: SWConvLayerDesc, - v1BN: SWBatchNormLayerDesc, - v1Activation: ActivationKind, - v2Mul: SWMatMulLayerDesc, - v2Bias: SWMatBiasLayerDesc, - v2Activation: ActivationKind, - v3Mul: SWMatMulLayerDesc, - v3Bias: SWMatBiasLayerDesc, - sv3Mul: SWMatMulLayerDesc, - sv3Bias: SWMatBiasLayerDesc, - vOwnershipConv: SWConvLayerDesc - ) { - self.version = version - self.v1Conv = v1Conv - self.v1BN = v1BN - self.v1Activation = v1Activation - self.v2Mul = v2Mul - self.v2Bias = v2Bias - self.v2Activation = v2Activation - self.v3Mul = v3Mul - self.v3Bias = v3Bias - self.sv3Mul = sv3Mul - self.sv3Bias = sv3Bias - self.vOwnershipConv = vOwnershipConv + /// Check if logging should occur this batch, and if so, mark as logged + /// Returns true if logging should occur (atomically checks and marks) + public func shouldLogAndMark() -> Bool { + if batchCount - lastLogBatchCount >= logInterval { + lastLogBatchCount = batchCount + return true + } + return false } -} - -public func createSWValueHeadDesc( - version: Int32, - v1Conv: SWConvLayerDesc, - v1BN: SWBatchNormLayerDesc, - v1Activation: ActivationKind, - v2Mul: SWMatMulLayerDesc, - v2Bias: SWMatBiasLayerDesc, - v2Activation: ActivationKind, - v3Mul: SWMatMulLayerDesc, - v3Bias: SWMatBiasLayerDesc, - sv3Mul: SWMatMulLayerDesc, - sv3Bias: SWMatBiasLayerDesc, - vOwnershipConv: SWConvLayerDesc -) -> SWValueHeadDesc { - return SWValueHeadDesc( - version: Int(version), - v1Conv: v1Conv, - v1BN: v1BN, - v1Activation: v1Activation, - v2Mul: v2Mul, - v2Bias: v2Bias, - v2Activation: v2Activation, - v3Mul: v3Mul, - v3Bias: v3Bias, - sv3Mul: sv3Mul, - sv3Bias: sv3Bias, - vOwnershipConv: vOwnershipConv) -} -/// A structure that creates a value head for the neural network, which produces the value, score value, and ownership tensors. -struct ValueHead { - /// The tensor that represents the value of the board - let valueTensor: MPSGraphTensor - /// The tensor that represents the score value of the board - let scoreValueTensor: MPSGraphTensor - /// The tensor that represents the ownership of the board - let ownershipTensor: MPSGraphTensor - - /// Initializes the value head using a graph, a descriptor, a source tensor, and other relevant tensors. - /// - Parameters: - /// - graph: The graph used to perform calculations on tensors - /// - descriptor: The SWValueHeadDesc object that describes the value head - /// - sourceTensor: The tensor used to source data to the neural network - /// - maskTensor: The tensor used to mask out invalid moves - /// - maskSumTensor: The tensor used to sum up the mask tensor values - /// - maskSumSqrtS14M01Tensor: The tensor used to calculate a square root value - /// - maskSumSqrtS14M01SquareS01Tensor: The tensor used to calculate a square value - /// - nnXLen: The x-axis length of the neural network - /// - nnYLen: The y-axis length of the neural network - init( - graph: MPSGraph, - descriptor: SWValueHeadDesc, - sourceTensor: MPSGraphTensor, - maskTensor: MPSGraphTensor, - maskSumTensor: MPSGraphTensor, - maskSumSqrtS14M01Tensor: MPSGraphTensor, - maskSumSqrtS14M01SquareS01Tensor: MPSGraphTensor, - nnXLen: NSNumber, - nnYLen: NSNumber + /// Get diagnostic stats for logging + public func getDiagnosticStats() -> ( + batchCount: Int, + coreMLSamplesPerSec: Double, + mpsGraphSamplesPerSec: Double, + ratio: Float, + totalCoreMLSamples: Int, + totalMPSGraphSamples: Int, + ratioVariance: Float ) { + return ( + batchCount, + coreMLSamplesPerSec, + mpsGraphSamplesPerSec, + getOptimalCoreMLRatio(), + totalCoreMLSamples, + totalMPSGraphSamples, + computeRatioVariance() + ) + } - let v1Conv = ConvLayer( - graph: graph, - sourceTensor: sourceTensor, - descriptor: descriptor.v1Conv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let v1BN = BatchNormLayer( - graph: graph, - sourceTensor: v1Conv.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.v1BN, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let v1Activation = ActivationLayer( - graph: graph, - sourceTensor: v1BN.resultTensor, - activationKind: descriptor.v1Activation) - - let v1Mean = - GlobalPoolingValueLayer( - graph: graph, - sourceTensor: v1Activation.resultTensor, - maskSumTensor: maskSumTensor, - maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, - maskSumSqrtS14M01SquareS01Tensor: maskSumSqrtS14M01SquareS01Tensor) - - assert(v1Mean.resultTensor.shape?[1] == descriptor.v2Mul.inChannels) - - let v2Mul = MatMulLayer( - graph: graph, - descriptor: descriptor.v2Mul, - sourceTensor: v1Mean.resultTensor) - - let v2Bias = MatBiasLayer( - graph: graph, - descriptor: descriptor.v2Bias, - sourceTensor: v2Mul.resultTensor) - - let v2Activation = ActivationLayer( - graph: graph, - sourceTensor: v2Bias.resultTensor, - activationKind: descriptor.v2Activation) - - let v3Mul = MatMulLayer( - graph: graph, - descriptor: descriptor.v3Mul, - sourceTensor: v2Activation.resultTensor) - - let v3Bias = MatBiasLayer( - graph: graph, - descriptor: descriptor.v3Bias, - sourceTensor: v3Mul.resultTensor) - - let sv3Mul = MatMulLayer( - graph: graph, - descriptor: descriptor.sv3Mul, - sourceTensor: v2Activation.resultTensor) - - let sv3Bias = MatBiasLayer( - graph: graph, - descriptor: descriptor.sv3Bias, - sourceTensor: sv3Mul.resultTensor) - - let vOwnershipConv = ConvLayer( - graph: graph, - sourceTensor: v1Activation.resultTensor, - descriptor: descriptor.vOwnershipConv, - nnXLen: nnXLen, - nnYLen: nnYLen) - - valueTensor = v3Bias.resultTensor - scoreValueTensor = sv3Bias.resultTensor - ownershipTensor = vOwnershipConv.resultTensor - - assert(valueTensor.shape?.count == 2) - assert(scoreValueTensor.shape?.count == 2) - assert(ownershipTensor.shape?.count == 4) + /// Compute variance of recent ratios + private func computeRatioVariance() -> Float { + guard ratioHistory.count >= 10 else { return 0.0 } + let recentRatios = Array(ratioHistory.suffix(20)) + let mean = recentRatios.reduce(0.0, +) / Float(recentRatios.count) + let variance = recentRatios.map { ($0 - mean) * ($0 - mean) }.reduce(0.0, +) / Float(recentRatios.count) + return variance } -} -/// A struct that describes a neural network model used for playing the game of Go. -public struct SWModelDesc { - /// The version of the model. - let version: Int - /// The name of the model. - let name: String - /// Number of channels for input features. - let numInputChannels: NSNumber - /// Number of channels for global input features. - let numInputGlobalChannels: NSNumber - /// Number of channels for meta input features. - let numInputMetaChannels: NSNumber - /// Number of channels for the value head output. - let numValueChannels: NSNumber - /// Number of channels for the score value head output. - let numScoreValueChannels: NSNumber - /// Number of channels for the ownership head output. - let numOwnershipChannels: NSNumber - /// The description of the trunk that makes up the backbone of the model. - let trunk: SWTrunkDesc - /// The description of the policy head that predicts the probability of playing at a particular position. - let policyHead: SWPolicyHeadDesc - /// The description of the value head that predicts the expected outcome of a game state. - let valueHead: SWValueHeadDesc - - /// Initializes an SWModelDesc object. - /// - Parameters: - /// - version: The version of the model. - /// - name: The name of the model. - /// - numInputChannels: Number of channels for input features. - /// - numInputGlobalChannels: Number of channels for global input features. - /// - numInputMetaChannels: Number of channels for meta input features. - /// - numValueChannels: Number of channels for the value head output. - /// - numScoreValueChannels: Number of channels for the score value head output. - /// - numOwnershipChannels: Number of channels for the ownership head output. - /// - trunk: The description of the trunk that makes up the backbone of the model. - /// - policyHead: The description of the policy head that predicts the probability of playing at a particular position. - /// - valueHead: The description of the value head that predicts the expected outcome of a game state. - init( - version: Int, - name: String, - numInputChannels: NSNumber, - numInputGlobalChannels: NSNumber, - numInputMetaChannels: NSNumber, - numValueChannels: NSNumber, - numScoreValueChannels: NSNumber, - numOwnershipChannels: NSNumber, - trunk: SWTrunkDesc, - policyHead: SWPolicyHeadDesc, - valueHead: SWValueHeadDesc - ) { - self.version = version - self.name = name - self.numInputChannels = numInputChannels - self.numInputGlobalChannels = numInputGlobalChannels - self.numInputMetaChannels = numInputMetaChannels - self.numValueChannels = numValueChannels - self.numScoreValueChannels = numScoreValueChannels - self.numOwnershipChannels = numOwnershipChannels - self.trunk = trunk - self.policyHead = policyHead - self.valueHead = valueHead + /// Check if ratio has converged (variance < threshold) + public func hasConverged(threshold: Float = 0.001) -> Bool { + let variance = computeRatioVariance() + return ratioHistory.count >= 20 && variance < threshold } } -public func createSWModelDesc( - version: Int32, - name: String, - numInputChannels: Int32, - numInputGlobalChannels: Int32, - numInputMetaChannels: Int32, - numValueChannels: Int32, - numScoreValueChannels: Int32, - numOwnershipChannels: Int32, - trunk: SWTrunkDesc, - policyHead: SWPolicyHeadDesc, - valueHead: SWValueHeadDesc -) -> SWModelDesc { - return SWModelDesc( - version: Int(version), - name: name, - numInputChannels: numInputChannels as NSNumber, - numInputGlobalChannels: numInputGlobalChannels as NSNumber, - numInputMetaChannels: numInputMetaChannels as NSNumber, - numValueChannels: numValueChannels as NSNumber, - numScoreValueChannels: numScoreValueChannels as NSNumber, - numOwnershipChannels: numOwnershipChannels as NSNumber, - trunk: trunk, - policyHead: policyHead, - valueHead: valueHead) -} +// MARK: - MPSGraph-based Model for GPU Inference -/// A structure representing a neural network model for processing Go game states. -struct Model { - /// The Metal device +/// GPU-based model using MPSGraph for inference +public class MPSGraphModelHandle { let device: MTLDevice - /// The command queue used to execute the graph on the GPU let commandQueue: MTLCommandQueue - /// The Metal Performance Shaders graph object used for building and executing the graph let graph: MPSGraph - /// The length of the neural network input in the x dimension - let nnXLen: NSNumber - /// The length of the neural network input in the y dimension - let nnYLen: NSNumber - /// The version of the model - let version: Int - /// The number of channels in the value output layer - let numValueChannels: NSNumber - /// The number of channels in the score value output layer - let numScoreValueChannels: NSNumber - /// The number of channels in the ownership output layer - let numOwnershipChannels: NSNumber - /// The input layer of the neural network + let nnXLen: Int32 + let nnYLen: Int32 + let numInputChannels: Int + let numInputGlobalChannels: Int + let numInputMetaChannels: Int + let numPolicyChannels: Int + let numValueChannels: Int + let numScoreValueChannels: Int + let numOwnershipChannels: Int + + // Layers let input: InputLayer - /// The global input layer of the neural network let inputGlobal: InputGlobalLayer - /// The meta input layer of the neural network let inputMeta: InputMetaLayer - /// The mask layer of the neural network let mask: MaskLayer - /// The trunk of the neural network let trunk: Trunk - /// The policy head of the neural network let policyHead: PolicyHead - /// The value head of the neural network let valueHead: ValueHead - /// The dictionary that maps the output tensors to the tensor data let targetTensors: [MPSGraphTensor] - /// Initializes a Model object. - /// - Parameters: - /// - device: The Metal device to use for computations. - /// - graph: The Metal Performance Shaders graph object used for building and executing the graph. - /// - descriptor: The description of the model. - /// - nnXLen: The length of the neural network input in the x dimension. - /// - nnYLen: The length of the neural network input in the y dimension. - init( - device: MTLDevice, - graph: MPSGraph, - descriptor: SWModelDesc, - nnXLen: NSNumber, - nnYLen: NSNumber + public init?( + modelDesc: SWModelDesc, + nnXLen: Int32, + nnYLen: Int32, + optimizeIdentityMask: Bool = false ) { + guard let device = MTLCreateSystemDefaultDevice() else { + printError("Metal backend: Failed to create Metal device") + return nil + } + self.device = device - self.commandQueue = device.makeCommandQueue()! - self.graph = graph + guard let queue = device.makeCommandQueue() else { + printError("Metal backend: Failed to create command queue") + return nil + } + self.commandQueue = queue + self.graph = MPSGraph() self.nnXLen = nnXLen self.nnYLen = nnYLen - self.version = descriptor.version - self.numValueChannels = descriptor.numValueChannels - self.numScoreValueChannels = descriptor.numScoreValueChannels - self.numOwnershipChannels = descriptor.numOwnershipChannels + self.numInputChannels = modelDesc.numInputChannels.intValue + self.numInputGlobalChannels = modelDesc.numInputGlobalChannels.intValue + self.numInputMetaChannels = modelDesc.numInputMetaChannels.intValue + self.numPolicyChannels = modelDesc.numPolicyChannels.intValue + self.numValueChannels = modelDesc.numValueChannels.intValue + self.numScoreValueChannels = modelDesc.numScoreValueChannels.intValue + self.numOwnershipChannels = modelDesc.numOwnershipChannels.intValue + + let nnXLenNS = nnXLen as NSNumber + let nnYLenNS = nnYLen as NSNumber input = InputLayer( graph: graph, - nnXLen: nnXLen, - nnYLen: nnYLen, - numChannels: descriptor.numInputChannels) + nnXLen: nnXLenNS, + nnYLen: nnYLenNS, + numChannels: modelDesc.numInputChannels) inputGlobal = InputGlobalLayer( graph: graph, - numGlobalFeatures: descriptor.numInputGlobalChannels) + numGlobalFeatures: modelDesc.numInputGlobalChannels) inputMeta = InputMetaLayer( graph: graph, - numMetaFeatures: descriptor.numInputMetaChannels) + numMetaFeatures: modelDesc.numInputMetaChannels) mask = MaskLayer( graph: graph, - nnXLen: nnXLen, - nnYLen: nnYLen) - - let maskSum = MaskSumLayer( - graph: graph, - maskTensor: mask.tensor) + nnXLen: nnXLenNS, + nnYLen: nnYLenNS) - let maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer( - graph: graph, - maskSum: maskSum) + // Use constant tensors when mask is all 1s (requireExactNNLen=true) + let maskSum: MaskSumLayer + let maskSumSqrtS14M01: MaskSumSqrtS14M01Layer + let maskSumSqrtS14M01SquareS01: MaskSumSqrtS14M01SquareS01Layer - let maskSumSqrtS14M01SquareS01 = MaskSumSqrtS14M01SquareS01Layer( - graph: graph, - maskSumSqrtS14M01: maskSumSqrtS14M01) + if optimizeIdentityMask { + maskSum = MaskSumLayer( + graph: graph, + nnXLen: nnXLenNS, + nnYLen: nnYLenNS) + maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer( + graph: graph, + nnXLen: nnXLenNS, + nnYLen: nnYLenNS) + maskSumSqrtS14M01SquareS01 = MaskSumSqrtS14M01SquareS01Layer( + graph: graph, + nnXLen: nnXLenNS, + nnYLen: nnYLenNS) + } else { + maskSum = MaskSumLayer( + graph: graph, + maskTensor: mask.tensor) + maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer( + graph: graph, + maskSum: maskSum) + maskSumSqrtS14M01SquareS01 = MaskSumSqrtS14M01SquareS01Layer( + graph: graph, + maskSumSqrtS14M01: maskSumSqrtS14M01) + } trunk = Trunk( graph: graph, - descriptor: descriptor.trunk, + descriptor: modelDesc.trunk, inputTensor: input.tensor, inputGlobalTensor: inputGlobal.tensor, inputMetaTensor: inputMeta.tensor, maskTensor: mask.tensor, maskSumTensor: maskSum.tensor, maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, - nnXLen: nnXLen, - nnYLen: nnYLen) + nnXLen: nnXLenNS, + nnYLen: nnYLenNS, + optimizeIdentityMask: optimizeIdentityMask) policyHead = PolicyHead( graph: graph, - descriptor: descriptor.policyHead, + descriptor: modelDesc.policyHead, sourceTensor: trunk.resultTensor, maskTensor: mask.tensor, maskSumTensor: maskSum.tensor, maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, - nnXLen: nnXLen, - nnYLen: nnYLen) + nnXLen: nnXLenNS, + nnYLen: nnYLenNS, + optimizeIdentityMask: optimizeIdentityMask) valueHead = ValueHead( graph: graph, - descriptor: descriptor.valueHead, + descriptor: modelDesc.valueHead, sourceTensor: trunk.resultTensor, maskTensor: mask.tensor, maskSumTensor: maskSum.tensor, maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, maskSumSqrtS14M01SquareS01Tensor: maskSumSqrtS14M01SquareS01.tensor, - nnXLen: nnXLen, - nnYLen: nnYLen) + nnXLen: nnXLenNS, + nnYLen: nnYLenNS, + optimizeIdentityMask: optimizeIdentityMask) targetTensors = [ policyHead.policyTensor, @@ -3035,20 +684,12 @@ struct Model { valueHead.scoreValueTensor, valueHead.ownershipTensor, ] + + printError("Metal backend: MPSGraph initialized on \(device.name)\(optimizeIdentityMask ? " (mask optimized)" : "")") } - /// Applies the model to the given input data, and generates predictions for policy, value and ownership - /// - Parameters: - /// - inputPointer: UnsafeMutablePointer to a flattened 2D array of floats representing the input state - /// - inputGlobalPointer: UnsafeMutablePointer to a flattened array of floats representing global state features - /// - inputMetaPointer: UnsafeMutablePointer to a flattened array of floats representing the metadata - /// - policy: UnsafeMutablePointer to a flattened 2D array of floats representing predicted policy - /// - policyPass: UnsafeMutablePointer to a flattened array of floats representing predicted probability of passing - /// - value: UnsafeMutablePointer to a flattened array of floats representing predicted value - /// - scoreValue: UnsafeMutablePointer to a flattened array of floats representing predicted score value - /// - ownership: UnsafeMutablePointer to a flattened 2D array of floats representing predicted ownership - /// - batchSize: The batch size - func apply( + /// Run inference on a batch using MPSGraph (GPU) + public func apply( input inputPointer: UnsafeMutablePointer, inputGlobal inputGlobalPointer: UnsafeMutablePointer, inputMeta inputMetaPointer: UnsafeMutablePointer, @@ -3059,15 +700,16 @@ struct Model { ownership: UnsafeMutablePointer, batchSize: Int ) { - let channelAxis = InputShape.getChannelAxis() let numInputChannels = input.shape[channelAxis] + let nnXLenNS = nnXLen as NSNumber + let nnYLenNS = nnYLen as NSNumber let inputShape = InputShape.create( batchSize: batchSize as NSNumber, numChannels: numInputChannels, - nnYLen: nnYLen, - nnXLen: nnXLen) + nnYLen: nnYLenNS, + nnXLen: nnXLenNS) let inputDescriptor = MPSNDArrayDescriptor( dataType: input.tensor.dataType, @@ -3118,8 +760,8 @@ struct Model { let maskShape = InputShape.create( batchSize: batchSize as NSNumber, numChannels: 1, - nnYLen: nnYLen, - nnXLen: nnXLen) + nnYLen: nnYLenNS, + nnXLen: nnXLenNS) let maskDescriptor = MPSNDArrayDescriptor( dataType: mask.tensor.dataType, @@ -3129,12 +771,12 @@ struct Model { device: device, descriptor: maskDescriptor) + // Extract mask from first channel of spatial input var maskStrideArray = [ MemoryLayout.size, - nnXLen.intValue * MemoryLayout.size, - nnYLen.intValue * nnXLen.intValue * MemoryLayout.size, - numInputChannels.intValue * nnYLen.intValue * nnXLen.intValue - * MemoryLayout.size, + Int(nnXLen) * MemoryLayout.size, + Int(nnYLen) * Int(nnXLen) * MemoryLayout.size, + numInputChannels.intValue * Int(nnYLen) * Int(nnXLen) * MemoryLayout.size, ] maskArray.writeBytes(inputPointer, strideBytes: &maskStrideArray) @@ -3152,12 +794,6 @@ struct Model { targetTensors: targetTensors, targetOperations: nil) - assert(fetch[policyHead.policyTensor] != nil) - assert(fetch[policyHead.policyPassTensor] != nil) - assert(fetch[valueHead.valueTensor] != nil) - assert(fetch[valueHead.scoreValueTensor] != nil) - assert(fetch[valueHead.ownershipTensor] != nil) - fetch[policyHead.policyTensor]?.mpsndarray().readBytes(policy) fetch[policyHead.policyPassTensor]?.mpsndarray().readBytes(policyPass) fetch[valueHead.valueTensor]?.mpsndarray().readBytes(value) @@ -3166,52 +802,48 @@ struct Model { } } -// A enum to represent enabled/disabled/auto option of a feature. -public enum SWEnable { - case False - case True - case Auto -} - -/// A class that represents context of GPU devices. -public class MetalComputeContext { - public let nnXLen: Int32 - public let nnYLen: Int32 +// MARK: - Hybrid Compute Handle - /// Initialize a context. - /// - Parameters: - /// - nnXLen: The width of the input tensor. - /// - nnYLen: The height of the input tensor. - init( - nnXLen: Int32, - nnYLen: Int32 - ) { - self.nnXLen = nnXLen - self.nnYLen = nnYLen +/// Global flag to enable/disable diagnostic logging (set via environment variable) +private let diagnosticLoggingEnabled: Bool = { + if let envValue = ProcessInfo.processInfo.environment["KATAGO_HYBRID_DIAG"] { + return envValue.lowercased() == "1" || envValue.lowercased() == "true" } -} - -public func createMetalComputeContext( - nnXLen: Int32, - nnYLen: Int32 -) -> MetalComputeContext { - return MetalComputeContext( - nnXLen: nnXLen, - nnYLen: nnYLen) -} - -/// A class that represents a handle of GPU device. -public class MetalComputeHandle { - let model: Model - - init(model: Model) { - self.model = model + return false +}() + +/// Hybrid compute handle that dispatches to both CoreML (CPU+ANE) and MPSGraph (GPU) +public class HybridComputeHandle { + let coremlHandle: CoreMLComputeHandle + let mpsGraphHandle: MPSGraphModelHandle + let throughputTracker: ThroughputTracker + let coremlQueue: DispatchQueue + let mpsGraphQueue: DispatchQueue + let nnXLen: Int32 + let nnYLen: Int32 + let serverThreadIdx: Int + + public init( + coremlHandle: CoreMLComputeHandle, + mpsGraphHandle: MPSGraphModelHandle, + serverThreadIdx: Int = 0 + ) { + self.coremlHandle = coremlHandle + self.mpsGraphHandle = mpsGraphHandle + self.serverThreadIdx = serverThreadIdx + self.throughputTracker = ThroughputTracker() + self.coremlQueue = DispatchQueue(label: "com.katago.coreml", qos: .userInitiated) + self.mpsGraphQueue = DispatchQueue(label: "com.katago.mpsgraph", qos: .userInitiated) + self.nnXLen = coremlHandle.nnXLen + self.nnYLen = coremlHandle.nnYLen } + /// Run hybrid inference - splits batch between CoreML and MPSGraph public func apply( - input inputPointer: UnsafeMutablePointer, - inputGlobal inputGlobalPointer: UnsafeMutablePointer, - inputMeta inputMetaPointer: UnsafeMutablePointer, + spatialInput: UnsafeMutablePointer, + globalInput: UnsafeMutablePointer, + metaInput: UnsafeMutablePointer, + maskInput: UnsafeMutablePointer, policy: UnsafeMutablePointer, policyPass: UnsafeMutablePointer, value: UnsafeMutablePointer, @@ -3219,48 +851,196 @@ public class MetalComputeHandle { ownership: UnsafeMutablePointer, batchSize: Int ) { - autoreleasepool { - model.apply( - input: inputPointer, - inputGlobal: inputGlobalPointer, - inputMeta: inputMetaPointer, - policy: policy, - policyPass: policyPass, - value: value, - scoreValue: scoreValue, - ownership: ownership, - batchSize: batchSize) + // Get optimal split ratio based on throughput + let ratio = throughputTracker.getOptimalCoreMLRatio() + // Prefer MPSGraph over CoreML for batch size 1, as MPSGraph is more stable + let coreMLBatchSize = max(0, min(batchSize - 1, Int(Float(batchSize) * ratio))) + let mpsGraphBatchSize = batchSize - coreMLBatchSize + + // Calculate buffer offsets + let spatialSize = Int(nnXLen) * Int(nnYLen) * coremlHandle.numInputChannels + let globalSize = coremlHandle.numInputGlobalChannels + let metaSize = coremlHandle.numInputMetaChannels + let policySize = Int(nnXLen) * Int(nnYLen) * coremlHandle.numPolicyChannels + let policyPassSize = coremlHandle.numPolicyChannels // Non-spatial pass output + let valueSize = coremlHandle.numValueChannels + let scoreValueSize = coremlHandle.numScoreValueChannels + let ownershipSize = Int(nnXLen) * Int(nnYLen) * coremlHandle.numOwnershipChannels + + #if DEBUG + // Verify batch split ensures non-overlapping buffer access + // CoreML writes [0, coreMLBatchSize), MPSGraph writes [coreMLBatchSize, batchSize) + assert(coreMLBatchSize >= 0 && mpsGraphBatchSize >= 0, "Batch sizes must be non-negative") + assert(coreMLBatchSize + mpsGraphBatchSize == batchSize, "Batch split must sum to total") + #endif + + let group = DispatchGroup() + + // CoreML path (CPU + ANE) + if coreMLBatchSize > 0 { + group.enter() + coremlQueue.async { [self] in + let start = CFAbsoluteTimeGetCurrent() + + autoreleasepool { + coremlHandle.apply( + spatialInput: spatialInput, + globalInput: globalInput, + metaInput: metaInput, + maskInput: maskInput, + policy: policy, + policyPass: policyPass, + value: value, + scoreValue: scoreValue, + ownership: ownership, + batchSize: coreMLBatchSize + ) + } + + let duration = CFAbsoluteTimeGetCurrent() - start + throughputTracker.updateCoreML(samples: coreMLBatchSize, duration: duration) + group.leave() + } + } + + // MPSGraph path (GPU) + if mpsGraphBatchSize > 0 { + group.enter() + mpsGraphQueue.async { [self] in + let start = CFAbsoluteTimeGetCurrent() + + // Offset pointers for MPSGraph batch portion + let spatialOffset = coreMLBatchSize * spatialSize + let globalOffset = coreMLBatchSize * globalSize + let metaOffset = coreMLBatchSize * metaSize + let policyOffset = coreMLBatchSize * policySize + let policyPassOffset = coreMLBatchSize * policyPassSize + let valueOffset = coreMLBatchSize * valueSize + let scoreValueOffset = coreMLBatchSize * scoreValueSize + let ownershipOffset = coreMLBatchSize * ownershipSize + + autoreleasepool { + mpsGraphHandle.apply( + input: spatialInput.advanced(by: spatialOffset), + inputGlobal: globalInput.advanced(by: globalOffset), + inputMeta: metaInput.advanced(by: metaOffset), + policy: policy.advanced(by: policyOffset), + policyPass: policyPass.advanced(by: policyPassOffset), + value: value.advanced(by: valueOffset), + scoreValue: scoreValue.advanced(by: scoreValueOffset), + ownership: ownership.advanced(by: ownershipOffset), + batchSize: mpsGraphBatchSize + ) + } + + let duration = CFAbsoluteTimeGetCurrent() - start + throughputTracker.updateMPSGraph(samples: mpsGraphBatchSize, duration: duration) + group.leave() + } + } + + // Wait for both paths to complete + group.wait() + + // Record batch for diagnostics + throughputTracker.recordBatch(ratio: ratio) + + // Periodic diagnostic logging + if diagnosticLoggingEnabled && throughputTracker.shouldLogAndMark() { + let stats = throughputTracker.getDiagnosticStats() + let converged = throughputTracker.hasConverged() + print(String(format: "[HybridDiag T%d] batch=%d ratio=%.3f coreml=%.1f/s mps=%.1f/s total=%d/%d var=%.5f conv=%@", + serverThreadIdx, + stats.batchCount, + stats.ratio, + stats.coreMLSamplesPerSec, + stats.mpsGraphSamplesPerSec, + stats.totalCoreMLSamples, + stats.totalMPSGraphSamples, + stats.ratioVariance, + converged ? "yes" : "no")) } } } -public func maybeCreateMetalComputeHandle( - condition: Bool, - serverThreadIdx: Int = 0, - descriptor: SWModelDesc, +/// Create a hybrid compute handle +public func createHybridComputeHandle( + coremlModelPath: String, + modelDesc: SWModelDesc, + serverThreadIdx: Int, + requireExactNNLen: Bool, + numInputChannels: Int32, + numInputGlobalChannels: Int32, + numInputMetaChannels: Int32, + numPolicyChannels: Int32, + numValueChannels: Int32, + numScoreValueChannels: Int32, + numOwnershipChannels: Int32, context: MetalComputeContext -) -> MetalComputeHandle? { - guard condition else { return nil } +) -> HybridComputeHandle? { + + // Create CoreML handle (CPU + ANE) + guard let coremlHandle = createCoreMLComputeHandle( + coremlModelPath: coremlModelPath, + serverThreadIdx: serverThreadIdx, + requireExactNNLen: requireExactNNLen, + numInputChannels: numInputChannels, + numInputGlobalChannels: numInputGlobalChannels, + numInputMetaChannels: numInputMetaChannels, + numPolicyChannels: numPolicyChannels, + numValueChannels: numValueChannels, + numScoreValueChannels: numScoreValueChannels, + numOwnershipChannels: numOwnershipChannels, + context: context + ) else { + printError("Metal backend \(serverThreadIdx): Failed to create CoreML handle") + return nil + } - let device = MTLCreateSystemDefaultDevice()! + // Create MPSGraph handle (GPU) + guard let mpsGraphHandle = MPSGraphModelHandle( + modelDesc: modelDesc, + nnXLen: context.nnXLen, + nnYLen: context.nnYLen, + optimizeIdentityMask: requireExactNNLen + ) else { + printError("Metal backend \(serverThreadIdx): Failed to create MPSGraph handle") + printError("Metal backend \(serverThreadIdx): CoreML handle will be released") + return nil + } - let model = Model( - device: device, - graph: MPSGraph(), - descriptor: descriptor, - nnXLen: context.nnXLen as NSNumber, - nnYLen: context.nnYLen as NSNumber) + printError("Metal backend \(serverThreadIdx): Initialized CoreML (CPU+ANE) + MPSGraph (GPU)") - let handle = MetalComputeHandle(model: model) + // Log if diagnostic mode is enabled + if diagnosticLoggingEnabled { + printError("Metal backend \(serverThreadIdx): Diagnostic logging enabled (KATAGO_HYBRID_DIAG=1)") + } - printError( - "Metal backend \(serverThreadIdx): \(device.name), Model version \(descriptor.version) \(descriptor.name), \(context.nnXLen)x\(context.nnYLen)" + return HybridComputeHandle( + coremlHandle: coremlHandle, + mpsGraphHandle: mpsGraphHandle, + serverThreadIdx: serverThreadIdx ) - - return handle } -public func printMetalDevices() { - let device = MTLCreateSystemDefaultDevice()! - printError("Found Metal Device: \(device.name)") +/// Create a GPU-only compute handle using MPSGraph +/// Used when useFP16=false to avoid slow FP32 CoreML execution on CPU+ANE +public func createMPSGraphOnlyHandle( + modelDesc: SWModelDesc, + serverThreadIdx: Int, + requireExactNNLen: Bool, + context: MetalComputeContext +) -> MPSGraphModelHandle? { + guard let mpsGraphHandle = MPSGraphModelHandle( + modelDesc: modelDesc, + nnXLen: context.nnXLen, + nnYLen: context.nnYLen, + optimizeIdentityMask: requireExactNNLen + ) else { + printError("Metal backend \(serverThreadIdx): Failed to create MPSGraph handle") + return nil + } + + printError("Metal backend \(serverThreadIdx): Initialized MPSGraph GPU-only mode") + return mpsGraphHandle } diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift new file mode 100644 index 000000000..9e3327c32 --- /dev/null +++ b/cpp/neuralnet/metallayers.swift @@ -0,0 +1,2914 @@ +// MPSGraph layer implementations shared between Metal and CoreML backends +// Extracted from metalbackend.swift to enable hybrid CoreML + MPSGraph execution + +import Foundation +import MetalPerformanceShaders +import MetalPerformanceShadersGraph + +// MARK: - Helper Extensions + +/// An extension to the Data struct for handling float data with optional FP16 conversion. +extension Data { + /// Initializes a new Data instance using an UnsafeMutablePointer, with optional conversion to FP16 format. + init( + floatsNoCopy: UnsafeMutablePointer, + shape: [NSNumber] + ) { + self.init( + bytesNoCopy: floatsNoCopy, + count: shape.countBytesOfFloat32(), + deallocator: .none) + } +} + +/// Extension to MPSNDArray to convert from MPSGraphTensor, and to read/write bytes from/to UnsafeMutableRawPointer +extension MPSNDArray { + /// Read bytes from the buffer + func readBytes(_ buffer: UnsafeMutableRawPointer) { + self.readBytes(buffer, strideBytes: nil) + } + + /// Write bytes to the buffer + func writeBytes(_ buffer: UnsafeMutableRawPointer) { + self.writeBytes(buffer, strideBytes: nil) + } +} + +/// Extension to Array to count number of elements and bytes +extension Array where Element == NSNumber { + /// Count number of elements + func countElements() -> Int { + return reduce(1, { $0 * $1.intValue }) + } + + /// Count number of bytes + func countBytesOfFloat32() -> Int { + return countElements() * MemoryLayout.size + } +} + +/// Extension to MPSGraph to the mish activation function +extension MPSGraph { + /// Mish activation: x * tanh(softplus(x)) + func mish(tensor: MPSGraphTensor) -> MPSGraphTensor { + assert(tensor.dataType == .float32) + + let one = 1.0 + let threshold = 20.0 + let thresholdTensor = constant(threshold, dataType: tensor.dataType) + let minimumTensor = minimum(tensor, thresholdTensor, name: nil) + let expTensor = exponent(with: minimumTensor, name: nil) + let oneTensor = constant(one, dataType: tensor.dataType) + let addTensor = addition(expTensor, oneTensor, name: nil) + let logTensor = logarithm(with: addTensor, name: nil) + let lessTensor = lessThan(tensor, thresholdTensor, name: nil) + let selectTensor = select( + predicate: lessTensor, trueTensor: logTensor, falseTensor: tensor, name: nil) + let tanhTensor = tanh(with: selectTensor, name: nil) + let mulTensor = multiplication(tensor, tanhTensor, name: nil) + + return mulTensor + } +} + +// MARK: - Input Shape Utilities + +/// A structure that represents the input shape (internal - not exposed to C++) +struct InputShape { + /// Create a shape for the input tensor + static func create( + batchSize: NSNumber, + numChannels: NSNumber, + nnYLen: NSNumber, + nnXLen: NSNumber + ) -> [NSNumber] { + return [batchSize, numChannels, nnYLen, nnXLen] + } + + /// Get the channel axis + static func getChannelAxis() -> Int { + return 1 + } + + /// Get the HW axes + static func getHWAxes() -> [NSNumber] { + return [2, 3] as [NSNumber] + } +} + +// MARK: - Input Layers + +/// A structure that represents the input layer +struct InputLayer { + let tensor: MPSGraphTensor + let shape: [NSNumber] + + init( + graph: MPSGraph, + nnXLen: NSNumber, + nnYLen: NSNumber, + numChannels: NSNumber, + dataType: MPSDataType = .float32 + ) { + shape = InputShape.create( + batchSize: -1, + numChannels: numChannels, + nnYLen: nnYLen, + nnXLen: nnXLen) + + self.tensor = graph.placeholder( + shape: shape, + dataType: dataType, + name: nil) + + assert(self.tensor.shape?.count == 4) + } +} + +/// A structure that represents an input global layer for a neural network model. +struct InputGlobalLayer { + let tensor: MPSGraphTensor + let shape: [NSNumber] + + init( + graph: MPSGraph, + numGlobalFeatures: NSNumber, + dataType: MPSDataType = .float32 + ) { + shape = InputShape.create( + batchSize: -1, + numChannels: numGlobalFeatures, + nnYLen: 1, + nnXLen: 1) + + self.tensor = graph.placeholder( + shape: shape, + dataType: dataType, + name: nil) + + assert(self.tensor.shape?.count == 4) + } +} + +/// A structure representing the input meta layer for a neural network graph. +struct InputMetaLayer { + let tensor: MPSGraphTensor + let shape: [NSNumber] + + init( + graph: MPSGraph, + numMetaFeatures: NSNumber, + dataType: MPSDataType = .float32 + ) { + shape = InputShape.create( + batchSize: -1, + numChannels: numMetaFeatures, + nnYLen: 1, + nnXLen: 1) + + self.tensor = graph.placeholder( + shape: shape, + dataType: dataType, + name: nil) + } +} + +/// A structure that represents a mask layer for a neural network model. +struct MaskLayer { + let tensor: MPSGraphTensor + let shape: [NSNumber] + + init( + graph: MPSGraph, + nnXLen: NSNumber, + nnYLen: NSNumber, + dataType: MPSDataType = .float32 + ) { + shape = InputShape.create( + batchSize: -1, + numChannels: 1, + nnYLen: nnYLen, + nnXLen: nnXLen) + + self.tensor = graph.placeholder( + shape: shape, + dataType: dataType, + name: nil) + + assert(self.tensor.shape?.count == 4) + } +} + +// MARK: - Mask Processing Layers + +/// A structure that represents a layer which performs the summation operation on a mask layer. +struct MaskSumLayer { + let tensor: MPSGraphTensor + + init(tensor: MPSGraphTensor) { + self.tensor = tensor + assert(self.tensor.shape?.count == 4) + } + + init( + graph: MPSGraph, + maskTensor: MPSGraphTensor + ) { + let hwAxes = InputShape.getHWAxes() + + self.tensor = graph.reductionSum( + with: maskTensor, + axes: hwAxes, + name: nil) + + assert(self.tensor.shape?.count == 4) + } + + /// Optimized init for when mask is all 1s (requireExactNNLen=true) + /// Returns constant tensor with boardSize value + init( + graph: MPSGraph, + nnXLen: NSNumber, + nnYLen: NSNumber, + dataType: MPSDataType = .float32 + ) { + let boardSize = Double(nnXLen.intValue * nnYLen.intValue) + self.tensor = graph.constant( + boardSize, + shape: [1, 1, 1, 1], + dataType: dataType) + + assert(self.tensor.shape?.count == 4) + } +} + +/// A structure that represents sqrt(maskSum) * 0.1 - 1.4 +struct MaskSumSqrtS14M01Layer { + let tensor: MPSGraphTensor + + init(tensor: MPSGraphTensor) { + self.tensor = tensor + assert(self.tensor.shape?.count == 4) + } + + init( + graph: MPSGraph, + maskSum: MaskSumLayer + ) { + let sqrtMaskSum = graph.squareRoot(with: maskSum.tensor, name: nil) + + let fourTeen = graph.constant( + 14.0, + shape: [1], + dataType: maskSum.tensor.dataType) + + let subtracted = graph.subtraction(sqrtMaskSum, fourTeen, name: nil) + + let zeroPointone = graph.constant( + 0.1, + shape: [1], + dataType: maskSum.tensor.dataType) + + self.tensor = graph.multiplication( + subtracted, + zeroPointone, + name: nil) + + assert(self.tensor.shape?.count == 4) + } + + /// Optimized init for when mask is all 1s (requireExactNNLen=true) + /// Returns constant tensor: (sqrt(boardSize) - 14) * 0.1 + init( + graph: MPSGraph, + nnXLen: NSNumber, + nnYLen: NSNumber, + dataType: MPSDataType = .float32 + ) { + let boardSize = Double(nnXLen.intValue * nnYLen.intValue) + let value = (sqrt(boardSize) - 14.0) * 0.1 + self.tensor = graph.constant( + value, + shape: [1, 1, 1, 1], + dataType: dataType) + + assert(self.tensor.shape?.count == 4) + } +} + +/// A structure for (sqrt(maskSum) * 0.1 - 1.4)^2 - 0.1 +struct MaskSumSqrtS14M01SquareS01Layer { + let tensor: MPSGraphTensor + + init(tensor: MPSGraphTensor) { + self.tensor = tensor + assert(self.tensor.shape?.count == 4) + } + + init( + graph: MPSGraph, + maskSumSqrtS14M01: MaskSumSqrtS14M01Layer + ) { + let squared = graph.square(with: maskSumSqrtS14M01.tensor, name: nil) + + let zeroPointone = graph.constant( + 0.1, + shape: [1], + dataType: maskSumSqrtS14M01.tensor.dataType) + + self.tensor = graph.subtraction( + squared, + zeroPointone, + name: nil) + + assert(self.tensor.shape?.count == 4) + } + + /// Optimized init for when mask is all 1s (requireExactNNLen=true) + /// Returns constant tensor: ((sqrt(boardSize) - 14) * 0.1)^2 - 0.1 + init( + graph: MPSGraph, + nnXLen: NSNumber, + nnYLen: NSNumber, + dataType: MPSDataType = .float32 + ) { + let boardSize = Double(nnXLen.intValue * nnYLen.intValue) + let sqrtS14M01 = (sqrt(boardSize) - 14.0) * 0.1 + let value = sqrtS14M01 * sqrtS14M01 - 0.1 + self.tensor = graph.constant( + value, + shape: [1, 1, 1, 1], + dataType: dataType) + + assert(self.tensor.shape?.count == 4) + } +} + +// MARK: - Layer Descriptors + +/// An enumeration of the different kinds of activation function. +public enum ActivationKind { + case identity + case relu + case mish +} + +/// A struct that represents a description of convolutional layer. +public struct SWConvLayerDesc { + let convYSize: NSNumber + let convXSize: NSNumber + let inChannels: NSNumber + let outChannels: NSNumber + let dilationY: Int + let dilationX: Int + let weights: UnsafeMutablePointer + + init( + convYSize: NSNumber, + convXSize: NSNumber, + inChannels: NSNumber, + outChannels: NSNumber, + dilationY: Int, + dilationX: Int, + weights: UnsafeMutablePointer + ) { + self.convYSize = convYSize + self.convXSize = convXSize + self.inChannels = inChannels + self.outChannels = outChannels + self.dilationY = dilationY + self.dilationX = dilationX + self.weights = weights + } +} + +public func createSWConvLayerDesc( + convYSize: Int32, + convXSize: Int32, + inChannels: Int32, + outChannels: Int32, + dilationY: Int32, + dilationX: Int32, + weights: UnsafeMutablePointer +) -> SWConvLayerDesc { + return SWConvLayerDesc( + convYSize: convYSize as NSNumber, + convXSize: convXSize as NSNumber, + inChannels: inChannels as NSNumber, + outChannels: outChannels as NSNumber, + dilationY: Int(dilationY), + dilationX: Int(dilationX), + weights: weights) +} + +/// A struct that represents a description of a batch normalization layer. +public struct SWBatchNormLayerDesc { + let numChannels: NSNumber + let mergedScale: UnsafeMutablePointer + let mergedBias: UnsafeMutablePointer + + init( + numChannels: NSNumber, + mergedScale: UnsafeMutablePointer, + mergedBias: UnsafeMutablePointer + ) { + self.numChannels = numChannels + self.mergedScale = mergedScale + self.mergedBias = mergedBias + } +} + +public func createSWBatchNormLayerDesc( + numChannels: Int32, + mergedScale: UnsafeMutablePointer, + mergedBias: UnsafeMutablePointer +) -> SWBatchNormLayerDesc { + return SWBatchNormLayerDesc( + numChannels: numChannels as NSNumber, + mergedScale: mergedScale, + mergedBias: mergedBias) +} + +/// A struct that represents a matrix multiplication layer descriptor +public struct SWMatMulLayerDesc { + let inChannels: NSNumber + let outChannels: NSNumber + let weights: UnsafeMutablePointer + + init( + inChannels: NSNumber, + outChannels: NSNumber, + weights: UnsafeMutablePointer + ) { + self.inChannels = inChannels + self.outChannels = outChannels + self.weights = weights + } +} + +public func createSWMatMulLayerDesc( + inChannels: Int32, + outChannels: Int32, + weights: UnsafeMutablePointer +) -> SWMatMulLayerDesc { + return SWMatMulLayerDesc( + inChannels: inChannels as NSNumber, + outChannels: outChannels as NSNumber, + weights: weights) +} + +/// A struct that represents the bias layer description. +public struct SWMatBiasLayerDesc { + let numChannels: NSNumber + let weights: UnsafeMutablePointer + + init( + numChannels: NSNumber, + weights: UnsafeMutablePointer + ) { + self.numChannels = numChannels + self.weights = weights + } +} + +public func createSWMatBiasLayerDesc( + numChannels: Int32, + weights: UnsafeMutablePointer +) -> SWMatBiasLayerDesc { + return SWMatBiasLayerDesc( + numChannels: numChannels as NSNumber, + weights: weights) +} + +// MARK: - Core Layers + +/// A class that represents a convolutional layer using MPSGraph +class ConvLayer { + let resultTensor: MPSGraphTensor + let convDescriptor = MPSGraphConvolution2DOpDescriptor( + strideInX: 1, + strideInY: 1, + dilationRateInX: 1, + dilationRateInY: 1, + groups: 1, + paddingStyle: .TF_SAME, + dataLayout: .NCHW, + weightsLayout: .OIHW)! + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + descriptor: SWConvLayerDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let weightsShape = [ + descriptor.outChannels, + descriptor.inChannels, + descriptor.convYSize, + descriptor.convXSize, + ] + + let weightsData = Data( + floatsNoCopy: descriptor.weights, + shape: weightsShape) + + let weightsTensor = graph.constant( + weightsData, + shape: weightsShape, + dataType: sourceTensor.dataType) + + resultTensor = graph.convolution2D( + sourceTensor, + weights: weightsTensor, + descriptor: convDescriptor, + name: nil) + + assert(resultTensor.shape?.count == 4) + } +} + +/// A class that represents a batch normalization layer. +class BatchNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWBatchNormLayerDesc, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let scaleBiasShape = InputShape.create( + batchSize: 1, + numChannels: descriptor.numChannels, + nnYLen: 1, + nnXLen: 1) + + let mergedScaleData = Data( + floatsNoCopy: descriptor.mergedScale, + shape: scaleBiasShape) + + let mergedBiasData = Data( + floatsNoCopy: descriptor.mergedBias, + shape: scaleBiasShape) + + let scaleTensor = graph.constant( + mergedScaleData, + shape: scaleBiasShape, + dataType: sourceTensor.dataType) + + let biasTensor = graph.constant( + mergedBiasData, + shape: scaleBiasShape, + dataType: sourceTensor.dataType) + + let scaled = graph.multiplication( + sourceTensor, + scaleTensor, + name: nil) + + let normalized = graph.addition( + scaled, + biasTensor, + name: nil) + + // Skip mask multiplication when all mask values are 1 + if optimizeIdentityMask { + resultTensor = normalized + } else { + resultTensor = graph.multiplication( + normalized, + maskTensor, + name: nil) + } + + assert(resultTensor.shape?.count == 4) + } +} + +/// A structure that represents an activation layer +struct ActivationLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + activationKind: ActivationKind + ) { + switch activationKind { + case .relu: + resultTensor = graph.reLU(with: sourceTensor, name: nil) + case .mish: + resultTensor = graph.mish(tensor: sourceTensor) + default: + resultTensor = sourceTensor + } + + assert(resultTensor.shape == sourceTensor.shape) + } +} + +/// A structure representing a matrix multiplication layer. +struct MatMulLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + descriptor: SWMatMulLayerDesc, + sourceTensor: MPSGraphTensor + ) { + assert( + (sourceTensor.shape?.count == 4) || (sourceTensor.shape?[1] == descriptor.inChannels)) + assert( + (sourceTensor.shape?.count == 2) || (sourceTensor.shape?[1] == descriptor.inChannels)) + + let weightsShape = [ + descriptor.inChannels, + descriptor.outChannels, + ] + + let weightsData = Data( + floatsNoCopy: descriptor.weights, + shape: weightsShape) + + let weightsTensor = graph.constant( + weightsData, + shape: weightsShape, + dataType: sourceTensor.dataType) + + let shape = [-1, descriptor.inChannels] + + let reshapedSource = graph.reshape( + sourceTensor, + shape: shape, + name: nil) + + resultTensor = graph.matrixMultiplication( + primary: reshapedSource, + secondary: weightsTensor, + name: nil) + + assert(resultTensor.shape?.count == 2) + } +} + +/// A structure that performs matrix bias operations +struct MatBiasLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + descriptor: SWMatBiasLayerDesc, + sourceTensor: MPSGraphTensor + ) { + assert( + (sourceTensor.shape?.count == 2) && (sourceTensor.shape?[1] == descriptor.numChannels)) + + let weightsShape = [1, descriptor.numChannels] + + let weightsData = Data( + floatsNoCopy: descriptor.weights, + shape: weightsShape) + + let weightsTensor = graph.constant( + weightsData, + shape: weightsShape, + dataType: sourceTensor.dataType) + + resultTensor = graph.addition( + sourceTensor, + weightsTensor, + name: nil) + } +} + +/// A structure that performs bias operations in NC coordinates. +struct AddNCBiasLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + biasTensor: MPSGraphTensor, + nnXLen: NSNumber, + nnYLen: NSNumber, + numChannels: NSNumber + ) { + let shape = InputShape.create( + batchSize: -1, + numChannels: numChannels, + nnYLen: 1, + nnXLen: 1) + + assert(biasTensor.shape?[1] == shape[1]) + + let reshaped = graph.reshape(biasTensor, shape: shape, name: nil) + resultTensor = graph.addition(sourceTensor, reshaped, name: nil) + + assert(resultTensor.shape?.count == 4) + assert(resultTensor.shape?[2] == nnYLen) + assert(resultTensor.shape?[3] == nnXLen) + } +} + +// MARK: - Pooling Layers + +/// A structure that represents a global pooling layer +struct GlobalPoolingLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + optimizeIdentityMask: Bool = false + ) { + let hwAxes = InputShape.getHWAxes() + let channelAxis = InputShape.getChannelAxis() + + let sumTensor = graph.reductionSum( + with: sourceTensor, + axes: hwAxes, + name: nil) + + let meanTensor = graph.division(sumTensor, maskSumTensor, name: nil) + + let meanMaskTensor = graph.multiplication( + meanTensor, + maskSumSqrtS14M01Tensor, + name: nil) + + let maxTensor: MPSGraphTensor + if optimizeIdentityMask { + // When all mask values are 1, directly compute max without mask adjustment + maxTensor = graph.reductionMaximum( + with: sourceTensor, + axes: hwAxes, + name: nil) + } else { + // Mask out invalid positions by subtracting 1 (making them very negative) + let oneTensor = graph.constant(1.0, dataType: sourceTensor.dataType) + let maskM1Tensor = graph.subtraction(maskTensor, oneTensor, name: nil) + let addition = graph.addition(sourceTensor, maskM1Tensor, name: nil) + + maxTensor = graph.reductionMaximum( + with: addition, + axes: hwAxes, + name: nil) + } + + resultTensor = graph.concatTensors( + [ + meanTensor, + meanMaskTensor, + maxTensor, + ], + dimension: channelAxis, + name: nil) + + assert(resultTensor.shape?.count == 4) + assert(resultTensor.shape?[2] == 1) + assert(resultTensor.shape?[3] == 1) + } +} + +/// A structure that represents a layer that performs global pooling on the input tensor +struct GlobalPoolingValueLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + maskSumSqrtS14M01SquareS01Tensor: MPSGraphTensor + ) { + let hwAxes = InputShape.getHWAxes() + let channelAxis = InputShape.getChannelAxis() + + let sumTensor = graph.reductionSum( + with: sourceTensor, + axes: hwAxes, + name: nil) + + let meanTensor = graph.division(sumTensor, maskSumTensor, name: nil) + + let meanMaskTensor = graph.multiplication( + meanTensor, + maskSumSqrtS14M01Tensor, + name: nil) + + let meanMaskSquareTensor = graph.multiplication( + meanTensor, + maskSumSqrtS14M01SquareS01Tensor, + name: nil) + + resultTensor = graph.concatTensors( + [ + meanTensor, + meanMaskTensor, + meanMaskSquareTensor, + ], + dimension: channelAxis, + name: nil) + + assert(resultTensor.shape?.count == 4) + assert(resultTensor.shape?[2] == 1) + assert(resultTensor.shape?[3] == 1) + } +} + +// MARK: - Block Descriptors + +/// Base class for block descriptors +public class BlockDescriptor { +} + +/// A class that represents a residual block. +public class SWResidualBlockDesc: BlockDescriptor { + let preBN: SWBatchNormLayerDesc + let preActivation: ActivationKind + let regularConv: SWConvLayerDesc + let midBN: SWBatchNormLayerDesc + let midActivation: ActivationKind + let finalConv: SWConvLayerDesc + + init( + preBN: SWBatchNormLayerDesc, + preActivation: ActivationKind, + regularConv: SWConvLayerDesc, + midBN: SWBatchNormLayerDesc, + midActivation: ActivationKind, + finalConv: SWConvLayerDesc + ) { + self.preBN = preBN + self.preActivation = preActivation + self.regularConv = regularConv + self.midBN = midBN + self.midActivation = midActivation + self.finalConv = finalConv + } +} + +public func createSWResidualBlockDesc( + preBN: SWBatchNormLayerDesc, + preActivation: ActivationKind, + regularConv: SWConvLayerDesc, + midBN: SWBatchNormLayerDesc, + midActivation: ActivationKind, + finalConv: SWConvLayerDesc +) -> SWResidualBlockDesc { + return SWResidualBlockDesc( + preBN: preBN, + preActivation: preActivation, + regularConv: regularConv, + midBN: midBN, + midActivation: midActivation, + finalConv: finalConv) +} + +/// A class that represents a residual block with global pooling. +public class SWGlobalPoolingResidualBlockDesc: BlockDescriptor { + let preBN: SWBatchNormLayerDesc + let preActivation: ActivationKind + let regularConv: SWConvLayerDesc + let gpoolConv: SWConvLayerDesc + let gpoolBN: SWBatchNormLayerDesc + let gpoolActivation: ActivationKind + let gpoolToBiasMul: SWMatMulLayerDesc + let midBN: SWBatchNormLayerDesc + let midActivation: ActivationKind + let finalConv: SWConvLayerDesc + + init( + preBN: SWBatchNormLayerDesc, + preActivation: ActivationKind, + regularConv: SWConvLayerDesc, + gpoolConv: SWConvLayerDesc, + gpoolBN: SWBatchNormLayerDesc, + gpoolActivation: ActivationKind, + gpoolToBiasMul: SWMatMulLayerDesc, + midBN: SWBatchNormLayerDesc, + midActivation: ActivationKind, + finalConv: SWConvLayerDesc + ) { + self.preBN = preBN + self.preActivation = preActivation + self.regularConv = regularConv + self.gpoolConv = gpoolConv + self.gpoolBN = gpoolBN + self.gpoolActivation = gpoolActivation + self.gpoolToBiasMul = gpoolToBiasMul + self.midBN = midBN + self.midActivation = midActivation + self.finalConv = finalConv + } +} + +public func createSWGlobalPoolingResidualBlockDesc( + preBN: SWBatchNormLayerDesc, + preActivation: ActivationKind, + regularConv: SWConvLayerDesc, + gpoolConv: SWConvLayerDesc, + gpoolBN: SWBatchNormLayerDesc, + gpoolActivation: ActivationKind, + gpoolToBiasMul: SWMatMulLayerDesc, + midBN: SWBatchNormLayerDesc, + midActivation: ActivationKind, + finalConv: SWConvLayerDesc +) -> SWGlobalPoolingResidualBlockDesc { + return SWGlobalPoolingResidualBlockDesc( + preBN: preBN, + preActivation: preActivation, + regularConv: regularConv, + gpoolConv: gpoolConv, + gpoolBN: gpoolBN, + gpoolActivation: gpoolActivation, + gpoolToBiasMul: gpoolToBiasMul, + midBN: midBN, + midActivation: midActivation, + finalConv: finalConv) +} + +/// A class that represents a nested bottleneck residual block +public class SWNestedBottleneckResidualBlockDesc: BlockDescriptor { + let preBN: SWBatchNormLayerDesc + let preActivation: ActivationKind + let preConv: SWConvLayerDesc + let blockDescriptors: [BlockDescriptor] + let postBN: SWBatchNormLayerDesc + let postActivation: ActivationKind + let postConv: SWConvLayerDesc + + init( + preBN: SWBatchNormLayerDesc, + preActivation: ActivationKind, + preConv: SWConvLayerDesc, + blockDescriptors: [BlockDescriptor], + postBN: SWBatchNormLayerDesc, + postActivation: ActivationKind, + postConv: SWConvLayerDesc + ) { + self.preBN = preBN + self.preActivation = preActivation + self.preConv = preConv + self.blockDescriptors = blockDescriptors + self.postBN = postBN + self.postActivation = postActivation + self.postConv = postConv + } +} + +public func createSWNestedBottleneckResidualBlockDesc( + preBN: SWBatchNormLayerDesc, + preActivation: ActivationKind, + preConv: SWConvLayerDesc, + blockDescriptors: [BlockDescriptor], + postBN: SWBatchNormLayerDesc, + postActivation: ActivationKind, + postConv: SWConvLayerDesc +) -> SWNestedBottleneckResidualBlockDesc { + return SWNestedBottleneckResidualBlockDesc( + preBN: preBN, + preActivation: preActivation, + preConv: preConv, + blockDescriptors: blockDescriptors, + postBN: postBN, + postActivation: postActivation, + postConv: postConv) +} + +public class BlockDescriptorBuilder { + public var blockDescriptors: [BlockDescriptor] = [] + + init() {} + + public func enque(with descriptor: BlockDescriptor) { + blockDescriptors.append(descriptor) + } +} + +public func createBlockDescriptorBuilder() -> BlockDescriptorBuilder { + return BlockDescriptorBuilder() +} + +// MARK: - Block Implementations + +/// A class that represents a Residual Block layer +class ResidualBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWResidualBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let preBN = BatchNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let preActivation = ActivationLayer( + graph: graph, + sourceTensor: preBN.resultTensor, + activationKind: descriptor.preActivation) + + let regularConv = ConvLayer( + graph: graph, + sourceTensor: preActivation.resultTensor, + descriptor: descriptor.regularConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let midBN = BatchNormLayer( + graph: graph, + sourceTensor: regularConv.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.midBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let midActivation = ActivationLayer( + graph: graph, + sourceTensor: midBN.resultTensor, + activationKind: descriptor.midActivation) + + let finalConv = ConvLayer( + graph: graph, + sourceTensor: midActivation.resultTensor, + descriptor: descriptor.finalConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + resultTensor = graph.addition( + sourceTensor, + finalConv.resultTensor, + name: nil) + + assert(resultTensor.shape?.count == 4) + } +} + +/// A class representing a residual block with global pooling +class GlobalPoolingResidualBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + descriptor: SWGlobalPoolingResidualBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let maskSum = MaskSumLayer(tensor: maskSumTensor) + let maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer(tensor: maskSumSqrtS14M01Tensor) + + let preBN = BatchNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let preActivation = ActivationLayer( + graph: graph, + sourceTensor: preBN.resultTensor, + activationKind: descriptor.preActivation) + + let regularConv = ConvLayer( + graph: graph, + sourceTensor: preActivation.resultTensor, + descriptor: descriptor.regularConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let gpoolConv = ConvLayer( + graph: graph, + sourceTensor: preActivation.resultTensor, + descriptor: descriptor.gpoolConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let gpoolBN = BatchNormLayer( + graph: graph, + sourceTensor: gpoolConv.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.gpoolBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let gpoolActivation = ActivationLayer( + graph: graph, + sourceTensor: gpoolBN.resultTensor, + activationKind: descriptor.gpoolActivation) + + let gpoolConcat = GlobalPoolingLayer( + graph: graph, + sourceTensor: gpoolActivation.resultTensor, + maskTensor: maskTensor, + maskSumTensor: maskSum.tensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, + optimizeIdentityMask: optimizeIdentityMask) + + assert(gpoolConcat.resultTensor.shape?[1] == descriptor.gpoolToBiasMul.inChannels) + + let gpoolToBiasMul = MatMulLayer( + graph: graph, + descriptor: descriptor.gpoolToBiasMul, + sourceTensor: gpoolConcat.resultTensor) + + let added = AddNCBiasLayer( + graph: graph, + sourceTensor: regularConv.resultTensor, + biasTensor: gpoolToBiasMul.resultTensor, + nnXLen: nnXLen, + nnYLen: nnYLen, + numChannels: descriptor.gpoolToBiasMul.outChannels) + + let midBN = BatchNormLayer( + graph: graph, + sourceTensor: added.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.midBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let midActivation = ActivationLayer( + graph: graph, + sourceTensor: midBN.resultTensor, + activationKind: descriptor.midActivation) + + let finalConv = ConvLayer( + graph: graph, + sourceTensor: midActivation.resultTensor, + descriptor: descriptor.finalConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + resultTensor = graph.addition( + sourceTensor, + finalConv.resultTensor, + name: nil) + + assert(resultTensor.shape?.count == 4) + } +} + +/// A structure that represents a block stack +struct BlockStack { + let resultTensor: MPSGraphTensor + + static func processBlockDescriptors( + _ graph: MPSGraph, + _ sourceTensor: MPSGraphTensor, + _ maskTensor: MPSGraphTensor, + _ maskSumTensor: MPSGraphTensor, + _ maskSumSqrtS14M01Tensor: MPSGraphTensor, + _ blockDescriptors: [BlockDescriptor], + _ index: Int, + _ nnXLen: NSNumber, + _ nnYLen: NSNumber, + _ optimizeIdentityMask: Bool + ) -> MPSGraphTensor { + guard index < blockDescriptors.count else { + return sourceTensor + } + + let blockDescriptor = blockDescriptors[index] + let blockInput: MPSGraphTensor + + switch blockDescriptor { + case let globalPoolingDescriptor as SWGlobalPoolingResidualBlockDesc: + let globalPooling = GlobalPoolingResidualBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + maskSumTensor: maskSumTensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, + descriptor: globalPoolingDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + blockInput = globalPooling.resultTensor + case let nestedBottleneckDescriptor as SWNestedBottleneckResidualBlockDesc: + let nestedBottleneck = NestedBottleneckResidualBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + maskSumTensor: maskSumTensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, + descriptor: nestedBottleneckDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + blockInput = nestedBottleneck.resultTensor + case let residualBlockDescriptor as SWResidualBlockDesc: + let ordinary = ResidualBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: residualBlockDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + blockInput = ordinary.resultTensor + default: + blockInput = sourceTensor + } + + return processBlockDescriptors( + graph, + blockInput, + maskTensor, + maskSumTensor, + maskSumSqrtS14M01Tensor, + blockDescriptors, + index + 1, + nnXLen, + nnYLen, + optimizeIdentityMask) + } + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + blockDescriptors: [BlockDescriptor], + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + resultTensor = BlockStack.processBlockDescriptors( + graph, + sourceTensor, + maskTensor, + maskSumTensor, + maskSumSqrtS14M01Tensor, + blockDescriptors, + 0, + nnXLen, + nnYLen, + optimizeIdentityMask) + } +} + +/// A structure that represents a nested bottleneck residual block +struct NestedBottleneckResidualBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + descriptor: SWNestedBottleneckResidualBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let preBN = BatchNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let preActivation = ActivationLayer( + graph: graph, + sourceTensor: preBN.resultTensor, + activationKind: descriptor.preActivation) + + let preConv = ConvLayer( + graph: graph, + sourceTensor: preActivation.resultTensor, + descriptor: descriptor.preConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let blocks = BlockStack( + graph: graph, + sourceTensor: preConv.resultTensor, + maskTensor: maskTensor, + maskSumTensor: maskSumTensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, + blockDescriptors: descriptor.blockDescriptors, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let postBN = BatchNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.postBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let postActivation = ActivationLayer( + graph: graph, + sourceTensor: postBN.resultTensor, + activationKind: descriptor.postActivation) + + let postConv = ConvLayer( + graph: graph, + sourceTensor: postActivation.resultTensor, + descriptor: descriptor.postConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + resultTensor = graph.addition( + sourceTensor, + postConv.resultTensor, + name: nil) + + assert(resultTensor.shape?.count == 4) + } +} + +// MARK: - SGF Metadata Encoder + +/// Class representing the description of the SGF Metadata Encoder. +public class SWSGFMetadataEncoderDesc { + let version: Int + let numInputMetaChannels: Int + let mul1: SWMatMulLayerDesc + let bias1: SWMatBiasLayerDesc + let act1: ActivationKind + let mul2: SWMatMulLayerDesc + let bias2: SWMatBiasLayerDesc + let act2: ActivationKind + let mul3: SWMatMulLayerDesc + + init( + version: Int, + numInputMetaChannels: Int, + mul1: SWMatMulLayerDesc, + bias1: SWMatBiasLayerDesc, + act1: ActivationKind, + mul2: SWMatMulLayerDesc, + bias2: SWMatBiasLayerDesc, + act2: ActivationKind, + mul3: SWMatMulLayerDesc + ) { + self.version = version + self.numInputMetaChannels = numInputMetaChannels + self.mul1 = mul1 + self.bias1 = bias1 + self.act1 = act1 + self.mul2 = mul2 + self.bias2 = bias2 + self.act2 = act2 + self.mul3 = mul3 + } +} + +public func createSWSGFMetadataEncoderDesc( + version: Int32, + numInputMetaChannels: Int32, + mul1: SWMatMulLayerDesc, + bias1: SWMatBiasLayerDesc, + act1: ActivationKind, + mul2: SWMatMulLayerDesc, + bias2: SWMatBiasLayerDesc, + act2: ActivationKind, + mul3: SWMatMulLayerDesc +) -> SWSGFMetadataEncoderDesc? { + return SWSGFMetadataEncoderDesc( + version: Int(version), + numInputMetaChannels: Int(numInputMetaChannels), + mul1: mul1, + bias1: bias1, + act1: act1, + mul2: mul2, + bias2: bias2, + act2: act2, + mul3: mul3) +} + +/// A class that encodes SGF metadata. +class SGFMetadataEncoder { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + descriptor: SWSGFMetadataEncoderDesc, + sourceTensor: MPSGraphTensor + ) { + let mul1 = MatMulLayer( + graph: graph, + descriptor: descriptor.mul1, + sourceTensor: sourceTensor) + + let bias1 = MatBiasLayer( + graph: graph, + descriptor: descriptor.bias1, + sourceTensor: mul1.resultTensor) + + let act1 = ActivationLayer( + graph: graph, + sourceTensor: bias1.resultTensor, + activationKind: descriptor.act1) + + let mul2 = MatMulLayer( + graph: graph, + descriptor: descriptor.mul2, + sourceTensor: act1.resultTensor) + + let bias2 = MatBiasLayer( + graph: graph, + descriptor: descriptor.bias2, + sourceTensor: mul2.resultTensor) + + let act2 = ActivationLayer( + graph: graph, + sourceTensor: bias2.resultTensor, + activationKind: descriptor.act2) + + let mul3 = MatMulLayer( + graph: graph, + descriptor: descriptor.mul3, + sourceTensor: act2.resultTensor) + + resultTensor = mul3.resultTensor + + assert(resultTensor.shape?.count == 2) + } +} + +// MARK: - Trunk + +/// A class that describes a trunk for a neural network +public class SWTrunkDesc { + let version: Int + let trunkNumChannels: NSNumber + let midNumChannels: NSNumber + let regularNumChannels: NSNumber + let gpoolNumChannels: NSNumber + let initialConv: SWConvLayerDesc + let initialMatMul: SWMatMulLayerDesc + let sgfMetadataEncoder: SWSGFMetadataEncoderDesc? + let blockDescriptors: [BlockDescriptor] + let trunkTipBN: SWBatchNormLayerDesc + let trunkTipActivation: ActivationKind + + init( + version: Int, + trunkNumChannels: NSNumber, + midNumChannels: NSNumber, + regularNumChannels: NSNumber, + gpoolNumChannels: NSNumber, + initialConv: SWConvLayerDesc, + initialMatMul: SWMatMulLayerDesc, + sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, + blockDescriptors: [BlockDescriptor], + trunkTipBN: SWBatchNormLayerDesc, + trunkTipActivation: ActivationKind + ) { + self.version = version + self.trunkNumChannels = trunkNumChannels + self.midNumChannels = midNumChannels + self.regularNumChannels = regularNumChannels + self.gpoolNumChannels = gpoolNumChannels + self.initialConv = initialConv + self.initialMatMul = initialMatMul + self.sgfMetadataEncoder = sgfMetadataEncoder + self.blockDescriptors = blockDescriptors + self.trunkTipBN = trunkTipBN + self.trunkTipActivation = trunkTipActivation + } +} + +public func createSWTrunkDesc( + version: Int32, + trunkNumChannels: Int32, + midNumChannels: Int32, + regularNumChannels: Int32, + gpoolNumChannels: Int32, + initialConv: SWConvLayerDesc, + initialMatMul: SWMatMulLayerDesc, + sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, + blockDescriptors: [BlockDescriptor], + trunkTipBN: SWBatchNormLayerDesc, + trunkTipActivation: ActivationKind +) -> SWTrunkDesc { + return SWTrunkDesc( + version: Int(version), + trunkNumChannels: trunkNumChannels as NSNumber, + midNumChannels: midNumChannels as NSNumber, + regularNumChannels: regularNumChannels as NSNumber, + gpoolNumChannels: gpoolNumChannels as NSNumber, + initialConv: initialConv, + initialMatMul: initialMatMul, + sgfMetadataEncoder: sgfMetadataEncoder, + blockDescriptors: blockDescriptors, + trunkTipBN: trunkTipBN, + trunkTipActivation: trunkTipActivation) +} + +/// A structure representing a ResNet trunk for a neural network +struct Trunk { + let resultTensor: MPSGraphTensor + + static func getBlockSourceTensor( + graph: MPSGraph, + descriptor: SWSGFMetadataEncoderDesc?, + initialAdd: AddNCBiasLayer, + inputMetaTensor: MPSGraphTensor?, + nnXLen: NSNumber, + nnYLen: NSNumber, + numChannels: NSNumber + ) -> MPSGraphTensor { + var blockSourceTensor: MPSGraphTensor + + if let inputMetaTensor, + let descriptor, descriptor.numInputMetaChannels > 0 + { + let encoded = SGFMetadataEncoder( + graph: graph, + descriptor: descriptor, + sourceTensor: inputMetaTensor) + + let encodedAdd = AddNCBiasLayer( + graph: graph, + sourceTensor: initialAdd.resultTensor, + biasTensor: encoded.resultTensor, + nnXLen: nnXLen, + nnYLen: nnYLen, + numChannels: numChannels) + + blockSourceTensor = encodedAdd.resultTensor + } else { + blockSourceTensor = initialAdd.resultTensor + } + + return blockSourceTensor + } + + init( + graph: MPSGraph, + descriptor: SWTrunkDesc, + inputTensor: MPSGraphTensor, + inputGlobalTensor: MPSGraphTensor, + inputMetaTensor: MPSGraphTensor?, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let initialConv = ConvLayer( + graph: graph, + sourceTensor: inputTensor, + descriptor: descriptor.initialConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let initialMatMul = MatMulLayer( + graph: graph, + descriptor: descriptor.initialMatMul, + sourceTensor: inputGlobalTensor) + + let initialAdd = AddNCBiasLayer( + graph: graph, + sourceTensor: initialConv.resultTensor, + biasTensor: initialMatMul.resultTensor, + nnXLen: nnXLen, + nnYLen: nnYLen, + numChannels: descriptor.initialMatMul.outChannels) + + let blockSourceTensor = Trunk.getBlockSourceTensor( + graph: graph, + descriptor: descriptor.sgfMetadataEncoder, + initialAdd: initialAdd, + inputMetaTensor: inputMetaTensor, + nnXLen: nnXLen, + nnYLen: nnYLen, + numChannels: descriptor.initialMatMul.outChannels) + + let blocks = BlockStack( + graph: graph, + sourceTensor: blockSourceTensor, + maskTensor: maskTensor, + maskSumTensor: maskSumTensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, + blockDescriptors: descriptor.blockDescriptors, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let trunkTipBN = BatchNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let trunkTipActivation = ActivationLayer( + graph: graph, + sourceTensor: trunkTipBN.resultTensor, + activationKind: descriptor.trunkTipActivation) + + resultTensor = trunkTipActivation.resultTensor + + assert(resultTensor.shape?.count == 4) + } +} + +// MARK: - Policy Head + +/// A class that describes a policy head for a neural network +public struct SWPolicyHeadDesc { + let version: Int + let p1Conv: SWConvLayerDesc + let g1Conv: SWConvLayerDesc + let g1BN: SWBatchNormLayerDesc + let g1Activation: ActivationKind + let gpoolToBiasMul: SWMatMulLayerDesc + let p1BN: SWBatchNormLayerDesc + let p1Activation: ActivationKind + let p2Conv: SWConvLayerDesc + let gpoolToPassMul: SWMatMulLayerDesc + let gpoolToPassBias: SWMatBiasLayerDesc? + let passActivation: ActivationKind? + let gpoolToPassMul2: SWMatMulLayerDesc? + + init( + version: Int, + p1Conv: SWConvLayerDesc, + g1Conv: SWConvLayerDesc, + g1BN: SWBatchNormLayerDesc, + g1Activation: ActivationKind, + gpoolToBiasMul: SWMatMulLayerDesc, + p1BN: SWBatchNormLayerDesc, + p1Activation: ActivationKind, + p2Conv: SWConvLayerDesc, + gpoolToPassMul: SWMatMulLayerDesc, + gpoolToPassBias: SWMatBiasLayerDesc?, + passActivation: ActivationKind?, + gpoolToPassMul2: SWMatMulLayerDesc? + ) { + self.version = version + self.p1Conv = p1Conv + self.g1Conv = g1Conv + self.g1BN = g1BN + self.g1Activation = g1Activation + self.gpoolToBiasMul = gpoolToBiasMul + self.p1BN = p1BN + self.p1Activation = p1Activation + self.p2Conv = p2Conv + self.gpoolToPassMul = gpoolToPassMul + self.gpoolToPassBias = gpoolToPassBias + self.passActivation = passActivation + self.gpoolToPassMul2 = gpoolToPassMul2 + + assert( + (version >= 15) + || ((gpoolToPassBias == nil) && (passActivation == nil) && (gpoolToPassMul2 == nil)) + ) + assert( + (version < 15) + || ((gpoolToPassBias != nil) && (passActivation != nil) && (gpoolToPassMul2 != nil)) + ) + } +} + +public func createSWPolicyHeadDesc( + version: Int32, + p1Conv: SWConvLayerDesc, + g1Conv: SWConvLayerDesc, + g1BN: SWBatchNormLayerDesc, + g1Activation: ActivationKind, + gpoolToBiasMul: SWMatMulLayerDesc, + p1BN: SWBatchNormLayerDesc, + p1Activation: ActivationKind, + p2Conv: SWConvLayerDesc, + gpoolToPassMul: SWMatMulLayerDesc, + gpoolToPassBias: SWMatBiasLayerDesc, + passActivation: ActivationKind, + gpoolToPassMul2: SWMatMulLayerDesc +) -> SWPolicyHeadDesc { + if version >= 15 { + return SWPolicyHeadDesc( + version: Int(version), + p1Conv: p1Conv, + g1Conv: g1Conv, + g1BN: g1BN, + g1Activation: g1Activation, + gpoolToBiasMul: gpoolToBiasMul, + p1BN: p1BN, + p1Activation: p1Activation, + p2Conv: p2Conv, + gpoolToPassMul: gpoolToPassMul, + gpoolToPassBias: gpoolToPassBias, + passActivation: passActivation, + gpoolToPassMul2: gpoolToPassMul2) + } else { + return SWPolicyHeadDesc( + version: Int(version), + p1Conv: p1Conv, + g1Conv: g1Conv, + g1BN: g1BN, + g1Activation: g1Activation, + gpoolToBiasMul: gpoolToBiasMul, + p1BN: p1BN, + p1Activation: p1Activation, + p2Conv: p2Conv, + gpoolToPassMul: gpoolToPassMul, + gpoolToPassBias: nil, + passActivation: nil, + gpoolToPassMul2: nil) + } +} + +/// A structure that represents a policy head of a neural network. +struct PolicyHead { + let policyTensor: MPSGraphTensor + let policyPassTensor: MPSGraphTensor + + init( + graph: MPSGraph, + descriptor: SWPolicyHeadDesc, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let p1Conv = ConvLayer( + graph: graph, + sourceTensor: sourceTensor, + descriptor: descriptor.p1Conv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let g1Conv = ConvLayer( + graph: graph, + sourceTensor: sourceTensor, + descriptor: descriptor.g1Conv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let g1BN = BatchNormLayer( + graph: graph, + sourceTensor: g1Conv.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.g1BN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let g1Activation = ActivationLayer( + graph: graph, + sourceTensor: g1BN.resultTensor, + activationKind: descriptor.g1Activation) + + let g1Concat = GlobalPoolingLayer( + graph: graph, + sourceTensor: g1Activation.resultTensor, + maskTensor: maskTensor, + maskSumTensor: maskSumTensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, + optimizeIdentityMask: optimizeIdentityMask) + + assert(g1Concat.resultTensor.shape?[1] == descriptor.gpoolToBiasMul.inChannels) + + let gpoolToBiasMul = MatMulLayer( + graph: graph, + descriptor: descriptor.gpoolToBiasMul, + sourceTensor: g1Concat.resultTensor) + + let added = AddNCBiasLayer( + graph: graph, + sourceTensor: p1Conv.resultTensor, + biasTensor: gpoolToBiasMul.resultTensor, + nnXLen: nnXLen, + nnYLen: nnYLen, + numChannels: descriptor.gpoolToBiasMul.outChannels) + + let p1BN = BatchNormLayer( + graph: graph, + sourceTensor: added.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.p1BN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let p1Activation = ActivationLayer( + graph: graph, + sourceTensor: p1BN.resultTensor, + activationKind: descriptor.p1Activation) + + let p2Conv = ConvLayer( + graph: graph, + sourceTensor: p1Activation.resultTensor, + descriptor: descriptor.p2Conv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + policyTensor = p2Conv.resultTensor + + assert(g1Concat.resultTensor.shape?[1] == descriptor.gpoolToPassMul.inChannels) + + let gpoolToPassMul = MatMulLayer( + graph: graph, + descriptor: descriptor.gpoolToPassMul, + sourceTensor: g1Concat.resultTensor) + + if let gpoolToPassBias = descriptor.gpoolToPassBias, + let passActivation = descriptor.passActivation, + let gpoolToPassMul2 = descriptor.gpoolToPassMul2 + { + assert(descriptor.version >= 15) + + let gpoolToPassBiasLayer = MatBiasLayer( + graph: graph, + descriptor: gpoolToPassBias, + sourceTensor: gpoolToPassMul.resultTensor) + + let passActivationLayer = ActivationLayer( + graph: graph, + sourceTensor: gpoolToPassBiasLayer.resultTensor, + activationKind: passActivation) + + let gpoolToPassMul2Layer = MatMulLayer( + graph: graph, + descriptor: gpoolToPassMul2, + sourceTensor: passActivationLayer.resultTensor) + + policyPassTensor = gpoolToPassMul2Layer.resultTensor + } else { + assert(descriptor.version < 15) + policyPassTensor = gpoolToPassMul.resultTensor + } + + assert(policyTensor.shape?.count == 4) + assert(policyPassTensor.shape?.count == 2) + } +} + +// MARK: - Value Head + +/// A struct that describes the value head of a neural network +public struct SWValueHeadDesc { + let version: Int + let v1Conv: SWConvLayerDesc + let v1BN: SWBatchNormLayerDesc + let v1Activation: ActivationKind + let v2Mul: SWMatMulLayerDesc + let v2Bias: SWMatBiasLayerDesc + let v2Activation: ActivationKind + let v3Mul: SWMatMulLayerDesc + let v3Bias: SWMatBiasLayerDesc + let sv3Mul: SWMatMulLayerDesc + let sv3Bias: SWMatBiasLayerDesc + let vOwnershipConv: SWConvLayerDesc + + init( + version: Int, + v1Conv: SWConvLayerDesc, + v1BN: SWBatchNormLayerDesc, + v1Activation: ActivationKind, + v2Mul: SWMatMulLayerDesc, + v2Bias: SWMatBiasLayerDesc, + v2Activation: ActivationKind, + v3Mul: SWMatMulLayerDesc, + v3Bias: SWMatBiasLayerDesc, + sv3Mul: SWMatMulLayerDesc, + sv3Bias: SWMatBiasLayerDesc, + vOwnershipConv: SWConvLayerDesc + ) { + self.version = version + self.v1Conv = v1Conv + self.v1BN = v1BN + self.v1Activation = v1Activation + self.v2Mul = v2Mul + self.v2Bias = v2Bias + self.v2Activation = v2Activation + self.v3Mul = v3Mul + self.v3Bias = v3Bias + self.sv3Mul = sv3Mul + self.sv3Bias = sv3Bias + self.vOwnershipConv = vOwnershipConv + } +} + +public func createSWValueHeadDesc( + version: Int32, + v1Conv: SWConvLayerDesc, + v1BN: SWBatchNormLayerDesc, + v1Activation: ActivationKind, + v2Mul: SWMatMulLayerDesc, + v2Bias: SWMatBiasLayerDesc, + v2Activation: ActivationKind, + v3Mul: SWMatMulLayerDesc, + v3Bias: SWMatBiasLayerDesc, + sv3Mul: SWMatMulLayerDesc, + sv3Bias: SWMatBiasLayerDesc, + vOwnershipConv: SWConvLayerDesc +) -> SWValueHeadDesc { + return SWValueHeadDesc( + version: Int(version), + v1Conv: v1Conv, + v1BN: v1BN, + v1Activation: v1Activation, + v2Mul: v2Mul, + v2Bias: v2Bias, + v2Activation: v2Activation, + v3Mul: v3Mul, + v3Bias: v3Bias, + sv3Mul: sv3Mul, + sv3Bias: sv3Bias, + vOwnershipConv: vOwnershipConv) +} + +/// A structure that creates a value head for the neural network +struct ValueHead { + let valueTensor: MPSGraphTensor + let scoreValueTensor: MPSGraphTensor + let ownershipTensor: MPSGraphTensor + + init( + graph: MPSGraph, + descriptor: SWValueHeadDesc, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + maskSumTensor: MPSGraphTensor, + maskSumSqrtS14M01Tensor: MPSGraphTensor, + maskSumSqrtS14M01SquareS01Tensor: MPSGraphTensor, + nnXLen: NSNumber, + nnYLen: NSNumber, + optimizeIdentityMask: Bool = false + ) { + let v1Conv = ConvLayer( + graph: graph, + sourceTensor: sourceTensor, + descriptor: descriptor.v1Conv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let v1BN = BatchNormLayer( + graph: graph, + sourceTensor: v1Conv.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.v1BN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let v1Activation = ActivationLayer( + graph: graph, + sourceTensor: v1BN.resultTensor, + activationKind: descriptor.v1Activation) + + let v1Mean = + GlobalPoolingValueLayer( + graph: graph, + sourceTensor: v1Activation.resultTensor, + maskSumTensor: maskSumTensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01Tensor, + maskSumSqrtS14M01SquareS01Tensor: maskSumSqrtS14M01SquareS01Tensor) + + assert(v1Mean.resultTensor.shape?[1] == descriptor.v2Mul.inChannels) + + let v2Mul = MatMulLayer( + graph: graph, + descriptor: descriptor.v2Mul, + sourceTensor: v1Mean.resultTensor) + + let v2Bias = MatBiasLayer( + graph: graph, + descriptor: descriptor.v2Bias, + sourceTensor: v2Mul.resultTensor) + + let v2Activation = ActivationLayer( + graph: graph, + sourceTensor: v2Bias.resultTensor, + activationKind: descriptor.v2Activation) + + let v3Mul = MatMulLayer( + graph: graph, + descriptor: descriptor.v3Mul, + sourceTensor: v2Activation.resultTensor) + + let v3Bias = MatBiasLayer( + graph: graph, + descriptor: descriptor.v3Bias, + sourceTensor: v3Mul.resultTensor) + + let sv3Mul = MatMulLayer( + graph: graph, + descriptor: descriptor.sv3Mul, + sourceTensor: v2Activation.resultTensor) + + let sv3Bias = MatBiasLayer( + graph: graph, + descriptor: descriptor.sv3Bias, + sourceTensor: sv3Mul.resultTensor) + + let vOwnershipConv = ConvLayer( + graph: graph, + sourceTensor: v1Activation.resultTensor, + descriptor: descriptor.vOwnershipConv, + nnXLen: nnXLen, + nnYLen: nnYLen) + + valueTensor = v3Bias.resultTensor + scoreValueTensor = sv3Bias.resultTensor + ownershipTensor = vOwnershipConv.resultTensor + + assert(valueTensor.shape?.count == 2) + assert(scoreValueTensor.shape?.count == 2) + assert(ownershipTensor.shape?.count == 4) + } +} + +// MARK: - Model Descriptor + +/// A struct that describes a neural network model used for playing the game of Go. +public struct SWModelDesc { + let version: Int + let name: String + let numInputChannels: NSNumber + let numInputGlobalChannels: NSNumber + let numInputMetaChannels: NSNumber + let numValueChannels: NSNumber + let numScoreValueChannels: NSNumber + let numOwnershipChannels: NSNumber + let numPolicyChannels: NSNumber + let trunk: SWTrunkDesc + let policyHead: SWPolicyHeadDesc + let valueHead: SWValueHeadDesc + + init( + version: Int, + name: String, + numInputChannels: NSNumber, + numInputGlobalChannels: NSNumber, + numInputMetaChannels: NSNumber, + numValueChannels: NSNumber, + numScoreValueChannels: NSNumber, + numOwnershipChannels: NSNumber, + numPolicyChannels: NSNumber, + trunk: SWTrunkDesc, + policyHead: SWPolicyHeadDesc, + valueHead: SWValueHeadDesc + ) { + self.version = version + self.name = name + self.numInputChannels = numInputChannels + self.numInputGlobalChannels = numInputGlobalChannels + self.numInputMetaChannels = numInputMetaChannels + self.numValueChannels = numValueChannels + self.numScoreValueChannels = numScoreValueChannels + self.numOwnershipChannels = numOwnershipChannels + self.numPolicyChannels = numPolicyChannels + self.trunk = trunk + self.policyHead = policyHead + self.valueHead = valueHead + } +} + +public func createSWModelDesc( + version: Int32, + name: String, + numInputChannels: Int32, + numInputGlobalChannels: Int32, + numInputMetaChannels: Int32, + numValueChannels: Int32, + numScoreValueChannels: Int32, + numOwnershipChannels: Int32, + numPolicyChannels: Int32, + trunk: SWTrunkDesc, + policyHead: SWPolicyHeadDesc, + valueHead: SWValueHeadDesc +) -> SWModelDesc { + return SWModelDesc( + version: Int(version), + name: name, + numInputChannels: numInputChannels as NSNumber, + numInputGlobalChannels: numInputGlobalChannels as NSNumber, + numInputMetaChannels: numInputMetaChannels as NSNumber, + numValueChannels: numValueChannels as NSNumber, + numScoreValueChannels: numScoreValueChannels as NSNumber, + numOwnershipChannels: numOwnershipChannels as NSNumber, + numPolicyChannels: numPolicyChannels as NSNumber, + trunk: trunk, + policyHead: policyHead, + valueHead: valueHead) +} + +// MARK: - MPSGraph Model (for GPU inference) + +/// A structure representing a neural network model for processing Go game states using MPSGraph. +struct MPSGraphModel { + let device: MTLDevice + let commandQueue: MTLCommandQueue + let graph: MPSGraph + let nnXLen: NSNumber + let nnYLen: NSNumber + let version: Int + let numValueChannels: NSNumber + let numScoreValueChannels: NSNumber + let numOwnershipChannels: NSNumber + let input: InputLayer + let inputGlobal: InputGlobalLayer + let inputMeta: InputMetaLayer + let mask: MaskLayer + let trunk: Trunk + let policyHead: PolicyHead + let valueHead: ValueHead + let targetTensors: [MPSGraphTensor] + + init( + device: MTLDevice, + graph: MPSGraph, + descriptor: SWModelDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + self.device = device + self.commandQueue = device.makeCommandQueue()! + self.graph = graph + self.nnXLen = nnXLen + self.nnYLen = nnYLen + self.version = descriptor.version + self.numValueChannels = descriptor.numValueChannels + self.numScoreValueChannels = descriptor.numScoreValueChannels + self.numOwnershipChannels = descriptor.numOwnershipChannels + + input = InputLayer( + graph: graph, + nnXLen: nnXLen, + nnYLen: nnYLen, + numChannels: descriptor.numInputChannels) + + inputGlobal = InputGlobalLayer( + graph: graph, + numGlobalFeatures: descriptor.numInputGlobalChannels) + + inputMeta = InputMetaLayer( + graph: graph, + numMetaFeatures: descriptor.numInputMetaChannels) + + mask = MaskLayer( + graph: graph, + nnXLen: nnXLen, + nnYLen: nnYLen) + + let maskSum = MaskSumLayer( + graph: graph, + maskTensor: mask.tensor) + + let maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer( + graph: graph, + maskSum: maskSum) + + let maskSumSqrtS14M01SquareS01 = MaskSumSqrtS14M01SquareS01Layer( + graph: graph, + maskSumSqrtS14M01: maskSumSqrtS14M01) + + trunk = Trunk( + graph: graph, + descriptor: descriptor.trunk, + inputTensor: input.tensor, + inputGlobalTensor: inputGlobal.tensor, + inputMetaTensor: inputMeta.tensor, + maskTensor: mask.tensor, + maskSumTensor: maskSum.tensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + policyHead = PolicyHead( + graph: graph, + descriptor: descriptor.policyHead, + sourceTensor: trunk.resultTensor, + maskTensor: mask.tensor, + maskSumTensor: maskSum.tensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + valueHead = ValueHead( + graph: graph, + descriptor: descriptor.valueHead, + sourceTensor: trunk.resultTensor, + maskTensor: mask.tensor, + maskSumTensor: maskSum.tensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, + maskSumSqrtS14M01SquareS01Tensor: maskSumSqrtS14M01SquareS01.tensor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + targetTensors = [ + policyHead.policyTensor, + policyHead.policyPassTensor, + valueHead.valueTensor, + valueHead.scoreValueTensor, + valueHead.ownershipTensor, + ] + } + + /// Applies the model to the given input data + public func apply( + input inputPointer: UnsafeMutablePointer, + inputGlobal inputGlobalPointer: UnsafeMutablePointer, + inputMeta inputMetaPointer: UnsafeMutablePointer, + policy: UnsafeMutablePointer, + policyPass: UnsafeMutablePointer, + value: UnsafeMutablePointer, + scoreValue: UnsafeMutablePointer, + ownership: UnsafeMutablePointer, + batchSize: Int + ) { + let channelAxis = InputShape.getChannelAxis() + let numInputChannels = input.shape[channelAxis] + + let inputShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: numInputChannels, + nnYLen: nnYLen, + nnXLen: nnXLen) + + let inputDescriptor = MPSNDArrayDescriptor( + dataType: input.tensor.dataType, + shape: inputShape) + + let inputArray = MPSNDArray( + device: device, + descriptor: inputDescriptor) + + inputArray.writeBytes(inputPointer) + + let numInputGlobalChannels = inputGlobal.shape[channelAxis] + + let inputGlobalShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: numInputGlobalChannels, + nnYLen: 1, + nnXLen: 1) + + let inputGlobalDescriptor = MPSNDArrayDescriptor( + dataType: inputGlobal.tensor.dataType, + shape: inputGlobalShape) + + let inputGlobalArray = MPSNDArray( + device: device, + descriptor: inputGlobalDescriptor) + + inputGlobalArray.writeBytes(inputGlobalPointer) + + let numInputMetaChannels = inputMeta.shape[channelAxis] + + let inputMetaShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: numInputMetaChannels, + nnYLen: 1, + nnXLen: 1) + + let inputMetaDescriptor = MPSNDArrayDescriptor( + dataType: inputMeta.tensor.dataType, + shape: inputMetaShape) + + let inputMetaArray = MPSNDArray( + device: device, + descriptor: inputMetaDescriptor) + + inputMetaArray.writeBytes(inputMetaPointer) + + let maskShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: 1, + nnYLen: nnYLen, + nnXLen: nnXLen) + + let maskDescriptor = MPSNDArrayDescriptor( + dataType: mask.tensor.dataType, + shape: maskShape) + + let maskArray = MPSNDArray( + device: device, + descriptor: maskDescriptor) + + var maskStrideArray = [ + MemoryLayout.size, + nnXLen.intValue * MemoryLayout.size, + nnYLen.intValue * nnXLen.intValue * MemoryLayout.size, + numInputChannels.intValue * nnYLen.intValue * nnXLen.intValue + * MemoryLayout.size, + ] + + maskArray.writeBytes(inputPointer, strideBytes: &maskStrideArray) + + let feeds = [ + input.tensor: MPSGraphTensorData(inputArray), + inputGlobal.tensor: MPSGraphTensorData(inputGlobalArray), + inputMeta.tensor: MPSGraphTensorData(inputMetaArray), + mask.tensor: MPSGraphTensorData(maskArray), + ] + + let fetch = graph.run( + with: commandQueue, + feeds: feeds, + targetTensors: targetTensors, + targetOperations: nil) + + assert(fetch[policyHead.policyTensor] != nil) + assert(fetch[policyHead.policyPassTensor] != nil) + assert(fetch[valueHead.valueTensor] != nil) + assert(fetch[valueHead.scoreValueTensor] != nil) + assert(fetch[valueHead.ownershipTensor] != nil) + + fetch[policyHead.policyTensor]?.mpsndarray().readBytes(policy) + fetch[policyHead.policyPassTensor]?.mpsndarray().readBytes(policyPass) + fetch[valueHead.valueTensor]?.mpsndarray().readBytes(value) + fetch[valueHead.scoreValueTensor]?.mpsndarray().readBytes(scoreValue) + fetch[valueHead.ownershipTensor]?.mpsndarray().readBytes(ownership) + } +} + +// MARK: - Test Infrastructure + +/// Helper struct for testing individual network layers using MPSGraph +struct NetworkTester { + let device: MTLDevice + let commandQueue: MTLCommandQueue + let graph: MPSGraph + let inputTensor: MPSGraphTensor + let maskTensor: MPSGraphTensor + let outputTensor: MPSGraphTensor + let inputShape: [NSNumber] + let maskShape: [NSNumber] + let outputShape: [NSNumber] + + /// Initialize a network tester for testing a single layer + init( + device: MTLDevice, + graph: MPSGraph, + inputTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + outputTensor: MPSGraphTensor, + batchSize: NSNumber, + nnXLen: NSNumber, + nnYLen: NSNumber, + inChannels: NSNumber, + outChannels: NSNumber + ) { + self.device = device + self.commandQueue = device.makeCommandQueue()! + self.graph = graph + self.inputTensor = inputTensor + self.maskTensor = maskTensor + self.outputTensor = outputTensor + self.inputShape = InputShape.create( + batchSize: batchSize, + numChannels: inChannels, + nnYLen: nnYLen, + nnXLen: nnXLen) + self.maskShape = InputShape.create( + batchSize: batchSize, + numChannels: 1, + nnYLen: nnYLen, + nnXLen: nnXLen) + self.outputShape = InputShape.create( + batchSize: batchSize, + numChannels: outChannels, + nnYLen: nnYLen, + nnXLen: nnXLen) + } + + /// Run the test with given input and mask data, writing results to output + func run( + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer + ) { + let inputDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: inputShape) + + let inputArray = MPSNDArray( + device: device, + descriptor: inputDescriptor) + + inputArray.writeBytes(UnsafeMutableRawPointer(mutating: inputPointer)) + + let maskDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: maskShape) + + let maskArray = MPSNDArray( + device: device, + descriptor: maskDescriptor) + + maskArray.writeBytes(UnsafeMutableRawPointer(mutating: maskPointer)) + + let feeds = [ + inputTensor: MPSGraphTensorData(inputArray), + maskTensor: MPSGraphTensorData(maskArray), + ] + + let fetch = graph.run( + with: commandQueue, + feeds: feeds, + targetTensors: [outputTensor], + targetOperations: nil) + + fetch[outputTensor]?.mpsndarray().readBytes(outputPointer) + } +} + +// MARK: - ConvLayer Test Extension + +extension ConvLayer { + /// Test the convolution layer with given parameters + static func test( + descriptor: SWConvLayerDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer + ) -> Bool { + guard let device = MTLCreateSystemDefaultDevice() else { + return false + } + + let graph = MPSGraph() + + let inputShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: descriptor.inChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputTensor = graph.placeholder( + shape: inputShape, + dataType: .float32, + name: nil) + + let convLayer = ConvLayer( + graph: graph, + sourceTensor: inputTensor, + descriptor: descriptor, + nnXLen: nnXLen as NSNumber, + nnYLen: nnYLen as NSNumber) + + // Run the graph + let commandQueue = device.makeCommandQueue()! + + let actualInputShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: descriptor.inChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualInputShape) + + let inputArray = MPSNDArray( + device: device, + descriptor: inputDescriptor) + + inputArray.writeBytes(UnsafeMutableRawPointer(mutating: inputPointer)) + + let feeds = [inputTensor: MPSGraphTensorData(inputArray)] + + let fetch = graph.run( + with: commandQueue, + feeds: feeds, + targetTensors: [convLayer.resultTensor], + targetOperations: nil) + + fetch[convLayer.resultTensor]?.mpsndarray().readBytes(outputPointer) + + return true + } +} + +// MARK: - BatchNormLayer Test Extension + +extension BatchNormLayer { + /// Test the batch normalization layer with given parameters + static func test( + descriptor: SWBatchNormLayerDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer + ) -> Bool { + guard let device = MTLCreateSystemDefaultDevice() else { + return false + } + + let graph = MPSGraph() + + let inputShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: descriptor.numChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputTensor = graph.placeholder( + shape: inputShape, + dataType: .float32, + name: nil) + + let maskShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: 1, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let maskTensor = graph.placeholder( + shape: maskShape, + dataType: .float32, + name: nil) + + let bnLayer = BatchNormLayer( + graph: graph, + sourceTensor: inputTensor, + maskTensor: maskTensor, + descriptor: descriptor, + nnXLen: nnXLen as NSNumber, + nnYLen: nnYLen as NSNumber) + + // Run the graph + let commandQueue = device.makeCommandQueue()! + + let actualInputShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: descriptor.numChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualInputShape) + + let inputArray = MPSNDArray( + device: device, + descriptor: inputDescriptor) + + inputArray.writeBytes(UnsafeMutableRawPointer(mutating: inputPointer)) + + let actualMaskShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: 1, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let maskDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualMaskShape) + + let maskArray = MPSNDArray( + device: device, + descriptor: maskDescriptor) + + maskArray.writeBytes(UnsafeMutableRawPointer(mutating: maskPointer)) + + let feeds = [ + inputTensor: MPSGraphTensorData(inputArray), + maskTensor: MPSGraphTensorData(maskArray), + ] + + let fetch = graph.run( + with: commandQueue, + feeds: feeds, + targetTensors: [bnLayer.resultTensor], + targetOperations: nil) + + fetch[bnLayer.resultTensor]?.mpsndarray().readBytes(outputPointer) + + return true + } +} + +// MARK: - ResidualBlock Test Extension + +extension ResidualBlock { + /// Test the residual block with given parameters + static func test( + descriptor: SWResidualBlockDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer + ) -> Bool { + guard let device = MTLCreateSystemDefaultDevice() else { + return false + } + + let graph = MPSGraph() + + let inputShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: descriptor.preBN.numChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputTensor = graph.placeholder( + shape: inputShape, + dataType: .float32, + name: nil) + + let maskShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: 1, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let maskTensor = graph.placeholder( + shape: maskShape, + dataType: .float32, + name: nil) + + let resBlock = ResidualBlock( + graph: graph, + sourceTensor: inputTensor, + maskTensor: maskTensor, + descriptor: descriptor, + nnXLen: nnXLen as NSNumber, + nnYLen: nnYLen as NSNumber) + + // Run the graph + let commandQueue = device.makeCommandQueue()! + + let actualInputShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: descriptor.preBN.numChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualInputShape) + + let inputArray = MPSNDArray( + device: device, + descriptor: inputDescriptor) + + inputArray.writeBytes(UnsafeMutableRawPointer(mutating: inputPointer)) + + let actualMaskShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: 1, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let maskDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualMaskShape) + + let maskArray = MPSNDArray( + device: device, + descriptor: maskDescriptor) + + maskArray.writeBytes(UnsafeMutableRawPointer(mutating: maskPointer)) + + let feeds = [ + inputTensor: MPSGraphTensorData(inputArray), + maskTensor: MPSGraphTensorData(maskArray), + ] + + let fetch = graph.run( + with: commandQueue, + feeds: feeds, + targetTensors: [resBlock.resultTensor], + targetOperations: nil) + + fetch[resBlock.resultTensor]?.mpsndarray().readBytes(outputPointer) + + return true + } +} + +// MARK: - GlobalPoolingResidualBlock Test Extension + +extension GlobalPoolingResidualBlock { + /// Test the global pooling residual block with given parameters + static func test( + descriptor: SWGlobalPoolingResidualBlockDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer + ) -> Bool { + guard let device = MTLCreateSystemDefaultDevice() else { + return false + } + + let graph = MPSGraph() + + let inputShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: descriptor.preBN.numChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputTensor = graph.placeholder( + shape: inputShape, + dataType: .float32, + name: nil) + + let maskShape = InputShape.create( + batchSize: -1 as NSNumber, + numChannels: 1, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let maskTensor = graph.placeholder( + shape: maskShape, + dataType: .float32, + name: nil) + + // Compute mask sum and related tensors from mask + let maskSum = MaskSumLayer(graph: graph, maskTensor: maskTensor) + let maskSumSqrtS14M01 = MaskSumSqrtS14M01Layer(graph: graph, maskSum: maskSum) + + let gpoolBlock = GlobalPoolingResidualBlock( + graph: graph, + sourceTensor: inputTensor, + maskTensor: maskTensor, + maskSumTensor: maskSum.tensor, + maskSumSqrtS14M01Tensor: maskSumSqrtS14M01.tensor, + descriptor: descriptor, + nnXLen: nnXLen as NSNumber, + nnYLen: nnYLen as NSNumber) + + // Run the graph + let commandQueue = device.makeCommandQueue()! + + let actualInputShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: descriptor.preBN.numChannels, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let inputDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualInputShape) + + let inputArray = MPSNDArray( + device: device, + descriptor: inputDescriptor) + + inputArray.writeBytes(UnsafeMutableRawPointer(mutating: inputPointer)) + + let actualMaskShape = InputShape.create( + batchSize: batchSize as NSNumber, + numChannels: 1, + nnYLen: nnYLen as NSNumber, + nnXLen: nnXLen as NSNumber) + + let maskDescriptor = MPSNDArrayDescriptor( + dataType: .float32, + shape: actualMaskShape) + + let maskArray = MPSNDArray( + device: device, + descriptor: maskDescriptor) + + maskArray.writeBytes(UnsafeMutableRawPointer(mutating: maskPointer)) + + let feeds = [ + inputTensor: MPSGraphTensorData(inputArray), + maskTensor: MPSGraphTensorData(maskArray), + ] + + let fetch = graph.run( + with: commandQueue, + feeds: feeds, + targetTensors: [gpoolBlock.resultTensor], + targetOperations: nil) + + fetch[gpoolBlock.resultTensor]?.mpsndarray().readBytes(outputPointer) + + return true + } +} + +// MARK: - Public Test Functions (callable from C++) + +/// Test the convolution layer +public func testConvLayer( + descriptor: SWConvLayerDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer +) -> Bool { + return ConvLayer.test( + descriptor: descriptor, + batchSize: batchSize, + nnXLen: nnXLen, + nnYLen: nnYLen, + inputPointer: inputPointer, + outputPointer: outputPointer) +} + +/// Test the batch normalization layer +public func testBatchNormLayer( + descriptor: SWBatchNormLayerDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer +) -> Bool { + return BatchNormLayer.test( + descriptor: descriptor, + batchSize: batchSize, + nnXLen: nnXLen, + nnYLen: nnYLen, + inputPointer: inputPointer, + maskPointer: maskPointer, + outputPointer: outputPointer) +} + +/// Test the residual block +public func testResidualBlock( + descriptor: SWResidualBlockDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer +) -> Bool { + return ResidualBlock.test( + descriptor: descriptor, + batchSize: batchSize, + nnXLen: nnXLen, + nnYLen: nnYLen, + inputPointer: inputPointer, + maskPointer: maskPointer, + outputPointer: outputPointer) +} + +/// Test the global pooling residual block +public func testGlobalPoolingResidualBlock( + descriptor: SWGlobalPoolingResidualBlockDesc, + batchSize: Int32, + nnXLen: Int32, + nnYLen: Int32, + inputPointer: UnsafePointer, + maskPointer: UnsafePointer, + outputPointer: UnsafeMutablePointer +) -> Bool { + return GlobalPoolingResidualBlock.test( + descriptor: descriptor, + batchSize: batchSize, + nnXLen: nnXLen, + nnYLen: nnYLen, + inputPointer: inputPointer, + maskPointer: maskPointer, + outputPointer: outputPointer) +}