From 6656158d23141b96f7abb712c90192382d19dc49 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 14:57:46 -0400 Subject: [PATCH 1/5] chore(O): remove SIMD/NEON CPU backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete cpu_buffer.h, compute_backend.mm (abandoned Obj-C++ draft), and the simd/ directory. Remove the misleadingly-named simd_graph() helper. Apple Silicon → MLX only. Linux/Windows → CUDA/ROCm (future phases). No CPU SIMD fallback path will ever be needed for LLM-scale inference. --- .../src/compute/backends/simd/cpu_buffer.h | 62 ------ compute/src/compute/core/compute_backend.mm | 189 ------------------ compute/src/compute/core/graph.h | 4 - 3 files changed, 255 deletions(-) delete mode 100644 compute/src/compute/backends/simd/cpu_buffer.h delete mode 100644 compute/src/compute/core/compute_backend.mm diff --git a/compute/src/compute/backends/simd/cpu_buffer.h b/compute/src/compute/backends/simd/cpu_buffer.h deleted file mode 100644 index 7399eab..0000000 --- a/compute/src/compute/backends/simd/cpu_buffer.h +++ /dev/null @@ -1,62 +0,0 @@ -#pragma once - -#include "../../core/tensor.h" -#include - -namespace compute { - -/** - * CPU-based buffer implementation for SIMD backends - * Simple wrapper around std::vector for CPU memory - */ -class CpuBuffer : public BackendBuffer { -private: - std::vector data_; - BackendType backend_type_; - -public: - /** - * Create buffer with data (copy from input) - */ - CpuBuffer(std::span data, BackendType backend_type) - : data_(data.begin(), data.end()), backend_type_(backend_type) {} - - /** - * Create uninitialized buffer - */ - CpuBuffer(size_t size, BackendType backend_type) - : data_(size), backend_type_(backend_type) {} - - /** - * Create buffer from existing vector (move) - */ - CpuBuffer(std::vector data, BackendType backend_type) - : data_(std::move(data)), backend_type_(backend_type) {} - - void* get_data() override { - return data_.data(); - } - - size_t get_size() const override { - return data_.size() * sizeof(float); - } - - BackendType get_backend_type() const override { - return backend_type_; - } - - void evaluate() override { - // CPU operations are eager - nothing to evaluate - } - - // CPU-specific access - std::vector& cpu_data() { return data_; } - const std::vector& cpu_data() const { return data_; } - - float* data_ptr() { return data_.data(); } - const float* data_ptr() const { return data_.data(); } - - size_t num_elements() const { return data_.size(); } -}; - -} // namespace compute \ No newline at end of file diff --git a/compute/src/compute/core/compute_backend.mm b/compute/src/compute/core/compute_backend.mm deleted file mode 100644 index 4dbf229..0000000 --- a/compute/src/compute/core/compute_backend.mm +++ /dev/null @@ -1,189 +0,0 @@ -#include "compute_backend.h" -#include "../backends/simd/neon_backend.h" -#if defined(__APPLE__) && defined(__aarch64__) -#include "../backends/metal/metal_backend.h" -#if defined(MLX_BACKEND_ENABLED) -#include "../backends/mlx/mlx_backend.h" -#endif -#endif -#include -#include - -namespace compute { - -// BackendFactory implementation -Result> BackendFactory::create(BackendType type) { - switch (type) { - case BackendType::SimdNeon: -#ifdef __ARM_NEON - return std::make_unique(); -#else - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "NEON backend only available on ARM platforms with NEON support"}); -#endif - - case BackendType::Auto: { - // Try Metal first (GPU acceleration), then NEON (CPU SIMD) -#if 
defined(__APPLE__) && defined(__aarch64__) - auto metal_backend = std::make_unique(); - if (metal_backend->is_available()) { - return std::move(metal_backend); - } -#endif -#ifdef __ARM_NEON - auto neon_backend = std::make_unique(); - if (neon_backend->is_available()) { - auto init_result = neon_backend->initialize(); - if (init_result) { - return std::move(neon_backend); - } - } -#endif - // No viable backend found - return std::unexpected(Error{ErrorCode::BackendNotAvailable, - "No compute backend available. LLM inference requires SIMD support (NEON/Metal on this platform)"}); - } - - case BackendType::Metal: -#if defined(__APPLE__) && defined(__aarch64__) - { - auto metal_backend = std::make_unique(); - if (metal_backend->is_available()) { - auto init_result = metal_backend->initialize(); - if (init_result) { - return std::move(metal_backend); - } - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Metal backend failed to initialize"}); - } - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Metal device not available"}); - } -#else - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Metal backend only available on Apple Silicon"}); -#endif - - case BackendType::MLX: -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - { - auto mlx_backend = std::make_unique(); - if (mlx_backend->is_available()) { - auto init_result = mlx_backend->initialize(); - if (init_result) { - return std::move(mlx_backend); - } - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "MLX backend failed to initialize"}); - } - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "MLX not available"}); - } -#else - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "MLX backend only available on Apple Silicon with MLX enabled"}); -#endif - - default: - return std::unexpected(Error{ErrorCode::UnknownError, "Unknown backend type"}); - } -} - -std::vector BackendFactory::available_backends() { - std::vector backends; - - // Check MLX availability (Apple Silicon with MLX enabled) -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - backends.push_back(BackendType::MLX); -#endif - - // Check Metal availability (Apple Silicon only) -#if defined(__APPLE__) && defined(__aarch64__) - backends.push_back(BackendType::Metal); -#endif - - // Check NEON availability at compile time -#ifdef __ARM_NEON - backends.push_back(BackendType::SimdNeon); -#endif - - return backends; -} - -BackendType BackendFactory::best_available_backend() { - // For now, just return the first available - auto available = available_backends(); - return available.empty() ? 
BackendType::Auto : available[0]; -} - -// BackendManager implementation -BackendManager& BackendManager::instance() { - static BackendManager instance; - return instance; -} - -Result BackendManager::initialize() { - if (initialized_) { - return {}; // Already initialized - } - - // Create available backends - create_available_backends(); - - // Set default backend - if (!backends_.empty()) { - default_backend_ = backends_[0].get(); - } - - initialized_ = true; - return {}; -} - -void BackendManager::cleanup() { - if (!initialized_) return; - - // Clean up all backends - for (auto& backend : backends_) { - backend->cleanup(); - } - - backends_.clear(); - default_backend_ = nullptr; - initialized_ = false; -} - -ComputeBackend* BackendManager::get_backend(BackendType type) { - if (!initialized_) { - auto init_result = initialize(); - if (!init_result) return nullptr; - } - - // Find backend of requested type - for (auto& backend : backends_) { - if (backend->type() == type) { - return backend.get(); - } - } - - return nullptr; -} - -ComputeBackend* BackendManager::get_default_backend() { - if (!initialized_) { - auto init_result = initialize(); - if (!init_result) return nullptr; - } - - return default_backend_; -} - -void BackendManager::create_available_backends() { - // Create all available backends - auto available_types = BackendFactory::available_backends(); - - for (auto type : available_types) { - auto backend_result = BackendFactory::create(type); - if (backend_result) { - auto& backend = *backend_result; - auto init_result = backend->initialize(); - if (init_result) { - backends_.push_back(std::move(backend)); - } - } - } -} - -} // namespace compute diff --git a/compute/src/compute/core/graph.h b/compute/src/compute/core/graph.h index 24d18ba..621a572 100644 --- a/compute/src/compute/core/graph.h +++ b/compute/src/compute/core/graph.h @@ -206,10 +206,6 @@ inline ComputeGraphBuilder graph(BackendType backend = BackendType::Auto) { return ComputeGraphBuilder(backend); } -inline ComputeGraphBuilder simd_graph() { - return ComputeGraphBuilder(BackendType::MLX); -} - inline ComputeGraphBuilder metal_graph() { return ComputeGraphBuilder(BackendType::Metal); } From 15858a985dda58d44fb578824775d00a3ace9c44 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 15:17:36 -0400 Subject: [PATCH 2/5] chore(O): remove SIMD/NEON backend and BackendType::Metal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete cpu_buffer.h, compute_backend.mm (abandoned Obj-C++ draft), and the simd/ directory. Remove simd_graph() and metal_graph() helpers, and drop BackendType::Metal from the enum. Apple Silicon → MLX only. Linux/Windows → CUDA/ROCm (future phases). No CPU SIMD or bare Metal backend will be added for LLM-scale inference. 
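
For reference, caller-side backend selection reduces to the two remaining enum values. A minimal sketch, assuming the surviving BackendFactory::create / Result API visible in the compute_backend.cpp hunk below; the include path and fallback logic are illustrative, not part of this patch:

    #include "compute/core/compute_backend.h"   // illustrative include path

    int pick_backend_example() {
        using namespace compute;

        // Only MLX and Auto remain in BackendType after this patch.
        auto backend = BackendFactory::create(BackendType::MLX);
        if (!backend) {
            // Off Apple Silicon (or without MLX_BACKEND_ENABLED) this yields
            // ErrorCode::BackendNotAvailable; fall back to Auto.
            backend = BackendFactory::create(BackendType::Auto);
        }
        if (!backend) {
            return 1;   // nothing compiled in; CUDA/ROCm arrive in later phases
        }
        auto init = (*backend)->initialize();
        return init ? 0 : 1;
    }
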
--- compute/src/compute/core/compute_backend.cpp | 3 --- compute/src/compute/core/compute_types.h | 1 - compute/src/compute/core/graph.h | 4 ---- compute/tests/compute/test_model_loader.cpp | 2 +- service/src/neurons_service.cpp | 4 ---- 5 files changed, 1 insertion(+), 13 deletions(-) diff --git a/compute/src/compute/core/compute_backend.cpp b/compute/src/compute/core/compute_backend.cpp index 1da3c68..80da442 100644 --- a/compute/src/compute/core/compute_backend.cpp +++ b/compute/src/compute/core/compute_backend.cpp @@ -26,9 +26,6 @@ Result> BackendFactory::create(BackendType type) return std::unexpected(Error{ErrorCode::BackendNotAvailable, "MLX backend only available on Apple Silicon with MLX enabled"}); #endif - case BackendType::Metal: - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Metal backend not yet implemented; use MLX"}); - default: return std::unexpected(Error{ErrorCode::UnknownError, "Unknown backend type"}); } diff --git a/compute/src/compute/core/compute_types.h b/compute/src/compute/core/compute_types.h index 4f0c692..6e01a0c 100644 --- a/compute/src/compute/core/compute_types.h +++ b/compute/src/compute/core/compute_types.h @@ -34,7 +34,6 @@ using Result = std::expected; // Backend types enum class BackendType { - Metal, MLX, // Apple MLX framework (Metal-accelerated on Apple Silicon) Auto // Let system choose best available }; diff --git a/compute/src/compute/core/graph.h b/compute/src/compute/core/graph.h index 621a572..6264f8c 100644 --- a/compute/src/compute/core/graph.h +++ b/compute/src/compute/core/graph.h @@ -206,8 +206,4 @@ inline ComputeGraphBuilder graph(BackendType backend = BackendType::Auto) { return ComputeGraphBuilder(backend); } -inline ComputeGraphBuilder metal_graph() { - return ComputeGraphBuilder(BackendType::Metal); -} - } // namespace compute diff --git a/compute/tests/compute/test_model_loader.cpp b/compute/tests/compute/test_model_loader.cpp index 4568321..91a0427 100644 --- a/compute/tests/compute/test_model_loader.cpp +++ b/compute/tests/compute/test_model_loader.cpp @@ -205,7 +205,7 @@ TEST_F(ModelLoaderTest, FindSafetensorsFiles) { // We'll use a mock backend that claims to be something other than MLX class NonMLXMockBackend : public ComputeBackend { public: - BackendType type() const override { return BackendType::Metal; } // Not MLX + BackendType type() const override { return BackendType::Auto; } // Not MLX std::string name() const override { return "NonMLXMock"; } bool is_available() const override { return false; } Result initialize() override { return {}; } diff --git a/service/src/neurons_service.cpp b/service/src/neurons_service.cpp index 316a88f..c336f02 100644 --- a/service/src/neurons_service.cpp +++ b/service/src/neurons_service.cpp @@ -140,10 +140,6 @@ grpc::Status NeuronsServiceImpl::GetStatus(grpc::ServerContext* /*ctx*/, backend_name = "mlx"; gpu_name = "Apple Silicon"; break; - case compute::BackendType::Metal: - backend_name = "metal"; - gpu_name = "Apple GPU"; - break; default: break; } } From 1f78bd18c298e12d41195317d8e9716702c04094 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 16:04:18 -0400 Subject: [PATCH 3/5] chore: remove Tensor/ComputeGraph abstractions and non-MLX model fallbacks Phase O cleanup: the Tensor/BackendBuffer and ComputeGraph/ComputeGraphBuilder abstractions were bypassed entirely on Apple Silicon (all three model families used mlx_weights_ directly). Removing them closes ~7 100 lines of dead code and leaves ComputeBackend as a thin lifecycle handle only. 
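
What remains of ComputeBackend is sketched below for orientation. Method names come from the "Simplified" list that follows; the exact signatures (e.g. Result<void>) are inferred from the NonMLXMockBackend test stub and may differ in detail:

    // Sketch only — not the literal header.
    class ComputeBackend {
    public:
        virtual ~ComputeBackend() = default;
        virtual BackendType type() const = 0;     // BackendType::MLX (Auto only in mocks)
        virtual std::string name() const = 0;     // e.g. "mlx", surfaced via GetStatus
        virtual bool is_available() const = 0;    // runtime check (MLX device present?)
        virtual Result<void> initialize() = 0;    // one-time setup, error on failure
        virtual void cleanup() = 0;               // release resources before shutdown
    };
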
Removed: - core/tensor.{h,cpp}, core/graph.{h,cpp} - backends/mlx/mlx_buffer.h, mlx_utils.h - model/kv_cache.h - model/gemma_model{,_base}.{h,cpp} - model/qwen3_moe_model{,_base}.{h,cpp} - tests/compute/test_symbolic_api.cpp, test_mlx_backend.cpp Simplified: - ComputeBackend: 5 lifecycle methods only (type/name/is_available/initialize/cleanup) - MlxBackend: implements those 5 methods; ~730 lines of Tensor ops deleted - LlamaModel, GemmaModelMLX, Qwen3MoeModelMLX: removed inheritance from base Tensor-path classes; MLX classes own config_ and tokenizer_ directly - ModelLoader: load_model()/load_all_safetensors() removed; load_model_mlx() kept - language_model.cpp: Gemma/Qwen3MoE dispatch is now MLX-only - BackendType::Metal removed (vestigial, never instantiated) - Tests updated to remove calls to deleted APIs (forward(), attention_layer(), wrap_native_tensor(), load_model(backend)) --- compute/CMakeLists.txt | 16 - compute/src/compute/backends/mlx/mlx_buffer.h | 121 -- compute/src/compute/backends/mlx/mlx_utils.h | 128 -- compute/src/compute/core/graph.cpp | 232 --- compute/src/compute/core/graph.h | 209 -- compute/src/compute/core/tensor.cpp | 18 - compute/src/compute/core/tensor.h | 164 -- compute/src/compute/model/gemma_model.cpp | 497 ----- compute/src/compute/model/gemma_model.h | 90 - .../src/compute/model/gemma_model_base.cpp | 27 - compute/src/compute/model/gemma_model_base.h | 29 - compute/src/compute/model/kv_cache.h | 36 - compute/src/compute/model/llama_model.cpp | 481 +---- compute/src/compute/model/qwen3_moe_model.cpp | 901 -------- compute/src/compute/model/qwen3_moe_model.h | 102 - .../compute/model/qwen3_moe_model_base.cpp | 37 - .../src/compute/model/qwen3_moe_model_base.h | 30 - .../compute/test_attention_qkv_trace.cpp | 147 +- compute/tests/compute/test_forward_pass.cpp | 32 - .../compute/test_mistral_integration.cpp | 40 - compute/tests/compute/test_mlx_backend.cpp | 1817 ----------------- compute/tests/compute/test_model_loader.cpp | 378 +--- compute/tests/compute/test_symbolic_api.cpp | 169 -- 23 files changed, 14 insertions(+), 5687 deletions(-) delete mode 100644 compute/src/compute/backends/mlx/mlx_buffer.h delete mode 100644 compute/src/compute/backends/mlx/mlx_utils.h delete mode 100644 compute/src/compute/core/graph.cpp delete mode 100644 compute/src/compute/core/graph.h delete mode 100644 compute/src/compute/core/tensor.cpp delete mode 100644 compute/src/compute/core/tensor.h delete mode 100644 compute/src/compute/model/gemma_model.cpp delete mode 100644 compute/src/compute/model/gemma_model.h delete mode 100644 compute/src/compute/model/gemma_model_base.cpp delete mode 100644 compute/src/compute/model/gemma_model_base.h delete mode 100644 compute/src/compute/model/kv_cache.h delete mode 100644 compute/src/compute/model/qwen3_moe_model.cpp delete mode 100644 compute/src/compute/model/qwen3_moe_model.h delete mode 100644 compute/src/compute/model/qwen3_moe_model_base.cpp delete mode 100644 compute/src/compute/model/qwen3_moe_model_base.h delete mode 100644 compute/tests/compute/test_mlx_backend.cpp delete mode 100644 compute/tests/compute/test_symbolic_api.cpp diff --git a/compute/CMakeLists.txt b/compute/CMakeLists.txt index 56e1848..83bfdf6 100644 --- a/compute/CMakeLists.txt +++ b/compute/CMakeLists.txt @@ -50,8 +50,6 @@ endif() # ── Library ─────────────────────────────────────────────────────────────────── add_library(compute_backend src/compute/core/compute_backend.cpp - src/compute/core/tensor.cpp - src/compute/core/graph.cpp 
src/compute/model/model_config.cpp src/compute/model/model_loader.cpp src/compute/model/tokenizer_config.cpp @@ -60,10 +58,6 @@ add_library(compute_backend src/compute/model/chat_template.cpp src/compute/model/language_model.cpp src/compute/model/llama_model.cpp - src/compute/model/gemma_model_base.cpp - src/compute/model/gemma_model.cpp - src/compute/model/qwen3_moe_model_base.cpp - src/compute/model/qwen3_moe_model.cpp ${BACKEND_SOURCES} ) @@ -111,16 +105,7 @@ if(BUILD_TESTING) set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) - # Platform-specific test-only sources (backend/model sources come via compute_backend) - set(TEST_PLATFORM_SOURCES "") - if(MLX_AVAILABLE AND TARGET mlx) - list(APPEND TEST_PLATFORM_SOURCES - tests/compute/test_mlx_backend.cpp - ) - endif() - add_executable(compute_tests - tests/compute/test_symbolic_api.cpp tests/compute/test_model_config.cpp tests/compute/test_model_loader.cpp tests/compute/test_tokenizer_config.cpp @@ -134,7 +119,6 @@ if(BUILD_TESTING) tests/compute/test_gemma_integration.cpp tests/compute/test_qwen3_moe_integration.cpp tests/compute/test_chat_template.cpp - ${TEST_PLATFORM_SOURCES} ) target_link_libraries(compute_tests PRIVATE diff --git a/compute/src/compute/backends/mlx/mlx_buffer.h b/compute/src/compute/backends/mlx/mlx_buffer.h deleted file mode 100644 index 95d6744..0000000 --- a/compute/src/compute/backends/mlx/mlx_buffer.h +++ /dev/null @@ -1,121 +0,0 @@ -#pragma once - -#include "../../core/tensor.h" - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) -#include -namespace mx = mlx::core; - -namespace compute { - -/** - * MLX-based buffer implementation using mx::array - * Wraps MLX lazy-evaluated arrays for tensor operations - */ -class MLXBuffer : public BackendBuffer { -private: - mx::array mlx_array_; - mutable bool evaluated_ = false; - -public: - /** - * Create buffer wrapping existing MLX array - */ - explicit MLXBuffer(mx::array array) - : mlx_array_(std::move(array)), evaluated_(false) {} - - /** - * Create buffer from data (creates MLX array) - */ - MLXBuffer(std::span data, const std::vector& shape) - : mlx_array_(create_array_from_data(data, shape)), evaluated_(false) { - } - - /** - * Create uninitialized buffer with given shape - */ - MLXBuffer(const std::vector& shape) - : mlx_array_(create_zeros_array(shape)), evaluated_(false) { - } - - ~MLXBuffer() override = default; - - void* get_data() override { - evaluate(); - // Convert to float32 if needed (e.g., from bfloat16) - // MLX .data() only works if the array dtype matches T exactly - if (mlx_array_.dtype() != mx::float32) { - mlx_array_ = mx::astype(mlx_array_, mx::float32); - mx::eval(mlx_array_); - evaluated_ = true; // Re-mark as evaluated after conversion - } - return mlx_array_.data(); - } - - size_t get_size() const override { - return mlx_array_.nbytes(); - } - - BackendType get_backend_type() const override { - return BackendType::MLX; - } - - void evaluate() override { - if (!evaluated_) { - mx::eval(mlx_array_); // Force MLX computation - evaluated_ = true; - } - } - - void evaluate() const { - if (!evaluated_) { - mx::eval(mlx_array_); // Force MLX computation - evaluated_ = true; - } - } - - // MLX-specific access - const mx::array& mlx_array() const { return mlx_array_; } - mx::array& mlx_array() { return mlx_array_; } - - float* data_ptr() { - evaluate(); - return mlx_array_.data(); - } - - const float* data_ptr() const { - evaluate(); - return mlx_array_.data(); - } - - size_t 
num_elements() const { return mlx_array_.size(); } - -private: - // Helper methods to create MLX arrays - static mx::array create_array_from_data(std::span data, const std::vector& shape) { - // Convert shape to MLX format (Shape is SmallVector) - mx::Shape mlx_shape; - mlx_shape.reserve(shape.size()); - for (size_t dim : shape) { - mlx_shape.push_back(static_cast(dim)); - } - - // Create MLX array using iterator constructor - return mx::array(data.begin(), mlx_shape, mx::float32); - } - - static mx::array create_zeros_array(const std::vector& shape) { - // Convert shape to MLX format (Shape is SmallVector) - mx::Shape mlx_shape; - mlx_shape.reserve(shape.size()); - for (size_t dim : shape) { - mlx_shape.push_back(static_cast(dim)); - } - - return mx::zeros(mlx_shape, mx::float32); - } -}; - -} // namespace compute - -#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) \ No newline at end of file diff --git a/compute/src/compute/backends/mlx/mlx_utils.h b/compute/src/compute/backends/mlx/mlx_utils.h deleted file mode 100644 index 824f39d..0000000 --- a/compute/src/compute/backends/mlx/mlx_utils.h +++ /dev/null @@ -1,128 +0,0 @@ -#pragma once - -#include "../../core/compute_types.h" -#include "../../core/tensor.h" -#include "mlx_buffer.h" - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) -#include -#include -#include - -namespace mx = mlx::core; - -namespace compute::mlx_utils { - -// Helper to validate a single tensor -constexpr auto validate_single_mlx_tensor(const Tensor& tensor) -> std::optional { - if (tensor.backend_type() != BackendType::MLX) { - return Error{ErrorCode::InvalidInput, "Tensor must be from MLX backend"}; - } - return std::nullopt; -} - -// Variadic validation using fold expression (C++17) -template -constexpr auto validate_mlx_tensors(const Tensors&... tensors) -> std::optional { - std::optional error = std::nullopt; - ((error = validate_single_mlx_tensor(tensors), error.has_value()) || ...); - return error; -} - -// Macro for validating one or more MLX tensors -#define VALIDATE_MLX_TENSOR(...) \ - do { \ - if (auto error = compute::mlx_utils::validate_mlx_tensors(__VA_ARGS__)) { \ - return std::unexpected(*error); \ - } \ - } while(0) - -// Concept to detect types that can convert to mx::array -template -concept MLXConvertible = requires(const T& t) { - { t.to_mlx() } -> std::convertible_to; -}; - -// Automatic conversion function -template -constexpr auto to_mlx_auto(const T& tensor) -> mx::array { - return tensor.to_mlx(); -} - -// Overload for mx::array (identity) -constexpr auto to_mlx_auto(const mx::array& array) -> const mx::array& { - return array; -} - -// Generic wrapper for MLX operations that may throw -template -auto mlx_safe(Func&& func, Args&&... args) noexcept - -> std::expected, Error> { - try { - return std::forward(func)(std::forward(args)...); - } catch (const std::invalid_argument& e) { - return std::unexpected(Error{ErrorCode::InvalidInput, - std::string("MLX invalid argument: ") + e.what()}); - } catch (const std::runtime_error& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX runtime error: ") + e.what()}); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX error: ") + e.what()}); - } catch (...) 
{ - return std::unexpected(Error{ErrorCode::ComputeError, - "MLX unknown error"}); - } -} - -// Specialized wrapper for operations returning mx::array -> Tensor -// This overload automatically converts arguments using ADL and concepts -template -auto mlx_tensor_op(const std::vector& expected_shape, Func&& func, const Args&... args) noexcept - -> Result { - auto result = mlx_safe([&]() { - return std::forward(func)(to_mlx_auto(args)...); - }); - - if (!result) { - return std::unexpected(result.error()); - } - - // Convert mx::array to Tensor - auto mlx_array = *result; - auto mlx_shape = mlx_array.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(mlx_array); - return Tensor(result_buffer, result_shape); -} - -// Fallback for lambda-based operations (original implementation) -template -auto mlx_tensor_op(const std::vector& expected_shape, Func&& func) noexcept - -> Result { - auto result = mlx_safe(std::forward(func)); - if (!result) { - return std::unexpected(result.error()); - } - - // Convert mx::array to Tensor - auto mlx_array = *result; - auto mlx_shape = mlx_array.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(mlx_array); - return Tensor(result_buffer, result_shape); -} - -// Helper to compute broadcast shape for binary operations -inline std::vector broadcast_shape(const std::vector& a_shape, - const std::vector& b_shape) { - // For now, return the larger shape - MLX will handle broadcasting - // This is a simplified implementation; MLX does the actual broadcast computation - return a_shape.size() >= b_shape.size() ? a_shape : b_shape; -} - -} // namespace compute::mlx_utils - -#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) diff --git a/compute/src/compute/core/graph.cpp b/compute/src/compute/core/graph.cpp deleted file mode 100644 index 835e58f..0000000 --- a/compute/src/compute/core/graph.cpp +++ /dev/null @@ -1,232 +0,0 @@ -#include "graph.h" -#include "compute_backend.h" -#include -#include - -namespace compute { - -// ComputeGraph implementations -NodeId ComputeGraph::add_node(GraphNode node) { - NodeId id = next_node_id_++; - node.id = id; // Ensure node has correct ID - nodes_.push_back(std::move(node)); - return id; -} - -const GraphNode* ComputeGraph::get_node(NodeId id) const { - auto it = std::find_if(nodes_.begin(), nodes_.end(), - [id](const GraphNode& node) { return node.id == id; }); - return it != nodes_.end() ? 
&(*it) : nullptr; -} - -Result ComputeGraph::execute() { - // Get compute backend - auto backend_result = get_backend(); - if (!backend_result) { - return std::unexpected(backend_result.error()); - } - ComputeBackend* backend = *backend_result; - - // Clear intermediate results from previous executions - intermediate_scalars_.clear(); - intermediate_tensors_.clear(); - - // Get execution order via topological sort - auto execution_order = topological_sort(); - - // Execute nodes in dependency order - for (NodeId node_id : execution_order) { - const GraphNode* node = get_node(node_id); - if (!node) { - return std::unexpected(Error{ErrorCode::UnknownError, "Node not found during execution"}); - } - - auto exec_result = execute_node(*node, backend); - if (!exec_result) { - return std::unexpected(exec_result.error()); - } - } - - // For now, return empty result - in the future this could return final outputs - std::vector result_data = {0.0f}; - return ComputeResult(Result>(std::move(result_data)), {1}); -} - -Result ComputeGraph::get_backend() { - auto& manager = BackendManager::instance(); - auto init_result = manager.initialize(); - if (!init_result) { - return std::unexpected(init_result.error()); - } - - ComputeBackend* backend = nullptr; - if (backend_type_ == BackendType::Auto) { - backend = manager.get_default_backend(); - } else { - backend = manager.get_backend(backend_type_); - } - - if (!backend) { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Requested backend not available"}); - } - - return backend; -} - -std::vector ComputeGraph::topological_sort() const { - std::vector result; - std::unordered_map in_degree; - std::queue ready_queue; - - // Initialize in-degree count for all nodes - for (const auto& node : nodes_) { - in_degree[node.id] = node.dependencies.size(); - - // If node has no dependencies, it's ready to execute - if (node.dependencies.empty()) { - ready_queue.push(node.id); - } - } - - // Process nodes in topological order - while (!ready_queue.empty()) { - NodeId current = ready_queue.front(); - ready_queue.pop(); - result.push_back(current); - - // Find all nodes that depend on the current node - for (const auto& node : nodes_) { - // Check if this node depends on the current node - auto it = std::find(node.dependencies.begin(), node.dependencies.end(), current); - if (it != node.dependencies.end()) { - // Reduce in-degree - in_degree[node.id]--; - - // If in-degree becomes 0, node is ready to execute - if (in_degree[node.id] == 0) { - ready_queue.push(node.id); - } - } - } - } - - // Check for cycles (if we didn't process all nodes) - if (result.size() != nodes_.size()) { - // There's a cycle in the dependency graph - // For now, just return the partial result - in the future we should handle this error - // TODO: Return error result indicating cyclic dependency - } - - return result; -} - -Result ComputeGraph::execute_node(const GraphNode& node, ComputeBackend* backend) { - return std::visit([this, &node, backend](const auto& params) -> Result { - using ParamType = std::decay_t; - - if constexpr (std::is_same_v) { - // Resolve tensor inputs (could be immediate or symbolic) - auto tensor_a_result = resolve_tensor_input(params.input_a); - if (!tensor_a_result) return std::unexpected(tensor_a_result.error()); - - auto tensor_b_result = resolve_tensor_input(params.input_b); - if (!tensor_b_result) return std::unexpected(tensor_b_result.error()); - - // Execute dot product - auto result_tensor = backend->dot_product(*tensor_a_result, 
*tensor_b_result); - - // Extract result to output buffer - auto extract_result = backend->extract(result_tensor, params.output); - if (!extract_result) return std::unexpected(extract_result.error()); - - // Store intermediate result for potential use by other nodes - if (params.output.size() == 1) { - intermediate_scalars_[node.id] = {params.output[0]}; - } - - return {}; - - } else if constexpr (std::is_same_v) { - // Resolve tensor input - auto tensor_result = resolve_tensor_input(params.input_tensor); - if (!tensor_result) return std::unexpected(tensor_result.error()); - - // Resolve scalar input (could be immediate value or symbolic) - auto scalar_result = resolve_scalar_input(params.scalar); - if (!scalar_result) return std::unexpected(scalar_result.error()); - - // Execute matrix scalar add - auto result_tensor = backend->matrix_scalar_add(*tensor_result, *scalar_result); - - // Extract result to output buffer - auto extract_result = backend->extract(result_tensor, params.output); - if (!extract_result) return std::unexpected(extract_result.error()); - - // Store intermediate result for potential use by other nodes - std::vector result_data(params.output.begin(), params.output.end()); - intermediate_tensors_[node.id] = std::move(result_data); - - return {}; - - } else { - return std::unexpected(Error{ErrorCode::UnknownError, "Unknown operation type"}); - } - }, node.params); -} - -Result ComputeGraph::resolve_tensor_input(const TensorInput& input) { - return std::visit([this](const auto& value) -> Result { - using ValueType = std::decay_t; - - if constexpr (std::is_same_v) { - // Immediate tensor value - return value; - - } else if constexpr (std::is_same_v) { - // Symbolic reference - need to get from intermediate results - auto it = intermediate_tensors_.find(value.node_id()); - if (it == intermediate_tensors_.end()) { - return std::unexpected(Error{ErrorCode::ComputeError, - "Symbolic tensor result not found - dependency not satisfied"}); - } - - // For now, we need a backend to create a tensor from the stored data - // This is a limitation of the current design - we should improve this - return std::unexpected(Error{ErrorCode::ComputeError, - "Symbolic tensor resolution not fully implemented yet"}); - - } else { - return std::unexpected(Error{ErrorCode::UnknownError, "Unknown tensor input type"}); - } - }, input); -} - -Result ComputeGraph::resolve_scalar_input(const ScalarInput& input) { - return std::visit([this](const auto& value) -> Result { - using ValueType = std::decay_t; - - if constexpr (std::is_same_v) { - // Immediate scalar value - return value; - - } else if constexpr (std::is_same_v) { - // Symbolic reference - need to get from intermediate results - auto it = intermediate_scalars_.find(value.node_id()); - if (it == intermediate_scalars_.end()) { - return std::unexpected(Error{ErrorCode::ComputeError, - "Symbolic scalar result not found - dependency not satisfied"}); - } - - if (it->second.empty()) { - return std::unexpected(Error{ErrorCode::ComputeError, "Scalar result is empty"}); - } - - return it->second[0]; - - } else { - return std::unexpected(Error{ErrorCode::UnknownError, "Unknown scalar input type"}); - } - }, input); -} - -} // namespace compute \ No newline at end of file diff --git a/compute/src/compute/core/graph.h b/compute/src/compute/core/graph.h deleted file mode 100644 index 6264f8c..0000000 --- a/compute/src/compute/core/graph.h +++ /dev/null @@ -1,209 +0,0 @@ -#pragma once - -#include "compute_types.h" -#include "compute_backend.h" -#include 
-#include -#include -#include - -namespace compute { - -// Operation types for the computation graph -enum class OpType { - DotProduct, - MatrixScalarAdd -}; - -// Input types for operations (can be immediate values or symbolic references) -using ScalarInput = std::variant; -using TensorInput = std::variant; - -// Parameters for different operations -struct DotProductParams { - TensorInput input_a; - TensorInput input_b; - std::span output; // Where to write the scalar result -}; - -struct MatrixScalarAddParams { - TensorInput input_tensor; - ScalarInput scalar; - std::span output; // Where to write the result data - std::vector output_shape; -}; - -using OpParams = std::variant; - -// Represents a single operation in the computation graph -struct GraphNode { - NodeId id; - OpType type; - OpParams params; - std::vector dependencies; // Which nodes must execute before this one - - GraphNode(NodeId node_id, OpType op_type, OpParams op_params) - : id(node_id), type(op_type), params(std::move(op_params)) {} -}; - -// Forward declaration for fluent interface -class ComputeGraphBuilder; - -// Execution result of a computation graph -class ComputeResult { -public: - ComputeResult(Result> data, std::vector shape) - : data_(std::move(data)), shape_(std::move(shape)) {} - - // Get the result data - Result> data() const { - if (!data_) return std::unexpected(data_.error()); - return std::span(data_->data(), data_->size()); - } - - // Get result shape - const std::vector& shape() const { return shape_; } - - // Convenience for scalar results (dot product) - Result scalar() const { - if (!data_) return std::unexpected(data_.error()); - if (data_->size() != 1) { - return std::unexpected(Error{ErrorCode::InvalidInput, "Result is not a scalar"}); - } - return (*data_)[0]; - } - -private: - Result> data_; - std::vector shape_; -}; - -// The computation graph that can be executed -class ComputeGraph { -public: - explicit ComputeGraph(BackendType backend_type = BackendType::Auto) - : backend_type_(backend_type), next_node_id_(0) {} - - // Execute the computation graph with dependency resolution - Result execute(); - - // For building the graph (used by ComputeGraphBuilder) - NodeId add_node(GraphNode node); - - // Get node information - const GraphNode* get_node(NodeId id) const; - - // Get next node ID (for ComputeGraphBuilder) - NodeId get_next_node_id() const { return next_node_id_; } - -private: - std::vector nodes_; - BackendType backend_type_; - NodeId next_node_id_; - - // Intermediate results storage during execution - std::unordered_map> intermediate_scalars_; - std::unordered_map> intermediate_tensors_; - - Result get_backend(); - std::vector topological_sort() const; - Result execute_node(const GraphNode& node, ComputeBackend* backend); - - // Helper methods for resolving inputs - Result resolve_tensor_input(const TensorInput& input); - Result resolve_scalar_input(const ScalarInput& input); -}; - -// Builder for fluently constructing computation graphs -class ComputeGraphBuilder { -public: - explicit ComputeGraphBuilder(BackendType backend = BackendType::Auto) - : graph_(std::make_unique(backend)) {} - - // Dot product operation - returns symbolic scalar - SymbolicScalar dot_product(TensorInput a, TensorInput b, std::span output) { - DotProductParams params{std::move(a), std::move(b), output}; - GraphNode node(graph_->get_next_node_id(), OpType::DotProduct, std::move(params)); - - // Add dependencies based on inputs - add_dependencies(node); - - NodeId node_id = graph_->add_node(std::move(node)); 
- return SymbolicScalar(node_id); - } - - // Convenience overload for immediate tensor values - SymbolicScalar dot_product(Tensor a, Tensor b, std::span output) { - return dot_product(TensorInput{a}, TensorInput{b}, output); - } - - // Matrix scalar addition operation - writes result to output buffer - void matrix_scalar_add(TensorInput tensor, ScalarInput scalar, std::span output, std::vector output_shape) { - MatrixScalarAddParams params{std::move(tensor), std::move(scalar), output, std::move(output_shape)}; - GraphNode node(graph_->get_next_node_id(), OpType::MatrixScalarAdd, std::move(params)); - - // Add dependencies based on inputs - add_dependencies(node); - - graph_->add_node(std::move(node)); - } - - // Convenience overloads for immediate values - void matrix_scalar_add(Tensor tensor, float scalar, std::span output, std::vector output_shape) { - matrix_scalar_add(TensorInput{tensor}, ScalarInput{scalar}, output, std::move(output_shape)); - } - - void matrix_scalar_add(Tensor tensor, SymbolicScalar scalar, std::span output, std::vector output_shape) { - matrix_scalar_add(TensorInput{tensor}, ScalarInput{scalar}, output, std::move(output_shape)); - } - - // Execute the built graph - Result execute() { - return graph_->execute(); - } - - // Get the built graph (transfers ownership) - std::unique_ptr build() { - return std::move(graph_); - } - -private: - std::unique_ptr graph_; - - // Helper to add dependencies based on symbolic inputs - void add_dependencies(GraphNode& node) { - // Extract dependencies from OpParams - std::visit([&node, this](const auto& params) { - using ParamType = std::decay_t; - - if constexpr (std::is_same_v) { - add_tensor_dependency(node, params.input_a); - add_tensor_dependency(node, params.input_b); - } else if constexpr (std::is_same_v) { - add_tensor_dependency(node, params.input_tensor); - add_scalar_dependency(node, params.scalar); - } - }, node.params); - } - - void add_tensor_dependency(GraphNode& node, const TensorInput& input) { - if (std::holds_alternative(input)) { - const auto& symbolic = std::get(input); - node.dependencies.push_back(symbolic.node_id()); - } - } - - void add_scalar_dependency(GraphNode& node, const ScalarInput& input) { - if (std::holds_alternative(input)) { - const auto& symbolic = std::get(input); - node.dependencies.push_back(symbolic.node_id()); - } - } -}; - -// Convenience factory functions -inline ComputeGraphBuilder graph(BackendType backend = BackendType::Auto) { - return ComputeGraphBuilder(backend); -} - -} // namespace compute diff --git a/compute/src/compute/core/tensor.cpp b/compute/src/compute/core/tensor.cpp deleted file mode 100644 index cd816bd..0000000 --- a/compute/src/compute/core/tensor.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "tensor.h" - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) -#include "../backends/mlx/mlx_buffer.h" - -namespace compute { - -mx::array& Tensor::to_mlx() { - return static_cast(buffer_.get())->mlx_array(); -} - -const mx::array& Tensor::to_mlx() const { - return static_cast(buffer_.get())->mlx_array(); -} - -} // namespace compute - -#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) \ No newline at end of file diff --git a/compute/src/compute/core/tensor.h b/compute/src/compute/core/tensor.h deleted file mode 100644 index 25592cc..0000000 --- a/compute/src/compute/core/tensor.h +++ /dev/null @@ -1,164 +0,0 @@ -#pragma once - -#include "compute_types.h" -#include -#include -#include - -#if defined(__APPLE__) && 
defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) -#include -namespace mx = mlx::core; -#endif - -namespace compute { - -// Forward declarations -class BackendBuffer; -class ComputeBackend; - -/** - * Backend-managed memory buffer abstraction - * Each backend implements this to manage memory in their optimal format - */ -class BackendBuffer { -public: - virtual ~BackendBuffer() = default; - - /** - * Get raw pointer to data (may trigger computation for lazy backends) - * @return Pointer to buffer data - */ - virtual void* get_data() = 0; - - /** - * Get buffer size in bytes - */ - virtual size_t get_size() const = 0; - - /** - * Get backend type that owns this buffer - */ - virtual BackendType get_backend_type() const = 0; - - /** - * Force evaluation of any pending computations (for lazy backends) - */ - virtual void evaluate() = 0; -}; - -/** - * Unified tensor that lives in backend-native format - * Eliminates TensorView abstraction - tensors are created directly in backend format - */ -class Tensor { -private: - std::shared_ptr buffer_; - std::vector shape_; - -public: - /** - * Construct tensor with backend buffer and shape - */ - Tensor(std::shared_ptr buffer, std::vector shape) - : buffer_(std::move(buffer)), shape_(std::move(shape)) {} - - // Move constructor and assignment - Tensor(Tensor&&) = default; - Tensor& operator=(Tensor&&) = default; - - // Copy constructor and assignment - Tensor(const Tensor&) = default; - Tensor& operator=(const Tensor&) = default; - - /** - * Get typed pointer to tensor data - * May trigger computation for lazy backends - */ - template - T* data() { - buffer_->evaluate(); // Ensure computation is complete - return static_cast(buffer_->get_data()); - } - - template - const T* data() const { - const_cast(buffer_.get())->evaluate(); - return static_cast(buffer_->get_data()); - } - - /** - * Get raw pointer (convenience for float tensors) - */ - float* data_f32() { return data(); } - const float* data_f32() const { return data(); } - - /** - * Get backend type - */ - BackendType backend_type() const { - return buffer_->get_backend_type(); - } - - /** - * Get tensor shape - */ - const std::vector& shape() const { return shape_; } - - /** - * Get total number of elements - */ - size_t size() const { - size_t total = 1; - for (size_t dim : shape_) { - total *= dim; - } - return total; - } - - /** - * Get size in bytes (assumes float32) - */ - size_t byte_size() const { - return size() * sizeof(float); - } - - // Convenience shape accessors - bool is_scalar() const { return shape_.empty() || (shape_.size() == 1 && shape_[0] == 1); } - bool is_vector() const { return shape_.size() == 1; } - bool is_matrix() const { return shape_.size() == 2; } - - size_t length() const { - return is_vector() ? shape_[0] : 0; - } - - size_t rows() const { - return is_matrix() ? shape_[0] : 0; - } - - size_t cols() const { - return is_matrix() ? 
shape_[1] : 0; - } - - /** - * Get underlying buffer (for backend-specific operations) - */ - std::shared_ptr buffer() const { return buffer_; } - - /** - * Force evaluation of any pending computations - */ - void evaluate() const { - buffer_->evaluate(); - } - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - /** - * Get MLX array for MLX backend tensors - * Throws if tensor is not from MLX backend - */ - mx::array& to_mlx(); - const mx::array& to_mlx() const; -#endif -}; - -} // namespace compute \ No newline at end of file diff --git a/compute/src/compute/model/gemma_model.cpp b/compute/src/compute/model/gemma_model.cpp deleted file mode 100644 index e63c128..0000000 --- a/compute/src/compute/model/gemma_model.cpp +++ /dev/null @@ -1,497 +0,0 @@ -#include "gemma_model.h" -#include "model_loader.h" -#include "sampler.h" -#include "../core/compute_backend.h" -#include -#include -#include -#include - -namespace compute { - -// ── Private constructor ─────────────────────────────────────────────────────── - -GemmaModel::GemmaModel( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights, - ComputeBackend* backend) - : GemmaModelBase(std::move(config), std::move(tokenizer), std::move(weights)) - , backend_(backend) -{} - -// ── Factory ─────────────────────────────────────────────────────────────────── - -Result GemmaModel::from_model_dir( - const std::filesystem::path& model_dir, - ComputeBackend* backend) -{ - if (!backend) { - return std::unexpected(Error{ErrorCode::InvalidArgument, "Backend cannot be null"}); - } - - auto model_result = ModelLoader::load_model(model_dir, backend); - if (!model_result) return std::unexpected(model_result.error()); - - auto& [config, weights] = *model_result; - - auto tokenizer_result = SimpleBpeTokenizer::from_model_dir(model_dir); - if (!tokenizer_result) return std::unexpected(tokenizer_result.error()); - - if (!config.is_valid()) { - return std::unexpected(Error{ErrorCode::InvalidModel, "Invalid model configuration"}); - } - if (!config.is_gemma_architecture()) { - return std::unexpected(Error{ErrorCode::InvalidModel, - "Not a Gemma architecture: " + config.model_type}); - } - - // Multimodal models (Gemma3ForConditionalGeneration) prefix all text weights - // with "language_model." — strip it so GemmaModel sees standard "model." keys. - { - const std::string prefix = "language_model."; - const bool has_prefixed = weights.count(prefix + "model.embed_tokens.weight") > 0; - const bool has_plain = weights.count("model.embed_tokens.weight") > 0; - if (has_prefixed && !has_plain) { - std::unordered_map remapped; - remapped.reserve(weights.size()); - for (auto& [key, val] : weights) { - if (key.size() > prefix.size() && - key.compare(0, prefix.size(), prefix) == 0) { - remapped.emplace(key.substr(prefix.size()), std::move(val)); - } else { - remapped.emplace(key, std::move(val)); - } - } - weights = std::move(remapped); - } - } - - // Gemma always stops on (ID 106) regardless of config. - // Some model variants (e.g. gemma-3-1b-it-4bit) only list eos_token_id=1 - // in config.json and omit 106, but the chat template uses - // as the actual conversation boundary token. 
- { - constexpr int kEndOfTurn = 106; - if (!config.eos_token_ids.has_value()) { - config.eos_token_ids = std::vector{config.primary_eos_token_id(), kEndOfTurn}; - } else { - auto& ids = *config.eos_token_ids; - if (std::find(ids.begin(), ids.end(), kEndOfTurn) == ids.end()) { - ids.push_back(kEndOfTurn); - } - } - } - - return GemmaModel( - std::move(config), - std::move(*tokenizer_result), - std::move(weights), - backend); -} - -// ── Linear projection (quantized or unquantized) ───────────────────────────── - -Result GemmaModel::linear(const Tensor& input, const std::string& weight_key) { - auto w = get_weight(weight_key + ".weight"); - if (!w) return std::unexpected(w.error()); - - auto s_it = weights_.find(weight_key + ".scales"); - if (s_it != weights_.end()) { - const int gs = config_.quantization ? config_.quantization->group_size : 64; - const int bits = config_.quantization ? config_.quantization->bits : 4; - auto b_it = weights_.find(weight_key + ".biases"); - const Tensor* biases = (b_it != weights_.end()) ? &b_it->second : nullptr; - return backend_->quantized_matmul(input, *w, s_it->second, biases, true, gs, bits); - } - - // Unquantized: weight stored [out, in] — transpose before matmul - auto w_t = backend_->swapaxes(*w, 0, 1); - if (!w_t) return std::unexpected(w_t.error()); - return backend_->matmul(input, *w_t); -} - -// ── Embedding ───────────────────────────────────────────────────────────────── - -Result GemmaModel::embedding(const std::vector& token_ids) { - if (token_ids.empty()) { - return std::unexpected(Error{ErrorCode::InvalidInput, "Empty token_ids"}); - } - - // Dequantize and cache the embedding table on first use. - if (!dequantized_embed_tokens_.has_value()) { - auto w = get_weight("model.embed_tokens.weight"); - if (!w) return std::unexpected(w.error()); - - auto scales = get_weight("model.embed_tokens.scales"); - if (scales) { - auto biases = get_weight("model.embed_tokens.biases"); - if (!biases) return std::unexpected(biases.error()); - const int gs = config_.quantization ? config_.quantization->group_size : 64; - const int bits = config_.quantization ? 
config_.quantization->bits : 4; - auto deq = backend_->dequantize(*w, *scales, *biases, gs, bits); - if (!deq) return std::unexpected(deq.error()); - dequantized_embed_tokens_ = std::move(*deq); - } else { - dequantized_embed_tokens_ = std::move(*w); - } - } - const auto& embed_weight = *dequantized_embed_tokens_; - - if (embed_weight.shape().size() != 2) { - return std::unexpected(Error{ErrorCode::InvalidModel, - "Embedding weight must be 2D"}); - } - - const size_t vocab_size = embed_weight.shape()[0]; - const size_t hidden_size = embed_weight.shape()[1]; - - for (int id : token_ids) { - if (id < 0 || static_cast(id) >= vocab_size) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Token ID " + std::to_string(id) + " out of range [0, " + - std::to_string(vocab_size) + ")"}); - } - } - - std::vector rows; - rows.reserve(token_ids.size()); - for (int id : token_ids) { - auto row = backend_->slice(embed_weight, id, id + 1, 0); - if (!row) return std::unexpected(row.error()); - auto vec = backend_->reshape(*row, {1, hidden_size}); - if (!vec) return std::unexpected(vec.error()); - rows.push_back(*vec); - } - - auto embedded = backend_->concatenate(rows, 0); - if (!embedded) return std::unexpected(embedded.error()); - - // Gemma-specific: scale embeddings by sqrt(hidden_size) - const float embed_scale = std::sqrt(static_cast(hidden_size)); - auto scale_tensor = backend_->create_tensor( - std::span(&embed_scale, 1), {1}); - auto scaled = backend_->multiply(*embedded, scale_tensor); - if (!scaled) return std::unexpected(scaled.error()); - - return *scaled; -} - -// ── RMSNorm helper ──────────────────────────────────────────────────────────── -// -// Gemma uses (1 + weight) as the effective scale factor — NOT just weight. -// The stored weights are deviations from 0 (HuggingFace initializes them to -// zeros; mlx-lm applies `mx.fast.rms_norm(x, 1.0 + self.weight, eps)`). -// Our backend's rms_norm computes: x / rms(x) * weight. -// So we pass (1 + weight) to get the correct Gemma semantics. - -Result GemmaModel::rms_norm(const Tensor& input, const Tensor& weight) { - auto effective_weight = backend_->matrix_scalar_add(weight, 1.0f); - return backend_->rms_norm(input, effective_weight, config_.rms_norm_eps); -} - -// ── Attention layer ─────────────────────────────────────────────────────────── - -Result GemmaModel::attention_layer( - const Tensor& input, - int layer_idx, - int position_offset, - GemmaLayerKVCache* cache) -{ - if (input.shape().size() != 2) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Attention input must be 2D [seq_len, hidden_size]"}); - } - - const size_t seq_len = input.shape()[0]; - const size_t n_heads = config_.num_attention_heads; - const size_t n_kv_heads = config_.num_key_value_heads; - const size_t head_dim = config_.effective_head_dim(); - const float scale = config_.effective_attention_scale(); - - // RoPE theta: local layers use rope_local_base_freq, global layers use rope_theta - const bool is_local = config_.is_local_layer(layer_idx); - const float rope_theta = (is_local && config_.rope_local_base_freq.has_value()) - ? *config_.rope_local_base_freq - : config_.rope_theta; - - const std::string prefix = "model.layers." 
+ std::to_string(layer_idx) + ".self_attn."; - - // ── QKV projections ─────────────────────────────────────────────────────── - auto queries_flat = linear(input, prefix + "q_proj"); - if (!queries_flat) return std::unexpected(queries_flat.error()); - - auto keys_flat = linear(input, prefix + "k_proj"); - if (!keys_flat) return std::unexpected(keys_flat.error()); - - auto values_flat = linear(input, prefix + "v_proj"); - if (!values_flat) return std::unexpected(values_flat.error()); - - // ── Reshape → [heads, seq, head_dim] ────────────────────────────────────── - auto q3 = backend_->reshape(*queries_flat, {seq_len, n_heads, head_dim}); - if (!q3) return std::unexpected(q3.error()); - auto qt = backend_->swapaxes(*q3, 0, 1); - if (!qt) return std::unexpected(qt.error()); - - auto k3 = backend_->reshape(*keys_flat, {seq_len, n_kv_heads, head_dim}); - if (!k3) return std::unexpected(k3.error()); - auto kt = backend_->swapaxes(*k3, 0, 1); - if (!kt) return std::unexpected(kt.error()); - - auto v3 = backend_->reshape(*values_flat, {seq_len, n_kv_heads, head_dim}); - if (!v3) return std::unexpected(v3.error()); - auto vt = backend_->swapaxes(*v3, 0, 1); - if (!vt) return std::unexpected(vt.error()); - - // ── Q/K norm (Gemma3-specific, before RoPE) ─────────────────────────────── - // Weights are [head_dim]; MLX rms_norm broadcasts over leading dims. - auto q_norm_w = get_weight(prefix + "q_norm.weight"); - if (q_norm_w) { - auto qn = rms_norm(*qt, *q_norm_w); - if (!qn) return std::unexpected(qn.error()); - qt = std::move(qn); - } - - auto k_norm_w = get_weight(prefix + "k_norm.weight"); - if (k_norm_w) { - auto kn = rms_norm(*kt, *k_norm_w); - if (!kn) return std::unexpected(kn.error()); - kt = std::move(kn); - } - - // ── RoPE ────────────────────────────────────────────────────────────────── - auto q_rope = backend_->rope(*qt, static_cast(head_dim), rope_theta, position_offset); - if (!q_rope) return std::unexpected(q_rope.error()); - auto k_rope = backend_->rope(*kt, static_cast(head_dim), rope_theta, position_offset); - if (!k_rope) return std::unexpected(k_rope.error()); - - // ── KV cache update ─────────────────────────────────────────────────────── - Tensor full_k = *k_rope; - Tensor full_v = *vt; - std::string attn_mask = "causal"; - - if (cache) { - if (!cache->valid) { - cache->keys = *k_rope; - cache->values = *vt; - cache->valid = true; - } else { - auto cat_k = backend_->concatenate({*cache->keys, *k_rope}, 1); - if (!cat_k) return std::unexpected(cat_k.error()); - auto cat_v = backend_->concatenate({*cache->values, *vt}, 1); - if (!cat_v) return std::unexpected(cat_v.error()); - cache->keys = *cat_k; - cache->values = *cat_v; - full_k = *cat_k; - full_v = *cat_v; - } - if (seq_len == 1) attn_mask = ""; - } - - // ── SDPA ────────────────────────────────────────────────────────────────── - auto attn_out = backend_->scaled_dot_product_attention( - *q_rope, full_k, full_v, scale, attn_mask); - if (!attn_out) return std::unexpected(attn_out.error()); - - // ── Reshape output → [seq, n_heads * head_dim] ──────────────────────────── - auto attn_t = backend_->swapaxes(*attn_out, 0, 1); - if (!attn_t) return std::unexpected(attn_t.error()); - auto attn_flat = backend_->reshape(*attn_t, {seq_len, n_heads * head_dim}); - if (!attn_flat) return std::unexpected(attn_flat.error()); - - // ── Output projection ───────────────────────────────────────────────────── - return linear(*attn_flat, prefix + "o_proj"); -} - -// ── MLP layer (GeGLU: gelu(gate) * up, projected by down) 
──────────────────── - -Result GemmaModel::mlp_layer(const Tensor& input, int layer_idx) { - const std::string prefix = "model.layers." + std::to_string(layer_idx) + ".mlp."; - - auto gate = linear(input, prefix + "gate_proj"); - if (!gate) return std::unexpected(gate.error()); - - auto up = linear(input, prefix + "up_proj"); - if (!up) return std::unexpected(up.error()); - - // GeGLU: gelu(gate) * up - auto activated = backend_->gelu(*gate); - if (!activated) return std::unexpected(activated.error()); - auto hidden = backend_->multiply(*activated, *up); - if (!hidden) return std::unexpected(hidden.error()); - - return linear(*hidden, prefix + "down_proj"); -} - -// ── Transformer block (Gemma3: 4 norms, post-norm residuals) ───────────────── -// -// Forward pass per layer: -// residual = x -// x = input_layernorm(x) -// x = attention(x) -// x = post_attention_layernorm(x) ← applied to attn OUTPUT -// x = residual + x -// residual = x -// x = pre_feedforward_layernorm(x) -// x = mlp(x) -// x = post_feedforward_layernorm(x) ← applied to FFN OUTPUT -// x = residual + x - -Result GemmaModel::transformer_block( - const Tensor& input, - int layer_idx, - int position_offset, - GemmaLayerKVCache* cache) -{ - const std::string prefix = "model.layers." + std::to_string(layer_idx) + "."; - - // ── Attention sub-block ─────────────────────────────────────────────────── - auto input_norm_w = get_weight(prefix + "input_layernorm.weight"); - if (!input_norm_w) return std::unexpected(input_norm_w.error()); - auto normed_in = rms_norm(input, *input_norm_w); - if (!normed_in) return std::unexpected(normed_in.error()); - - auto attn_out = attention_layer(*normed_in, layer_idx, position_offset, cache); - if (!attn_out) return std::unexpected(attn_out.error()); - - auto post_attn_norm_w = get_weight(prefix + "post_attention_layernorm.weight"); - if (!post_attn_norm_w) return std::unexpected(post_attn_norm_w.error()); - auto normed_attn = rms_norm(*attn_out, *post_attn_norm_w); - if (!normed_attn) return std::unexpected(normed_attn.error()); - - auto residual1 = backend_->add(input, *normed_attn); - if (!residual1) return std::unexpected(residual1.error()); - - // ── FFN sub-block ───────────────────────────────────────────────────────── - auto pre_ffn_norm_w = get_weight(prefix + "pre_feedforward_layernorm.weight"); - if (!pre_ffn_norm_w) return std::unexpected(pre_ffn_norm_w.error()); - auto normed_ffn_in = rms_norm(*residual1, *pre_ffn_norm_w); - if (!normed_ffn_in) return std::unexpected(normed_ffn_in.error()); - - auto ffn_out = mlp_layer(*normed_ffn_in, layer_idx); - if (!ffn_out) return std::unexpected(ffn_out.error()); - - auto post_ffn_norm_w = get_weight(prefix + "post_feedforward_layernorm.weight"); - if (!post_ffn_norm_w) return std::unexpected(post_ffn_norm_w.error()); - auto normed_ffn = rms_norm(*ffn_out, *post_ffn_norm_w); - if (!normed_ffn) return std::unexpected(normed_ffn.error()); - - return backend_->add(*residual1, *normed_ffn); -} - -// ── Internal forward ───────────────────────────────────────────────────────── - -Result> GemmaModel::forward_impl( - const std::vector& input_ids, - int position_offset, - std::vector* cache_vec) -{ - auto hidden = embedding(input_ids); - if (!hidden) return std::unexpected(hidden.error()); - - for (int i = 0; i < static_cast(config_.num_hidden_layers); ++i) { - GemmaLayerKVCache* layer_cache = cache_vec ? 
&(*cache_vec)[i] : nullptr; - auto block_out = transformer_block(*hidden, i, position_offset, layer_cache); - if (!block_out) return std::unexpected(block_out.error()); - hidden = std::move(block_out); - } - - // Final layer norm - auto norm_w = get_weight("model.norm.weight"); - if (!norm_w) return std::unexpected(norm_w.error()); - auto normed = rms_norm(*hidden, *norm_w); - if (!normed) return std::unexpected(normed.error()); - - // LM head — Gemma3 1B has a separate lm_head; multimodal variants (4B+) - // tie word embeddings (no lm_head.weight, reuse embed_tokens transposed). - Result logits_tensor{std::unexpected(Error{ErrorCode::TensorNotFound, "unset"})}; - if (weights_.count("lm_head.weight")) { - logits_tensor = linear(*normed, "lm_head"); - } else if (config_.tie_word_embeddings) { - if (!dequantized_embed_tokens_.has_value()) { - // Trigger embedding table dequantization without a real token - auto ew = get_weight("model.embed_tokens.weight"); - if (!ew) return std::unexpected(ew.error()); - auto es = get_weight("model.embed_tokens.scales"); - if (es) { - const int gs = config_.quantization ? config_.quantization->group_size : 64; - const int bits = config_.quantization ? config_.quantization->bits : 4; - auto eb = get_weight("model.embed_tokens.biases"); - if (!eb) return std::unexpected(eb.error()); - auto deq = backend_->dequantize(*ew, *es, *eb, gs, bits); - if (!deq) return std::unexpected(deq.error()); - dequantized_embed_tokens_ = std::move(*deq); - } else { - dequantized_embed_tokens_ = std::move(*ew); - } - } - // normed: [seq, hidden] × embed_tokens^T: [hidden, vocab] → [seq, vocab] - auto embed_t = backend_->swapaxes(*dequantized_embed_tokens_, 0, 1); - if (!embed_t) return std::unexpected(embed_t.error()); - logits_tensor = backend_->matmul(*normed, *embed_t); - } else { - return std::unexpected(Error{ErrorCode::TensorNotFound, - "lm_head.weight not found and tie_word_embeddings is false"}); - } - if (!logits_tensor) return std::unexpected(logits_tensor.error()); - - const size_t seq_len = input_ids.size(); - const size_t vocab_size = config_.vocab_size; - - // Extract last-token logits - auto last = backend_->slice(*logits_tensor, - static_cast(seq_len - 1), - static_cast(seq_len), 0); - if (!last) return std::unexpected(last.error()); - auto flat = backend_->reshape(*last, {vocab_size}); - if (!flat) return std::unexpected(flat.error()); - - std::vector result(vocab_size); - auto extract = backend_->extract(*flat, result); - if (!extract) return std::unexpected(extract.error()); - return result; -} - -// ── LanguageModel interface ─────────────────────────────────────────────────── - -Result> GemmaModel::generate( - const std::vector& input_ids, - size_t max_new_tokens, - SamplingParams params, - std::function on_token) -{ - return GenerateHelper::run( - input_ids, max_new_tokens, params, on_token, config_, - [this](const std::vector& ids) { return prefill(ids); }, - [this](int tok) { return decode(tok); }); -} - -// ── KV-cache steps ──────────────────────────────────────────────────────────── - -void GemmaModel::reset_cache() { - kv_cache_.clear(); - cache_position_ = 0; -} - -Result> GemmaModel::prefill(const std::vector& prompt_ids) { - if (prompt_ids.empty()) { - return std::unexpected(Error{ErrorCode::InvalidInput, "prefill: prompt_ids cannot be empty"}); - } - reset_cache(); - kv_cache_.assign(config_.num_hidden_layers, GemmaLayerKVCache{}); - auto result = forward_impl(prompt_ids, 0, &kv_cache_); - if (result) cache_position_ = prompt_ids.size(); - return 
result; -} - -Result> GemmaModel::decode(int token_id) { - if (cache_position_ == 0) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "decode: must call prefill() before decode()"}); - } - auto result = forward_impl({token_id}, static_cast(cache_position_), &kv_cache_); - if (result) ++cache_position_; - return result; -} - -} // namespace compute diff --git a/compute/src/compute/model/gemma_model.h b/compute/src/compute/model/gemma_model.h deleted file mode 100644 index a5378ab..0000000 --- a/compute/src/compute/model/gemma_model.h +++ /dev/null @@ -1,90 +0,0 @@ -#pragma once - -#include "language_model.h" -#include "gemma_model_base.h" -#include "kv_cache.h" -#include -#include - -namespace compute { - -/** - * ComputeBackend-path implementation of Gemma/Gemma2/Gemma3. - * - * Handles model_type: "gemma", "gemma2", "gemma3_text" - * On Apple Silicon the factory dispatches to GemmaModelMLX instead. - * - * Key differences from LlamaModel: - * - Embedding scale: multiply by sqrt(hidden_size) after lookup - * - 4 norms per block: input_layernorm, post_attention_layernorm (on attn out), - * pre_feedforward_layernorm, post_feedforward_layernorm (on FFN out) - * - Q/K norm inside each attention layer (Gemma3) - * - GeGLU FFN: down(gelu(gate) * up) instead of down(silu(gate) * up) - * - Layer-specific RoPE theta (local vs global layers in Gemma3) - */ -class GemmaModel final : public GemmaModelBase, public LanguageModel { -public: - static Result from_model_dir( - const std::filesystem::path& model_dir, - ComputeBackend* backend); - - // ── LanguageModel interface ─────────────────────────────────────────────── - - Result> generate( - const std::vector& input_ids, - size_t max_new_tokens = 4096, - SamplingParams params = {}, - std::function on_token = nullptr) override; - - const ModelConfig& config() const override { return config_; } - const std::string& model_type() const override { return config_.model_type; } - const SimpleBpeTokenizer& tokenizer() const override { return tokenizer_; } - size_t num_parameters() const override { return GemmaModelBase::num_parameters(); } - -private: - GemmaModel( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights, - ComputeBackend* backend); - - // ── KV-cache steps (private — used by generate() only) ─────────────────── - - Result> prefill(const std::vector& prompt_ids); - Result> decode(int token_id); - void reset_cache(); - - // ── Layer implementations ───────────────────────────────────────────────── - - Result linear(const Tensor& input, const std::string& weight_key); - Result embedding(const std::vector& token_ids); - Result rms_norm(const Tensor& input, const Tensor& weight); - - Result attention_layer( - const Tensor& input, - int layer_idx, - int position_offset, - GemmaLayerKVCache* cache); - - Result mlp_layer(const Tensor& input, int layer_idx); - - Result transformer_block( - const Tensor& input, - int layer_idx, - int position_offset, - GemmaLayerKVCache* cache); - - Result> forward_impl( - const std::vector& input_ids, - int position_offset, - std::vector* cache_vec); - - // ── State ───────────────────────────────────────────────────────────────── - - ComputeBackend* backend_; - std::vector kv_cache_; - size_t cache_position_ = 0; - mutable std::optional dequantized_embed_tokens_; -}; - -} // namespace compute diff --git a/compute/src/compute/model/gemma_model_base.cpp b/compute/src/compute/model/gemma_model_base.cpp deleted file mode 100644 index 6cc391b..0000000 --- 
a/compute/src/compute/model/gemma_model_base.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "gemma_model_base.h" - -namespace compute { - -GemmaModelBase::GemmaModelBase( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights) - : config_(std::move(config)) - , tokenizer_(std::move(tokenizer)) - , weights_(std::move(weights)) -{} - -size_t GemmaModelBase::num_parameters() const { - size_t total = 0; - for (const auto& [name, tensor] : weights_) total += tensor.size(); - return total; -} - -Result GemmaModelBase::get_weight(const std::string& name) const { - auto it = weights_.find(name); - if (it == weights_.end()) - return std::unexpected(Error{ErrorCode::TensorNotFound, "Weight not found: " + name}); - return it->second; -} - -} // namespace compute diff --git a/compute/src/compute/model/gemma_model_base.h b/compute/src/compute/model/gemma_model_base.h deleted file mode 100644 index b019665..0000000 --- a/compute/src/compute/model/gemma_model_base.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include "model_config.h" -#include "simple_bpe_tokenizer.h" -#include "../core/tensor.h" -#include "../core/compute_types.h" -#include -#include - -namespace compute { - -class GemmaModelBase { -public: - size_t num_parameters() const; - -protected: - GemmaModelBase( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights); - - Result get_weight(const std::string& name) const; - - ModelConfig config_; - SimpleBpeTokenizer tokenizer_; - std::unordered_map weights_; -}; - -} // namespace compute diff --git a/compute/src/compute/model/kv_cache.h b/compute/src/compute/model/kv_cache.h deleted file mode 100644 index d84559f..0000000 --- a/compute/src/compute/model/kv_cache.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include "../core/tensor.h" -#include - -namespace compute { - -/** - * Per-layer KV cache for efficient autoregressive decoding. - * Keys/values are stored post-RoPE so they can be directly concatenated - * with new tokens during decode without re-applying positional encoding. - * - * Shape when valid: - * keys: [n_kv_heads, seq_so_far, head_dim] - * values: [n_kv_heads, seq_so_far, head_dim] - */ -struct LayerKVCache { - std::optional keys; - std::optional values; - bool valid = false; -}; - -/** - * Per-layer KV cache for Gemma (same shape convention as LayerKVCache). - * - * Gemma3 has local (sliding-window) and global layers. - * For the first implementation we store all keys/values without - * window truncation (correct for short sequences, conservative for long ones). 
- */ -struct GemmaLayerKVCache { - std::optional keys; - std::optional values; - bool valid = false; -}; - -} // namespace compute diff --git a/compute/src/compute/model/llama_model.cpp b/compute/src/compute/model/llama_model.cpp index fe0789f..662d162 100644 --- a/compute/src/compute/model/llama_model.cpp +++ b/compute/src/compute/model/llama_model.cpp @@ -1,9 +1,6 @@ #include "llama_model.h" #include "model_loader.h" #include "sampler.h" -#include "../core/compute_backend.h" -#include -#include #include #include @@ -18,13 +15,11 @@ namespace compute { // ── Private constructor ─────────────────────────────────────────────────────── LlamaModel::LlamaModel( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights, - ComputeBackend* backend) + ModelConfig config, + SimpleBpeTokenizer tokenizer, + ComputeBackend* backend) : config_(std::move(config)) , tokenizer_(std::move(tokenizer)) - , weights_(std::move(weights)) , backend_(backend) , tool_family_(detect_tool_family(tokenizer_, config_)) #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) @@ -76,408 +71,18 @@ Result LlamaModel::from_model_dir( }(); mx::eval(embed_mat); - LlamaModel model(std::move(config), std::move(*tokenizer_result), {}, backend); + LlamaModel model(std::move(config), std::move(*tokenizer_result), backend); model.mlx_setup(std::move(mlx_weights), std::move(embed_mat), context_size); return model; - #else - if (!backend) - return std::unexpected(Error{ErrorCode::InvalidArgument, "Backend cannot be null"}); - - auto model_result = ModelLoader::load_model(model_dir, backend); - if (!model_result) return std::unexpected(model_result.error()); - - auto& [config, weights] = *model_result; - - auto tokenizer_result = SimpleBpeTokenizer::from_model_dir(model_dir); - if (!tokenizer_result) return std::unexpected(tokenizer_result.error()); - - if (!config.is_valid()) - return std::unexpected(Error{ErrorCode::InvalidModel, "Invalid model configuration"}); - if (!config.is_supported_architecture()) - return std::unexpected(Error{ErrorCode::InvalidModel, - "Unsupported model architecture: " + config.model_type}); - - return LlamaModel(std::move(config), std::move(*tokenizer_result), std::move(weights), backend); + (void)model_dir; (void)backend; (void)context_size; + return std::unexpected(Error{ErrorCode::BackendNotAvailable, "LlamaModel requires MLX backend"}); #endif } -// ── Weight lookup ───────────────────────────────────────────────────────────── - -Result LlamaModel::get_weight(const std::string& name) const { - auto it = weights_.find(name); - if (it == weights_.end()) { - return std::unexpected(Error{ErrorCode::TensorNotFound, "Weight not found: " + name}); - } - return it->second; -} - -// ── Linear projection (quantized or unquantized) ───────────────────────────── - -Result LlamaModel::linear(const Tensor& input, const std::string& weight_key) { - auto w = get_weight(weight_key + ".weight"); - if (!w) return std::unexpected(w.error()); - - auto s_it = weights_.find(weight_key + ".scales"); - if (s_it != weights_.end()) { - // Quantized path (mlx-community int4/int8): use quantized_matmul - const int gs = config_.quantization ? config_.quantization->group_size : 64; - const int bits = config_.quantization ? config_.quantization->bits : 4; - auto b_it = weights_.find(weight_key + ".biases"); - const Tensor* biases = (b_it != weights_.end()) ? 
&b_it->second : nullptr; - return backend_->quantized_matmul(input, *w, s_it->second, biases, true, gs, bits); - } - - // Unquantized path (fp16/bf16): weight is stored [out, in] — transpose before matmul - auto w_t = backend_->swapaxes(*w, 0, 1); - if (!w_t) return std::unexpected(w_t.error()); - return backend_->matmul(input, *w_t); -} - -// ── Embedding ───────────────────────────────────────────────────────────────── - -Result LlamaModel::embedding(const std::vector& token_ids) { - if (token_ids.empty()) { - return std::unexpected(Error{ErrorCode::InvalidInput, "Empty token_ids"}); - } - - // Dequantize and cache the embedding table on first use. - if (!dequantized_embed_tokens_.has_value()) { - auto w = get_weight("model.embed_tokens.weight"); - if (!w) return std::unexpected(w.error()); - - auto scales = get_weight("model.embed_tokens.scales"); - if (scales) { - auto biases = get_weight("model.embed_tokens.biases"); - if (!biases) return std::unexpected(biases.error()); - const int gs = config_.quantization ? config_.quantization->group_size : 64; - const int bits = config_.quantization ? config_.quantization->bits : 4; - auto deq = backend_->dequantize(*w, *scales, *biases, gs, bits); - if (!deq) return std::unexpected(deq.error()); - dequantized_embed_tokens_ = std::move(*deq); - } else { - dequantized_embed_tokens_ = std::move(*w); - } - } - const auto& embed_weight = *dequantized_embed_tokens_; - - if (embed_weight.shape().size() != 2) { - return std::unexpected(Error{ErrorCode::InvalidModel, - "Embedding weight must be 2D"}); - } - - const size_t vocab_size = embed_weight.shape()[0]; - const size_t hidden_size = embed_weight.shape()[1]; - - for (int id : token_ids) { - if (id < 0 || static_cast(id) >= vocab_size) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Token ID " + std::to_string(id) + " out of range [0, " + - std::to_string(vocab_size) + ")"}); - } - } - - std::vector rows; - rows.reserve(token_ids.size()); - for (int id : token_ids) { - auto row = backend_->slice(embed_weight, id, id + 1, 0); - if (!row) return std::unexpected(row.error()); - auto vec = backend_->reshape(*row, {1, hidden_size}); - if (!vec) return std::unexpected(vec.error()); - rows.push_back(*vec); - } - - auto result = backend_->concatenate(rows, 0); - if (!result) return std::unexpected(result.error()); - - const auto& shape = result->shape(); - if (shape.size() != 2 || shape[0] != token_ids.size() || shape[1] != hidden_size) { - return std::unexpected(Error{ErrorCode::ComputeError, "Embedding shape mismatch"}); - } - return *result; -} - -// ── RMSNorm ─────────────────────────────────────────────────────────────────── - -Result LlamaModel::rms_norm(const Tensor& input, const Tensor& weight, float eps) { - return backend_->rms_norm(input, weight, eps); -} - -// ── Attention layer ─────────────────────────────────────────────────────────── - -// Public no-cache shim used by tests -Result LlamaModel::attention_layer(const Tensor& input, int layer_idx) { - return attention_layer(input, layer_idx, 0, nullptr); -} - -Result LlamaModel::attention_layer( - const Tensor& input, - int layer_idx, - int position_offset, - LayerKVCache* cache) -{ - if (input.shape().size() != 2) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Attention input must be 2D [seq_len, hidden_size]"}); - } - - const size_t seq_len = input.shape()[0]; - const size_t hidden_size = input.shape()[1]; - const size_t n_heads = config_.num_attention_heads; - const size_t n_kv_heads = config_.num_key_value_heads; - 
const size_t head_dim = hidden_size / n_heads; - const float scale = 1.0f / std::sqrt(static_cast(head_dim)); - - const std::string prefix = "model.layers." + std::to_string(layer_idx) + ".self_attn."; - - // ── QKV projections ─────────────────────────────────────────────────────── - // linear() dispatches to quantized_matmul or matmul based on whether - // {proj}.scales exists — works for both int4 and fp16/bf16 weights. - // Separately, Qwen2 adds a learned attention bias (*q_proj.bias*) after the - // projection. We always probe for it — absent in Llama/Mistral, so no-op there. - auto queries_flat = linear(input, prefix + "q_proj"); - if (!queries_flat) return std::unexpected(queries_flat.error()); - { - auto it = weights_.find(prefix + "q_proj.bias"); - if (it != weights_.end()) { - queries_flat = backend_->add(*queries_flat, it->second); - if (!queries_flat) return std::unexpected(queries_flat.error()); - } - } - - auto keys_flat = linear(input, prefix + "k_proj"); - if (!keys_flat) return std::unexpected(keys_flat.error()); - { - auto it = weights_.find(prefix + "k_proj.bias"); - if (it != weights_.end()) { - keys_flat = backend_->add(*keys_flat, it->second); - if (!keys_flat) return std::unexpected(keys_flat.error()); - } - } - - auto values_flat = linear(input, prefix + "v_proj"); - if (!values_flat) return std::unexpected(values_flat.error()); - { - auto it = weights_.find(prefix + "v_proj.bias"); - if (it != weights_.end()) { - values_flat = backend_->add(*values_flat, it->second); - if (!values_flat) return std::unexpected(values_flat.error()); - } - } - - // ── Reshape → [heads, seq, head_dim] ────────────────────────────────────── - auto q3 = backend_->reshape(*queries_flat, {seq_len, n_heads, head_dim}); - if (!q3) return std::unexpected(q3.error()); - auto qt = backend_->swapaxes(*q3, 0, 1); - if (!qt) return std::unexpected(qt.error()); - - auto k3 = backend_->reshape(*keys_flat, {seq_len, n_kv_heads, head_dim}); - if (!k3) return std::unexpected(k3.error()); - auto kt = backend_->swapaxes(*k3, 0, 1); - if (!kt) return std::unexpected(kt.error()); - - auto v3 = backend_->reshape(*values_flat, {seq_len, n_kv_heads, head_dim}); - if (!v3) return std::unexpected(v3.error()); - auto vt = backend_->swapaxes(*v3, 0, 1); - if (!vt) return std::unexpected(vt.error()); - - // ── Optional per-head QK normalization (Qwen3) ─────────────────────────── - // Qwen3 adds learned RMSNorm per head-dimension before RoPE. Weights are - // absent in Qwen2/Llama/Mistral, so probing is a no-op for those families. 
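The reshape trick used here, flattening [heads, seq, head_dim] into (heads x seq) rows of length head_dim so a single [head_dim] weight broadcasts over every head and position, is easy to check in isolation. A standalone sketch with plain std::vector (not the ComputeBackend API; function name, values, and eps are illustrative):

#include <cmath>
#include <cstddef>
#include <vector>

// Normalize each length-head_dim row of a row-major buffer with a learned
// [head_dim] weight, i.e. per-head RMSNorm after the flatten described above.
void rms_norm_rows(std::vector<float>& x,            // n_rows * head_dim values
                   const std::vector<float>& weight,  // head_dim values
                   std::size_t head_dim, float eps = 1e-6f) {
    const std::size_t n_rows = x.size() / head_dim;
    for (std::size_t r = 0; r < n_rows; ++r) {
        float* row = x.data() + r * head_dim;
        float ss = 0.0f;
        for (std::size_t i = 0; i < head_dim; ++i) ss += row[i] * row[i];
        const float inv_rms = 1.0f / std::sqrt(ss / static_cast<float>(head_dim) + eps);
        for (std::size_t i = 0; i < head_dim; ++i) row[i] = row[i] * inv_rms * weight[i];
    }
}

int main() {
    std::vector<float> q = {0.5f, -1.0f, 2.0f, 0.1f,    // head 0, position 0
                            1.5f,  0.0f, -0.5f, 0.3f};  // head 1, position 0
    std::vector<float> w = {1.0f, 1.0f, 1.0f, 1.0f};    // q_norm.weight, head_dim = 4
    rms_norm_rows(q, w, 4);
}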
- { - auto it = weights_.find(prefix + "q_norm.weight"); - if (it != weights_.end()) { - auto flat = backend_->reshape(*qt, {n_heads * seq_len, head_dim}); - if (!flat) return std::unexpected(flat.error()); - auto normed = backend_->rms_norm(*flat, it->second, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - qt = backend_->reshape(*normed, {n_heads, seq_len, head_dim}); - if (!qt) return std::unexpected(qt.error()); - } - } - { - auto it = weights_.find(prefix + "k_norm.weight"); - if (it != weights_.end()) { - auto flat = backend_->reshape(*kt, {n_kv_heads * seq_len, head_dim}); - if (!flat) return std::unexpected(flat.error()); - auto normed = backend_->rms_norm(*flat, it->second, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - kt = backend_->reshape(*normed, {n_kv_heads, seq_len, head_dim}); - if (!kt) return std::unexpected(kt.error()); - } - } - - // ── RoPE ────────────────────────────────────────────────────────────────── - auto q_rope = backend_->rope(*qt, head_dim, config_.rope_theta, position_offset); - if (!q_rope) return std::unexpected(q_rope.error()); - auto k_rope = backend_->rope(*kt, head_dim, config_.rope_theta, position_offset); - if (!k_rope) return std::unexpected(k_rope.error()); - - // ── KV cache update ─────────────────────────────────────────────────────── - Tensor full_k = *k_rope; - Tensor full_v = *vt; - std::string attn_mask = "causal"; - - if (cache) { - if (!cache->valid) { - cache->keys = *k_rope; - cache->values = *vt; - cache->valid = true; - } else { - auto cat_k = backend_->concatenate({*cache->keys, *k_rope}, 1); - if (!cat_k) return std::unexpected(cat_k.error()); - auto cat_v = backend_->concatenate({*cache->values, *vt}, 1); - if (!cat_v) return std::unexpected(cat_v.error()); - cache->keys = *cat_k; - cache->values = *cat_v; - full_k = *cat_k; - full_v = *cat_v; - } - if (seq_len == 1) attn_mask = ""; - } - - // ── SDPA ────────────────────────────────────────────────────────────────── - auto attn_out = backend_->scaled_dot_product_attention(*q_rope, full_k, full_v, scale, attn_mask); - if (!attn_out) return std::unexpected(attn_out.error()); - - // ── Reshape output → [seq, hidden_size] ─────────────────────────────────── - auto attn_t = backend_->swapaxes(*attn_out, 0, 1); - if (!attn_t) return std::unexpected(attn_t.error()); - auto attn_flat = backend_->reshape(*attn_t, {seq_len, hidden_size}); - if (!attn_flat) return std::unexpected(attn_flat.error()); - - // ── Output projection ───────────────────────────────────────────────────── - return linear(*attn_flat, prefix + "o_proj"); -} - -// ── MLP layer ───────────────────────────────────────────────────────────────── - -Result LlamaModel::mlp_layer(const Tensor& input, int layer_idx) { - const std::string prefix = "model.layers." 
+ std::to_string(layer_idx) + ".mlp."; - - auto gate = linear(input, prefix + "gate_proj"); - if (!gate) return std::unexpected(gate.error()); - - auto up = linear(input, prefix + "up_proj"); - if (!up) return std::unexpected(up.error()); - - auto activated = backend_->silu(*gate); - if (!activated) return std::unexpected(activated.error()); - auto hidden = backend_->multiply(*activated, *up); - if (!hidden) return std::unexpected(hidden.error()); - - return linear(*hidden, prefix + "down_proj"); -} - -// ── Transformer block ───────────────────────────────────────────────────────── - -Result LlamaModel::transformer_block( - const Tensor& input, int layer_idx, - int position_offset, LayerKVCache* cache) -{ - const std::string prefix = "model.layers." + std::to_string(layer_idx) + "."; - - auto pre_norm_w = get_weight(prefix + "input_layernorm.weight"); - if (!pre_norm_w) return std::unexpected(pre_norm_w.error()); - auto normed = rms_norm(input, *pre_norm_w, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - - auto attn_out = attention_layer(*normed, layer_idx, position_offset, cache); - if (!attn_out) return std::unexpected(attn_out.error()); - - auto residual1 = backend_->add(input, *attn_out); - if (!residual1) return std::unexpected(residual1.error()); - - auto post_norm_w = get_weight(prefix + "post_attention_layernorm.weight"); - if (!post_norm_w) return std::unexpected(post_norm_w.error()); - auto normed2 = rms_norm(*residual1, *post_norm_w, config_.rms_norm_eps); - if (!normed2) return std::unexpected(normed2.error()); - - auto mlp_out = mlp_layer(*normed2, layer_idx); - if (!mlp_out) return std::unexpected(mlp_out.error()); - - return backend_->add(*residual1, *mlp_out); -} - -// ── Internal forward ───────────────────────────────────────────────────────── - -Result> LlamaModel::forward_impl( - const std::vector& input_ids, - int position_offset, - std::vector* cache_vec) -{ - auto hidden = embedding(input_ids); - if (!hidden) return std::unexpected(hidden.error()); - - for (int i = 0; i < static_cast(config_.num_hidden_layers); ++i) { - LayerKVCache* layer_cache = cache_vec ? &(*cache_vec)[i] : nullptr; - auto block_out = transformer_block(*hidden, i, position_offset, layer_cache); - if (!block_out) return std::unexpected(block_out.error()); - hidden = std::move(block_out); - } - - auto norm_w = get_weight("model.norm.weight"); - if (!norm_w) return std::unexpected(norm_w.error()); - auto normed = rms_norm(*hidden, *norm_w, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - - // ── LM head projection ─────────────────────────────────────────────────── - // When tie_word_embeddings=true (e.g. Llama-3), there is no separate lm_head — - // the model reuses the embedding table transposed. Fall back to the dequantized - // embed_tokens weight and an unquantized matmul in that case. - Result logits_tensor{std::unexpected(Error{ErrorCode::TensorNotFound, "unset"})}; - if (weights_.count("lm_head.weight")) { - logits_tensor = linear(*normed, "lm_head"); - } else if (config_.tie_word_embeddings) { - // Ensure embedding table is dequantized (embedding() is normally called first, - // but forward_impl may be reached via forward_logits which doesn't go through it). - if (!dequantized_embed_tokens_.has_value()) { - auto ew = get_weight("model.embed_tokens.weight"); - if (!ew) return std::unexpected(ew.error()); - auto es = get_weight("model.embed_tokens.scales"); - if (es) { - const int gs = config_.quantization ? 
config_.quantization->group_size : 64; - const int bits = config_.quantization ? config_.quantization->bits : 4; - auto eb = get_weight("model.embed_tokens.biases"); - if (!eb) return std::unexpected(eb.error()); - auto deq = backend_->dequantize(*ew, *es, *eb, gs, bits); - if (!deq) return std::unexpected(deq.error()); - dequantized_embed_tokens_ = std::move(*deq); - } else { - dequantized_embed_tokens_ = std::move(*ew); - } - } - // normed: [seq, hidden] × embed_tokens^T: [hidden, vocab] → [seq, vocab] - auto embed_t = backend_->swapaxes(*dequantized_embed_tokens_, 0, 1); - if (!embed_t) return std::unexpected(embed_t.error()); - logits_tensor = backend_->matmul(*normed, *embed_t); - } else { - return std::unexpected(Error{ErrorCode::TensorNotFound, - "lm_head.weight not found and tie_word_embeddings is false"}); - } - if (!logits_tensor) return std::unexpected(logits_tensor.error()); - - const size_t seq_len = input_ids.size(); - const size_t vocab_size = config_.vocab_size; - - auto last = backend_->slice(*logits_tensor, - static_cast(seq_len - 1), - static_cast(seq_len), 0); - if (!last) return std::unexpected(last.error()); - auto flat = backend_->reshape(*last, {vocab_size}); - if (!flat) return std::unexpected(flat.error()); - - std::vector result(vocab_size); - auto extract = backend_->extract(*flat, result); - if (!extract) return std::unexpected(extract.error()); - return result; -} - // ── KV-cache public API ─────────────────────────────────────────────────────── void LlamaModel::reset_cache() { - kv_cache_.clear(); - cache_position_ = 0; #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) mlx_state_.reset(); mlx_pos_ = 0; @@ -511,91 +116,25 @@ Result> LlamaModel::generate( Result> LlamaModel::prefill(const std::vector& prompt_ids) { if (prompt_ids.empty()) return std::unexpected(Error{ErrorCode::InvalidInput, "prefill: prompt_ids cannot be empty"}); - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) reset_cache(); - mlx_state_.emplace(); // create state struct; KV filled by prefill_batch - mlx_build_decode_fn(); // build decode lambda for subsequent decode() calls + mlx_state_.emplace(); + mlx_build_decode_fn(); return mlx_prefill_batch(prompt_ids); -#else - reset_cache(); - kv_cache_.assign(config_.num_hidden_layers, LayerKVCache{}); - auto result = forward_impl(prompt_ids, 0, &kv_cache_); - if (result) cache_position_ = prompt_ids.size(); - return result; -#endif } Result> LlamaModel::decode(int token_id) { -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) if (mlx_pos_ == 0) return std::unexpected(Error{ErrorCode::InvalidInput, "decode: must call prefill() before decode()"}); return mlx_run_step(token_id); -#else - if (cache_position_ == 0) - return std::unexpected(Error{ErrorCode::InvalidInput, - "decode: must call prefill() before decode()"}); - auto result = forward_impl({token_id}, static_cast(cache_position_), &kv_cache_); - if (result) ++cache_position_; - return result; -#endif -} - -// ── No-cache forward (test interface) ──────────────────────────────────────── - -Result> LlamaModel::forward(const std::vector& input_ids) { - return forward_impl(input_ids, 0, nullptr); -} - -Result LlamaModel::forward_logits(const std::vector& input_ids) { - auto hidden = embedding(input_ids); - if (!hidden) return std::unexpected(hidden.error()); - - for (int i = 0; i < static_cast(config_.num_hidden_layers); ++i) { - auto block_out = transformer_block(*hidden, i, 0, nullptr); - if (!block_out) 
return std::unexpected(block_out.error()); - hidden = std::move(block_out); - } - - auto norm_w = get_weight("model.norm.weight"); - if (!norm_w) return std::unexpected(norm_w.error()); - auto normed = rms_norm(*hidden, *norm_w, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - - if (weights_.count("lm_head.weight")) { - return linear(*normed, "lm_head"); - } else if (config_.tie_word_embeddings) { - if (!dequantized_embed_tokens_.has_value()) { - auto ew = get_weight("model.embed_tokens.weight"); - if (!ew) return std::unexpected(ew.error()); - auto es = get_weight("model.embed_tokens.scales"); - if (es) { - const int gs = config_.quantization ? config_.quantization->group_size : 64; - const int bits = config_.quantization ? config_.quantization->bits : 4; - auto eb = get_weight("model.embed_tokens.biases"); - if (!eb) return std::unexpected(eb.error()); - auto deq = backend_->dequantize(*ew, *es, *eb, gs, bits); - if (!deq) return std::unexpected(deq.error()); - dequantized_embed_tokens_ = std::move(*deq); - } else { - dequantized_embed_tokens_ = std::move(*ew); - } - } - auto embed_t = backend_->swapaxes(*dequantized_embed_tokens_, 0, 1); - if (!embed_t) return std::unexpected(embed_t.error()); - return backend_->matmul(*normed, *embed_t); - } - return std::unexpected(Error{ErrorCode::TensorNotFound, - "lm_head.weight not found and tie_word_embeddings is false"}); } // ── Metadata ────────────────────────────────────────────────────────────────── size_t LlamaModel::num_parameters() const { size_t total = 0; - for (const auto& [name, tensor] : weights_) - total += tensor.size(); + for (const auto& [name, w] : mlx_weights_) + total += static_cast(w.size()); return total; } diff --git a/compute/src/compute/model/qwen3_moe_model.cpp b/compute/src/compute/model/qwen3_moe_model.cpp deleted file mode 100644 index 861dcc3..0000000 --- a/compute/src/compute/model/qwen3_moe_model.cpp +++ /dev/null @@ -1,901 +0,0 @@ -#include "qwen3_moe_model.h" -#include "model_loader.h" -#include "sampler.h" -#include "simple_bpe_tokenizer.h" -#include "../core/compute_backend.h" -#include -#include -#include - -namespace compute { - -// ── Factory ────────────────────────────────────────────────────────────────── - -Result Qwen3MoeModel::from_model_dir( - const std::filesystem::path& model_dir, - ComputeBackend* backend) -{ - auto model_result = ModelLoader::load_model(model_dir, backend); - if (!model_result) return std::unexpected(model_result.error()); - - auto& [config, weights] = *model_result; - - auto tokenizer_result = SimpleBpeTokenizer::from_model_dir(model_dir); - if (!tokenizer_result) return std::unexpected(tokenizer_result.error()); - - return Qwen3MoeModel( - std::move(config), - std::move(*tokenizer_result), - std::move(weights), - backend); -} - -Qwen3MoeModel::Qwen3MoeModel( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights, - ComputeBackend* backend) - : Qwen3MoeModelBase(std::move(config), std::move(tokenizer), std::move(weights)) - , backend_(backend) -{} - -// ── LanguageModel interface ─────────────────────────────────────────────────── - -Result> Qwen3MoeModel::generate( - const std::vector& input_ids, - size_t max_new_tokens, - SamplingParams params, - std::function on_token) -{ - return GenerateHelper::run( - input_ids, max_new_tokens, params, on_token, config_, - [this](const std::vector& ids) { return prefill(ids); }, - [this](int tok) { return decode(tok); }); -} - -// ── KV-cache steps 
──────────────────────────────────────────────────────────── - -Result> Qwen3MoeModel::prefill(const std::vector& prompt_ids) { - if (prompt_ids.empty()) - return std::unexpected(Error{ErrorCode::InvalidInput, "prefill: prompt_ids cannot be empty"}); - reset_cache(); - kv_cache_.assign(config_.num_hidden_layers, LayerKVCache{}); - ssm_cache_.assign(config_.num_hidden_layers, SsmState{}); - auto result = forward_impl(prompt_ids, 0, &kv_cache_); - if (result) cache_position_ = prompt_ids.size(); - return result; -} - -Result> Qwen3MoeModel::decode(int token_id) { - if (cache_position_ == 0) - return std::unexpected(Error{ErrorCode::InvalidInput, "decode: must call prefill() first"}); - - auto result = forward_impl({token_id}, static_cast(cache_position_), &kv_cache_); - if (result) ++cache_position_; - return result; -} - -void Qwen3MoeModel::reset_cache() { - kv_cache_.clear(); - ssm_cache_.clear(); - cache_position_ = 0; -} - -// ── Embedding ───────────────────────────────────────────────────────────────── - -Result Qwen3MoeModel::embedding(const std::vector& token_ids) { - if (token_ids.empty()) - return std::unexpected(Error{ErrorCode::InvalidInput, "Empty token_ids"}); - - if (!dequantized_embed_tokens_.has_value()) { - auto w = get_weight("language_model.model.embed_tokens.weight"); - if (!w) return std::unexpected(w.error()); - - auto s_it = weights_.find("language_model.model.embed_tokens.scales"); - if (s_it != weights_.end()) { - auto b_it = weights_.find("language_model.model.embed_tokens.biases"); - if (b_it == weights_.end()) - return std::unexpected(Error{ErrorCode::InvalidModel, "embed_tokens.biases missing"}); - int gs = config_.quantization ? config_.quantization->group_size : 64; - int bits = infer_quant_bits(*w, s_it->second); - auto deq = backend_->dequantize(*w, s_it->second, b_it->second, gs, bits); - if (!deq) return std::unexpected(deq.error()); - dequantized_embed_tokens_ = std::move(*deq); - } else { - dequantized_embed_tokens_ = std::move(*w); - } - } - - const auto& ew = *dequantized_embed_tokens_; - const size_t vocab_size = ew.shape()[0]; - const size_t hidden_size = ew.shape()[1]; - - std::vector rows; - rows.reserve(token_ids.size()); - for (int id : token_ids) { - if (id < 0 || static_cast(id) >= vocab_size) - return std::unexpected(Error{ErrorCode::InvalidInput, - "Token ID " + std::to_string(id) + " out of range"}); - auto row = backend_->slice(ew, id, id + 1, 0); - if (!row) return std::unexpected(row.error()); - auto vec = backend_->reshape(*row, {1, hidden_size}); - if (!vec) return std::unexpected(vec.error()); - rows.push_back(*vec); - } - - return backend_->concatenate(rows, 0); -} - -// ── Linear projection ───────────────────────────────────────────────────────── - -Result Qwen3MoeModel::linear(const Tensor& input, const std::string& weight_key) { - auto w = get_weight(weight_key + ".weight"); - if (!w) return std::unexpected(w.error()); - - auto s_it = weights_.find(weight_key + ".scales"); - if (s_it != weights_.end()) { - int gs = config_.quantization ? config_.quantization->group_size : 64; - int bits = infer_quant_bits(*w, s_it->second); - auto b_it = weights_.find(weight_key + ".biases"); - const Tensor* biases = (b_it != weights_.end()) ? 
&b_it->second : nullptr; - return backend_->quantized_matmul(input, *w, s_it->second, biases, true, gs, bits, "affine"); - } - - auto w_t = backend_->swapaxes(*w, 0, 1); - if (!w_t) return std::unexpected(w_t.error()); - return backend_->matmul(input, *w_t); -} - -// ── Expert linear (slice from 3-D weight bank) ─────────────────────────────── - -Result Qwen3MoeModel::expert_linear( - const Tensor& input, const std::string& weight_key, int expert_idx) -{ - auto w3d = get_weight(weight_key + ".weight"); - if (!w3d) return std::unexpected(w3d.error()); - - // Slice expert e: [E, out, in_packed] → [1, out, in_packed] → [out, in_packed] - auto w_e3 = backend_->slice(*w3d, expert_idx, expert_idx + 1, 0); - if (!w_e3) return std::unexpected(w_e3.error()); - auto w_e = backend_->reshape(*w_e3, {w3d->shape()[1], w3d->shape()[2]}); - if (!w_e) return std::unexpected(w_e.error()); - - auto s_it = weights_.find(weight_key + ".scales"); - if (s_it != weights_.end()) { - auto s_e3 = backend_->slice(s_it->second, expert_idx, expert_idx + 1, 0); - if (!s_e3) return std::unexpected(s_e3.error()); - auto s_e = backend_->reshape(*s_e3, - {s_it->second.shape()[1], s_it->second.shape()[2]}); - if (!s_e) return std::unexpected(s_e.error()); - - int gs = config_.quantization ? config_.quantization->group_size : 64; - int bits = infer_quant_bits(*w_e, *s_e); - - auto b_it = weights_.find(weight_key + ".biases"); - if (b_it != weights_.end()) { - auto b_e3 = backend_->slice(b_it->second, expert_idx, expert_idx + 1, 0); - if (!b_e3) return std::unexpected(b_e3.error()); - auto b_e = backend_->reshape(*b_e3, - {b_it->second.shape()[1], b_it->second.shape()[2]}); - if (!b_e) return std::unexpected(b_e.error()); - return backend_->quantized_matmul(input, *w_e, *s_e, &*b_e, true, gs, bits, "affine"); - } - return backend_->quantized_matmul(input, *w_e, *s_e, nullptr, true, gs, bits, "affine"); - } - - auto w_t = backend_->swapaxes(*w_e, 0, 1); - if (!w_t) return std::unexpected(w_t.error()); - return backend_->matmul(input, *w_t); -} - -// ── MoE MLP ─────────────────────────────────────────────────────────────────── - -Result Qwen3MoeModel::moe_mlp(const Tensor& input, int layer_idx) { - const std::string pfx = "language_model.model.layers." + std::to_string(layer_idx) + ".mlp."; - const std::string sw = pfx + "switch_mlp."; - - size_t num_experts = config_.num_experts.value_or(256); - size_t top_k = config_.num_experts_per_tok.value_or(8); - size_t seq = input.shape()[0]; - size_t hidden = input.shape()[1]; - int gs = config_.quantization ? config_.quantization->group_size : 64; - - // Router gate logits [seq, num_experts] — computed once, shared by both paths - auto gate_logits = linear(input, pfx + "gate"); - if (!gate_logits) return std::unexpected(gate_logits.error()); - - // ── Switch expert computation ───────────────────────────────────────────── - auto switch_out = [&]() -> Result { - if (seq == 1) { - // ── GPU-side decode routing (zero CPU-GPU syncs) ────────────────────── - // All routing and expert computation stays lazy on GPU. - // topk_indices on 1-D logit vector → no extract() needed. 
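The zero-sync routing above leans on a small identity: taking the softmax over only the k largest logits gives exactly the same weights as taking the full softmax over all experts and then renormalizing the k selected probabilities, because the shared partition function cancels. A standalone numeric check (plain std C++, not the backend API; the logits are made up):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

static std::vector<float> softmax(std::vector<float> v) {
    const float m = *std::max_element(v.begin(), v.end());
    float sum = 0.0f;
    for (float& x : v) { x = std::exp(x - m); sum += x; }
    for (float& x : v) x /= sum;
    return v;
}

int main() {
    const std::vector<float> logits = {1.2f, -0.3f, 2.7f, 0.4f, 2.1f, -1.0f};
    const std::size_t k = 3;

    // Indices of the k largest logits.
    std::vector<std::size_t> idx(logits.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](std::size_t a, std::size_t b) { return logits[a] > logits[b]; });

    // Path A: softmax over just the top-k logits.
    std::vector<float> topk_logits(k);
    for (std::size_t i = 0; i < k; ++i) topk_logits[i] = logits[idx[i]];
    const std::vector<float> a = softmax(topk_logits);

    // Path B: full softmax, then renormalize the selected probabilities.
    const std::vector<float> full = softmax(logits);
    float sum = 0.0f;
    for (std::size_t i = 0; i < k; ++i) sum += full[idx[i]];
    for (std::size_t i = 0; i < k; ++i)
        std::printf("expert %zu: %.6f vs %.6f\n", idx[i], a[i], full[idx[i]] / sum);
}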
- - auto logits_1d = backend_->reshape(*gate_logits, {num_experts}); - if (!logits_1d) return std::unexpected(logits_1d.error()); - - // Top-k expert indices as GPU tensor [top_k] - auto topk_idx = backend_->topk_indices(*logits_1d, (int)top_k, 0); - if (!topk_idx) return std::unexpected(topk_idx.error()); - - // Routing scores: softmax of top-k logits (= renormalized top-k probs) - auto topk_logits = backend_->take(*logits_1d, *topk_idx, 0); - if (!topk_logits) return std::unexpected(topk_logits.error()); - auto scores_1d = backend_->softmax(*topk_logits, -1); - if (!scores_1d) return std::unexpected(scores_1d.error()); - auto normed_scores = backend_->reshape(*scores_1d, {1, top_k}); // [1, top_k] - if (!normed_scores) return std::unexpected(normed_scores.error()); - - const size_t k = top_k; - - auto gw3 = get_weight(sw + "gate_proj.weight"); if (!gw3) return std::unexpected(gw3.error()); - auto gs3 = get_weight(sw + "gate_proj.scales"); if (!gs3) return std::unexpected(gs3.error()); - auto uw3 = get_weight(sw + "up_proj.weight"); if (!uw3) return std::unexpected(uw3.error()); - auto us3 = get_weight(sw + "up_proj.scales"); if (!us3) return std::unexpected(us3.error()); - auto dw3 = get_weight(sw + "down_proj.weight"); if (!dw3) return std::unexpected(dw3.error()); - auto ds3 = get_weight(sw + "down_proj.scales"); if (!ds3) return std::unexpected(ds3.error()); - - auto gb_it = weights_.find(sw + "gate_proj.biases"); - auto ub_it = weights_.find(sw + "up_proj.biases"); - auto db_it = weights_.find(sw + "down_proj.biases"); - - int bits = infer_quant_bits(*gw3, *gs3); - - // Gather k expert weight slices via GPU-tensor indices (one take per tensor) - auto gw_k = backend_->take(*gw3, *topk_idx, 0); if (!gw_k) return std::unexpected(gw_k.error()); - auto gs_k = backend_->take(*gs3, *topk_idx, 0); if (!gs_k) return std::unexpected(gs_k.error()); - auto uw_k = backend_->take(*uw3, *topk_idx, 0); if (!uw_k) return std::unexpected(uw_k.error()); - auto us_k = backend_->take(*us3, *topk_idx, 0); if (!us_k) return std::unexpected(us_k.error()); - auto dw_k = backend_->take(*dw3, *topk_idx, 0); if (!dw_k) return std::unexpected(dw_k.error()); - auto ds_k = backend_->take(*ds3, *topk_idx, 0); if (!ds_k) return std::unexpected(ds_k.error()); - - // Stack gate + up weights: [k, out, in_p] → [k*out, in_p] - const size_t out_e = gw3->shape()[1]; - const size_t in_p = gw3->shape()[2]; - const size_t gss = gs3->shape()[2]; - - auto gw2 = backend_->reshape(*gw_k, {k * out_e, in_p}); if (!gw2) return std::unexpected(gw2.error()); - auto gs2 = backend_->reshape(*gs_k, {k * out_e, gss}); if (!gs2) return std::unexpected(gs2.error()); - auto uw2 = backend_->reshape(*uw_k, {k * out_e, in_p}); if (!uw2) return std::unexpected(uw2.error()); - auto us2 = backend_->reshape(*us_k, {k * out_e, gss}); if (!us2) return std::unexpected(us2.error()); - - std::optional gb2, ub2; - if (gb_it != weights_.end()) { - auto gb_k = backend_->take(gb_it->second, *topk_idx, 0); if (!gb_k) return std::unexpected(gb_k.error()); - auto r = backend_->reshape(*gb_k, {k * out_e, gb_it->second.shape()[2]}); if (!r) return std::unexpected(r.error()); - gb2 = std::move(*r); - } - if (ub_it != weights_.end()) { - auto ub_k = backend_->take(ub_it->second, *topk_idx, 0); if (!ub_k) return std::unexpected(ub_k.error()); - auto r = backend_->reshape(*ub_k, {k * out_e, ub_it->second.shape()[2]}); if (!r) return std::unexpected(r.error()); - ub2 = std::move(*r); - } - - // Single quantized_matmul for all k gate projections: [1,hidden] × 
[k*out,in_p]^T → [1, k*out] - auto gate_k = backend_->quantized_matmul(input, *gw2, *gs2, gb2 ? &*gb2 : nullptr, true, gs, bits, "affine"); - if (!gate_k) return std::unexpected(gate_k.error()); - auto up_k = backend_->quantized_matmul(input, *uw2, *us2, ub2 ? &*ub2 : nullptr, true, gs, bits, "affine"); - if (!up_k) return std::unexpected(up_k.error()); - - // Reshape [1, k*out] → [k, out] for per-expert activation - auto gate_ke = backend_->reshape(*gate_k, {k, out_e}); if (!gate_ke) return std::unexpected(gate_ke.error()); - auto up_ke = backend_->reshape(*up_k, {k, out_e}); if (!up_ke) return std::unexpected(up_ke.error()); - - auto act = backend_->silu(*gate_ke); if (!act) return std::unexpected(act.error()); - auto h_ke = backend_->multiply(*act, *up_ke); if (!h_ke) return std::unexpected(h_ke.error()); // [k, out] - - // Down projection: k matmuls (each expert has different h_ke[i]). - // Weights already gathered contiguously in dw_k/ds_k → sequential slices = cache-friendly. - const size_t down_out = dw3->shape()[1]; - const size_t int_p = dw3->shape()[2]; - const size_t dss = ds3->shape()[2]; - - std::optional db_k; - if (db_it != weights_.end()) { - auto r = backend_->take(db_it->second, *topk_idx, 0); - if (!r) return std::unexpected(r.error()); - db_k = std::move(*r); - } - - std::vector down_outs; - down_outs.reserve(k); - for (size_t i = 0; i < k; ++i) { - auto h_i = backend_->slice(*h_ke, (int)i, (int)i + 1, 0); if (!h_i) return std::unexpected(h_i.error()); - - auto dw_i3 = backend_->slice(*dw_k, (int)i, (int)i + 1, 0); if (!dw_i3) return std::unexpected(dw_i3.error()); - auto dw_i = backend_->reshape(*dw_i3, {down_out, int_p}); if (!dw_i) return std::unexpected(dw_i.error()); - - auto dsi_3 = backend_->slice(*ds_k, (int)i, (int)i + 1, 0); if (!dsi_3) return std::unexpected(dsi_3.error()); - auto dsi = backend_->reshape(*dsi_3, {down_out, dss}); if (!dsi) return std::unexpected(dsi.error()); - - std::optional db_i; - if (db_k) { - auto dbi_3 = backend_->slice(*db_k, (int)i, (int)i + 1, 0); if (!dbi_3) return std::unexpected(dbi_3.error()); - auto dbi = backend_->reshape(*dbi_3, {down_out, db_it->second.shape()[2]}); if (!dbi) return std::unexpected(dbi.error()); - db_i = std::move(*dbi); - } - - auto out_i = backend_->quantized_matmul( - *h_i, *dw_i, *dsi, db_i ? &*db_i : nullptr, true, gs, bits, "affine"); - if (!out_i) return std::unexpected(out_i.error()); - down_outs.push_back(std::move(*out_i)); - } - - // Concatenate [k × [1, hidden]] → [k, hidden], then score-weighted sum via matmul - auto down_all = backend_->concatenate(down_outs, 0); // [k, hidden] - if (!down_all) return std::unexpected(down_all.error()); - - // normed_scores [1, top_k] @ down_all [top_k, hidden] → [1, hidden] (all GPU) - return backend_->matmul(*normed_scores, *down_all); - - } else { - // ── CPU-side prefill routing (seq > 1) ──────────────────────────────── - // Extract + CPU top-k is acceptable for prefill (one-time cost per prompt). 
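For reference, the per-token routing performed in the prefill loop below reduces to a small self-contained step: select the top_k experts by probability with std::partial_sort and renormalize their scores to sum to 1 before the score-weighted combine. A simplified sketch with plain std types (illustrative only; it mirrors the selection logic used below):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Given one token's expert probabilities, return (expert index, weight) pairs
// for the top_k experts, with weights renormalized to sum to 1.
std::vector<std::pair<int, float>> route_token(const std::vector<float>& probs,
                                               std::size_t top_k) {
    std::vector<std::pair<int, float>> cands(probs.size());
    for (std::size_t e = 0; e < probs.size(); ++e)
        cands[e] = {static_cast<int>(e), probs[e]};

    std::partial_sort(cands.begin(), cands.begin() + top_k, cands.end(),
                      [](const auto& a, const auto& b) { return a.second > b.second; });
    cands.resize(top_k);

    float sum = 0.0f;
    for (const auto& c : cands) sum += c.second;
    if (sum > 1e-8f)
        for (auto& c : cands) c.second /= sum;
    return cands;
}

int main() {
    const std::vector<float> probs = {0.05f, 0.40f, 0.10f, 0.30f, 0.15f};
    for (const auto& r : route_token(probs, 2))
        std::printf("expert %d weight %.3f\n", r.first, r.second);
}

The token's MoE output is then the sum of the selected experts' SwiGLU outputs scaled by these weights, which is what the loop below computes.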
- auto gate_probs = backend_->softmax(*gate_logits, -1); - if (!gate_probs) return std::unexpected(gate_probs.error()); - - std::vector probs_cpu(seq * num_experts); - { - auto flat = backend_->reshape(*gate_probs, {seq * num_experts}); - if (!flat) return std::unexpected(flat.error()); - auto ex = backend_->extract(*flat, probs_cpu); - if (!ex) return std::unexpected(ex.error()); - } - - struct ES { int idx; float score; }; - std::vector> selected(seq); - for (size_t t = 0; t < seq; ++t) { - float* row = probs_cpu.data() + t * num_experts; - std::vector cands(num_experts); - for (size_t e = 0; e < num_experts; ++e) cands[e] = {(int)e, row[e]}; - std::partial_sort(cands.begin(), cands.begin() + top_k, cands.end(), - [](const ES& a, const ES& b){ return a.score > b.score; }); - selected[t].assign(cands.begin(), cands.begin() + top_k); - float sum = 0.0f; - for (auto& es : selected[t]) sum += es.score; - if (sum > 1e-8f) - for (auto& es : selected[t]) es.score /= sum; - } - - // Sequential per-token expert loop - std::vector token_outs; - token_outs.reserve(seq); - for (size_t t = 0; t < seq; ++t) { - auto tok3 = backend_->slice(input, (int)t, (int)t + 1, 0); - if (!tok3) return std::unexpected(tok3.error()); - auto tok = backend_->reshape(*tok3, {1, hidden}); - if (!tok) return std::unexpected(tok.error()); - - std::optional acc; - for (auto& es : selected[t]) { - auto g_out = expert_linear(*tok, sw + "gate_proj", es.idx); - if (!g_out) return std::unexpected(g_out.error()); - auto u_out = expert_linear(*tok, sw + "up_proj", es.idx); - if (!u_out) return std::unexpected(u_out.error()); - auto act = backend_->silu(*g_out); - if (!act) return std::unexpected(act.error()); - auto h = backend_->multiply(*act, *u_out); - if (!h) return std::unexpected(h.error()); - auto out = expert_linear(*h, sw + "down_proj", es.idx); - if (!out) return std::unexpected(out.error()); - - Tensor score_t = backend_->create_tensor( - std::span(&es.score, 1), {1}); - auto scaled = backend_->multiply(*out, score_t); - if (!scaled) return std::unexpected(scaled.error()); - - if (!acc) { acc = std::move(*scaled); } - else { - auto added = backend_->add(*acc, *scaled); - if (!added) return std::unexpected(added.error()); - acc = std::move(*added); - } - } - if (!acc) { - std::vector z(hidden, 0.0f); - acc = backend_->create_tensor(z, {1, hidden}); - } - token_outs.push_back(std::move(*acc)); - } - - return (token_outs.size() == 1) - ? 
Result{token_outs[0]} - : backend_->concatenate(token_outs, 0); - } - }(); - - if (!switch_out) return std::unexpected(switch_out.error()); - - // Shared expert: standard SwiGLU MLP - auto se_g = linear(input, pfx + "shared_expert.gate_proj"); - if (!se_g) return std::unexpected(se_g.error()); - auto se_u = linear(input, pfx + "shared_expert.up_proj"); - if (!se_u) return std::unexpected(se_u.error()); - auto se_act = backend_->silu(*se_g); - if (!se_act) return std::unexpected(se_act.error()); - auto se_h = backend_->multiply(*se_act, *se_u); - if (!se_h) return std::unexpected(se_h.error()); - auto se_out = linear(*se_h, pfx + "shared_expert.down_proj"); - if (!se_out) return std::unexpected(se_out.error()); - - // Shared expert gate (scalar sigmoid): [seq, 1] * [seq, hidden] - auto seg = linear(input, pfx + "shared_expert_gate"); - if (!seg) return std::unexpected(seg.error()); - auto seg_sig = backend_->sigmoid(*seg); - if (!seg_sig) return std::unexpected(seg_sig.error()); - auto se_gated = backend_->multiply(*se_out, *seg_sig); - if (!se_gated) return std::unexpected(se_gated.error()); - - return backend_->add(*switch_out, *se_gated); -} - -// ── Full attention block ────────────────────────────────────────────────────── - -Result Qwen3MoeModel::full_attention_block( - const Tensor& input, int layer_idx, int position_offset, LayerKVCache* cache) -{ - const std::string pfx = "language_model.model.layers." + std::to_string(layer_idx) + ".self_attn."; - - const size_t seq = input.shape()[0]; - const size_t n_heads = config_.num_attention_heads; - const size_t n_kv = config_.num_key_value_heads; - const size_t head_dim = config_.effective_head_dim(); - const float scale = 1.0f / std::sqrt(static_cast(head_dim)); - - // Partial RoPE dims - float prf = config_.partial_rotary_factor.value_or(0.25f); - int rope_dims = static_cast(head_dim * prf); - - // q_proj: [seq, n_heads * head_dim * 2] (includes output gate) - auto q_raw = linear(input, pfx + "q_proj"); - if (!q_raw) return std::unexpected(q_raw.error()); - - // Split q_proj output into q and gate (along last axis, each head has head_dim*2) - // Reshape to [seq, n_heads, head_dim*2] then slice - auto q_4d = backend_->reshape(*q_raw, {seq, n_heads, head_dim * 2}); - if (!q_4d) return std::unexpected(q_4d.error()); - auto q_h = backend_->slice(*q_4d, 0, (int)head_dim, 2); - if (!q_h) return std::unexpected(q_h.error()); - auto gate_h = backend_->slice(*q_4d, (int)head_dim, (int)(head_dim * 2), 2); - if (!gate_h) return std::unexpected(gate_h.error()); - // gate: [seq, n_heads, head_dim] → flatten to [seq, n_heads*head_dim] - auto gate = backend_->reshape(*gate_h, {seq, n_heads * head_dim}); - if (!gate) return std::unexpected(gate.error()); - // q: [seq, n_heads, head_dim] → swapaxes → [n_heads, seq, head_dim] - auto qt = backend_->swapaxes(*q_h, 0, 1); - if (!qt) return std::unexpected(qt.error()); - - // k, v: standard projections - auto k_raw = linear(input, pfx + "k_proj"); - if (!k_raw) return std::unexpected(k_raw.error()); - auto v_raw = linear(input, pfx + "v_proj"); - if (!v_raw) return std::unexpected(v_raw.error()); - - auto k3 = backend_->reshape(*k_raw, {seq, n_kv, head_dim}); - if (!k3) return std::unexpected(k3.error()); - auto kt = backend_->swapaxes(*k3, 0, 1); - if (!kt) return std::unexpected(kt.error()); - - auto v3 = backend_->reshape(*v_raw, {seq, n_kv, head_dim}); - if (!v3) return std::unexpected(v3.error()); - auto vt = backend_->swapaxes(*v3, 0, 1); - if (!vt) return std::unexpected(vt.error()); - - // QK norms - 
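Unlike the full-head RoPE in the Llama path, this attention block rotates only rope_dims = head_dim * partial_rotary_factor dimensions per head and passes the remaining dimensions through unchanged. A standalone sketch of that idea (plain std C++, not the backend's rope(); the half-split pairing used here is one common convention and may differ from what the backend actually implements):

#include <cmath>
#include <cstddef>
#include <vector>

// Apply rotary position embedding to the first rope_dims entries of one head
// vector; entries [rope_dims, head_dim) are intentionally left untouched.
void partial_rope(std::vector<float>& head,   // one head vector, length head_dim
                  std::size_t rope_dims,       // e.g. head_dim * partial_rotary_factor
                  int position, float theta = 10000.0f) {
    const std::size_t half = rope_dims / 2;
    for (std::size_t i = 0; i < half; ++i) {
        const float freq  = std::pow(theta, -2.0f * static_cast<float>(i)
                                                 / static_cast<float>(rope_dims));
        const float angle = static_cast<float>(position) * freq;
        const float c = std::cos(angle), s = std::sin(angle);
        const float x0 = head[i], x1 = head[i + half];
        head[i]        = x0 * c - x1 * s;
        head[i + half] = x0 * s + x1 * c;
    }
}

int main() {
    std::vector<float> head(8, 1.0f);        // head_dim = 8
    partial_rope(head, 4, /*position=*/3);   // rotate dims 0..3 only; dims 4..7 stay 1.0f
}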
{ - auto w = weights_.find(pfx + "q_norm.weight"); - if (w != weights_.end()) { - auto f = backend_->reshape(*qt, {n_heads * seq, head_dim}); - if (!f) return std::unexpected(f.error()); - auto n = backend_->rms_norm(*f, w->second, config_.rms_norm_eps); - if (!n) return std::unexpected(n.error()); - qt = backend_->reshape(*n, {n_heads, seq, head_dim}); - if (!qt) return std::unexpected(qt.error()); - } - } - { - auto w = weights_.find(pfx + "k_norm.weight"); - if (w != weights_.end()) { - auto f = backend_->reshape(*kt, {n_kv * seq, head_dim}); - if (!f) return std::unexpected(f.error()); - auto n = backend_->rms_norm(*f, w->second, config_.rms_norm_eps); - if (!n) return std::unexpected(n.error()); - kt = backend_->reshape(*n, {n_kv, seq, head_dim}); - if (!kt) return std::unexpected(kt.error()); - } - } - - // RoPE (partial) - auto q_rope = backend_->rope(*qt, rope_dims, config_.rope_theta, position_offset); - if (!q_rope) return std::unexpected(q_rope.error()); - auto k_rope = backend_->rope(*kt, rope_dims, config_.rope_theta, position_offset); - if (!k_rope) return std::unexpected(k_rope.error()); - - // KV cache - Tensor full_k = *k_rope; - Tensor full_v = *vt; - std::string attn_mask = "causal"; - - if (cache) { - if (!cache->valid) { - cache->keys = *k_rope; - cache->values = *vt; - cache->valid = true; - } else { - auto cat_k = backend_->concatenate({*cache->keys, *k_rope}, 1); - if (!cat_k) return std::unexpected(cat_k.error()); - auto cat_v = backend_->concatenate({*cache->values, *vt}, 1); - if (!cat_v) return std::unexpected(cat_v.error()); - cache->keys = *cat_k; - cache->values = *cat_v; - full_k = *cat_k; - full_v = *cat_v; - } - if (seq == 1) attn_mask = ""; - } - - // SDPA - auto attn_out = backend_->scaled_dot_product_attention( - *q_rope, full_k, full_v, scale, attn_mask); - if (!attn_out) return std::unexpected(attn_out.error()); - - // Reshape output: [n_heads, seq, head_dim] → [seq, n_heads*head_dim] - auto attn_t = backend_->swapaxes(*attn_out, 0, 1); - if (!attn_t) return std::unexpected(attn_t.error()); - auto attn_flat = backend_->reshape(*attn_t, {seq, n_heads * head_dim}); - if (!attn_flat) return std::unexpected(attn_flat.error()); - - // Output gate: attn_flat * sigmoid(gate) - auto gate_sig = backend_->sigmoid(*gate); - if (!gate_sig) return std::unexpected(gate_sig.error()); - auto gated = backend_->multiply(*attn_flat, *gate_sig); - if (!gated) return std::unexpected(gated.error()); - - return linear(*gated, pfx + "o_proj"); -} - -// ── GatedDeltaNet SSM block ─────────────────────────────────────────────────── - -Result Qwen3MoeModel::linear_attention_block(const Tensor& input, int layer_idx) { - const std::string pfx = - "language_model.model.layers." 
+ std::to_string(layer_idx) + ".linear_attn."; - - // SSM dimensions from config (with sensible defaults) - const size_t Hv = config_.linear_num_value_heads.value_or(32); - const size_t Dv = config_.linear_value_head_dim.value_or(128); - const size_t Hk = config_.linear_num_key_heads.value_or(Hv); - const size_t Dk = config_.linear_key_head_dim.value_or(128); - const size_t kernel_size = config_.linear_conv_kernel_dim.value_or(4); - const size_t key_dim = Hk * Dk; - const size_t value_dim = Hv * Dv; - const size_t conv_dim = key_dim * 2 + value_dim; - - const size_t seq = input.shape()[0]; - const size_t hidden = input.shape()[1]; - - // Linear projections - auto qkv = linear(input, pfx + "in_proj_qkv"); // [seq, conv_dim] - if (!qkv) return std::unexpected(qkv.error()); - auto z = linear(input, pfx + "in_proj_z"); // [seq, Hv*Dv] - if (!z) return std::unexpected(z.error()); - auto b = linear(input, pfx + "in_proj_b"); // [seq, Hv] - if (!b) return std::unexpected(b.error()); - auto a = linear(input, pfx + "in_proj_a"); // [seq, Hv] - if (!a) return std::unexpected(a.error()); - - // Model weights (plain floats) - auto A_log_w = get_weight(pfx + "A_log"); // [Hv] - if (!A_log_w) return std::unexpected(A_log_w.error()); - auto dt_bias_w = get_weight(pfx + "dt_bias"); // [Hv] - if (!dt_bias_w) return std::unexpected(dt_bias_w.error()); - auto conv1d_w = get_weight(pfx + "conv1d.weight"); // [conv_dim, kernel_size, 1] - if (!conv1d_w) return std::unexpected(conv1d_w.error()); - auto norm_w = get_weight(pfx + "norm.weight"); // [Dv] - if (!norm_w) return std::unexpected(norm_w.error()); - - // Initialize SSM state if needed - SsmState& state = ssm_cache_[layer_idx]; - if (!state.valid) { - std::vector czeros((kernel_size - 1) * conv_dim, 0.0f); - state.conv_state = backend_->create_tensor(czeros, {kernel_size - 1, conv_dim}); - std::vector rzeros(Hv * Dv * Dk, 0.0f); - state.rec_state = backend_->create_tensor(rzeros, {Hv, Dv, Dk}); - state.valid = true; - } - - // Causal conv: concat [conv_state, qkv] along time axis, then apply depthwise conv1d - // conv_state: [kernel-1, conv_dim], qkv: [seq, conv_dim] - auto conv_input = backend_->concatenate({*state.conv_state, *qkv}, 0); - if (!conv_input) return std::unexpected(conv_input.error()); - - // Save updated conv state (last kernel-1 rows of conv_input) - { - size_t total = kernel_size - 1 + seq; - auto ns = backend_->slice(*conv_input, (int)(total - (kernel_size - 1)), (int)total, 0); - if (!ns) return std::unexpected(ns.error()); - state.conv_state = *ns; - } - - // Depthwise conv1d then silu: output [seq, conv_dim] - auto conv_raw = backend_->conv1d(*conv_input, *conv1d_w, 1, 0, (int)conv_dim); - if (!conv_raw) return std::unexpected(conv_raw.error()); - auto conv_out = backend_->silu(*conv_raw); - if (!conv_out) return std::unexpected(conv_out.error()); - - // Split conv output into q, k, v along feature axis - auto q_flat = backend_->slice(*conv_out, 0, (int)key_dim, 1); - if (!q_flat) return std::unexpected(q_flat.error()); - auto k_flat = backend_->slice(*conv_out, (int)key_dim, (int)(2 * key_dim), 1); - if (!k_flat) return std::unexpected(k_flat.error()); - auto v_flat = backend_->slice(*conv_out, (int)(2*key_dim), (int)conv_dim, 1); - if (!v_flat) return std::unexpected(v_flat.error()); - - // Reshape to head format - auto q_3d = backend_->reshape(*q_flat, {seq, Hk, Dk}); - if (!q_3d) return std::unexpected(q_3d.error()); - auto k_3d = backend_->reshape(*k_flat, {seq, Hk, Dk}); - if (!k_3d) return std::unexpected(k_3d.error()); - auto 
v_3d = backend_->reshape(*v_flat, {seq, Hv, Dv}); - if (!v_3d) return std::unexpected(v_3d.error()); - auto z_3d = backend_->reshape(*z, {seq, Hv, Dv}); - if (!z_3d) return std::unexpected(z_3d.error()); - - // RMSNorm q and k (unit weight) then scale by 1/Dk and 1/sqrt(Dk) - std::vector ones_Dk_v(Dk, 1.0f); - Tensor ones_Dk = backend_->create_tensor(ones_Dk_v, {Dk}); - - { - auto f = backend_->reshape(*q_3d, {seq * Hk, Dk}); - if (!f) return std::unexpected(f.error()); - auto n = backend_->rms_norm(*f, ones_Dk, 1e-6f); - if (!n) return std::unexpected(n.error()); - q_3d = backend_->reshape(*n, {seq, Hk, Dk}); - if (!q_3d) return std::unexpected(q_3d.error()); - } - { - auto f = backend_->reshape(*k_3d, {seq * Hk, Dk}); - if (!f) return std::unexpected(f.error()); - auto n = backend_->rms_norm(*f, ones_Dk, 1e-6f); - if (!n) return std::unexpected(n.error()); - k_3d = backend_->reshape(*n, {seq, Hk, Dk}); - if (!k_3d) return std::unexpected(k_3d.error()); - } - - // Scale: q *= 1/Dk, k *= 1/sqrt(Dk) - float inv_scale = 1.0f / std::sqrt(static_cast(Dk)); - float q_scale_val = inv_scale * inv_scale; - Tensor q_scale_t = backend_->create_tensor( - std::span(&q_scale_val, 1), {1}); - Tensor k_scale_t = backend_->create_tensor( - std::span(&inv_scale, 1), {1}); - - auto q_sc = backend_->multiply(*q_3d, q_scale_t); - if (!q_sc) return std::unexpected(q_sc.error()); - auto k_sc = backend_->multiply(*k_3d, k_scale_t); - if (!k_sc) return std::unexpected(k_sc.error()); - - // Repeat q and k along head axis if Hv > Hk - if (Hv > Hk) { - int rf = static_cast(Hv / Hk); - auto qr = backend_->repeat(*q_sc, rf, 1); - if (!qr) return std::unexpected(qr.error()); - q_sc = qr; - auto kr = backend_->repeat(*k_sc, rf, 1); - if (!kr) return std::unexpected(kr.error()); - k_sc = kr; - } - // q_sc, k_sc: [seq, Hv, Dk] - - // Compute per-position decay (g) and input gate (beta) for all positions at once - // g = exp(-exp(A_log) * softplus(a + dt_bias)) → [seq, Hv] - // beta = sigmoid(b) → [seq, Hv] - auto a_biased = backend_->add(*a, *dt_bias_w); // [seq, Hv] + [Hv] → [seq, Hv] - if (!a_biased) return std::unexpected(a_biased.error()); - auto sp = backend_->softplus(*a_biased); // [seq, Hv] - if (!sp) return std::unexpected(sp.error()); - auto exp_A = backend_->exp(*A_log_w); // [Hv] - if (!exp_A) return std::unexpected(exp_A.error()); - auto neg_pre = backend_->multiply(*exp_A, *sp); // [seq, Hv] - if (!neg_pre) return std::unexpected(neg_pre.error()); - { - float zero = 0.0f; - Tensor zt = backend_->create_tensor(std::span(&zero, 1), {1}); - auto neg = backend_->subtract(zt, *neg_pre); // [seq, Hv] (negate) - if (!neg) return std::unexpected(neg.error()); - neg_pre = neg; - } - auto g_seq = backend_->exp(*neg_pre); // [seq, Hv] - if (!g_seq) return std::unexpected(g_seq.error()); - auto beta_seq = backend_->sigmoid(*b); // [seq, Hv] - if (!beta_seq) return std::unexpected(beta_seq.error()); - - // Sequential SSM recurrence (one step per time position) - std::vector ys; - ys.reserve(seq); - - for (size_t t = 0; t < seq; ++t) { - // Extract position t from [seq, Hv, Dk] tensors - auto qt3 = backend_->slice(*q_sc, (int)t, (int)t+1, 0); - if (!qt3) return std::unexpected(qt3.error()); - auto q_t = backend_->reshape(*qt3, {Hv, Dk}); - if (!q_t) return std::unexpected(q_t.error()); - - auto kt3 = backend_->slice(*k_sc, (int)t, (int)t+1, 0); - if (!kt3) return std::unexpected(kt3.error()); - auto k_t = backend_->reshape(*kt3, {Hv, Dk}); - if (!k_t) return std::unexpected(k_t.error()); - - auto vt3 = 
backend_->slice(*v_3d, (int)t, (int)t+1, 0); - if (!vt3) return std::unexpected(vt3.error()); - auto v_t = backend_->reshape(*vt3, {Hv, Dv}); - if (!v_t) return std::unexpected(v_t.error()); - - auto gt3 = backend_->slice(*g_seq, (int)t, (int)t+1, 0); - if (!gt3) return std::unexpected(gt3.error()); - auto g_t = backend_->reshape(*gt3, {Hv}); - if (!g_t) return std::unexpected(g_t.error()); - - auto bt3 = backend_->slice(*beta_seq, (int)t, (int)t+1, 0); - if (!bt3) return std::unexpected(bt3.error()); - auto beta_t = backend_->reshape(*bt3, {Hv}); - if (!beta_t) return std::unexpected(beta_t.error()); - - // SSM step on state: [Hv, Dv, Dk] - Tensor& h = *state.rec_state; - - // 1. Decay: h = h * g_t (g_t broadcast [Hv] → [Hv,1,1]) - auto g_exp = backend_->reshape(*g_t, {Hv, 1, 1}); - if (!g_exp) return std::unexpected(g_exp.error()); - auto h_dec = backend_->multiply(h, *g_exp); // [Hv, Dv, Dk] - if (!h_dec) return std::unexpected(h_dec.error()); - - // 2. kv_mem = h_dec @ k_t → [Hv, Dv] - auto k_col = backend_->reshape(*k_t, {Hv, Dk, 1}); - if (!k_col) return std::unexpected(k_col.error()); - auto kv3d = backend_->matmul(*h_dec, *k_col); // [Hv, Dv, 1] - if (!kv3d) return std::unexpected(kv3d.error()); - auto kv_mem = backend_->reshape(*kv3d, {Hv, Dv}); - if (!kv_mem) return std::unexpected(kv_mem.error()); - - // 3. delta = (v_t - kv_mem) * beta_t → [Hv, Dv] - auto err = backend_->subtract(*v_t, *kv_mem); - if (!err) return std::unexpected(err.error()); - auto beta_exp = backend_->reshape(*beta_t, {Hv, 1}); - if (!beta_exp) return std::unexpected(beta_exp.error()); - auto delta = backend_->multiply(*err, *beta_exp); // [Hv, Dv] - if (!delta) return std::unexpected(delta.error()); - - // 4. state += outer(delta, k_t) → [Hv, Dv, Dk] - auto k_row = backend_->reshape(*k_t, {Hv, 1, Dk}); - if (!k_row) return std::unexpected(k_row.error()); - auto d_col = backend_->reshape(*delta, {Hv, Dv, 1}); - if (!d_col) return std::unexpected(d_col.error()); - auto outer = backend_->multiply(*d_col, *k_row); // [Hv, Dv, Dk] - if (!outer) return std::unexpected(outer.error()); - auto h_new = backend_->add(*h_dec, *outer); - if (!h_new) return std::unexpected(h_new.error()); - - // 5. y = h_new @ q_t → [Hv, Dv] - auto q_col = backend_->reshape(*q_t, {Hv, Dk, 1}); - if (!q_col) return std::unexpected(q_col.error()); - auto y3d = backend_->matmul(*h_new, *q_col); // [Hv, Dv, 1] - if (!y3d) return std::unexpected(y3d.error()); - auto y_t = backend_->reshape(*y3d, {Hv, Dv}); - if (!y_t) return std::unexpected(y_t.error()); - - state.rec_state = *h_new; - ys.push_back(*y_t); - } - - // Stack outputs: [seq, Hv, Dv] - std::vector ys_batched; - ys_batched.reserve(seq); - for (auto& y : ys) { - auto yr = backend_->reshape(y, {1, Hv, Dv}); - if (!yr) return std::unexpected(yr.error()); - ys_batched.push_back(*yr); - } - Result out_3d = (ys_batched.size() == 1) - ? 
Result{ys_batched[0]} - : backend_->concatenate(ys_batched, 0); - if (!out_3d) return std::unexpected(out_3d.error()); - - // RMSNormGated: norm(out, norm.weight, eps) * silu(z) - // Flatten [seq, Hv, Dv] → [seq*Hv, Dv] for rms_norm - auto out_flat = backend_->reshape(*out_3d, {seq * Hv, Dv}); - if (!out_flat) return std::unexpected(out_flat.error()); - auto out_normed = backend_->rms_norm(*out_flat, *norm_w, config_.rms_norm_eps); - if (!out_normed) return std::unexpected(out_normed.error()); - auto out_n3d = backend_->reshape(*out_normed, {seq, Hv, Dv}); - if (!out_n3d) return std::unexpected(out_n3d.error()); - - auto z_silu = backend_->silu(*z_3d); - if (!z_silu) return std::unexpected(z_silu.error()); - auto gated = backend_->multiply(*out_n3d, *z_silu); // [seq, Hv, Dv] - if (!gated) return std::unexpected(gated.error()); - - // Reshape and out_proj: [seq, Hv*Dv] → [seq, hidden] - auto gated_flat = backend_->reshape(*gated, {seq, Hv * Dv}); - if (!gated_flat) return std::unexpected(gated_flat.error()); - - return linear(*gated_flat, pfx + "out_proj"); -} - -// ── Forward pass ────────────────────────────────────────────────────────────── - -Result> Qwen3MoeModel::forward_impl( - const std::vector& input_ids, - int position_offset, - std::vector* cache_vec) -{ - auto hidden = embedding(input_ids); - if (!hidden) return std::unexpected(hidden.error()); - - const int full_attn_interval = 4; // (layer_idx + 1) % 4 == 0 → full attention - - for (int i = 0; i < static_cast(config_.num_hidden_layers); ++i) { - const std::string lpfx = - "language_model.model.layers." + std::to_string(i) + "."; - - // Pre-attention RMSNorm - auto pre_norm_w = get_weight(lpfx + "input_layernorm.weight"); - if (!pre_norm_w) return std::unexpected(pre_norm_w.error()); - auto normed = backend_->rms_norm(*hidden, *pre_norm_w, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - - // Attention (SSM or full) - bool is_linear = (i + 1) % full_attn_interval != 0; - Result attn_out{std::unexpected(Error{ErrorCode::NotImplemented, "unset"})}; - if (is_linear) { - attn_out = linear_attention_block(*normed, i); - } else { - LayerKVCache* kvc = (cache_vec && i < (int)cache_vec->size()) - ? 
&(*cache_vec)[i] : nullptr; - attn_out = full_attention_block(*normed, i, position_offset, kvc); - } - if (!attn_out) return std::unexpected(attn_out.error()); - - auto residual1 = backend_->add(*hidden, *attn_out); - if (!residual1) return std::unexpected(residual1.error()); - - // Post-attention RMSNorm - auto post_norm_w = get_weight(lpfx + "post_attention_layernorm.weight"); - if (!post_norm_w) return std::unexpected(post_norm_w.error()); - auto normed2 = backend_->rms_norm(*residual1, *post_norm_w, config_.rms_norm_eps); - if (!normed2) return std::unexpected(normed2.error()); - - // MoE MLP (all layers) - auto mlp_out = moe_mlp(*normed2, i); - if (!mlp_out) return std::unexpected(mlp_out.error()); - - auto result = backend_->add(*residual1, *mlp_out); - if (!result) return std::unexpected(result.error()); - hidden = std::move(result); - } - - // Final RMSNorm - auto norm_w = get_weight("language_model.model.norm.weight"); - if (!norm_w) return std::unexpected(norm_w.error()); - auto normed = backend_->rms_norm(*hidden, *norm_w, config_.rms_norm_eps); - if (!normed) return std::unexpected(normed.error()); - - // LM head - auto logits = linear(*normed, "language_model.lm_head"); - if (!logits) return std::unexpected(logits.error()); - - // Extract last token's logits - const size_t seq_len = input_ids.size(); - const size_t vocab_size = config_.vocab_size; - - auto last = backend_->slice(*logits, (int)(seq_len - 1), (int)seq_len, 0); - if (!last) return std::unexpected(last.error()); - auto flat = backend_->reshape(*last, {vocab_size}); - if (!flat) return std::unexpected(flat.error()); - - std::vector result(vocab_size); - auto ex = backend_->extract(*flat, result); - if (!ex) return std::unexpected(ex.error()); - return result; -} - -} // namespace compute diff --git a/compute/src/compute/model/qwen3_moe_model.h b/compute/src/compute/model/qwen3_moe_model.h deleted file mode 100644 index ae1a8b0..0000000 --- a/compute/src/compute/model/qwen3_moe_model.h +++ /dev/null @@ -1,102 +0,0 @@ -#pragma once - -#include "language_model.h" -#include "qwen3_moe_model_base.h" -#include "kv_cache.h" -#include -#include -#include - -namespace compute { - -/** - * Qwen3.5 MoE — hybrid GatedDeltaNet SSM + GQA + MoE language model. 
- * - * model_type: "qwen3_5_moe" - * - * Architecture (40 layers): - * - Every layer: MoE MLP (switch_mlp batched experts + shared_expert + shared_expert_gate) - * - is_linear = (layer_idx + 1) % 4 != 0 → GatedDeltaNet SSM - * - is_linear = false (layers 3, 7, 11, …) → full GQA with q_norm/k_norm + output gate - * - * Weight prefix: language_model.model.layers.{i}.* - */ -class Qwen3MoeModel final : public Qwen3MoeModelBase, public LanguageModel { -public: - static Result from_model_dir( - const std::filesystem::path& model_dir, - ComputeBackend* backend); - - // ── LanguageModel interface ─────────────────────────────────────────────── - - Result> generate( - const std::vector& input_ids, - size_t max_new_tokens = 4096, - SamplingParams params = {}, - std::function on_token = nullptr) override; - - const ModelConfig& config() const override { return config_; } - const std::string& model_type() const override { return config_.model_type; } - const SimpleBpeTokenizer& tokenizer() const override { return tokenizer_; } - size_t num_parameters() const override { return Qwen3MoeModelBase::num_parameters(); } - -private: - // ── KV-cache steps (private — used by generate() only) ─────────────────── - - Result> prefill(const std::vector& prompt_ids); - Result> decode(int token_id); - void reset_cache(); - - // Per-SSM-layer state (GatedDeltaNet) - struct SsmState { - std::optional conv_state; // [kernel_size-1, conv_dim] - std::optional rec_state; // [Hv, Dv, Dk] - bool valid = false; - }; - - Qwen3MoeModel( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights, - ComputeBackend* backend); - - // ── Helpers ─────────────────────────────────────────────────────────────── - - Result embedding(const std::vector& token_ids); - - // Linear projection: dispatches to quantized_matmul or matmul. - // Uses infer_quant_bits so gate layers (8-bit) work correctly. - Result linear(const Tensor& input, const std::string& weight_key); - - // Linear projection for a single expert slice from a 3D weight bank. - // weight_key names a weight of shape [num_experts, out, in_packed]. - Result expert_linear(const Tensor& input, const std::string& weight_key, int expert_idx); - - // MoE MLP — used by every layer regardless of attention type. - Result moe_mlp(const Tensor& input, int layer_idx); - - // Full-attention transformer block (every 4th layer: 3, 7, 11, …). - Result full_attention_block( - const Tensor& input, int layer_idx, int position_offset, LayerKVCache* cache); - - // GatedDeltaNet SSM transformer block (all other layers). - Result linear_attention_block(const Tensor& input, int layer_idx); - - // Top-level forward: embedding → layer loop → norm → lm_head. 
- Result> forward_impl( - const std::vector& input_ids, - int position_offset, - std::vector* cache_vec); - - // ── State ───────────────────────────────────────────────────────────────── - - ComputeBackend* backend_; - - std::optional dequantized_embed_tokens_; - std::vector kv_cache_; - std::vector ssm_cache_; - size_t cache_position_ = 0; - -}; - -} // namespace compute diff --git a/compute/src/compute/model/qwen3_moe_model_base.cpp b/compute/src/compute/model/qwen3_moe_model_base.cpp deleted file mode 100644 index cc38a04..0000000 --- a/compute/src/compute/model/qwen3_moe_model_base.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "qwen3_moe_model_base.h" -#include - -namespace compute { - -Qwen3MoeModelBase::Qwen3MoeModelBase( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights) - : config_(std::move(config)) - , tokenizer_(std::move(tokenizer)) - , weights_(std::move(weights)) -{} - -size_t Qwen3MoeModelBase::num_parameters() const { - size_t total = 0; - for (const auto& [name, tensor] : weights_) total += tensor.size(); - return total; -} - -Result Qwen3MoeModelBase::get_weight(const std::string& name) const { - auto it = weights_.find(name); - if (it == weights_.end()) - return std::unexpected(Error{ErrorCode::InvalidModel, "Weight not found: " + name}); - return it->second; -} - -int Qwen3MoeModelBase::infer_quant_bits(const Tensor& w, const Tensor& scales) const { - size_t in_packed = w.shape().back(); - size_t groups = scales.shape().back(); - size_t gs = config_.quantization ? config_.quantization->group_size : 64; - double ratio = static_cast(in_packed) / static_cast(groups); - int bits = static_cast(std::round(32.0 * ratio / static_cast(gs))); - return (bits > 0) ? bits : 4; -} - -} // namespace compute diff --git a/compute/src/compute/model/qwen3_moe_model_base.h b/compute/src/compute/model/qwen3_moe_model_base.h deleted file mode 100644 index ae9baff..0000000 --- a/compute/src/compute/model/qwen3_moe_model_base.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include "model_config.h" -#include "simple_bpe_tokenizer.h" -#include "../core/tensor.h" -#include "../core/compute_types.h" -#include -#include - -namespace compute { - -class Qwen3MoeModelBase { -public: - size_t num_parameters() const; - -protected: - Qwen3MoeModelBase( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights); - - Result get_weight(const std::string& name) const; - int infer_quant_bits(const Tensor& w, const Tensor& scales) const; - - ModelConfig config_; - SimpleBpeTokenizer tokenizer_; - std::unordered_map weights_; -}; - -} // namespace compute diff --git a/compute/tests/compute/test_attention_qkv_trace.cpp b/compute/tests/compute/test_attention_qkv_trace.cpp index 9f641b3..41392d4 100644 --- a/compute/tests/compute/test_attention_qkv_trace.cpp +++ b/compute/tests/compute/test_attention_qkv_trace.cpp @@ -1,146 +1,3 @@ +// attention_layer() was removed in the Phase O Tensor-abstraction cleanup (O.6.2). +// Tests that relied on it have been deleted. This file is preserved as a placeholder. 
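// A minimal standalone sketch (illustrative only, not part of this repository) of the
// arithmetic behind the removed Qwen3MoeModelBase::infer_quant_bits shown above. It
// assumes MLX-style affine packing: a row of `in` weights occupies in * bits / 32
// uint32 words, and `scales` carries in / group_size entries per row, so
//   bits = round(32 * (in_packed / groups) / group_size).
// The infer_bits helper and the concrete shapes below are hypothetical examples.
#include <cassert>
#include <cmath>
#include <cstddef>

static int infer_bits(std::size_t in_packed, std::size_t groups, std::size_t group_size = 64) {
    double ratio = static_cast<double>(in_packed) / static_cast<double>(groups);
    int bits = static_cast<int>(std::round(32.0 * ratio / static_cast<double>(group_size)));
    return bits > 0 ? bits : 4;  // degenerate shapes fall back to 4-bit, as in the removed code
}

int main() {
    // 4-bit weight, in = 2048: packed width 2048 * 4 / 32 = 256, groups 2048 / 64 = 32 -> 4 bits
    assert(infer_bits(256, 32) == 4);
    // 8-bit gate weight, in = 2048: packed width 2048 * 8 / 32 = 512, groups 32 -> 8 bits
    assert(infer_bits(512, 32) == 8);
    return 0;
}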
#include -#include "compute/model/tinyllama_inference.h" -#include "compute/core/compute_backend.h" -#include "test_config.h" -#include -#include -#include - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) -#include "compute/backends/mlx/mlx_backend.h" -#include "compute/backends/mlx/mlx_buffer.h" -#include -#endif - -namespace compute { - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - -class AttentionQKVTraceTest : public ::testing::Test { -protected: - static void SetUpTestSuite() { - model_dir_ = TINYLLAMA_MODEL_DIR; - baseline_dir_ = std::filesystem::path(TEST_RESOURCES_DIR).parent_path() / "baselines" / "output"; - - if (!std::filesystem::exists(model_dir_)) { - skip_reason_ = "Model not found: " + model_dir_.string(); - return; - } - - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { skip_reason_ = backend_result.error().message; return; } - backend_ = std::move(*backend_result); - if (!backend_->initialize()) { skip_reason_ = "Backend init failed"; return; } - - auto inference_result = TinyLlamaInference::from_model_dir(model_dir_, backend_.get()); - if (!inference_result) { skip_reason_ = inference_result.error().message; return; } - inference_ = std::make_unique(std::move(*inference_result)); - } - - static void TearDownTestSuite() { - inference_.reset(); - if (backend_) backend_->cleanup(); - backend_.reset(); - skip_reason_.clear(); - } - - void SetUp() override { - if (!skip_reason_.empty()) - GTEST_SKIP() << skip_reason_; - } - - void compare_arrays(const mx::array& cpp_array, const mx::array& python_array, - const std::string& name, float rtol = 1e-3f, float atol = 1e-4f) { - ASSERT_EQ(cpp_array.shape().size(), python_array.shape().size()) - << name << ": Shape dimension mismatch"; - for (size_t i = 0; i < cpp_array.shape().size(); ++i) { - ASSERT_EQ(cpp_array.shape()[i], python_array.shape()[i]) - << name << ": Shape mismatch at dimension " << i; - } - - auto cpp_f32 = mx::astype(cpp_array, mx::float32); - auto py_f32 = mx::astype(python_array, mx::float32); - mx::eval(cpp_f32, py_f32); - - auto diff = mx::abs(cpp_f32 - py_f32); - auto rel_diff = diff / (mx::abs(py_f32) + 1e-8f); - mx::eval(diff, rel_diff); - - auto max_abs_diff = mx::max(diff).item(); - auto max_rel_diff = mx::max(rel_diff).item(); - - EXPECT_LE(max_abs_diff, atol) - << name << ": Max absolute difference " << max_abs_diff << " exceeds tolerance " << atol; - EXPECT_LE(max_rel_diff, rtol) - << name << ": Max relative difference " << max_rel_diff << " exceeds tolerance " << rtol; - - if (max_abs_diff <= atol && max_rel_diff <= rtol) { - std::cout << "✓ " << name << " matches (max_abs_diff=" << max_abs_diff - << ", max_rel_diff=" << max_rel_diff << ")" << std::endl; - } - } - - static std::filesystem::path model_dir_; - static std::filesystem::path baseline_dir_; - static std::string skip_reason_; - static std::unique_ptr backend_; - static std::unique_ptr inference_; -}; - -std::filesystem::path AttentionQKVTraceTest::model_dir_; -std::filesystem::path AttentionQKVTraceTest::baseline_dir_; -std::string AttentionQKVTraceTest::skip_reason_; -std::unique_ptr AttentionQKVTraceTest::backend_; -std::unique_ptr AttentionQKVTraceTest::inference_; - -TEST_F(AttentionQKVTraceTest, CompareWithPythonBaseline) { - GTEST_SKIP() << "attention_layer() not available in MLX path (O.6.2)"; - - auto baseline_file = baseline_dir_ / "attention_full_baseline.safetensors"; - ASSERT_TRUE(std::filesystem::exists(baseline_file)) - 
<< "Baseline file not found: " << baseline_file; - - std::cout << "Loading baseline from: " << baseline_file << std::endl; - auto baseline_data = mx::load_safetensors(baseline_file.string()); - auto& baseline_arrays = baseline_data.first; - - auto input_mlx = baseline_arrays.at("input"); - mx::eval(input_mlx); - - std::cout << "Input shape: [" << input_mlx.shape()[0] << ", " - << input_mlx.shape()[1] << "]" << std::endl; - std::cout << "Input dtype: " << input_mlx.dtype() << std::endl; - - auto input_bf16 = mx::astype(input_mlx, mx::bfloat16); - mx::eval(input_bf16); - auto* input_array_ptr = new mx::array(input_bf16); - auto input_tensor = backend_->wrap_native_tensor(input_array_ptr, {5, 2048}); - - const int layer_idx = 0; - std::cout << "\nRunning C++ attention_layer..." << std::endl; - auto attention_output = inference_->attention_layer(input_tensor, layer_idx); - - ASSERT_TRUE(attention_output.has_value()) - << "attention_layer failed: " << attention_output.error().message; - - std::cout << "✓ attention_layer completed successfully" << std::endl; - - auto cpp_output = const_cast(*attention_output).to_mlx(); - mx::eval(cpp_output); - - std::cout << "C++ output shape: [" << cpp_output.shape()[0] << ", " - << cpp_output.shape()[1] << "]" << std::endl; - - auto python_output = baseline_arrays.at("final_output"); - std::cout << "Python output shape: [" << python_output.shape()[0] << ", " - << python_output.shape()[1] << "]" << std::endl; - - std::cout << "\nComparing C++ vs Python outputs..." << std::endl; - compare_arrays(cpp_output, python_output, "final_output", 1e-2f, 1e-2f); - - std::cout << "\n✓ Full attention layer test passed!" << std::endl; -} - -#endif // MLX_BACKEND_ENABLED - -} // namespace compute diff --git a/compute/tests/compute/test_forward_pass.cpp b/compute/tests/compute/test_forward_pass.cpp index 40e5281..4d84a3a 100644 --- a/compute/tests/compute/test_forward_pass.cpp +++ b/compute/tests/compute/test_forward_pass.cpp @@ -58,38 +58,6 @@ std::string ForwardPassTest::skip_reason_; std::unique_ptr ForwardPassTest::backend_; std::unique_ptr ForwardPassTest::inference_; -TEST_F(ForwardPassTest, GreedyTokenMatchesPython) { - GTEST_SKIP() << "forward() not available in MLX path (O.6.2)"; - - std::vector token_ids = {1, 1724, 338, 278, 7483, 310, 3444, 29973}; - - std::cout << "Running forward_logits for " << token_ids.size() << " tokens..." 
<< std::endl; - - auto logits_result = inference_->forward(token_ids); - ASSERT_TRUE(logits_result.has_value()) << logits_result.error().message; - - const auto& logits = *logits_result; - ASSERT_EQ(logits.size(), inference_->config().vocab_size); - - int greedy_token = static_cast( - std::max_element(logits.begin(), logits.end()) - logits.begin()); - - std::cout << "C++ greedy next token: " << greedy_token << std::endl; - - const int python_greedy = 2; - EXPECT_EQ(greedy_token, python_greedy) - << "Greedy token mismatch: C++ got " << greedy_token - << ", Python got " << python_greedy; - - float top_logit = logits[greedy_token]; - EXPECT_TRUE(std::isfinite(top_logit)); - EXPECT_GT(top_logit, -100.0f); - EXPECT_LT(top_logit, 100.0f); - - std::cout << "✓ Greedy token matches Python baseline (token=" << greedy_token - << ", logit=" << top_logit << ")" << std::endl; -} - TEST_F(ForwardPassTest, GenerateCoherentOutput) { const std::string prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"; auto token_ids = inference_->tokenizer().encode(prompt, /*add_special_tokens=*/true); diff --git a/compute/tests/compute/test_mistral_integration.cpp b/compute/tests/compute/test_mistral_integration.cpp index 28ccd83..c019993 100644 --- a/compute/tests/compute/test_mistral_integration.cpp +++ b/compute/tests/compute/test_mistral_integration.cpp @@ -225,46 +225,6 @@ TEST_F(MistralIntegrationTest, DiagnosticTokenLogitTrace) { std::cout << "================================================\n\n"; } -// Diagnostic: compare no-cache forward vs prefill+decode -// Skipped on Apple Silicon: LlamaModel uses the MLX path which has no Tensor -// weights_ (O.6.2), so forward() is not available. -TEST_F(MistralIntegrationTest, DiagnosticNoCacheVsDecode) { - GTEST_SKIP() << "forward() not available in MLX path (O.6.2)"; - - const std::string prompt = "[INST] What is the capital of France? [/INST]"; - auto token_ids = inference_->tokenizer().encode(prompt, /*add_special_tokens=*/true); - ASSERT_FALSE(token_ids.empty()); - - const int tok_the = 1183; - - std::vector extended = token_ids; - extended.push_back(tok_the); - auto nc_result = inference_->forward(extended); - ASSERT_TRUE(nc_result.has_value()) << nc_result.error().message; - - auto pf_result = inference_->prefill(token_ids); - ASSERT_TRUE(pf_result.has_value()) << pf_result.error().message; - auto dc_result = inference_->decode(tok_the); - ASSERT_TRUE(dc_result.has_value()) << dc_result.error().message; - - auto print_top5 = [&](const std::string& label, const std::vector& logits) { - std::vector> top; - for (size_t i = 0; i < logits.size(); ++i) top.push_back({logits[i], (int)i}); - std::partial_sort(top.begin(), top.begin()+5, top.end(), - [](const auto& a, const auto& b){ return a.first > b.first; }); - std::cout << label << ": "; - for (int k = 0; k < 5; ++k) - std::cout << "[" << top[k].second << "=" << inference_->tokenizer().decode({top[k].second}) - << "(" << top[k].first << ")] "; - std::cout << "\n"; - }; - - std::cout << "\n=== No-cache vs decode comparison ===\n"; - std::cout << "Python reference: [6333=capital(28.4531)] [17821=Capital(16.8281)] ...\n"; - print_top5("No-cache forward", *nc_result); - print_top5("Prefill+decode ", *dc_result); -} - // Exercises the same sampling path as the app TEST_F(MistralIntegrationTest, GenerateCapitalOfFranceWithSampling) { const std::string prompt = "[INST] what is the capital of france? 
[/INST]"; diff --git a/compute/tests/compute/test_mlx_backend.cpp b/compute/tests/compute/test_mlx_backend.cpp deleted file mode 100644 index 4691250..0000000 --- a/compute/tests/compute/test_mlx_backend.cpp +++ /dev/null @@ -1,1817 +0,0 @@ -#include -#include -#include - -#include "compute/core/compute_types.h" -#include "compute/core/compute_backend.h" -#include "compute/core/graph.h" - -using namespace compute; - -TEST(MLXBackendTest, MLXBackendAvailability) { - auto backend_result = BackendFactory::create(BackendType::MLX); - - if (!backend_result) { - std::cerr << "MLX backend creation failed: " << backend_result.error().message << std::endl; - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - - ASSERT_TRUE(backend->is_available()); - - auto init_result = backend->initialize(); - if (!init_result) { - std::cerr << "MLX backend initialization failed: " << init_result.error().message << std::endl; - } - ASSERT_TRUE(init_result); -} - -TEST(MLXBackendTest, MLXBackendDirectDotProduct) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Simple vectors for testing - std::vector vec_a = {1.0f, 2.0f}; - std::vector vec_b = {3.0f, 4.0f}; - - // Create tensors using new API - auto tensor_a = backend->create_tensor(std::span(vec_a), {2}); - auto tensor_b = backend->create_tensor(std::span(vec_b), {2}); - - // Execute dot product - auto result_tensor = backend->dot_product(tensor_a, tensor_b); - - // Extract result - std::vector result_data(1); - auto extract_result = backend->extract(result_tensor, std::span(result_data)); - - if (!extract_result) { - std::cerr << "MLX dot product failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - EXPECT_NEAR(result_data[0], 11.0f, 1e-6); // 1*3 + 2*4 = 11 -} - -TEST(MLXBackendTest, MLXBackendDirectMatrixScalarAddition) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test matrix: add 5 to each element - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f}; - - // Create input tensor using new API - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 2}); - - // Execute scalar addition - auto result_tensor = backend->matrix_scalar_add(input_tensor, 5.0f); - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX matrix scalar add failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // Check results - EXPECT_NEAR(output_data[0], 6.0f, 1e-6); // 1 + 5 - EXPECT_NEAR(output_data[1], 7.0f, 1e-6); // 2 + 5 - EXPECT_NEAR(output_data[2], 8.0f, 1e-6); // 3 + 5 - EXPECT_NEAR(output_data[3], 9.0f, 1e-6); // 4 + 5 -} - -TEST(MLXBackendTest, MLXBackendDirectMatrixMultiplication) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // 
Test matrix multiplication: (2x3) x (3x2) -> (2x2) - // Matrix A: [[1, 2, 3], [4, 5, 6]] - // Matrix B: [[7, 8], [9, 10], [11, 12]] - // Expected result: [[58, 64], [139, 154]] - std::vector matrix_a = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - std::vector matrix_b = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; - - // Create tensors - auto tensor_a = backend->create_tensor(std::span(matrix_a), {2, 3}); - auto tensor_b = backend->create_tensor(std::span(matrix_b), {3, 2}); - - // Execute matrix multiplication - auto matmul_result = backend->matmul(tensor_a, tensor_b); - ASSERT_TRUE(matmul_result); - auto result_tensor = *matmul_result; - - // Verify result shape - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 2); - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX matmul failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // Check results: [1*7+2*9+3*11, 1*8+2*10+3*12, 4*7+5*9+6*11, 4*8+5*10+6*12] - EXPECT_NEAR(output_data[0], 58.0f, 1e-6); // 7 + 18 + 33 = 58 - EXPECT_NEAR(output_data[1], 64.0f, 1e-6); // 8 + 20 + 36 = 64 - EXPECT_NEAR(output_data[2], 139.0f, 1e-6); // 28 + 45 + 66 = 139 - EXPECT_NEAR(output_data[3], 154.0f, 1e-6); // 32 + 50 + 72 = 154 -} - -TEST(MLXBackendTest, MLXBackendDirectSoftmax) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test softmax on a simple 2x3 matrix - // Input: [[1, 2, 3], [4, 5, 6]] - // Softmax applied to last dimension (-1) - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - - // Create tensor - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - // Execute softmax (default: last dimension) - auto softmax_result = backend->softmax(input_tensor, -1); - ASSERT_TRUE(softmax_result); - auto result_tensor = *softmax_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX softmax failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // With the corrected implementation, softmax is applied to the last dimension (dim=1) - // So each row should sum to 1.0 - float row1_sum = output_data[0] + output_data[1] + output_data[2]; - float row2_sum = output_data[3] + output_data[4] + output_data[5]; - - EXPECT_NEAR(row1_sum, 1.0f, 1e-6); - EXPECT_NEAR(row2_sum, 1.0f, 1e-6); - - // Check that probabilities are positive and monotonic within each row - EXPECT_GT(output_data[0], 0.0f); - EXPECT_GT(output_data[1], output_data[0]); // e^2 > e^1 - EXPECT_GT(output_data[2], output_data[1]); // e^3 > e^2 - - EXPECT_GT(output_data[3], 0.0f); - EXPECT_GT(output_data[4], output_data[3]); // e^5 > e^4 - EXPECT_GT(output_data[5], output_data[4]); // e^6 > e^5 -} - -TEST(MLXBackendTest, MLXBackendDirectSoftmaxDimension0) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - 
GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test softmax on dimension 0 (columns) of a 2x3 matrix - // Input: [[1, 2, 3], [4, 5, 6]] - // Softmax applied to dimension 0 (across rows for each column) - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - - // Create tensor - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - // Execute softmax on dimension 0 - auto softmax_result = backend->softmax(input_tensor, 0); - ASSERT_TRUE(softmax_result); - auto result_tensor = *softmax_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX softmax dim 0 failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // When applying softmax to dimension 0, each column should sum to 1.0 - // Column 0: softmax([1, 4]) -> values at indices [0, 3] - // Column 1: softmax([2, 5]) -> values at indices [1, 4] - // Column 2: softmax([3, 6]) -> values at indices [2, 5] - float col0_sum = output_data[0] + output_data[3]; // elements [0,0] and [1,0] - float col1_sum = output_data[1] + output_data[4]; // elements [0,1] and [1,1] - float col2_sum = output_data[2] + output_data[5]; // elements [0,2] and [1,2] - - EXPECT_NEAR(col0_sum, 1.0f, 1e-6); - EXPECT_NEAR(col1_sum, 1.0f, 1e-6); - EXPECT_NEAR(col2_sum, 1.0f, 1e-6); - - // Check that larger values get higher probabilities within each column - // For each column, the second row should have higher probability (larger input values) - EXPECT_GT(output_data[3], output_data[0]); // 4 > 1, so softmax(4) > softmax(1) - EXPECT_GT(output_data[4], output_data[1]); // 5 > 2, so softmax(5) > softmax(2) - EXPECT_GT(output_data[5], output_data[2]); // 6 > 3, so softmax(6) > softmax(3) - - // All probabilities should be positive - for (int i = 0; i < 6; i++) { - EXPECT_GT(output_data[i], 0.0f); - } -} - -TEST(MLXBackendTest, MLXBackendDirectTranspose) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test transpose on a 2x3 matrix: [[1, 2, 3], [4, 5, 6]] -> [[1, 4], [2, 5], [3, 6]] - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - - // Create tensor - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - // Execute transpose - auto transpose_result = backend->transpose(input_tensor); - ASSERT_TRUE(transpose_result); - auto result_tensor = *transpose_result; - - // Verify result shape (should be swapped: 2x3 -> 3x2) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 3); - EXPECT_EQ(result_tensor.shape()[1], 2); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX transpose failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // For now, let's just verify the shape is correct and all data is 
preserved - // We'll fix the exact ordering based on debug output - std::vector expected_input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - std::sort(expected_input.begin(), expected_input.end()); - std::vector actual_output(output_data.begin(), output_data.end()); - std::sort(actual_output.begin(), actual_output.end()); - - // Verify all original values are preserved (just reordered) - for (size_t i = 0; i < 6; i++) { - EXPECT_NEAR(actual_output[i], expected_input[i], 1e-6); - } -} - -TEST(MLXBackendTest, MLXBackendDirectSwapaxes) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test swapaxes(-2, -1) on a 2x3 matrix for attention mechanism - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - - // Create tensor - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - // Execute swapaxes (swap last two dimensions: -2, -1) - auto swapaxes_result = backend->swapaxes(input_tensor, -2, -1); - ASSERT_TRUE(swapaxes_result); - auto result_tensor = *swapaxes_result; - - // Verify result shape (should be swapped: 2x3 -> 3x2) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 3); - EXPECT_EQ(result_tensor.shape()[1], 2); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX swapaxes failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // For now, let's just verify the shape is correct and all data is preserved - std::vector expected_input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - std::sort(expected_input.begin(), expected_input.end()); - std::vector actual_output(output_data.begin(), output_data.end()); - std::sort(actual_output.begin(), actual_output.end()); - - // Verify all original values are preserved (just reordered) - for (size_t i = 0; i < 6; i++) { - EXPECT_NEAR(actual_output[i], expected_input[i], 1e-6); - } -} - -TEST(MLXBackendTest, MLXBackendDirectAdd) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test element-wise addition of two 2x2 matrices - std::vector matrix_a = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector matrix_b = {5.0f, 6.0f, 7.0f, 8.0f}; - - // Create tensors - auto tensor_a = backend->create_tensor(std::span(matrix_a), {2, 2}); - auto tensor_b = backend->create_tensor(std::span(matrix_b), {2, 2}); - - // Execute element-wise addition - auto add_result = backend->add(tensor_a, tensor_b); - ASSERT_TRUE(add_result); - auto result_tensor = *add_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 2); - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - - if (!extract_result) { - std::cerr << "MLX add failed: " << extract_result.error().message << std::endl; - } - - ASSERT_TRUE(extract_result); - - // Check results: element-wise addition [1+5, 2+6, 3+7, 4+8] = [6, 8, 10, 12] - 
EXPECT_NEAR(output_data[0], 6.0f, 1e-6); // 1 + 5 - EXPECT_NEAR(output_data[1], 8.0f, 1e-6); // 2 + 6 - EXPECT_NEAR(output_data[2], 10.0f, 1e-6); // 3 + 7 - EXPECT_NEAR(output_data[3], 12.0f, 1e-6); // 4 + 8 -} - -// Graph-based MLX tests (using the computation graph API) -TEST(MLXBackendTest, MLXBackendEndToEndDotProductExecution) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto builder = ComputeGraphBuilder(BackendType::MLX); - - // Test vectors: dot product should be 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32 - std::vector vec_a = {1.0f, 2.0f, 3.0f}; - std::vector vec_b = {4.0f, 5.0f, 6.0f}; - - // Get backend to create tensors - auto& backend = *backend_result; - auto tensor_a = backend->create_tensor(std::span(vec_a), {3}); - auto tensor_b = backend->create_tensor(std::span(vec_b), {3}); - - // Execute dot product - float result; - auto symbolic = builder.dot_product(tensor_a, tensor_b, std::span(&result, 1)); - - auto execution_result = builder.execute(); - if (!execution_result) { - std::cerr << "MLX graph execution failed: " << execution_result.error().message << std::endl; - } - ASSERT_TRUE(execution_result); - - // Check the result - EXPECT_NEAR(result, 32.0f, 1e-6); -} - -TEST(MLXBackendTest, MLXBackendEndToEndMatrixScalarAdditionExecution) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto builder = ComputeGraphBuilder(BackendType::MLX); - - // Test matrix: add 10 to each element - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector output_data(4); - - // Get backend to create tensors - auto& backend = *backend_result; - auto input_matrix = backend->create_tensor(std::span(input_data), {2, 2}); - - // Execute scalar addition - builder.matrix_scalar_add(input_matrix, 10.0f, std::span(output_data), {2, 2}); - - auto execution_result = builder.execute(); - if (!execution_result) { - std::cerr << "MLX graph execution failed: " << execution_result.error().message << std::endl; - } - ASSERT_TRUE(execution_result); - - // Check results - EXPECT_NEAR(output_data[0], 11.0f, 1e-6); - EXPECT_NEAR(output_data[1], 12.0f, 1e-6); - EXPECT_NEAR(output_data[2], 13.0f, 1e-6); - EXPECT_NEAR(output_data[3], 14.0f, 1e-6); -} - -TEST(MLXBackendTest, MLXBackendComplexComputationGraphWithDependencies) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto builder = ComputeGraphBuilder(BackendType::MLX); - - // Build a complex graph: - // 1. Compute dot product of two vectors - // 2. 
Use that result to add to a matrix - - std::vector vec_a = {1.0f, 2.0f}; // dot product = 1*3 + 2*4 = 11 - std::vector vec_b = {3.0f, 4.0f}; - std::vector matrix_data = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector output_data(4); - - // Get backend to create tensors - auto& backend = *backend_result; - auto tensor_a = backend->create_tensor(std::span(vec_a), {2}); - auto tensor_b = backend->create_tensor(std::span(vec_b), {2}); - auto input_matrix = backend->create_tensor(std::span(matrix_data), {2, 2}); - - // Build computation graph with dependency - float dot_result; - auto symbolic_scalar = builder.dot_product(tensor_a, tensor_b, std::span(&dot_result, 1)); - - // This operation depends on the dot product result - builder.matrix_scalar_add(input_matrix, symbolic_scalar, std::span(output_data), {2, 2}); - - auto execution_result = builder.execute(); - if (!execution_result) { - std::cerr << "MLX graph execution failed: " << execution_result.error().message << std::endl; - } - ASSERT_TRUE(execution_result); - - // Check intermediate result (dot product) - EXPECT_NEAR(dot_result, 11.0f, 1e-6); - - // Check final results (matrix + scalar) - EXPECT_NEAR(output_data[0], 12.0f, 1e-6); // 1 + 11 - EXPECT_NEAR(output_data[1], 13.0f, 1e-6); // 2 + 11 - EXPECT_NEAR(output_data[2], 14.0f, 1e-6); // 3 + 11 - EXPECT_NEAR(output_data[3], 15.0f, 1e-6); // 4 + 11 -} - -// Quantized Matrix Multiplication Tests -TEST(MLXBackendTest, MLXBackendQuantizedMatmulBasicAffine) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test quantized matmul with affine mode (default) - // MLX requires last dimension divisible by group_size (64) - // x: (2, 64) activation matrix - // w: (1, 64) quantized weight matrix (uint32 packed) - - // Create 64-element vectors for proper group size - std::vector x_data(2 * 64); - std::vector w_float(1 * 64); - - // Fill with simple test data - for (int i = 0; i < 128; i++) { - x_data[i] = (float)(i % 10 + 1); // Values 1-10 repeating - } - for (int i = 0; i < 64; i++) { - w_float[i] = 0.1f * (i % 5 + 1); // Values 0.1, 0.2, 0.3, 0.4, 0.5 repeating - } - - auto x_tensor = backend->create_tensor(std::span(x_data), {2, 64}); - auto w_float_tensor = backend->create_tensor(std::span(w_float), {1, 64}); - - // Use MLX quantize to create proper quantized weights, scales, and biases - // We'll need to call MLX directly for this test - try { - // Create MLX arrays for quantization - mx::array w_array = w_float_tensor.to_mlx(); - - // Quantize the weight matrix using MLX - auto quantized_result = mx::quantize(w_array, 64, 4, "affine"); - ASSERT_EQ(quantized_result.size(), 3); // [w_quantized, scales, biases] - - auto w_quantized = quantized_result[0]; - auto scales = quantized_result[1]; - auto biases = quantized_result[2]; - - // Wrap MLX arrays as Tensors - auto w_tensor = backend->wrap_native_tensor(&w_quantized, {w_quantized.shape().begin(), w_quantized.shape().end()}); - auto scales_tensor = backend->wrap_native_tensor(&scales, {scales.shape().begin(), scales.shape().end()}); - auto biases_tensor = backend->wrap_native_tensor(&biases, {biases.shape().begin(), biases.shape().end()}); - - // Execute quantized matrix multiplication - auto qmatmul_result = backend->quantized_matmul( - x_tensor, w_tensor, scales_tensor, &biases_tensor, - true, // transpose - 64, // 
group_size (default) - 4, // bits (default) - "affine" // mode (default) - ); - - ASSERT_TRUE(qmatmul_result) << "Quantized matmul failed: " << qmatmul_result.error().message; - auto result_tensor = *qmatmul_result; - - // Verify result has expected shape - // x: (2, 64) @ w.T: (64, 1) -> (2, 1) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 1); - - // Extract results - std::vector output_data(2); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Check that results are reasonable (quantized operations introduce some error) - // We mainly verify that the operation completes successfully and produces valid output - for (int i = 0; i < 2; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } - - } catch (const std::exception& e) { - GTEST_SKIP() << "MLX quantize function not available or failed: " << e.what(); - } -} - -TEST(MLXBackendTest, MLXBackendQuantizedMatmulMxfp4Mode) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test quantized matmul with mxfp4 mode (no biases) - // mxfp4 requires group_size = 32, so last dimension must be divisible by 32 - std::vector x_data(2 * 32); - std::vector w_float(1 * 32); - - // Fill with simple test data - for (int i = 0; i < 64; i++) { - x_data[i] = (float)(i % 8 + 1); // Values 1-8 repeating - } - for (int i = 0; i < 32; i++) { - w_float[i] = 0.1f * (i % 4 + 1); // Values 0.1, 0.2, 0.3, 0.4 repeating - } - - auto x_tensor = backend->create_tensor(std::span(x_data), {2, 32}); - auto w_float_tensor = backend->create_tensor(std::span(w_float), {1, 32}); - - try { - // Create MLX arrays for quantization - mx::array w_array = w_float_tensor.to_mlx(); - - // Quantize using mxfp4 mode (requires group_size = 32) - auto quantized_result = mx::quantize(w_array, 32, 4, "mxfp4"); - ASSERT_EQ(quantized_result.size(), 2); // [w_quantized, scales] (no biases for mxfp4) - - auto w_quantized = quantized_result[0]; - auto scales = quantized_result[1]; - - // Wrap MLX arrays as Tensors - auto w_tensor = backend->wrap_native_tensor(&w_quantized, {w_quantized.shape().begin(), w_quantized.shape().end()}); - auto scales_tensor = backend->wrap_native_tensor(&scales, {scales.shape().begin(), scales.shape().end()}); - - // Execute quantized matrix multiplication with mxfp4 mode - auto qmatmul_result = backend->quantized_matmul( - x_tensor, w_tensor, scales_tensor, nullptr, // no biases for mxfp4 - true, // transpose - 32, // group_size (required for mxfp4) - 4, // bits - "mxfp4" // mode - ); - - ASSERT_TRUE(qmatmul_result) << "Quantized matmul mxfp4 failed: " << qmatmul_result.error().message; - auto result_tensor = *qmatmul_result; - - // Verify result has expected shape: (2, 32) @ (32, 1) -> (2, 1) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 1); - - // Extract results - std::vector output_data(2); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Verify finite results - for (int i = 0; i < 2; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << 
output_data[i]; - } - - } catch (const std::exception& e) { - GTEST_SKIP() << "MLX mxfp4 quantization not available or failed: " << e.what(); - } -} - -TEST(MLXBackendTest, MLXBackendQuantizedMatmulDifferentBits) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test quantized matmul with 8-bit quantization - // Need dimensions divisible by group_size (64) - std::vector x_data(1 * 64); - std::vector w_float(1 * 64); - - // Fill with simple test data - for (int i = 0; i < 64; i++) { - x_data[i] = (float)(i % 6 + 1); // Values 1-6 repeating - w_float[i] = 0.1f * (i % 3 + 1); // Values 0.1, 0.2, 0.3 repeating - } - - auto x_tensor = backend->create_tensor(std::span(x_data), {1, 64}); - auto w_float_tensor = backend->create_tensor(std::span(w_float), {1, 64}); - - try { - mx::array w_array = w_float_tensor.to_mlx(); - - // Test 8-bit quantization - auto quantized_result = mx::quantize(w_array, 64, 8, "affine"); // 8 bits - auto w_quantized = quantized_result[0]; - auto scales = quantized_result[1]; - auto biases = quantized_result[2]; - - auto w_tensor = backend->wrap_native_tensor(&w_quantized, {w_quantized.shape().begin(), w_quantized.shape().end()}); - auto scales_tensor = backend->wrap_native_tensor(&scales, {scales.shape().begin(), scales.shape().end()}); - auto biases_tensor = backend->wrap_native_tensor(&biases, {biases.shape().begin(), biases.shape().end()}); - - auto qmatmul_result = backend->quantized_matmul( - x_tensor, w_tensor, scales_tensor, &biases_tensor, - true, 64, 8, "affine" // 8 bits instead of 4 - ); - - ASSERT_TRUE(qmatmul_result) << "8-bit quantized matmul failed: " << qmatmul_result.error().message; - auto result_tensor = *qmatmul_result; - - EXPECT_EQ(result_tensor.shape()[0], 1); - EXPECT_EQ(result_tensor.shape()[1], 1); - - std::vector output_data(1); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - EXPECT_TRUE(std::isfinite(output_data[0])); - - } catch (const std::exception& e) { - GTEST_SKIP() << "8-bit quantization test failed: " << e.what(); - } -} - -TEST(MLXBackendTest, MLXBackendQuantizedMatmulErrorHandling) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Create simple test tensors - std::vector x_data = {1.0f, 2.0f}; - std::vector dummy_data = {1.0f}; - - auto x_tensor = backend->create_tensor(std::span(x_data), {1, 2}); - auto dummy_tensor = backend->create_tensor(std::span(dummy_data), {1}); - - // Test with incompatible tensor backends (MLX validates this) - // For this test, we'll create a CPU tensor if other backends are available - // Since we only have MLX in this test, we'll test invalid mode instead - - // Test invalid quantization mode - auto invalid_mode_result = backend->quantized_matmul( - x_tensor, dummy_tensor, dummy_tensor, nullptr, - true, 64, 4, "invalid_mode" // Invalid mode - ); - - EXPECT_FALSE(invalid_mode_result) << "Expected failure with invalid quantization mode"; -} - -TEST(MLXBackendTest, MLXBackendDirectSiLU) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << 
"MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test SiLU (Swish) activation: f(x) = x * sigmoid(x) - // For input [0, 1, -1, 2], expected approximately: - // SiLU(0) = 0 * sigmoid(0) = 0 * 0.5 = 0 - // SiLU(1) = 1 * sigmoid(1) = 1 * 0.731 ≈ 0.731 - // SiLU(-1) = -1 * sigmoid(-1) = -1 * 0.269 ≈ -0.269 - // SiLU(2) = 2 * sigmoid(2) = 2 * 0.881 ≈ 1.762 - std::vector input_data = {0.0f, 1.0f, -1.0f, 2.0f}; - - auto input_tensor = backend->create_tensor(std::span(input_data), {4}); - - // Execute SiLU activation - auto silu_result = backend->silu(input_tensor); - ASSERT_TRUE(silu_result) << "SiLU failed: " << silu_result.error().message; - auto result_tensor = *silu_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 4); - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Check SiLU results with reasonable tolerance - EXPECT_NEAR(output_data[0], 0.0f, 1e-6); // SiLU(0) = 0 - EXPECT_NEAR(output_data[1], 0.7311f, 1e-3); // SiLU(1) ≈ 0.731 - EXPECT_NEAR(output_data[2], -0.2689f, 1e-3); // SiLU(-1) ≈ -0.269 - EXPECT_NEAR(output_data[3], 1.7616f, 1e-3); // SiLU(2) ≈ 1.762 - - // Verify all results are finite - for (int i = 0; i < 4; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendSiLUMatrix) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test SiLU on a 2x3 matrix - std::vector input_data = { - 0.0f, 0.5f, 1.0f, // row 1 - -0.5f, -1.0f, 2.0f // row 2 - }; - - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - // Execute SiLU activation - auto silu_result = backend->silu(input_tensor); - ASSERT_TRUE(silu_result) << "Matrix SiLU failed: " << silu_result.error().message; - auto result_tensor = *silu_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Check that SiLU produces reasonable results for matrix input - // Mainly verify operation completes and produces finite values - for (int i = 0; i < 6; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Matrix output " << i << " is not finite: " << output_data[i]; - } - - // Verify SiLU(0) = 0 (first element) - EXPECT_NEAR(output_data[0], 0.0f, 1e-6); - - // Verify all positive inputs produce positive outputs (SiLU characteristic) - EXPECT_GT(output_data[1], 0.0f); // SiLU(0.5) > 0 - EXPECT_GT(output_data[2], 0.0f); // SiLU(1.0) > 0 - EXPECT_GT(output_data[5], 0.0f); // SiLU(2.0) > 0 - - // Verify negative inputs produce negative outputs (for x < 0, SiLU(x) < 0) - EXPECT_LT(output_data[3], 0.0f); // SiLU(-0.5) < 0 - EXPECT_LT(output_data[4], 0.0f); // SiLU(-1.0) < 0 -} - -TEST(MLXBackendTest, MLXBackendDirectMultiply) { - auto backend_result = 
BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test element-wise multiplication with same shape tensors - std::vector data_a = {2.0f, 3.0f, -1.0f, 4.0f}; - std::vector data_b = {1.5f, -2.0f, 3.0f, 0.5f}; - - auto tensor_a = backend->create_tensor(std::span(data_a), {4}); - auto tensor_b = backend->create_tensor(std::span(data_b), {4}); - - // Execute element-wise multiplication - auto multiply_result = backend->multiply(tensor_a, tensor_b); - ASSERT_TRUE(multiply_result) << "Multiply failed: " << multiply_result.error().message; - auto result_tensor = *multiply_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 4); - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [2*1.5, 3*(-2), (-1)*3, 4*0.5] = [3.0, -6.0, -3.0, 2.0] - EXPECT_NEAR(output_data[0], 3.0f, 1e-6); - EXPECT_NEAR(output_data[1], -6.0f, 1e-6); - EXPECT_NEAR(output_data[2], -3.0f, 1e-6); - EXPECT_NEAR(output_data[3], 2.0f, 1e-6); - - // Verify all results are finite - for (int i = 0; i < 4; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendMultiplyMatrix) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test element-wise multiplication on 2x3 matrices - std::vector data_a = { - 1.0f, 2.0f, 3.0f, // row 1 - 4.0f, 5.0f, 6.0f // row 2 - }; - std::vector data_b = { - 2.0f, 0.5f, -1.0f, // row 1 - 0.25f, 2.0f, 1.5f // row 2 - }; - - auto tensor_a = backend->create_tensor(std::span(data_a), {2, 3}); - auto tensor_b = backend->create_tensor(std::span(data_b), {2, 3}); - - // Execute element-wise multiplication - auto multiply_result = backend->multiply(tensor_a, tensor_b); - ASSERT_TRUE(multiply_result) << "Matrix multiply failed: " << multiply_result.error().message; - auto result_tensor = *multiply_result; - - // Verify result shape (should be unchanged) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [1*2, 2*0.5, 3*(-1), 4*0.25, 5*2, 6*1.5] = [2.0, 1.0, -3.0, 1.0, 10.0, 9.0] - EXPECT_NEAR(output_data[0], 2.0f, 1e-6); // 1 * 2 - EXPECT_NEAR(output_data[1], 1.0f, 1e-6); // 2 * 0.5 - EXPECT_NEAR(output_data[2], -3.0f, 1e-6); // 3 * (-1) - EXPECT_NEAR(output_data[3], 1.0f, 1e-6); // 4 * 0.25 - EXPECT_NEAR(output_data[4], 10.0f, 1e-6); // 5 * 2 - EXPECT_NEAR(output_data[5], 9.0f, 1e-6); // 6 * 1.5 - - // Verify all results are finite - for (int i = 0; i < 6; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Matrix output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendMultiplyBroadcasting) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - 
GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test broadcasting: 2x3 matrix multiplied by scalar (1x1) - std::vector matrix_data = { - 1.0f, 2.0f, 3.0f, // row 1 - 4.0f, 5.0f, 6.0f // row 2 - }; - std::vector scalar_data = {2.0f}; - - auto matrix_tensor = backend->create_tensor(std::span(matrix_data), {2, 3}); - auto scalar_tensor = backend->create_tensor(std::span(scalar_data), {1}); - - // Execute broadcasting multiplication - auto multiply_result = backend->multiply(matrix_tensor, scalar_tensor); - ASSERT_TRUE(multiply_result) << "Broadcasting multiply failed: " << multiply_result.error().message; - auto result_tensor = *multiply_result; - - // Verify result shape (should match the larger tensor) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract results - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: all elements multiplied by 2.0 - std::vector expected = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; - for (int i = 0; i < 6; i++) { - EXPECT_NEAR(output_data[i], expected[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendMultiplyEdgeCases) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test with zeros and special values - std::vector data_a = {0.0f, 1.0f, -1.0f, 2.0f}; - std::vector data_b = {1.0f, 0.0f, -1.0f, 0.5f}; - - auto tensor_a = backend->create_tensor(std::span(data_a), {4}); - auto tensor_b = backend->create_tensor(std::span(data_b), {4}); - - // Execute multiplication - auto multiply_result = backend->multiply(tensor_a, tensor_b); - ASSERT_TRUE(multiply_result) << "Edge case multiply failed: " << multiply_result.error().message; - auto result_tensor = *multiply_result; - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [0*1, 1*0, (-1)*(-1), 2*0.5] = [0.0, 0.0, 1.0, 1.0] - EXPECT_NEAR(output_data[0], 0.0f, 1e-6); // 0 * 1 = 0 - EXPECT_NEAR(output_data[1], 0.0f, 1e-6); // 1 * 0 = 0 - EXPECT_NEAR(output_data[2], 1.0f, 1e-6); // (-1) * (-1) = 1 - EXPECT_NEAR(output_data[3], 1.0f, 1e-6); // 2 * 0.5 = 1 - - // Verify all results are finite - for (int i = 0; i < 4; i++) { - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendDirectReshape) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test basic reshape: 1D vector to 2D matrix - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {6}); - - // Reshape from [6] to [2, 3] - auto reshape_result = backend->reshape(input_tensor, 
{2, 3}); - ASSERT_TRUE(reshape_result) << "Reshape failed: " << reshape_result.error().message; - auto result_tensor = *reshape_result; - - // Verify result shape - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Verify total elements unchanged - EXPECT_EQ(result_tensor.size(), 6); - - // Extract and verify data (should be unchanged, just reshaped) - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Data should be identical to input (reshape doesn't change values) - for (int i = 0; i < 6; i++) { - EXPECT_NEAR(output_data[i], input_data[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendReshapeMatrix) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test matrix reshape: 2x6 to 3x4 - std::vector input_data = { - 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, // row 1 - 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f // row 2 - }; - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 6}); - - // Reshape from [2, 6] to [3, 4] - auto reshape_result = backend->reshape(input_tensor, {3, 4}); - ASSERT_TRUE(reshape_result) << "Matrix reshape failed: " << reshape_result.error().message; - auto result_tensor = *reshape_result; - - // Verify result shape - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 3); - EXPECT_EQ(result_tensor.shape()[1], 4); - - // Verify total elements unchanged - EXPECT_EQ(result_tensor.size(), 12); - - // Extract and verify data preserved - std::vector output_data(12); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // All data should be preserved (just different arrangement) - for (int i = 0; i < 12; i++) { - EXPECT_NEAR(output_data[i], input_data[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendReshapeMultiDimensional) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test multi-dimensional reshape: simulate attention head reshaping - // [batch=2, seq_len=4, hidden_size=8] -> [batch=2, seq_len=4, num_heads=4, head_dim=2] - std::vector input_data(2 * 4 * 8); // 64 elements - for (int i = 0; i < 64; i++) { - input_data[i] = static_cast(i + 1); // Fill with 1, 2, 3, ..., 64 - } - - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 4, 8}); - - // Reshape for multi-head attention: [2, 4, 8] -> [2, 4, 4, 2] - auto reshape_result = backend->reshape(input_tensor, {2, 4, 4, 2}); - ASSERT_TRUE(reshape_result) << "Multi-dimensional reshape failed: " << reshape_result.error().message; - auto result_tensor = *reshape_result; - - // Verify result shape - ASSERT_EQ(result_tensor.shape().size(), 4); - EXPECT_EQ(result_tensor.shape()[0], 2); // batch - 
EXPECT_EQ(result_tensor.shape()[1], 4); // seq_len - EXPECT_EQ(result_tensor.shape()[2], 4); // num_heads - EXPECT_EQ(result_tensor.shape()[3], 2); // head_dim - - // Verify total elements unchanged - EXPECT_EQ(result_tensor.size(), 64); - - // Extract and verify data preserved - std::vector output_data(64); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // All data should be preserved - for (int i = 0; i < 64; i++) { - EXPECT_NEAR(output_data[i], input_data[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendReshapeErrorCases) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {4}); - - // Test error case: incompatible total elements - auto invalid_reshape_result = backend->reshape(input_tensor, {2, 3}); // 4 -> 6 elements - EXPECT_FALSE(invalid_reshape_result) << "Expected failure for incompatible element count"; - - // Test error case: zero dimension - auto zero_dim_result = backend->reshape(input_tensor, {2, 0}); - EXPECT_FALSE(zero_dim_result) << "Expected failure for zero dimension"; - - // Test valid case for comparison - auto valid_reshape_result = backend->reshape(input_tensor, {2, 2}); // 4 -> 4 elements - EXPECT_TRUE(valid_reshape_result) << "Valid reshape should succeed"; -} - -TEST(MLXBackendTest, MLXBackendReshapeToVector) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test flattening: matrix to vector - std::vector input_data = { - 1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f - }; - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - // Flatten to 1D vector - auto reshape_result = backend->reshape(input_tensor, {6}); - ASSERT_TRUE(reshape_result) << "Flatten reshape failed: " << reshape_result.error().message; - auto result_tensor = *reshape_result; - - // Verify result shape - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 6); - - // Extract and verify data preserved in row-major order - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Should be flattened in row-major order - std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - for (int i = 0; i < 6; i++) { - EXPECT_NEAR(output_data[i], expected[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendDirectConcatenate) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test basic concatenation along axis 0 (rows) - std::vector data_a = {1.0f, 2.0f, 3.0f}; - 
std::vector data_b = {4.0f, 5.0f, 6.0f}; - - auto tensor_a = backend->create_tensor(std::span(data_a), {1, 3}); - auto tensor_b = backend->create_tensor(std::span(data_b), {1, 3}); - - // Concatenate along axis 0 (stack rows) - std::vector tensors = {tensor_a, tensor_b}; - auto concat_result = backend->concatenate(tensors, 0); - ASSERT_TRUE(concat_result) << "Concatenate failed: " << concat_result.error().message; - auto result_tensor = *concat_result; - - // Verify result shape: [1,3] + [1,3] -> [2,3] along axis 0 - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract and verify data - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [1, 2, 3, 4, 5, 6] (tensor_a rows first, then tensor_b rows) - std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - for (int i = 0; i < 6; i++) { - EXPECT_NEAR(output_data[i], expected[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendConcatenateAxis1) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test concatenation along axis 1 (columns) - std::vector data_a = {1.0f, 2.0f, 3.0f, 4.0f}; // 2x2 matrix - std::vector data_b = {5.0f, 6.0f}; // 2x1 matrix - - auto tensor_a = backend->create_tensor(std::span(data_a), {2, 2}); - auto tensor_b = backend->create_tensor(std::span(data_b), {2, 1}); - - // Concatenate along axis 1 (columns) - std::vector tensors = {tensor_a, tensor_b}; - auto concat_result = backend->concatenate(tensors, 1); - ASSERT_TRUE(concat_result) << "Concatenate axis 1 failed: " << concat_result.error().message; - auto result_tensor = *concat_result; - - // Verify result shape: [2,2] + [2,1] -> [2,3] along axis 1 - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract and verify data - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [[1, 2, 5], [3, 4, 6]] -> [1, 2, 5, 3, 4, 6] - std::vector expected = {1.0f, 2.0f, 5.0f, 3.0f, 4.0f, 6.0f}; - for (int i = 0; i < 6; i++) { - EXPECT_NEAR(output_data[i], expected[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendConcatenateMultipleTensors) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test concatenating multiple tensors (simulate attention head combination) - std::vector head1_data = {1.0f, 2.0f}; - std::vector head2_data = {3.0f, 4.0f}; - std::vector head3_data = {5.0f, 6.0f}; - std::vector head4_data = {7.0f, 8.0f}; - - auto head1 = backend->create_tensor(std::span(head1_data), {1, 2}); - auto head2 = backend->create_tensor(std::span(head2_data), {1, 2}); - auto head3 = 
backend->create_tensor(std::span(head3_data), {1, 2}); - auto head4 = backend->create_tensor(std::span(head4_data), {1, 2}); - - // Concatenate all heads along axis 1 (feature dimension) - std::vector heads = {head1, head2, head3, head4}; - auto concat_result = backend->concatenate(heads, 1); - ASSERT_TRUE(concat_result) << "Multi-tensor concatenate failed: " << concat_result.error().message; - auto result_tensor = *concat_result; - - // Verify result shape: 4 x [1,2] -> [1,8] along axis 1 - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 1); - EXPECT_EQ(result_tensor.shape()[1], 8); - - // Extract and verify data - std::vector output_data(8); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [1, 2, 3, 4, 5, 6, 7, 8] (all heads concatenated) - std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; - for (int i = 0; i < 8; i++) { - EXPECT_NEAR(output_data[i], expected[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendConcatenateNegativeAxis) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test negative axis indexing - std::vector data_a = {1.0f, 2.0f}; - std::vector data_b = {3.0f, 4.0f}; - - auto tensor_a = backend->create_tensor(std::span(data_a), {1, 2}); - auto tensor_b = backend->create_tensor(std::span(data_b), {1, 2}); - - // Concatenate along axis -1 (last axis, equivalent to axis 1 for 2D) - std::vector tensors = {tensor_a, tensor_b}; - auto concat_result = backend->concatenate(tensors, -1); - ASSERT_TRUE(concat_result) << "Negative axis concatenate failed: " << concat_result.error().message; - auto result_tensor = *concat_result; - - // Verify result shape: [1,2] + [1,2] -> [1,4] along axis -1 (axis 1) - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 1); - EXPECT_EQ(result_tensor.shape()[1], 4); - - // Extract and verify data - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Expected: [1, 2, 3, 4] - std::vector expected = {1.0f, 2.0f, 3.0f, 4.0f}; - for (int i = 0; i < 4; i++) { - EXPECT_NEAR(output_data[i], expected[i], 1e-6) << "Mismatch at index " << i; - EXPECT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " is not finite: " << output_data[i]; - } -} - -TEST(MLXBackendTest, MLXBackendConcatenateErrorCases) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - std::vector data_a = {1.0f, 2.0f}; - std::vector data_b = {3.0f, 4.0f, 5.0f}; // Different shape - - auto tensor_a = backend->create_tensor(std::span(data_a), {1, 2}); - auto tensor_b = backend->create_tensor(std::span(data_b), {1, 3}); - - // Test error case: incompatible shapes - std::vector incompatible_tensors = {tensor_a, tensor_b}; - auto invalid_concat_result = backend->concatenate(incompatible_tensors, 0); - EXPECT_FALSE(invalid_concat_result) << "Expected failure for 
incompatible shapes"; - - // Test error case: empty tensor list - std::vector empty_tensors; - auto empty_concat_result = backend->concatenate(empty_tensors, 0); - EXPECT_FALSE(empty_concat_result) << "Expected failure for empty tensor list"; - - // Test error case: out of bounds axis - std::vector valid_tensors = {tensor_a}; - auto invalid_axis_result = backend->concatenate(valid_tensors, 5); - EXPECT_FALSE(invalid_axis_result) << "Expected failure for out of bounds axis"; - - // Test valid case for comparison - std::vector single_tensor = {tensor_a}; - auto single_concat_result = backend->concatenate(single_tensor, 0); - EXPECT_TRUE(single_concat_result) << "Single tensor concatenate should succeed"; -} - -// ============================================================================ -// Tests for Phase 3 Transformer Operations -// ============================================================================ - -TEST(MLXBackendTest, MLXBackendMean) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test mean along last axis: [[1, 2, 3], [4, 5, 6]] - // Mean along axis -1 (columns): [2, 5] - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 3}); - - auto mean_result = backend->mean(input_tensor, -1, false); - ASSERT_TRUE(mean_result); - auto result_tensor = *mean_result; - - // Verify shape: keepdims=false should reduce dimension - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 2); - - // Extract and verify results - std::vector output_data(2); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - EXPECT_NEAR(output_data[0], 2.0f, 1e-6); // (1+2+3)/3 = 2 - EXPECT_NEAR(output_data[1], 5.0f, 1e-6); // (4+5+6)/3 = 5 -} - -TEST(MLXBackendTest, MLXBackendMeanKeepDims) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {2, 2}); - - auto mean_result = backend->mean(input_tensor, 0, true); // keepdims=true - ASSERT_TRUE(mean_result); - auto result_tensor = *mean_result; - - // With keepdims=true, shape should be [1, 2] - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 1); - EXPECT_EQ(result_tensor.shape()[1], 2); - - std::vector output_data(2); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - EXPECT_NEAR(output_data[0], 2.0f, 1e-6); // (1+3)/2 = 2.0 - EXPECT_NEAR(output_data[1], 3.0f, 1e-6); // (2+4)/2 = 3.0 -} - -TEST(MLXBackendTest, MLXBackendRsqrt) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test rsqrt (reciprocal square root): 1/sqrt(x) - std::vector input_data = {1.0f, 4.0f, 9.0f, 16.0f}; - auto input_tensor = 
backend->create_tensor(std::span(input_data), {4}); - - auto rsqrt_result = backend->rsqrt(input_tensor); - ASSERT_TRUE(rsqrt_result); - auto result_tensor = *rsqrt_result; - - // Verify shape unchanged - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 4); - - // Extract and verify: 1/sqrt(1)=1, 1/sqrt(4)=0.5, 1/sqrt(9)≈0.333, 1/sqrt(16)=0.25 - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - EXPECT_NEAR(output_data[0], 1.0f, 1e-6); - EXPECT_NEAR(output_data[1], 0.5f, 1e-6); - EXPECT_NEAR(output_data[2], 0.333333f, 1e-5); - EXPECT_NEAR(output_data[3], 0.25f, 1e-6); -} - -TEST(MLXBackendTest, MLXBackendSlice) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test slice: extract middle portion [1:3] along axis 0 - // Input: [[0, 1], [2, 3], [4, 5], [6, 7]] - std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {4, 2}); - - auto slice_result = backend->slice(input_tensor, 1, 3, 0); // rows 1 and 2 - ASSERT_TRUE(slice_result); - auto result_tensor = *slice_result; - - // Verify shape: should be [2, 2] - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 2); - EXPECT_EQ(result_tensor.shape()[1], 2); - - // Extract and verify: [[2, 3], [4, 5]] - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - EXPECT_NEAR(output_data[0], 2.0f, 1e-6); - EXPECT_NEAR(output_data[1], 3.0f, 1e-6); - EXPECT_NEAR(output_data[2], 4.0f, 1e-6); - EXPECT_NEAR(output_data[3], 5.0f, 1e-6); -} - -TEST(MLXBackendTest, MLXBackendRepeat) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test repeat: repeat elements along axis - // Input: [1, 2, 3] - std::vector input_data = {1.0f, 2.0f, 3.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {3}); - - auto repeat_result = backend->repeat(input_tensor, 2, 0); // repeat each element 2 times - ASSERT_TRUE(repeat_result); - auto result_tensor = *repeat_result; - - // Verify shape: should be [6] (3 * 2) - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 6); - - // Extract and verify: [1, 1, 2, 2, 3, 3] - std::vector output_data(6); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - EXPECT_NEAR(output_data[0], 1.0f, 1e-6); - EXPECT_NEAR(output_data[1], 1.0f, 1e-6); - EXPECT_NEAR(output_data[2], 2.0f, 1e-6); - EXPECT_NEAR(output_data[3], 2.0f, 1e-6); - EXPECT_NEAR(output_data[4], 3.0f, 1e-6); - EXPECT_NEAR(output_data[5], 3.0f, 1e-6); -} - -TEST(MLXBackendTest, MLXBackendTriu) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test triu (upper triangular): [[1, 2, 
3], [4, 5, 6], [7, 8, 9]] - // Expected: [[1, 2, 3], [0, 5, 6], [0, 0, 9]] - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {3, 3}); - - auto triu_result = backend->triu(input_tensor, 0); // main diagonal (k=0) - ASSERT_TRUE(triu_result); - auto result_tensor = *triu_result; - - // Verify shape unchanged - ASSERT_EQ(result_tensor.shape().size(), 2); - EXPECT_EQ(result_tensor.shape()[0], 3); - EXPECT_EQ(result_tensor.shape()[1], 3); - - // Extract and verify - std::vector output_data(9); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // Upper triangle should be preserved, lower should be 0 - EXPECT_NEAR(output_data[0], 1.0f, 1e-6); // [0,0] - EXPECT_NEAR(output_data[1], 2.0f, 1e-6); // [0,1] - EXPECT_NEAR(output_data[2], 3.0f, 1e-6); // [0,2] - EXPECT_NEAR(output_data[3], 0.0f, 1e-6); // [1,0] - zeroed - EXPECT_NEAR(output_data[4], 5.0f, 1e-6); // [1,1] - EXPECT_NEAR(output_data[5], 6.0f, 1e-6); // [1,2] - EXPECT_NEAR(output_data[6], 0.0f, 1e-6); // [2,0] - zeroed - EXPECT_NEAR(output_data[7], 0.0f, 1e-6); // [2,1] - zeroed - EXPECT_NEAR(output_data[8], 9.0f, 1e-6); // [2,2] -} - -TEST(MLXBackendTest, MLXBackendRMSNorm) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test RMSNorm: normalize input and scale by weight - // Input: [1, 2, 3, 4] (will be normalized) - // Weight: [1, 1, 1, 1] (no scaling) - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector weight_data = {1.0f, 1.0f, 1.0f, 1.0f}; - - auto input_tensor = backend->create_tensor(std::span(input_data), {4}); - auto weight_tensor = backend->create_tensor(std::span(weight_data), {4}); - - auto rms_norm_result = backend->rms_norm(input_tensor, weight_tensor, 1e-5f); - ASSERT_TRUE(rms_norm_result); - auto result_tensor = *rms_norm_result; - - // Verify shape unchanged - ASSERT_EQ(result_tensor.shape().size(), 1); - EXPECT_EQ(result_tensor.shape()[0], 4); - - // Extract results - std::vector output_data(4); - auto extract_result = backend->extract(result_tensor, std::span(output_data)); - ASSERT_TRUE(extract_result); - - // RMSNorm should normalize: x / sqrt(mean(x^2) + eps) - // RMS = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.739 - // Normalized: [1/2.739, 2/2.739, 3/2.739, 4/2.739] - float expected_rms = std::sqrt((1.0f + 4.0f + 9.0f + 16.0f) / 4.0f + 1e-5f); - EXPECT_NEAR(output_data[0], 1.0f / expected_rms, 1e-4); - EXPECT_NEAR(output_data[1], 2.0f / expected_rms, 1e-4); - EXPECT_NEAR(output_data[2], 3.0f / expected_rms, 1e-4); - EXPECT_NEAR(output_data[3], 4.0f / expected_rms, 1e-4); -} - -TEST(MLXBackendTest, MLXBackendRoPE) { - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "MLX backend not available on this platform"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test RoPE (Rotary Position Embedding) - // RoPE requires at least 3D input: [batch, seq_len, features] - // Input: [1, 2, 4] - batch=1, seq_len=2, features=4 - std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; - auto input_tensor = backend->create_tensor(std::span(input_data), {1, 2, 4}); - - auto rope_result = 
backend->rope(input_tensor, 4, 10000.0f, 0);
-    ASSERT_TRUE(rope_result) << "RoPE failed: " << rope_result.error().message;
-    auto result_tensor = *rope_result;
-
-    // Verify shape unchanged (should remain [1, 2, 4])
-    ASSERT_EQ(result_tensor.shape().size(), 3);
-    EXPECT_EQ(result_tensor.shape()[0], 1);
-    EXPECT_EQ(result_tensor.shape()[1], 2);
-    EXPECT_EQ(result_tensor.shape()[2], 4);
-
-    // Extract results - just verify shape and that operation succeeded
-    std::vector<float> output_data(8);
-    auto extract_result = backend->extract(result_tensor, std::span(output_data));
-    ASSERT_TRUE(extract_result);
-
-    // RoPE applies rotations, so output should be different from input
-    bool values_changed = false;
-    for (size_t i = 0; i < input_data.size(); ++i) {
-        if (std::abs(output_data[i] - input_data[i]) > 1e-6) {
-            values_changed = true;
-            break;
-        }
-    }
-    EXPECT_TRUE(values_changed) << "RoPE should modify input values";
-}
-
-TEST(MLXBackendTest, MLXBackendScaledDotProductAttention) {
-    auto backend_result = BackendFactory::create(BackendType::MLX);
-    if (!backend_result) {
-        GTEST_SKIP() << "MLX backend not available on this platform";
-    }
-
-    auto& backend = *backend_result;
-    auto init_result = backend->initialize();
-    ASSERT_TRUE(init_result);
-
-    // Test scaled dot-product attention with simple inputs
-    // Q, K, V: [batch=1, seq_len=2, heads=1, head_dim=4]
-    std::vector<float> q_data = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
-    std::vector<float> k_data = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
-    std::vector<float> v_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
-
-    auto q_tensor = backend->create_tensor(std::span(q_data), {1, 2, 1, 4});
-    auto k_tensor = backend->create_tensor(std::span(k_data), {1, 2, 1, 4});
-    auto v_tensor = backend->create_tensor(std::span(v_data), {1, 2, 1, 4});
-
-    float scale = 1.0f / std::sqrt(4.0f);  // 1/sqrt(head_dim)
-
-    auto attn_result = backend->scaled_dot_product_attention(q_tensor, k_tensor, v_tensor, scale, "");
-    ASSERT_TRUE(attn_result);
-    auto result_tensor = *attn_result;
-
-    // Verify shape: should match V shape [1, 2, 1, 4]
-    ASSERT_EQ(result_tensor.shape().size(), 4);
-    EXPECT_EQ(result_tensor.shape()[0], 1);
-    EXPECT_EQ(result_tensor.shape()[1], 2);
-    EXPECT_EQ(result_tensor.shape()[2], 1);
-    EXPECT_EQ(result_tensor.shape()[3], 4);
-
-    // Extract results
-    std::vector<float> output_data(8);
-    auto extract_result = backend->extract(result_tensor, std::span(output_data));
-    ASSERT_TRUE(extract_result);
-
-    // All values should be valid (not NaN or Inf)
-    for (float val : output_data) {
-        EXPECT_FALSE(std::isnan(val)) << "Output contains NaN";
-        EXPECT_FALSE(std::isinf(val)) << "Output contains Inf";
-    }
-}
-
-TEST(MLXBackendTest, MLXBackendTopkIndices) {
-    auto backend_result = BackendFactory::create(BackendType::MLX);
-    if (!backend_result) GTEST_SKIP() << "MLX backend not available";
-    auto& backend = *backend_result;
-    ASSERT_TRUE(backend->initialize());
-
-    // scores: [3.0, 1.0, 4.0, 1.5, 2.0] — top-2 should be indices 2 (4.0) and 0 (3.0)
-    std::vector<float> data = {3.0f, 1.0f, 4.0f, 1.5f, 2.0f};
-    auto t = backend->create_tensor(std::span(data), {5});
-    auto idx_result = backend->topk_indices(t, 2, 0);
-    ASSERT_TRUE(idx_result);
-    EXPECT_EQ(idx_result->shape(), (std::vector{2}));
-
-    std::vector<float> idx_f(2);
-    ASSERT_TRUE(backend->extract(*idx_result, std::span(idx_f)));
-    std::vector<int> idx = {(int)idx_f[0], (int)idx_f[1]};
-    std::sort(idx.begin(), idx.end());
-    EXPECT_EQ(idx[0], 0);
-    EXPECT_EQ(idx[1], 2);
-}
-
-TEST(MLXBackendTest, MLXBackendTakeTensor) {
auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) GTEST_SKIP() << "MLX backend not available"; - auto& backend = *backend_result; - ASSERT_TRUE(backend->initialize()); - - // source: [[1,2,3],[4,5,6],[7,8,9]] - std::vector data = {1.f,2.f,3.f, 4.f,5.f,6.f, 7.f,8.f,9.f}; - auto src = backend->create_tensor(std::span(data), {3, 3}); - - // Use topk_indices to get index 1 (score 9.0) then take that row via Tensor overload - std::vector scores = {1.f, 9.f, 3.f}; - auto s = backend->create_tensor(std::span(scores), {3}); - auto top1 = backend->topk_indices(s, 1, 0); - ASSERT_TRUE(top1); - auto picked = backend->take(src, *top1, 0); // row 1 = [4,5,6] - ASSERT_TRUE(picked); - std::vector row(3); - ASSERT_TRUE(backend->extract(*picked, std::span(row))); - EXPECT_FLOAT_EQ(row[0], 4.f); - EXPECT_FLOAT_EQ(row[1], 5.f); - EXPECT_FLOAT_EQ(row[2], 6.f); -} diff --git a/compute/tests/compute/test_model_loader.cpp b/compute/tests/compute/test_model_loader.cpp index 91a0427..3923ded 100644 --- a/compute/tests/compute/test_model_loader.cpp +++ b/compute/tests/compute/test_model_loader.cpp @@ -1,167 +1,19 @@ #include #include "compute/model/model_loader.h" -#include "compute/core/compute_backend.h" #include "test_config.h" #include -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) -#include "compute/backends/mlx/mlx_backend.h" -#endif - namespace compute { -// Create a minimal backend mock for testing -class MockBackend : public ComputeBackend { -public: - BackendType type() const override { return BackendType::MLX; } - std::string name() const override { return "Mock"; } - bool is_available() const override { return false; } - Result initialize() override { return {}; } - void cleanup() override {} - - // Create dummy tensors with nullptr buffer - these won't actually work but satisfy the interface - Tensor create_tensor(std::span data, std::vector shape) override { - return Tensor(nullptr, shape); - } - Tensor create_tensor(std::vector shape) override { - return Tensor(nullptr, shape); - } - Tensor wrap_native_tensor(void*, std::vector shape) override { - return Tensor(nullptr, shape); - } - - Tensor dot_product(const Tensor& a, const Tensor& b) override { - return Tensor(nullptr, {1}); - } - Tensor matrix_scalar_add(const Tensor& input, float) override { - return Tensor(nullptr, input.shape()); - } - - Result matmul(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result dequantize(const Tensor&, const Tensor&, const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result quantized_matmul(const Tensor&, const Tensor&, const Tensor&, const Tensor*, bool, int, int, const std::string&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result add(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result multiply(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result softmax(const Tensor&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result silu(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result gelu(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result sigmoid(const 
Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result softplus(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result exp(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result subtract(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result take(const Tensor&, const std::vector&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result take(const Tensor&, const Tensor&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result topk_indices(const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result conv1d(const Tensor&, const Tensor&, int, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result transpose(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result swapaxes(const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result reshape(const Tensor&, const std::vector&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result concatenate(const std::vector&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result mean(const Tensor&, int, bool) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result rsqrt(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result slice(const Tensor&, int, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result repeat(const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result triu(const Tensor&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result rms_norm(const Tensor&, const Tensor&, float) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result rope(const Tensor&, int, float, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result scaled_dot_product_attention(const Tensor&, const Tensor&, const Tensor&, float, const std::string&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - - Result extract(const Tensor&, std::span) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result evaluate_all() override { return {}; } - - std::unordered_map load_model(const std::string&) override { return {}; } -}; - class ModelLoaderTest : public ::testing::Test { protected: void SetUp() override { - // Use configured paths from CMake tinyllama_model_dir = TINYLLAMA_MODEL_DIR; - test_resources_dir = TEST_RESOURCES_DIR; - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - // Create MLX backend for testing if available - auto backend_result = BackendFactory::create(BackendType::MLX); - if (backend_result.has_value()) { - backend = std::move(*backend_result); - auto init_result = backend->initialize(); - mlx_available = init_result.has_value(); - } else { - mlx_available = false; - } -#else - 
mlx_available = false; -#endif - } - - void TearDown() override { - if (backend) { - backend->cleanup(); - } + test_resources_dir = TEST_RESOURCES_DIR; } std::filesystem::path tinyllama_model_dir; std::filesystem::path test_resources_dir; - std::unique_ptr backend; - bool mlx_available = false; }; TEST_F(ModelLoaderTest, LoadConfigOnly) { @@ -186,239 +38,13 @@ TEST_F(ModelLoaderTest, LoadConfigNonExistentDirectory) { } TEST_F(ModelLoaderTest, LoadConfigMissingConfigFile) { - // Create temporary directory without config.json auto temp_dir = std::filesystem::temp_directory_path() / "test_model_no_config"; std::filesystem::create_directories(temp_dir); auto result = ModelLoader::load_config(temp_dir); EXPECT_FALSE(result.has_value()); - // Cleanup - std::filesystem::remove_all(temp_dir); -} - -TEST_F(ModelLoaderTest, FindSafetensorsFiles) { - if (!std::filesystem::exists(tinyllama_model_dir)) - GTEST_SKIP() << "Model not found: " << tinyllama_model_dir; - - // This tests file discovery indirectly by checking that load_model finds the files - // We'll use a mock backend that claims to be something other than MLX - class NonMLXMockBackend : public ComputeBackend { - public: - BackendType type() const override { return BackendType::Auto; } // Not MLX - std::string name() const override { return "NonMLXMock"; } - bool is_available() const override { return false; } - Result initialize() override { return {}; } - void cleanup() override {} - - Tensor create_tensor(std::span, std::vector shape) override { - return Tensor(nullptr, shape); - } - Tensor create_tensor(std::vector shape) override { - return Tensor(nullptr, shape); - } - Tensor wrap_native_tensor(void*, std::vector shape) override { - return Tensor(nullptr, shape); - } - Tensor dot_product(const Tensor&, const Tensor&) override { - return Tensor(nullptr, {1}); - } - Tensor matrix_scalar_add(const Tensor&, float) override { - return Tensor(nullptr, {1}); - } - - Result matmul(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result dequantize(const Tensor&, const Tensor&, const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result quantized_matmul(const Tensor&, const Tensor&, const Tensor&, const Tensor*, bool, int, int, const std::string&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result add(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result multiply(const Tensor&, const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result softmax(const Tensor&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result silu(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result gelu(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result sigmoid(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result softplus(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result exp(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result subtract(const Tensor&, const Tensor&) override { - return 
std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result take(const Tensor&, const std::vector&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result take(const Tensor&, const Tensor&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result topk_indices(const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result conv1d(const Tensor&, const Tensor&, int, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result transpose(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result swapaxes(const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result reshape(const Tensor&, const std::vector&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result concatenate(const std::vector&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result mean(const Tensor&, int, bool) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result rsqrt(const Tensor&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result slice(const Tensor&, int, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result repeat(const Tensor&, int, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result triu(const Tensor&, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result rms_norm(const Tensor&, const Tensor&, float) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result rope(const Tensor&, int, float, int) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result scaled_dot_product_attention(const Tensor&, const Tensor&, const Tensor&, float, const std::string&) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result extract(const Tensor&, std::span) override { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, "Mock"}); - } - Result evaluate_all() override { return {}; } - std::unordered_map load_model(const std::string&) override { return {}; } - }; - - NonMLXMockBackend mock_backend; - - // This should fail because ModelLoader requires MLX backend for safetensors - auto result = ModelLoader::load_model(tinyllama_model_dir, &mock_backend); - - // Should fail because non-MLX backend can't load safetensors - EXPECT_FALSE(result.has_value()); - - // But error should NOT be about missing safetensors files - should be about backend type - EXPECT_EQ(result.error().message.find("No .safetensors files found"), std::string::npos); - EXPECT_NE(result.error().message.find("MLX backend"), std::string::npos); -} - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - -TEST_F(ModelLoaderTest, LoadCompleteModelWithMLX) { - if (!mlx_available) { - GTEST_SKIP() << "MLX backend not available"; - } - - if (!std::filesystem::exists(tinyllama_model_dir)) - GTEST_SKIP() << "Model not found: " << tinyllama_model_dir; - - ASSERT_NE(backend, nullptr) << "MLX backend not initialized"; - - auto result = ModelLoader::load_model(tinyllama_model_dir, 
backend.get()); - ASSERT_TRUE(result.has_value()) - << "Failed to load model: " << result.error().message; - - const auto& [config, tensors] = *result; - - // Verify config is correct - EXPECT_EQ(config.vocab_size, 32000); - EXPECT_EQ(config.hidden_size, 2048); - EXPECT_TRUE(config.is_valid()); - - // Verify tensors were loaded - EXPECT_FALSE(tensors.empty()) << "No tensors loaded"; - - // Check for expected tensor patterns (basic smoke test) - bool found_embedding = false; - bool found_attention = false; - bool found_mlp = false; - bool found_norm = false; - - for (const auto& [name, tensor] : tensors) { - if (name.find("embed_tokens") != std::string::npos) { - found_embedding = true; - } - if (name.find("self_attn") != std::string::npos) { - found_attention = true; - } - if (name.find("mlp") != std::string::npos) { - found_mlp = true; - } - if (name.find("norm") != std::string::npos) { - found_norm = true; - } - } - - EXPECT_TRUE(found_embedding) << "No embedding tensors found"; - EXPECT_TRUE(found_attention) << "No attention tensors found"; - EXPECT_TRUE(found_mlp) << "No MLP tensors found"; - EXPECT_TRUE(found_norm) << "No normalization tensors found"; - - std::cout << "Successfully loaded " << tensors.size() << " tensors" << std::endl; -} - -#endif - -TEST_F(ModelLoaderTest, LoadModelNullBackend) { - auto result = ModelLoader::load_model(tinyllama_model_dir, nullptr); - EXPECT_FALSE(result.has_value()); - EXPECT_NE(result.error().message.find("Backend cannot be null"), std::string::npos); -} - -TEST_F(ModelLoaderTest, LoadModelNonExistentDirectory) { - MockBackend mock_backend; - - auto result = ModelLoader::load_model("/nonexistent/directory", &mock_backend); - EXPECT_FALSE(result.has_value()); - EXPECT_NE(result.error().message.find("does not exist"), std::string::npos); -} - -TEST_F(ModelLoaderTest, LoadModelDirectoryWithoutSafetensors) { - MockBackend mock_backend; - - // Create temporary directory with only config.json - auto temp_dir = std::filesystem::temp_directory_path() / "test_model_no_safetensors"; - std::filesystem::create_directories(temp_dir); - - // Copy config file - std::filesystem::copy_file(test_resources_dir / "tinyllama_config.json", - temp_dir / "config.json"); - - auto result = ModelLoader::load_model(temp_dir, &mock_backend); - EXPECT_FALSE(result.has_value()); - EXPECT_NE(result.error().message.find("No .safetensors files found"), std::string::npos); - - // Cleanup std::filesystem::remove_all(temp_dir); } -} // namespace compute \ No newline at end of file +} // namespace compute diff --git a/compute/tests/compute/test_symbolic_api.cpp b/compute/tests/compute/test_symbolic_api.cpp deleted file mode 100644 index cb33b16..0000000 --- a/compute/tests/compute/test_symbolic_api.cpp +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include - -#include "compute/core/compute_types.h" -#include "compute/core/compute_backend.h" -#include "compute/core/graph.h" - -using namespace compute; - -TEST(SymbolicApiTest, TensorBasicFunctionality) { - std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; - - // Get a backend to create tensors - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "No backend available for testing"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - auto tensor = backend->create_tensor(std::span(data), {2, 2}); - - EXPECT_EQ(tensor.size(), 4); - EXPECT_EQ(tensor.shape().size(), 2); - EXPECT_EQ(tensor.shape()[0], 2); - 
EXPECT_EQ(tensor.shape()[1], 2); - - auto tensor_data = tensor.data(); - EXPECT_EQ(tensor_data[0], 1.0f); - EXPECT_EQ(tensor_data[3], 4.0f); - - backend->cleanup(); -} - -TEST(SymbolicApiTest, SymbolicScalarCreation) { - auto builder = ComputeGraphBuilder(); - - std::vector vec_a = {1.0f, 2.0f, 3.0f}; - std::vector vec_b = {4.0f, 5.0f, 6.0f}; - - // Get a backend to create tensors - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "No backend available for testing"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - auto tensor_a = backend->create_tensor(std::span(vec_a), {3}); - auto tensor_b = backend->create_tensor(std::span(vec_b), {3}); - - float result; - auto symbolic = builder.dot_product(tensor_a, tensor_b, std::span(&result, 1)); - - // Test symbolic scalar properties - EXPECT_EQ(symbolic.node_id(), 0); // First node should have ID 0 - - backend->cleanup(); -} - -TEST(SymbolicApiTest, ComputationGraphBuildingWithDependencies) { - auto builder = ComputeGraphBuilder(); - - // Get a backend to create tensors - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "No backend available for testing"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Create test data - std::vector vec_a = {1.0f, 2.0f}; - std::vector vec_b = {3.0f, 4.0f}; - std::vector matrix_data = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector output_data(4); - - auto tensor_a = backend->create_tensor(std::span(vec_a), {2}); - auto tensor_b = backend->create_tensor(std::span(vec_b), {2}); - auto matrix = backend->create_tensor(std::span(matrix_data), {2, 2}); - - // Build computation graph with dependencies - float scalar_result; - auto dot_result = builder.dot_product(tensor_a, tensor_b, std::span(&scalar_result, 1)); - - // This should create a dependency on the dot_product result - builder.matrix_scalar_add(matrix, dot_result, std::span(output_data), {2, 2}); - - // Test that we can build the graph without errors - auto graph_ptr = builder.build(); - EXPECT_NE(graph_ptr, nullptr); - - backend->cleanup(); -} - -TEST(SymbolicApiTest, MultipleOperationTypesInGraph) { - auto builder = ComputeGraphBuilder(); - - // Get a backend to create tensors - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "No backend available for testing"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - // Test data - std::vector vec_data = {2.0f, 3.0f, 4.0f}; - std::vector matrix_data = {1.0f, 1.0f, 1.0f, 1.0f}; - std::vector output1(4), output2(4); - - auto vector = backend->create_tensor(std::span(vec_data), {3}); - auto matrix = backend->create_tensor(std::span(matrix_data), {2, 2}); - - // Build complex graph - float scalar1, scalar2; - auto dot1 = builder.dot_product(vector, vector, std::span(&scalar1, 1)); - auto dot2 = builder.dot_product(vector, vector, std::span(&scalar2, 1)); - - builder.matrix_scalar_add(matrix, dot1, std::span(output1), {2, 2}); - builder.matrix_scalar_add(matrix, dot2, std::span(output2), {2, 2}); - - // Should build successfully - auto graph_ptr = builder.build(); - EXPECT_NE(graph_ptr, nullptr); - - backend->cleanup(); -} - -TEST(SymbolicApiTest, ConvenienceOverloadsWorkCorrectly) { - auto builder = ComputeGraphBuilder(); - - // Get a backend to create 
tensors - auto backend_result = BackendFactory::create(BackendType::MLX); - if (!backend_result) { - GTEST_SKIP() << "No backend available for testing"; - } - - auto& backend = *backend_result; - auto init_result = backend->initialize(); - ASSERT_TRUE(init_result); - - std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector output(4); - - auto tensor = backend->create_tensor(std::span(data), {2, 2}); - - // Test immediate value overload - builder.matrix_scalar_add(tensor, 5.0f, std::span(output), {2, 2}); - - // Test symbolic reference overload - float scalar_result; - auto symbolic = builder.dot_product(tensor, tensor, std::span(&scalar_result, 1)); - builder.matrix_scalar_add(tensor, symbolic, std::span(output), {2, 2}); - - // Should compile and build without errors - auto graph_ptr = builder.build(); - EXPECT_NE(graph_ptr, nullptr); - - backend->cleanup(); -} \ No newline at end of file From 6382dfa0fbf67e5c6f21845ea0a3e3c6a2f09216 Mon Sep 17 00:00:00 2001 From: Dex Date: Wed, 6 May 2026 16:05:50 -0400 Subject: [PATCH 4/5] chore: strip ComputeBackend to lifecycle handle; remove Tensor math methods ComputeBackend is now a pure lifecycle abstraction: type(), name(), is_available(), initialize(), cleanup(). All ~40 Tensor-based math methods (matmul, dequantize, rope, softmax, sdpa, etc.) are removed from the interface and MlxBackend. GemmaModelMLX and Qwen3MoeModelMLX no longer inherit from their Tensor-based base classes; config_ and tokenizer_ are owned directly. ModelLoader no longer exposes load_model() or load_all_safetensors(). --- .../src/compute/backends/mlx/mlx_backend.cpp | 754 +----------------- .../src/compute/backends/mlx/mlx_backend.h | 85 +- compute/src/compute/core/compute_backend.h | 307 +------ compute/src/compute/core/compute_types.h | 52 +- compute/src/compute/model/gemma_model_mlx.cpp | 3 +- compute/src/compute/model/gemma_model_mlx.h | 5 +- compute/src/compute/model/language_model.cpp | 22 +- compute/src/compute/model/llama_model.h | 67 +- compute/src/compute/model/model_loader.cpp | 118 +-- compute/src/compute/model/model_loader.h | 50 +- .../src/compute/model/qwen3_moe_model_mlx.cpp | 3 +- .../src/compute/model/qwen3_moe_model_mlx.h | 5 +- 12 files changed, 58 insertions(+), 1413 deletions(-) diff --git a/compute/src/compute/backends/mlx/mlx_backend.cpp b/compute/src/compute/backends/mlx/mlx_backend.cpp index 23baf5b..24bd939 100644 --- a/compute/src/compute/backends/mlx/mlx_backend.cpp +++ b/compute/src/compute/backends/mlx/mlx_backend.cpp @@ -1,13 +1,10 @@ #include "mlx_backend.h" -#include "../../model/model_loader.h" -#include #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) namespace compute { -MLXBackend::MLXBackend() : m_initialized(false) { -} +MLXBackend::MLXBackend() : m_initialized(false) {} MLXBackend::~MLXBackend() { cleanup(); @@ -22,8 +19,6 @@ std::string MLXBackend::name() const { } bool MLXBackend::is_available() const { - // Simple check - just try to create a basic MLX array - // If MLX is working, this should succeed try { auto test_array = mx::array({1.0f}); return true; @@ -33,761 +28,32 @@ bool MLXBackend::is_available() const { } Result MLXBackend::initialize() { - if (m_initialized) { - return {}; - } - - if (!is_available()) { - return std::unexpected(Error{ErrorCode::BackendNotAvailable, - "MLX backend not available"}); - } - + if (m_initialized) return {}; + + if (!is_available()) + return std::unexpected(Error{ErrorCode::BackendNotAvailable, "MLX backend not available"}); + try { - // Set MLX to use GPU 
by default mx::set_default_device(mx::Device::gpu); - - // Test basic functionality auto test = mx::array({1.0f, 2.0f}); mx::eval(test); - m_initialized = true; return {}; } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX initialization failed: ") + e.what()}); + return std::unexpected(Error{ErrorCode::ComputeError, + std::string("MLX initialization failed: ") + e.what()}); } } void MLXBackend::cleanup() { if (m_initialized) { - // Clear MLX memory cache mx::clear_cache(); m_initialized = false; } } -// Tensor creation -Tensor MLXBackend::create_tensor(std::span data, std::vector shape) { - auto buffer = std::make_shared(data, shape); - return Tensor(buffer, shape); -} - -Tensor MLXBackend::create_tensor(std::vector shape) { - auto buffer = std::make_shared(shape); - return Tensor(buffer, shape); -} - -Tensor MLXBackend::wrap_native_tensor(void* native_tensor, std::vector shape) { - // For MLX backend, native tensor is mx::array* - auto* array_ptr = static_cast(native_tensor); - auto buffer = std::make_shared(*array_ptr); - return Tensor(buffer, shape); -} - -// Core operations -Tensor MLXBackend::dot_product(const Tensor& a, const Tensor& b) { - // Validate inputs - if (a.size() != b.size()) { - throw std::runtime_error("MLX dot_product: Vector sizes must match"); - } - - if (a.shape().size() != 1 || b.shape().size() != 1) { - throw std::runtime_error("MLX dot_product: Both inputs must be vectors"); - } - - // Use MLX inner product for vectors - creates lazy computation graph - mx::array result = mx::inner(a.to_mlx(), b.to_mlx()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, {1}); -} - -Tensor MLXBackend::matrix_scalar_add(const Tensor& input, float scalar) { - // Use MLX broadcasting for scalar addition - creates lazy computation graph - mx::array result = input.to_mlx() + scalar; - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, input.shape()); -} - -// Core operations with proper error handling -Result MLXBackend::matmul(const Tensor& a, const Tensor& b) { - if (a.shape().size() < 2 || b.shape().size() < 2) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX matmul: inputs must be at least 2D"}); - } - - // Check inner dimension compatibility (last of a vs second-to-last of b) - const auto& sa = a.shape(); - const auto& sb = b.shape(); - if (sa[sa.size() - 1] != sb[sb.size() - 2]) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX matmul: incompatible inner dimensions"}); - } - - try { - mx::array result = mx::matmul(a.to_mlx(), b.to_mlx()); - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX matmul failed: ") + e.what()}); - } -} - -Result MLXBackend::dequantize( - const Tensor& w, - const Tensor& scales, - const Tensor& biases, - int group_size, - int bits -) { - if (w.backend_type() != BackendType::MLX || - scales.backend_type() != BackendType::MLX || - biases.backend_type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "MLX dequantize: all tensors must be from MLX backend"}); - } - - try { - mx::array result = mx::dequantize( - w.to_mlx(), scales.to_mlx(), biases.to_mlx(), group_size, bits); - auto mlx_shape = result.shape(); - std::vector 
result_shape(mlx_shape.begin(), mlx_shape.end()); - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX dequantize failed: ") + e.what()}); - } -} - -Result MLXBackend::quantized_matmul( - const Tensor& x, - const Tensor& w, - const Tensor& scales, - const Tensor* biases, - bool transpose, - int group_size, - int bits, - const std::string& mode -) { - // Validate input tensors are from MLX backend - if (x.backend_type() != BackendType::MLX || - w.backend_type() != BackendType::MLX || - scales.backend_type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "MLX quantized_matmul: all tensors must be from MLX backend"}); - } - - try { - // Prepare optional biases for MLX call - std::optional mlx_biases = std::nullopt; - if (biases != nullptr && biases->backend_type() == BackendType::MLX) { - mlx_biases = biases->to_mlx(); - } - - // Call MLX quantized matrix multiplication - MLX validates all parameters - mx::array result = mx::quantized_matmul( - x.to_mlx(), - w.to_mlx(), - scales.to_mlx(), - mlx_biases, - transpose, - group_size, - bits, - mode - ); - - // Get result shape from MLX array - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX quantized_matmul failed: ") + e.what()}); - } -} - -Result MLXBackend::add(const Tensor& a, const Tensor& b) { - // Validate tensors have compatible shapes for broadcasting - if (a.backend_type() != BackendType::MLX || b.backend_type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX add: tensors must be from MLX backend"}); - } - - try { - mx::array result = a.to_mlx() + b.to_mlx(); - - // Get result shape from MLX array - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX add failed: ") + e.what()}); - } -} - -Result MLXBackend::multiply(const Tensor& a, const Tensor& b) { - VALIDATE_MLX_TENSOR(a, b); - - return mlx_utils::mlx_tensor_op(mlx_utils::broadcast_shape(a.shape(), b.shape()), [&]() { - auto mlx_a = mlx_utils::to_mlx_auto(a); - auto mlx_b = mlx_utils::to_mlx_auto(b); - return mlx_a * mlx_b; // Element-wise multiplication - }); -} - -Result MLXBackend::softmax(const Tensor& input, int dim) { - VALIDATE_MLX_TENSOR(input); - - return mlx_utils::mlx_tensor_op(input.shape(), [&]() { - auto mlx_input = mlx_utils::to_mlx_auto(input); - return (dim == -1) - ? 
mx::softmax(mlx_input, static_cast(input.shape().size()) - 1, true) - : mx::softmax(mlx_input, dim, true); - }); -} - -Result MLXBackend::silu(const Tensor& input) { - VALIDATE_MLX_TENSOR(input); - - return mlx_utils::mlx_tensor_op(input.shape(), [&]() { - auto mlx_input = mlx_utils::to_mlx_auto(input); - // SiLU = x * sigmoid(x) - return mlx_input * mx::sigmoid(mlx_input); - }); -} - -Result MLXBackend::gelu(const Tensor& input) { - VALIDATE_MLX_TENSOR(input); - - return mlx_utils::mlx_tensor_op(input.shape(), [&]() { - auto mlx_input = mlx_utils::to_mlx_auto(input); - // GELU tanh approximation (gelu_pytorch_tanh) used by Gemma GeGLU: - // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) - const float c = 0.7978845608028654f; // sqrt(2 / pi) - const float k = 0.044715f; - auto x3 = mlx_input * mlx_input * mlx_input; - auto inner = mlx_input + k * x3; - auto tanh_part = mx::tanh(c * inner); - return 0.5f * mlx_input * (1.0f + tanh_part); - }); -} - -Result MLXBackend::sigmoid(const Tensor& input) { - VALIDATE_MLX_TENSOR(input); - - return mlx_utils::mlx_tensor_op(input.shape(), [&]() { - auto mlx_input = mlx_utils::to_mlx_auto(input); - return mx::sigmoid(mlx_input); - }); -} - -Result MLXBackend::softplus(const Tensor& input) { - VALIDATE_MLX_TENSOR(input); - return mlx_utils::mlx_tensor_op(input.shape(), [&]() { - auto x = mlx_utils::to_mlx_auto(input); - return mx::log(mx::ones_like(x) + mx::exp(x)); - }); -} - -Result MLXBackend::exp(const Tensor& input) { - VALIDATE_MLX_TENSOR(input); - return mlx_utils::mlx_tensor_op(input.shape(), [&]() { - return mx::exp(mlx_utils::to_mlx_auto(input)); - }); -} - -Result MLXBackend::subtract(const Tensor& a, const Tensor& b) { - VALIDATE_MLX_TENSOR(a, b); - return mlx_utils::mlx_tensor_op(mlx_utils::broadcast_shape(a.shape(), b.shape()), [&]() { - return mlx_utils::to_mlx_auto(a) - mlx_utils::to_mlx_auto(b); - }); -} - -Result MLXBackend::conv1d( - const Tensor& input, const Tensor& weight, - int stride, int padding, int groups) -{ - VALIDATE_MLX_TENSOR(input); - VALIDATE_MLX_TENSOR(weight); - - auto mlx_input = mlx_utils::to_mlx_auto(input); - auto mlx_weight = mlx_utils::to_mlx_auto(weight); - - // mx::conv1d expects [N, L, C_in]; add batch dim if missing. 
- bool added_batch = false; - if (mlx_input.ndim() == 2) { - mlx_input = mx::expand_dims(mlx_input, 0); - added_batch = true; - } - - auto out_shape = input.shape(); // placeholder — recomputed after call - return mlx_utils::mlx_tensor_op(out_shape, [&]() { - auto out = mx::conv1d(mlx_input, mlx_weight, stride, padding, /*dilation=*/1, groups); - if (added_batch) out = mx::squeeze(out, 0); - return out; - }); -} - -Result MLXBackend::transpose(const Tensor& input) { - if (input.backend_type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX transpose: tensor must be from MLX backend"}); - } - - try { - // General transpose (reverses dimension order) - mx::array result = mx::transpose(input.to_mlx()); - - // Get result shape from MLX array - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX transpose failed: ") + e.what()}); - } -} - -Result MLXBackend::swapaxes(const Tensor& input, int axis1, int axis2) { - if (input.backend_type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX swapaxes: tensor must be from MLX backend"}); - } - - const auto& shape = input.shape(); - int ndim = static_cast(shape.size()); - - // Normalize negative indices - if (axis1 < 0) axis1 += ndim; - if (axis2 < 0) axis2 += ndim; - - // Validate axes - if (axis1 < 0 || axis1 >= ndim || axis2 < 0 || axis2 >= ndim) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX swapaxes: axis out of bounds"}); - } - - try { - // Use MLX swapaxes for efficient axis swapping - mx::array result = mx::swapaxes(input.to_mlx(), axis1, axis2); - - // Get result shape from MLX array - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX swapaxes failed: ") + e.what()}); - } -} - -Result MLXBackend::reshape(const Tensor& input, const std::vector& new_shape) { - VALIDATE_MLX_TENSOR(input); - - // Validate that total number of elements remains the same - size_t input_size = input.size(); - size_t new_size = 1; - for (size_t dim : new_shape) { - if (dim == 0) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX reshape: shape dimensions cannot be zero"}); - } - new_size *= dim; - } - - if (input_size != new_size) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "MLX reshape: total elements must remain the same (input: " + - std::to_string(input_size) + ", new: " + std::to_string(new_size) + ")"}); - } - - return mlx_utils::mlx_tensor_op(new_shape, [&]() { - auto mlx_input = mlx_utils::to_mlx_auto(input); - - // Convert size_t vector to MLX Shape (SmallVector) - mx::Shape mlx_shape; - mlx_shape.reserve(new_shape.size()); - for (size_t dim : new_shape) { - mlx_shape.push_back(static_cast(dim)); - } - - return mx::reshape(mlx_input, mlx_shape); - }); -} - -Result MLXBackend::concatenate(const std::vector& tensors, int axis) { - // Validate input - if (tensors.empty()) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX concatenate: empty tensor list"}); - } - - // Validate all tensors are from MLX backend - for (const auto& tensor : 
tensors) { - if (auto error = mlx_utils::validate_single_mlx_tensor(tensor)) { - return std::unexpected(*error); - } - } - - // Get reference shape for validation - const auto& ref_shape = tensors[0].shape(); - int ndim = static_cast(ref_shape.size()); - - // Normalize negative axis - if (axis < 0) { - axis += ndim; - } - - // Validate axis bounds - if (axis < 0 || axis >= ndim) { - return std::unexpected(Error{ErrorCode::InvalidInput, "MLX concatenate: axis out of bounds"}); - } - - if (tensors.size() == 1) { - // Single tensor - just return a copy (after validation) - return tensors[0]; - } - - // Validate all tensors have compatible shapes - for (size_t i = 1; i < tensors.size(); ++i) { - const auto& current_shape = tensors[i].shape(); - - if (current_shape.size() != ref_shape.size()) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "MLX concatenate: all tensors must have same number of dimensions"}); - } - - // Check all dimensions except the concatenation axis - for (int dim = 0; dim < ndim; ++dim) { - if (dim != axis && current_shape[dim] != ref_shape[dim]) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "MLX concatenate: tensor shapes must match except along concatenation axis"}); - } - } - } - - // Compute output shape - std::vector output_shape = ref_shape; - for (size_t i = 1; i < tensors.size(); ++i) { - output_shape[axis] += tensors[i].shape()[axis]; - } - - return mlx_utils::mlx_tensor_op(output_shape, [&]() { - // Convert tensors to MLX arrays - std::vector mlx_arrays; - mlx_arrays.reserve(tensors.size()); - for (const auto& tensor : tensors) { - mlx_arrays.push_back(mlx_utils::to_mlx_auto(tensor)); - } - - return mx::concatenate(mlx_arrays, axis); - }); -} - -// Utility operations -Result MLXBackend::extract(const Tensor& tensor, std::span output) { - if (tensor.backend_type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, "Tensor not from MLX backend"}); - } - - if (output.size() != tensor.size()) { - return std::unexpected(Error{ErrorCode::InvalidInput, "Output buffer size mismatch"}); - } - - const float* data = tensor.data_f32(); - std::copy(data, data + tensor.size(), output.begin()); - - return {}; -} - -Result MLXBackend::evaluate_all() { - return {}; // MLX operations are lazy by default -} - -std::unordered_map MLXBackend::load_model(const std::string& path) { - try { - // Use ModelLoader to load the complete model - auto result = ModelLoader::load_model(std::filesystem::path(path), this); - - if (!result) { - throw std::runtime_error("Failed to load model: " + result.error().message); - } - - // Extract tensor map from the result pair (config, tensors) - return std::move(result->second); - } catch (const std::exception& e) { - throw std::runtime_error("MLX load_model failed: " + std::string(e.what())); - } -} - -size_t MLXBackend::preferred_batch_size() const { - return 2048; // MLX optimized batch size -} - -bool MLXBackend::supports_async() const { - return true; // MLX supports async operations -} - -Result MLXBackend::mean(const Tensor& input, int axis, bool keepdims) { - try { - mx::array input_array = input.to_mlx(); - mx::array result = mx::mean(input_array, axis, keepdims); - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX mean failed: ") + 
e.what()}); - } -} - -Result MLXBackend::rsqrt(const Tensor& input) { - try { - mx::array input_array = input.to_mlx(); - mx::array result = mx::rsqrt(input_array); - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX rsqrt failed: ") + e.what()}); - } -} - -Result MLXBackend::slice(const Tensor& input, int start, int stop, int axis) { - try { - mx::array input_array = input.to_mlx(); - - // Use split to extract slice along axis, then take first piece - std::vector splits = mx::split(input_array, {start, stop}, axis); - mx::array result = splits[1]; // Middle piece is what we want - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX slice failed: ") + e.what()}); - } -} - -Result MLXBackend::repeat(const Tensor& input, int repeats, int axis) { - try { - mx::array input_array = input.to_mlx(); - mx::array result = mx::repeat(input_array, repeats, axis); - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX repeat failed: ") + e.what()}); - } -} - -Result MLXBackend::triu(const Tensor& input, int k) { - try { - mx::array input_array = input.to_mlx(); - mx::array result = mx::triu(input_array, k); - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX triu failed: ") + e.what()}); - } -} - -Result MLXBackend::take(const Tensor& input, const std::vector& indices, int axis) { - VALIDATE_MLX_TENSOR(input); - try { - mx::array idx = mx::array(indices.data(), {static_cast(indices.size())}, mx::int32); - mx::array result = mx::take(mlx_utils::to_mlx_auto(input), idx, axis); - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - return Tensor(std::make_shared(result), result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX take failed: ") + e.what()}); - } -} - -Result MLXBackend::take(const Tensor& input, const Tensor& indices, int axis) { - VALIDATE_MLX_TENSOR(input); - VALIDATE_MLX_TENSOR(indices); - try { - mx::array result = mx::take(mlx_utils::to_mlx_auto(input), - mlx_utils::to_mlx_auto(indices), axis); - auto sh = result.shape(); - std::vector shape(sh.begin(), sh.end()); - return Tensor(std::make_shared(result), shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX take (tensor idx) failed: ") + e.what()}); - } -} - -Result MLXBackend::topk_indices(const Tensor& input, int k, int axis) { - VALIDATE_MLX_TENSOR(input); - try { - auto x = mlx_utils::to_mlx_auto(input); - // argsort of negated values → 
indices in descending order - auto sorted_idx = mx::argsort(mx::negative(x), axis); - // Select first k indices along the axis using a range index array - std::vector range(k); - std::iota(range.begin(), range.end(), 0); - mx::array range_arr(range.data(), {k}, mx::int32); - auto topk_idx = mx::take(sorted_idx, range_arr, axis); - auto sh = topk_idx.shape(); - std::vector shape(sh.begin(), sh.end()); - return Tensor(std::make_shared(topk_idx), shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX topk_indices failed: ") + e.what()}); - } -} - -Result MLXBackend::rms_norm(const Tensor& input, const Tensor& weight, float eps) { - try { - mx::array input_array = input.to_mlx(); - mx::array weight_array = weight.to_mlx(); - - // Use MLX optimized fused RMSNorm - mx::array result = mx::fast::rms_norm(input_array, weight_array, eps); - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX rms_norm failed: ") + e.what()}); - } -} - -Result MLXBackend::rope(const Tensor& input, int dims, float theta, int offset) { - try { - mx::array input_array = input.to_mlx(); - - // MLX fast::rope Metal kernel dispatches with: - // B = shape[0], T = shape[-2], N = product(shape[1..ndim-3]) - // For 3D input [n_heads, seq, head_dim], the kernel sees B=n_heads, T=seq, N=1. - // When seq=1 (decode), the "single" fast path activates with grid=(dims/2, N=1, 1), - // processing only N=1 head instead of n_heads — corrupting all heads except the first. - // - // Fix: always present rope with a 4D tensor [1, n_heads, seq, head_dim], matching - // the Python mlx-lm convention. This makes N=n_heads so all heads are processed. 
- bool added_batch_dim = false; - if (input_array.ndim() == 3) { - auto s = input_array.shape(); - input_array = mx::reshape(input_array, {1, s[0], s[1], s[2]}); - added_batch_dim = true; - } - - // MLX fast::rope parameters: - // - dims: number of dimensions to apply RoPE to - // - traditional: false for modern RoPE formulation (same as Python) - // - base: theta value (optional, default 10000.0) - // - scale: scaling factor (default 1.0) - // - offset: position offset - mx::array result = mx::fast::rope( - input_array, - dims, - false, // modern RoPE (not traditional) - std::optional(theta), // base frequency - 1.0f, // scale - offset // position offset - ); - - // Remove the batch dimension we added - if (added_batch_dim) { - result = mx::squeeze(result, 0); - } - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX rope failed: ") + e.what()}); - } -} - -Result MLXBackend::scaled_dot_product_attention( - const Tensor& queries, - const Tensor& keys, - const Tensor& values, - float scale, - const std::string& mask) { - try { - mx::array q_array = queries.to_mlx(); - mx::array k_array = keys.to_mlx(); - mx::array v_array = values.to_mlx(); - - // MLX fast::scaled_dot_product_attention expects rank 4: [B, n_heads, seq, head_dim] - // If inputs are rank 3 [n_heads, seq, head_dim], add batch dimension via reshape - // (not expand_dims, which creates a non-contiguous view with stride=1 at dim 0 — - // that fails the sdpa_vector q_copy_unless check and triggers an incorrect copy path) - bool added_batch_dim = false; - if (q_array.ndim() == 3) { - auto q_shape = q_array.shape(); - auto k_shape = k_array.shape(); - auto v_shape = v_array.shape(); - q_array = mx::reshape(q_array, {1, q_shape[0], q_shape[1], q_shape[2]}); - k_array = mx::reshape(k_array, {1, k_shape[0], k_shape[1], k_shape[2]}); - v_array = mx::reshape(v_array, {1, v_shape[0], v_shape[1], v_shape[2]}); - added_batch_dim = true; - } - - // MLX fast::scaled_dot_product_attention parameters: - // - queries, keys, values: input tensors - // - scale: scaling factor for attention scores - // - mask: attention mask ("causal" for autoregressive, "" for none) - mx::array result = mx::fast::scaled_dot_product_attention( - q_array, - k_array, - v_array, - scale, - mask - ); - - // Remove batch dimension if we added it - if (added_batch_dim) { - result = mx::squeeze(result, 0); - } - - auto mlx_shape = result.shape(); - std::vector result_shape(mlx_shape.begin(), mlx_shape.end()); - - auto result_buffer = std::make_shared(result); - return Tensor(result_buffer, result_shape); - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - std::string("MLX scaled_dot_product_attention failed: ") + e.what()}); - } -} +size_t MLXBackend::preferred_batch_size() const { return 1024; } +bool MLXBackend::supports_async() const { return false; } } // namespace compute diff --git a/compute/src/compute/backends/mlx/mlx_backend.h b/compute/src/compute/backends/mlx/mlx_backend.h index e5e6cbd..7432621 100644 --- a/compute/src/compute/backends/mlx/mlx_backend.h +++ b/compute/src/compute/backends/mlx/mlx_backend.h @@ -1,8 +1,6 @@ #pragma once #include "../../core/compute_backend.h" -#include "mlx_buffer.h" -#include "mlx_utils.h" #if defined(__APPLE__) && 
defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) #include @@ -10,95 +8,18 @@ namespace mx = mlx::core; namespace compute { -// Forward declaration for MLX buffer wrapper -class MLXBuffer; - -// MLX-based backend for Apple Silicon using MLX framework class MLXBackend : public ComputeBackend { public: MLXBackend(); ~MLXBackend(); - + BackendType type() const override; std::string name() const override; bool is_available() const override; - + Result initialize() override; void cleanup() override; - - // Tensor creation - Tensor create_tensor(std::span data, std::vector shape) override; - Tensor create_tensor(std::vector shape) override; - Tensor wrap_native_tensor(void* native_tensor, std::vector shape) override; - - // Core operations using MLX - Tensor dot_product(const Tensor& a, const Tensor& b) override; - Tensor matrix_scalar_add(const Tensor& input, float scalar) override; - - // Core operations with proper error handling - Result matmul(const Tensor& a, const Tensor& b) override; - Result dequantize( - const Tensor& w, - const Tensor& scales, - const Tensor& biases, - int group_size = 64, - int bits = 4 - ) override; - Result quantized_matmul( - const Tensor& x, - const Tensor& w, - const Tensor& scales, - const Tensor* biases = nullptr, - bool transpose = true, - int group_size = 64, - int bits = 4, - const std::string& mode = "affine" - ) override; - Result add(const Tensor& a, const Tensor& b) override; - Result multiply(const Tensor& a, const Tensor& b) override; - Result softmax(const Tensor& input, int dim = -1) override; - Result silu(const Tensor& input) override; - Result gelu(const Tensor& input) override; - Result sigmoid(const Tensor& input) override; - Result softplus(const Tensor& input) override; - Result exp(const Tensor& input) override; - Result subtract(const Tensor& a, const Tensor& b) override; - Result conv1d(const Tensor& input, const Tensor& weight, - int stride = 1, int padding = 0, int groups = 1) override; - Result transpose(const Tensor& input) override; - Result swapaxes(const Tensor& input, int axis1, int axis2) override; - Result reshape(const Tensor& input, const std::vector& new_shape) override; - Result concatenate(const std::vector& tensors, int axis = 0) override; - - // Additional tensor operations for transformer inference - Result mean(const Tensor& input, int axis = -1, bool keepdims = false) override; - Result rsqrt(const Tensor& input) override; - Result slice(const Tensor& input, int start, int stop, int axis = 0) override; - Result repeat(const Tensor& input, int repeats, int axis) override; - Result triu(const Tensor& input, int k = 0) override; - Result take(const Tensor& input, const std::vector& indices, int axis = 0) override; - Result take(const Tensor& input, const Tensor& indices, int axis = 0) override; - Result topk_indices(const Tensor& input, int k, int axis = -1) override; - - // Optimized transformer operations using MLX fast implementations - Result rms_norm(const Tensor& input, const Tensor& weight, float eps) override; - Result rope(const Tensor& input, int dims, float theta, int offset) override; - Result scaled_dot_product_attention( - const Tensor& queries, - const Tensor& keys, - const Tensor& values, - float scale, - const std::string& mask = "" - ) override; - // Utility operations - Result extract(const Tensor& tensor, std::span output) override; - Result evaluate_all() override; - - // Model loading (stub for now) - std::unordered_map load_model(const std::string& path) override; - - // Performance hints 
optimized for MLX size_t preferred_batch_size() const override; bool supports_async() const override; @@ -108,4 +29,4 @@ class MLXBackend : public ComputeBackend { } // namespace compute -#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) \ No newline at end of file +#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) diff --git a/compute/src/compute/core/compute_backend.h b/compute/src/compute/core/compute_backend.h index dfe5d78..bf6d461 100644 --- a/compute/src/compute/core/compute_backend.h +++ b/compute/src/compute/core/compute_backend.h @@ -1,312 +1,27 @@ #pragma once #include "compute_types.h" -#include "tensor.h" #include -#include #include -#include +#include namespace compute { -// Abstract base for all compute backends using new Tensor-based API +// Thin lifecycle handle for compute backends. +// Math operations are performed directly via backend-native APIs (e.g. mx::array for MLX). class ComputeBackend { public: virtual ~ComputeBackend() = default; - + // Backend identification virtual BackendType type() const = 0; virtual std::string name() const = 0; virtual bool is_available() const = 0; - + // Lifecycle virtual Result initialize() = 0; virtual void cleanup() = 0; - - // Tensor creation (directly in backend-optimal format) - - /** - * Create tensor from raw data (single conversion to backend format) - */ - virtual Tensor create_tensor(std::span data, std::vector shape) = 0; - - /** - * Create uninitialized tensor (backend allocates optimal memory) - */ - virtual Tensor create_tensor(std::vector shape) = 0; - - /** - * Wrap existing backend-native tensor (e.g., from model loading) - * @param native_tensor Backend-specific tensor pointer (mx::array*, etc.) - * @param shape Tensor shape - */ - virtual Tensor wrap_native_tensor(void* native_tensor, std::vector shape) = 0; - - // Core operations (stay in backend-native format) - - /** - * Dot product of two vectors (returns scalar tensor) - */ - virtual Tensor dot_product(const Tensor& a, const Tensor& b) = 0; - - /** - * Add scalar to all elements of tensor - */ - virtual Tensor matrix_scalar_add(const Tensor& input, float scalar) = 0; - - /** - * Matrix multiplication - */ - virtual Result matmul(const Tensor& a, const Tensor& b) = 0; - - /** - * Dequantize a quantized tensor back to float - * @param w Quantized weight matrix - * @param scales Scale factors - * @param biases Zero-point biases - * @param group_size Elements per quantization group - * @param bits Quantization bits - */ - virtual Result dequantize( - const Tensor& w, - const Tensor& scales, - const Tensor& biases, - int group_size = 64, - int bits = 4 - ) = 0; - - /** - * Quantized matrix multiplication (universal quantization support) - * @param x Input activation tensor - * @param w Quantized weight matrix - * @param scales Scale factors for dequantization - * @param biases Optional bias terms (can be nullptr) - * @param transpose Whether to transpose w during computation - * @param group_size Elements per quantization group (default 64) - * @param bits Quantization bits (default 4, supports 1-8) - * @param mode Quantization mode ("affine" default, "symmetric") - */ - virtual Result quantized_matmul( - const Tensor& x, - const Tensor& w, - const Tensor& scales, - const Tensor* biases = nullptr, - bool transpose = true, - int group_size = 64, - int bits = 4, - const std::string& mode = "affine" - ) = 0; - - /** - * Element-wise addition - */ - virtual Result add(const Tensor& a, const Tensor& b) 
= 0; - - /** - * Element-wise multiplication - */ - virtual Result multiply(const Tensor& a, const Tensor& b) = 0; - - /** - * Softmax along specified dimension - */ - virtual Result softmax(const Tensor& input, int dim = -1) = 0; - - /** - * SiLU (Swish) activation function: x * sigmoid(x) - */ - virtual Result silu(const Tensor& input) = 0; - - /** - * GELU activation (tanh approximation): used by Gemma GeGLU FFN - */ - virtual Result gelu(const Tensor& input) = 0; - - /** - * Sigmoid activation: 1 / (1 + exp(-x)) - */ - virtual Result sigmoid(const Tensor& input) = 0; - - /** - * Softplus activation: log(1 + exp(x)) - */ - virtual Result softplus(const Tensor& input) = 0; - - /** - * Element-wise exponential: e^x - */ - virtual Result exp(const Tensor& input) = 0; - - /** - * Element-wise subtraction: a - b (supports broadcasting) - */ - virtual Result subtract(const Tensor& a, const Tensor& b) = 0; - - /** - * 1D convolution (channels-last layout: input [N, L, C_in], weight [C_out, kW, C_in/groups]) - * @param input [N, L, C_in] or [L, C_in] (batch dim added if missing) - * @param weight [C_out, kernel_size, C_in/groups] - * @param stride Convolution stride (default 1) - * @param padding Zero-padding applied symmetrically on both sides (default 0) - * @param groups Number of groups for grouped/depthwise conv (default 1) - */ - virtual Result conv1d( - const Tensor& input, - const Tensor& weight, - int stride = 1, - int padding = 0, - int groups = 1) = 0; - - /** - * Transpose tensor (general transpose - reverses dimension order) - */ - virtual Result transpose(const Tensor& input) = 0; - - /** - * Swap two axes of tensor (for attention mechanism) - */ - virtual Result swapaxes(const Tensor& input, int axis1, int axis2) = 0; - - /** - * Reshape tensor to new shape (total elements must remain the same) - */ - virtual Result reshape(const Tensor& input, const std::vector& new_shape) = 0; - - /** - * Concatenate tensors along specified axis - */ - virtual Result concatenate(const std::vector& tensors, int axis = 0) = 0; - // Additional tensor operations for transformer inference - - /** - * Compute mean along specified axis - * @param input Input tensor - * @param axis Axis to reduce (default -1 for last axis) - * @param keepdims Whether to keep reduced dimension (default false) - */ - virtual Result mean(const Tensor& input, int axis = -1, bool keepdims = false) = 0; - - /** - * Reciprocal square root: 1/sqrt(x) - * @param input Input tensor - */ - virtual Result rsqrt(const Tensor& input) = 0; - - /** - * Extract slice from tensor along specified axis - * @param input Input tensor - * @param start Start index - * @param stop Stop index (exclusive) - * @param axis Axis to slice along (default 0) - */ - virtual Result slice(const Tensor& input, int start, int stop, int axis = 0) = 0; - - /** - * Repeat tensor elements along specified axis - * @param input Input tensor - * @param repeats Number of repetitions - * @param axis Axis to repeat along - */ - virtual Result repeat(const Tensor& input, int repeats, int axis) = 0; - - /** - * Upper triangular matrix (for causal masking) - * @param input Input tensor - * @param k Diagonal offset (0 = main diagonal, >0 = above, <0 = below) - */ - virtual Result triu(const Tensor& input, int k = 0) = 0; - - /** - * Gather rows by index along an axis (numpy-style take, CPU indices). - * take(input, {2,0,5}, axis=0) returns rows input[2], input[0], input[5]. 
- */ - virtual Result take( - const Tensor& input, - const std::vector& indices, - int axis = 0) = 0; - - /** - * Gather rows by index along an axis (GPU tensor indices — no CPU roundtrip). - * Equivalent to the CPU-index overload but indices remain on the GPU, - * enabling fully lazy evaluation with no mx::eval() sync point. - */ - virtual Result take( - const Tensor& input, - const Tensor& indices, - int axis = 0) = 0; - - /** - * Return the indices of the top-k largest values along an axis. - * Result shape: same as input with the given axis replaced by k. - * Indices are returned as a GPU-resident tensor (no CPU extraction). - * @param input Source tensor - * @param k Number of top elements to select - * @param axis Axis to reduce (default -1 for last axis) - */ - virtual Result topk_indices( - const Tensor& input, - int k, - int axis = -1) = 0; - - // Optimized transformer operations - - /** - * RMSNorm layer normalization (fused implementation) - * @param input Input tensor - * @param weight Scale weights - * @param eps Epsilon for numerical stability - */ - virtual Result rms_norm(const Tensor& input, const Tensor& weight, float eps) = 0; - - /** - * Rotary Position Embedding (RoPE) - fused implementation - * @param input Input tensor - * @param dims Number of dimensions to apply RoPE to - * @param theta Base for frequency computation (typically 10000.0) - * @param offset Position offset for the sequence - */ - virtual Result rope(const Tensor& input, int dims, float theta, int offset) = 0; - - /** - * Scaled dot-product attention (fused implementation) - * @param queries Query tensor - * @param keys Key tensor - * @param values Value tensor - * @param scale Scaling factor (typically 1/sqrt(head_dim)) - * @param mask Attention mask type ("causal" for autoregressive, "" for none) - */ - virtual Result scaled_dot_product_attention( - const Tensor& queries, - const Tensor& keys, - const Tensor& values, - float scale, - const std::string& mask = "" - ) = 0; - - // Utility operations - - /** - * Extract tensor data to CPU buffer (triggers computation for lazy backends) - * @param tensor Source tensor - * @param output Destination CPU buffer - */ - virtual Result extract(const Tensor& tensor, std::span output) = 0; - - /** - * Force evaluation of all pending operations (for lazy backends) - */ - virtual Result evaluate_all() = 0; - - // Model loading integration (implementation in Phase 2+) - - /** - * Load model tensors from file (backend-specific format) - * @param path Path to model file (.safetensors, .gguf, etc.) 
- * @return Map of tensor names to tensors - */ - virtual std::unordered_map load_model(const std::string& path) = 0; - // Performance hints virtual size_t preferred_batch_size() const { return 1024; } virtual bool supports_async() const { return false; } @@ -320,24 +35,24 @@ class BackendFactory { static BackendType best_available_backend(); }; -// Backend manager - singleton that manages backend lifecycle +// Backend manager — singleton that manages backend lifecycle class BackendManager { public: static BackendManager& instance(); - + Result initialize(); void cleanup(); - + ComputeBackend* get_backend(BackendType type); ComputeBackend* get_default_backend(); - + private: BackendManager() = default; std::vector> backends_; ComputeBackend* default_backend_ = nullptr; bool initialized_ = false; - + void create_available_backends(); }; -} // namespace compute \ No newline at end of file +} // namespace compute diff --git a/compute/src/compute/core/compute_types.h b/compute/src/compute/core/compute_types.h index 6e01a0c..b90d44e 100644 --- a/compute/src/compute/core/compute_types.h +++ b/compute/src/compute/core/compute_types.h @@ -1,10 +1,7 @@ #pragma once -#include -#include #include #include -#include namespace compute { @@ -25,7 +22,7 @@ enum class ErrorCode { struct Error { ErrorCode code; std::string message; - + Error(ErrorCode c, std::string msg) : code(c), message(std::move(msg)) {} }; @@ -38,50 +35,7 @@ enum class BackendType { Auto // Let system choose best available }; -// Forward declarations +// Forward declarations class ComputeBackend; -class ComputeGraph; -class ComputeGraphBuilder; -class Tensor; -class BackendBuffer; - -// Node ID for tracking operations in the computation graph -using NodeId = size_t; - -// Symbolic reference to a scalar value in the computation graph -class SymbolicScalar { -public: - explicit SymbolicScalar(NodeId node_id) : node_id_(node_id) {} - - // Get the node ID for dependency tracking - NodeId node_id() const { return node_id_; } - -private: - NodeId node_id_; -}; - -// Symbolic reference to a tensor in the computation graph -class SymbolicTensor { -public: - explicit SymbolicTensor(NodeId node_id, std::vector shape) - : node_id_(node_id), shape_(std::move(shape)) {} - - // Get the node ID for dependency tracking - NodeId node_id() const { return node_id_; } - - // Get the expected output shape - const std::vector& shape() const { return shape_; } - - // Convenience accessors (same as TensorView) - bool is_vector() const { return shape_.size() == 1; } - bool is_matrix() const { return shape_.size() == 2; } - size_t length() const { return is_vector() ? shape_[0] : 0; } - size_t rows() const { return is_matrix() ? shape_[0] : 0; } - size_t cols() const { return is_matrix() ? 
shape_[1] : 0; } - -private: - NodeId node_id_; - std::vector shape_; -}; -} // namespace compute \ No newline at end of file +} // namespace compute diff --git a/compute/src/compute/model/gemma_model_mlx.cpp b/compute/src/compute/model/gemma_model_mlx.cpp index 2c1fa1d..326f8e8 100644 --- a/compute/src/compute/model/gemma_model_mlx.cpp +++ b/compute/src/compute/model/gemma_model_mlx.cpp @@ -100,7 +100,8 @@ GemmaModelMLX::GemmaModelMLX( SimpleBpeTokenizer tokenizer, std::unordered_map mlx_weights, mx::array embed_mat) - : GemmaModelBase(std::move(config), std::move(tokenizer), {}) + : config_(std::move(config)) + , tokenizer_(std::move(tokenizer)) , mlx_weights_(std::move(mlx_weights)) , embed_mat_(std::move(embed_mat)) {} diff --git a/compute/src/compute/model/gemma_model_mlx.h b/compute/src/compute/model/gemma_model_mlx.h index a5a0c6e..24a8c46 100644 --- a/compute/src/compute/model/gemma_model_mlx.h +++ b/compute/src/compute/model/gemma_model_mlx.h @@ -3,7 +3,6 @@ #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) #include "language_model.h" -#include "gemma_model_base.h" #include #include #include @@ -27,7 +26,7 @@ namespace compute { * - Per-layer rope_theta (local vs global in Gemma3) * - LM head may be tied to embedding table */ -class GemmaModelMLX final : public GemmaModelBase, public LanguageModel { +class GemmaModelMLX final : public LanguageModel { public: static Result from_model_dir( const std::filesystem::path& model_dir); @@ -74,6 +73,8 @@ class GemmaModelMLX final : public GemmaModelBase, public LanguageModel { SamplingParams params, std::function on_token); + ModelConfig config_; + SimpleBpeTokenizer tokenizer_; std::unordered_map mlx_weights_; mlx::core::array embed_mat_; std::optional mlx_state_; diff --git a/compute/src/compute/model/language_model.cpp b/compute/src/compute/model/language_model.cpp index 110fd06..38de4bf 100644 --- a/compute/src/compute/model/language_model.cpp +++ b/compute/src/compute/model/language_model.cpp @@ -1,7 +1,5 @@ #include "language_model.h" #include "llama_model.h" -#include "gemma_model.h" -#include "qwen3_moe_model.h" #include "model_loader.h" #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) @@ -11,17 +9,11 @@ namespace compute { -// ── Factory ─────────────────────────────────────────────────────────────────── -// -// Reads model_type from config.json and instantiates the right subclass. -// When a new architecture is added, add one branch here and one new subclass. - Result> LanguageModel::load( const std::filesystem::path& model_dir, ComputeBackend* backend, size_t context_size) { - // Peek at config to determine the model family without loading weights twice. 
auto config_result = ModelLoader::load_config(model_dir); if (!config_result) return std::unexpected(config_result.error()); @@ -34,29 +26,19 @@ Result> LanguageModel::load( return std::make_unique(std::move(*result)); } - if (model_type == "gemma" || model_type == "gemma2" || model_type == "gemma3_text") { #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) + if (model_type == "gemma" || model_type == "gemma2" || model_type == "gemma3_text") { auto result = GemmaModelMLX::from_model_dir(model_dir); if (!result) return std::unexpected(result.error()); return std::make_unique(std::move(*result)); -#else - auto result = GemmaModel::from_model_dir(model_dir, backend); - if (!result) return std::unexpected(result.error()); - return std::make_unique(std::move(*result)); -#endif } if (model_type == "qwen3_5_moe" || model_type == "qwen3_moe") { -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) auto result = Qwen3MoeModelMLX::from_model_dir(model_dir, context_size); if (!result) return std::unexpected(result.error()); return std::make_unique(std::move(*result)); -#else - auto result = Qwen3MoeModel::from_model_dir(model_dir, backend); - if (!result) return std::unexpected(result.error()); - return std::make_unique(std::move(*result)); -#endif } +#endif return std::unexpected(Error{ErrorCode::InvalidModel, "Unsupported model type: \"" + model_type + diff --git a/compute/src/compute/model/llama_model.h b/compute/src/compute/model/llama_model.h index ef6e98b..31f98ed 100644 --- a/compute/src/compute/model/llama_model.h +++ b/compute/src/compute/model/llama_model.h @@ -1,7 +1,6 @@ #pragma once #include "language_model.h" -#include "kv_cache.h" #include #include @@ -14,14 +13,12 @@ namespace compute { /** * Concrete LanguageModel for all Llama-family and Mistral-family models. * - * Handles model_type: "llama", "mistral", "qwen2" - * All three families use the identical forward pass: - * RMSNorm → RoPE → GQA → SwiGLU + * Handles model_type: "llama", "mistral", "qwen2", "qwen3" + * All families share the same forward pass: RMSNorm → RoPE → GQA → SwiGLU. * Differences (hidden_size, n_heads, rope_theta, …) are config-driven. */ class LlamaModel final : public LanguageModel { public: - // Factory — loads config, weights, and tokenizer from model_dir. static Result from_model_dir( const std::filesystem::path& model_dir, ComputeBackend* backend, @@ -53,68 +50,22 @@ class LlamaModel final : public LanguageModel { std::string format_tool_result(const std::string& tool_name, const std::string& result_json) const override; - // ── Testing / diagnostic interface ─────────────────────────────────────── - // (Not part of LanguageModel — used by unit tests only via LlamaModel& or TinyLlamaInference&) - - Result> forward(const std::vector& input_ids); - Result forward_logits(const std::vector& input_ids); - Result embedding(const std::vector& token_ids); - Result rms_norm(const Tensor& input, const Tensor& weight, float eps); - Result attention_layer(const Tensor& input, int layer_idx); - Result get_weight(const std::string& name) const; - size_t cache_position() const { return cache_position_; } - private: LlamaModel( - ModelConfig config, - SimpleBpeTokenizer tokenizer, - std::unordered_map weights, - ComputeBackend* backend); - - // ── Layer implementations ───────────────────────────────────────────────── - - // Linear projection: dispatches to quantized_matmul or matmul depending on - // whether {weight_key}.scales exists in the weight map. 
Works for both - // quantized (mlx-community int4) and unquantized (HF fp16/bf16) models. - Result linear(const Tensor& input, const std::string& weight_key); - - Result attention_layer( - const Tensor& input, - int layer_idx, - int position_offset, - LayerKVCache* cache); - - Result mlp_layer(const Tensor& input, int layer_idx); - - Result transformer_block( - const Tensor& input, - int layer_idx, - int position_offset, - LayerKVCache* cache); - - Result> forward_impl( - const std::vector& input_ids, - int position_offset, - std::vector* cache_vec); + ModelConfig config, + SimpleBpeTokenizer tokenizer, + ComputeBackend* backend); // ── State ───────────────────────────────────────────────────────────────── - // Detected at construction via tokenizer vocab probe. enum class ToolFamily { None, Qwen25, Llama31, MistralTool }; static ToolFamily detect_tool_family(const SimpleBpeTokenizer& tok, const ModelConfig& cfg); - ModelConfig config_; - SimpleBpeTokenizer tokenizer_; - std::unordered_map weights_; - ComputeBackend* backend_; - ToolFamily tool_family_ = ToolFamily::None; - - std::vector kv_cache_; - size_t cache_position_ = 0; - - // Cached dequantized embedding table (populated on first use for quantized models) - mutable std::optional dequantized_embed_tokens_; + ModelConfig config_; + SimpleBpeTokenizer tokenizer_; + ComputeBackend* backend_; + ToolFamily tool_family_ = ToolFamily::None; #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) struct MlxDecodeState { diff --git a/compute/src/compute/model/model_loader.cpp b/compute/src/compute/model/model_loader.cpp index c77bdb5..2daf822 100644 --- a/compute/src/compute/model/model_loader.cpp +++ b/compute/src/compute/model/model_loader.cpp @@ -1,5 +1,4 @@ #include "model_loader.h" -#include "../core/compute_backend.h" #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) #include @@ -7,48 +6,9 @@ namespace mx = mlx::core; #endif #include -#include namespace compute { -Result>> -ModelLoader::load_model(const std::filesystem::path& model_dir, ComputeBackend* backend) { - if (!backend) { - return std::unexpected(Error{ErrorCode::InvalidInput, "Backend cannot be null"}); - } - - if (!std::filesystem::exists(model_dir)) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Model directory does not exist: " + model_dir.string()}); - } - - if (!std::filesystem::is_directory(model_dir)) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Path is not a directory: " + model_dir.string()}); - } - - // 1. Load configuration - auto config_result = load_config(model_dir); - if (!config_result) { - return std::unexpected(config_result.error()); - } - - // 2. Find safetensors files - auto safetensors_files = find_safetensors_files(model_dir); - if (safetensors_files.empty()) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "No .safetensors files found in directory: " + model_dir.string()}); - } - - // 3. 
Load all tensors - auto tensors_result = load_all_safetensors(safetensors_files, backend); - if (!tensors_result) { - return std::unexpected(tensors_result.error()); - } - - return std::make_pair(*config_result, *tensors_result); -} - Result ModelLoader::load_config(const std::filesystem::path& model_dir) { auto config_path = model_dir / "config.json"; return ModelConfig::from_config_file(config_path); @@ -56,79 +16,15 @@ Result ModelLoader::load_config(const std::filesystem::path& model_ std::vector ModelLoader::find_safetensors_files(const std::filesystem::path& model_dir) { - std::vector safetensors_files; - + std::vector files; try { for (const auto& entry : std::filesystem::directory_iterator(model_dir)) { - if (entry.is_regular_file() && entry.path().extension() == ".safetensors") { - safetensors_files.push_back(entry.path()); - } - } - } catch (const std::filesystem::filesystem_error& e) { - // Return empty vector on filesystem errors - return {}; - } - - // Sort files for consistent ordering - std::sort(safetensors_files.begin(), safetensors_files.end()); - return safetensors_files; -} - -Result> -ModelLoader::load_all_safetensors(const std::vector& safetensors_files, - ComputeBackend* backend) { - -#if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - // Check that we have an MLX backend - if (backend->type() != BackendType::MLX) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "ModelLoader requires MLX backend for safetensors loading"}); - } - - std::unordered_map all_tensors; - - try { - for (const auto& file_path : safetensors_files) { - // Use MLX to load safetensors file - auto safetensors_result = mx::load_safetensors(file_path.string()); - - // mx::load_safetensors returns a pair - const auto& tensor_map = safetensors_result.first; - - // Convert each MLX array to compute Tensor using wrap_native_tensor - for (const auto& [name, mlx_array] : tensor_map) { - // Get shape from MLX array - auto mlx_shape = mlx_array.shape(); - std::vector shape(mlx_shape.begin(), mlx_shape.end()); - - // Create a copy of the MLX array for wrapping - // Note: This preserves lazy evaluation - no .eval() called - auto* array_ptr = new mx::array(mlx_array); - - // Wrap the MLX array as a Tensor - auto tensor = backend->wrap_native_tensor(array_ptr, shape); - - // Check for duplicate tensor names across files - if (all_tensors.find(name) != all_tensors.end()) { - return std::unexpected(Error{ErrorCode::InvalidInput, - "Duplicate tensor name found: " + name}); - } - - all_tensors.emplace(name, std::move(tensor)); - } + if (entry.is_regular_file() && entry.path().extension() == ".safetensors") + files.push_back(entry.path()); } - - return all_tensors; - - } catch (const std::exception& e) { - return std::unexpected(Error{ErrorCode::ComputeError, - "Failed to load safetensors: " + std::string(e.what())}); - } - -#else - return std::unexpected(Error{ErrorCode::BackendNotAvailable, - "MLX backend not available - cannot load safetensors"}); -#endif + } catch (const std::filesystem::filesystem_error&) {} + std::sort(files.begin(), files.end()); + return files; } #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) @@ -164,4 +60,4 @@ ModelLoader::load_model_mlx(const std::filesystem::path& model_dir) { #endif // MLX_BACKEND_ENABLED -} // namespace compute \ No newline at end of file +} // namespace compute diff --git a/compute/src/compute/model/model_loader.h b/compute/src/compute/model/model_loader.h index e929507..b098a95 100644 --- 
a/compute/src/compute/model/model_loader.h +++ b/compute/src/compute/model/model_loader.h @@ -1,11 +1,9 @@ #pragma once #include "../core/compute_types.h" -#include "../core/tensor.h" #include "model_config.h" -#include -#include #include +#include #include #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) @@ -14,62 +12,20 @@ namespace compute { -// Forward declaration -class ComputeBackend; - -/** - * ModelLoader handles loading transformer models from safetensors format - * Simple, focused implementation for MLX backend with lazy evaluation preservation - * Works generically with any MLX safetensors model - */ class ModelLoader { public: - /** - * Load complete model from directory containing config.json and .safetensors files - * @param model_dir Path to model directory - * @param backend ComputeBackend to use for tensor operations (must support MLX) - * @return Result containing model config and tensor map, or error - */ - static Result>> - load_model(const std::filesystem::path& model_dir, ComputeBackend* backend); - - /** - * Load model configuration only (no weights) - * @param model_dir Path to model directory containing config.json - * @return Result containing parsed model config or error - */ static Result load_config(const std::filesystem::path& model_dir); #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) - /** - * Load model weights as native MLX arrays — no ComputeBackend required. - * Eliminates the Tensor wrapper overhead for MLX-native model classes. - */ static Result>> load_model_mlx(const std::filesystem::path& model_dir); #endif private: - /** - * Find all .safetensors files in model directory - * @param model_dir Path to model directory - * @return Vector of paths to .safetensors files - */ static std::vector find_safetensors_files( const std::filesystem::path& model_dir); - /** - * Load tensors from all .safetensors files using MLX backend - * Preserves MLX lazy evaluation - tensors are not evaluated - * @param safetensors_files List of .safetensors file paths - * @param backend ComputeBackend to use (must support MLX) - * @return Result containing tensor name -> tensor map, or error - */ - static Result> - load_all_safetensors(const std::vector& safetensors_files, - ComputeBackend* backend); - - ModelLoader() = delete; // Static utility class + ModelLoader() = delete; }; -} // namespace compute \ No newline at end of file +} // namespace compute diff --git a/compute/src/compute/model/qwen3_moe_model_mlx.cpp b/compute/src/compute/model/qwen3_moe_model_mlx.cpp index f009900..aa4fb7c 100644 --- a/compute/src/compute/model/qwen3_moe_model_mlx.cpp +++ b/compute/src/compute/model/qwen3_moe_model_mlx.cpp @@ -429,7 +429,8 @@ Qwen3MoeModelMLX::Qwen3MoeModelMLX( std::unordered_map mlx_weights, mx::array embed_mat, size_t context_size) - : Qwen3MoeModelBase(std::move(config), std::move(tokenizer), {}) + : config_(std::move(config)) + , tokenizer_(std::move(tokenizer)) , mlx_weights_(std::move(mlx_weights)) , embed_mat_(std::move(embed_mat)) , context_size_(context_size) diff --git a/compute/src/compute/model/qwen3_moe_model_mlx.h b/compute/src/compute/model/qwen3_moe_model_mlx.h index a426976..bbe931f 100644 --- a/compute/src/compute/model/qwen3_moe_model_mlx.h +++ b/compute/src/compute/model/qwen3_moe_model_mlx.h @@ -3,7 +3,6 @@ #if defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED) #include "language_model.h" -#include "qwen3_moe_model_base.h" #include #include #include @@ -21,7 +20,7 @@ 
  * Prefill runs all T prompt tokens in one eager pass: one Metal dispatch per SSM layer
  * (T-loop inside kernel), one SDPA per attention layer, T per-token MoE calls.
  */
-class Qwen3MoeModelMLX final : public Qwen3MoeModelBase, public LanguageModel {
+class Qwen3MoeModelMLX final : public LanguageModel {
 public:
     static Result from_model_dir(
         const std::filesystem::path& model_dir,
@@ -72,6 +71,8 @@ class Qwen3MoeModelMLX final : public Qwen3MoeModelBase, public LanguageModel {
         SamplingParams params,
         std::function on_token);
 
+    ModelConfig config_;
+    SimpleBpeTokenizer tokenizer_;
     std::unordered_map mlx_weights_;
     mlx::core::array embed_mat_;
     std::optional mlx_state_;

From 0596aa22c6b309fdcadc97f374d3c2f6a4ed8c90 Mon Sep 17 00:00:00 2001
From: Dex
Date: Wed, 6 May 2026 16:37:44 -0400
Subject: [PATCH 5/5] chore: remove vestigial dead code identified in audit
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- ErrorCode: remove InvalidArgument, InsufficientMemory, TensorNotFound, NotImplemented — none were ever returned in production code
- ComputeBackend: remove preferred_batch_size() and supports_async() — declared and overridden in MlxBackend but never called by any client
- ModelConfig: remove name_or_path and transformers_version — parsed from JSON but never read after parsing
- LlamaModel: remove context_size_ member — set in mlx_setup(), never read
- Qwen3MoeModelMLX: remove context_size_ member — same pattern
- Delete tinyllama_inference.h/.cpp — Phase D compatibility alias no longer needed; update 5 test files to use LlamaModel directly
- Delete test_attention_qkv_trace.cpp — became an empty placeholder after attention_layer() was removed in Phase O
---
 compute/CMakeLists.txt                              |  1 -
 compute/src/compute/backends/mlx/mlx_backend.cpp    |  3 ---
 compute/src/compute/backends/mlx/mlx_backend.h      |  3 ---
 compute/src/compute/core/compute_backend.h          |  4 ----
 compute/src/compute/core/compute_types.h            |  4 ----
 compute/src/compute/model/llama_model.cpp           |  1 -
 compute/src/compute/model/llama_model.h             |  1 -
 compute/src/compute/model/model_config.cpp          |  9 ---------
 compute/src/compute/model/model_config.h            |  4 ----
 compute/src/compute/model/qwen3_moe_model_mlx.cpp   |  5 +++--
 compute/src/compute/model/qwen3_moe_model_mlx.h     |  1 -
 compute/src/compute/model/tinyllama_inference.cpp   |  2 --
 compute/src/compute/model/tinyllama_inference.h     | 12 ------------
 compute/tests/compute/test_attention_qkv_trace.cpp  |  3 ---
 compute/tests/compute/test_forward_pass.cpp         | 10 +++++-----
 compute/tests/compute/test_fp16bf16_integration.cpp | 10 +++++-----
 compute/tests/compute/test_llama3_integration.cpp   | 10 +++++-----
 compute/tests/compute/test_mistral_integration.cpp  | 10 +++++-----
 compute/tests/compute/test_qwen2_integration.cpp    | 10 +++++-----
 19 files changed, 28 insertions(+), 75 deletions(-)
 delete mode 100644 compute/src/compute/model/tinyllama_inference.cpp
 delete mode 100644 compute/src/compute/model/tinyllama_inference.h
 delete mode 100644 compute/tests/compute/test_attention_qkv_trace.cpp

diff --git a/compute/CMakeLists.txt b/compute/CMakeLists.txt
index 83bfdf6..719a486 100644
--- a/compute/CMakeLists.txt
+++ b/compute/CMakeLists.txt
@@ -110,7 +110,6 @@ if(BUILD_TESTING)
         tests/compute/test_model_loader.cpp
         tests/compute/test_tokenizer_config.cpp
         tests/compute/test_simple_bpe_tokenizer.cpp
-        tests/compute/test_attention_qkv_trace.cpp
         tests/compute/test_forward_pass.cpp
         tests/compute/test_mistral_integration.cpp
         tests/compute/test_llama3_integration.cpp
diff --git a/compute/src/compute/backends/mlx/mlx_backend.cpp b/compute/src/compute/backends/mlx/mlx_backend.cpp
index 24bd939..d382395 100644
--- a/compute/src/compute/backends/mlx/mlx_backend.cpp
+++ b/compute/src/compute/backends/mlx/mlx_backend.cpp
@@ -52,9 +52,6 @@ void MLXBackend::cleanup() {
     }
 }
 
-size_t MLXBackend::preferred_batch_size() const { return 1024; }
-bool MLXBackend::supports_async() const { return false; }
-
 } // namespace compute
 
 #endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED)
diff --git a/compute/src/compute/backends/mlx/mlx_backend.h b/compute/src/compute/backends/mlx/mlx_backend.h
index 7432621..1fce31c 100644
--- a/compute/src/compute/backends/mlx/mlx_backend.h
+++ b/compute/src/compute/backends/mlx/mlx_backend.h
@@ -20,9 +20,6 @@ class MLXBackend : public ComputeBackend {
     Result initialize() override;
     void cleanup() override;
 
-    size_t preferred_batch_size() const override;
-    bool supports_async() const override;
-
 private:
     bool m_initialized;
 };
diff --git a/compute/src/compute/core/compute_backend.h b/compute/src/compute/core/compute_backend.h
index bf6d461..3da435f 100644
--- a/compute/src/compute/core/compute_backend.h
+++ b/compute/src/compute/core/compute_backend.h
@@ -21,10 +21,6 @@ class ComputeBackend {
     // Lifecycle
     virtual Result initialize() = 0;
     virtual void cleanup() = 0;
-
-    // Performance hints
-    virtual size_t preferred_batch_size() const { return 1024; }
-    virtual bool supports_async() const { return false; }
 };
 
 // Factory for creating backends
diff --git a/compute/src/compute/core/compute_types.h b/compute/src/compute/core/compute_types.h
index b90d44e..333b57a 100644
--- a/compute/src/compute/core/compute_types.h
+++ b/compute/src/compute/core/compute_types.h
@@ -9,13 +9,9 @@ namespace compute {
 
 enum class ErrorCode {
     Success,
     InvalidInput,
-    InvalidArgument,
     InvalidModel,
     BackendNotAvailable,
-    InsufficientMemory,
     ComputeError,
-    TensorNotFound,
-    NotImplemented,
     UnknownError
 };
diff --git a/compute/src/compute/model/llama_model.cpp b/compute/src/compute/model/llama_model.cpp
index 662d162..f7e9308 100644
--- a/compute/src/compute/model/llama_model.cpp
+++ b/compute/src/compute/model/llama_model.cpp
@@ -333,7 +333,6 @@ void LlamaModel::mlx_setup(
 {
     mlx_weights_ = std::move(mlx_weights);
     mlx_embed_mat_ = std::move(mlx_embed_mat);
-    context_size_ = context_size;
 }
 
 void LlamaModel::mlx_init_state() {
diff --git a/compute/src/compute/model/llama_model.h b/compute/src/compute/model/llama_model.h
index 31f98ed..095476a 100644
--- a/compute/src/compute/model/llama_model.h
+++ b/compute/src/compute/model/llama_model.h
@@ -78,7 +78,6 @@ class LlamaModel final : public LanguageModel {
     std::unordered_map mlx_weights_;
     mlx::core::array mlx_embed_mat_;
     std::optional mlx_state_;
-    size_t context_size_ = 0;
     size_t mlx_pos_ = 0;
 
     void mlx_setup(std::unordered_map mlx_weights,
diff --git a/compute/src/compute/model/model_config.cpp b/compute/src/compute/model/model_config.cpp
index 42ec0d5..07e2ace 100644
--- a/compute/src/compute/model/model_config.cpp
+++ b/compute/src/compute/model/model_config.cpp
@@ -252,15 +252,6 @@ Result ModelConfig::from_json_string(const std::string& json_str) {
         config.mrope_section = json["mrope_section"].get>();
     }
 
-    // Parse OPTIONAL metadata
-    if (json.contains("_name_or_path") && !json["_name_or_path"].is_null()) {
-        config.name_or_path = json["_name_or_path"].get();
-    }
-
-    if (json.contains("transformers_version") && !json["transformers_version"].is_null()) {
-        config.transformers_version = json["transformers_version"].get();
-    }
-
     return config;
 
 } catch (const nlohmann::json::parse_error& e) {
diff --git a/compute/src/compute/model/model_config.h b/compute/src/compute/model/model_config.h
index bb4e7bf..3125e5f 100644
--- a/compute/src/compute/model/model_config.h
+++ b/compute/src/compute/model/model_config.h
@@ -81,10 +81,6 @@ struct ModelConfig {
     std::optional mrope_interleaved;  // interleaved mRoPE mode
     std::optional> mrope_section;     // mRoPE section sizes [11,11,10]
 
-    // Model path info - may be optional
-    std::optional name_or_path;
-    std::optional transformers_version;
-
     /**
      * Parse ModelConfig from config.json file
      * @param config_path Path to config.json file
diff --git a/compute/src/compute/model/qwen3_moe_model_mlx.cpp b/compute/src/compute/model/qwen3_moe_model_mlx.cpp
index aa4fb7c..d348926 100644
--- a/compute/src/compute/model/qwen3_moe_model_mlx.cpp
+++ b/compute/src/compute/model/qwen3_moe_model_mlx.cpp
@@ -433,8 +433,9 @@ Qwen3MoeModelMLX::Qwen3MoeModelMLX(
     , tokenizer_(std::move(tokenizer))
     , mlx_weights_(std::move(mlx_weights))
     , embed_mat_(std::move(embed_mat))
-    , context_size_(context_size)
-{}
+{
+    (void)context_size;
+}
 
 size_t Qwen3MoeModelMLX::num_parameters() const {
     size_t total = 0;
diff --git a/compute/src/compute/model/qwen3_moe_model_mlx.h b/compute/src/compute/model/qwen3_moe_model_mlx.h
index bbe931f..f3c9734 100644
--- a/compute/src/compute/model/qwen3_moe_model_mlx.h
+++ b/compute/src/compute/model/qwen3_moe_model_mlx.h
@@ -77,7 +77,6 @@ class Qwen3MoeModelMLX final : public LanguageModel {
     mlx::core::array embed_mat_;
     std::optional mlx_state_;
     size_t cache_position_ = 0;
-    size_t context_size_ = 0;
 };
 
 } // namespace compute
diff --git a/compute/src/compute/model/tinyllama_inference.cpp b/compute/src/compute/model/tinyllama_inference.cpp
deleted file mode 100644
index 344276b..0000000
--- a/compute/src/compute/model/tinyllama_inference.cpp
+++ /dev/null
@@ -1,2 +0,0 @@
-// Implementation moved to llama_model.cpp (Phase D refactor).
-// This file is intentionally empty and no longer compiled.
diff --git a/compute/src/compute/model/tinyllama_inference.h b/compute/src/compute/model/tinyllama_inference.h
deleted file mode 100644
index 3e32da6..0000000
--- a/compute/src/compute/model/tinyllama_inference.h
+++ /dev/null
@@ -1,12 +0,0 @@
-#pragma once
-// TinyLlamaInference has been renamed to LlamaModel (Phase D refactor).
-// This header is kept for backwards compatibility so existing test files
-// continue to compile without edits.
-//
-// New code should #include "llama_model.h" and use compute::LlamaModel directly.
-// This alias will be removed in Phase D.6.
-#include "llama_model.h"
-
-namespace compute {
-    using TinyLlamaInference = LlamaModel;
-}
diff --git a/compute/tests/compute/test_attention_qkv_trace.cpp b/compute/tests/compute/test_attention_qkv_trace.cpp
deleted file mode 100644
index 41392d4..0000000
--- a/compute/tests/compute/test_attention_qkv_trace.cpp
+++ /dev/null
@@ -1,3 +0,0 @@
-// attention_layer() was removed in the Phase O Tensor-abstraction cleanup (O.6.2).
-// Tests that relied on it have been deleted. This file is preserved as a placeholder.
-#include
diff --git a/compute/tests/compute/test_forward_pass.cpp b/compute/tests/compute/test_forward_pass.cpp
index 4d84a3a..859fd02 100644
--- a/compute/tests/compute/test_forward_pass.cpp
+++ b/compute/tests/compute/test_forward_pass.cpp
@@ -1,5 +1,5 @@
 #include
-#include "compute/model/tinyllama_inference.h"
+#include "compute/model/llama_model.h"
 #include "compute/core/compute_backend.h"
 #include "test_config.h"
 #include
@@ -28,9 +28,9 @@ class ForwardPassTest : public ::testing::Test {
         backend_ = std::move(*backend_result);
         if (!backend_->initialize()) { skip_reason_ = "Backend init failed"; return; }
 
-        auto inf_result = TinyLlamaInference::from_model_dir(model_dir_, backend_.get());
+        auto inf_result = LlamaModel::from_model_dir(model_dir_, backend_.get());
         if (!inf_result) { skip_reason_ = inf_result.error().message; return; }
-        inference_ = std::make_unique(std::move(*inf_result));
+        inference_ = std::make_unique(std::move(*inf_result));
     }
 
     static void TearDownTestSuite() {
@@ -49,14 +49,14 @@ class ForwardPassTest : public ::testing::Test {
     static std::filesystem::path baseline_dir_;
     static std::string skip_reason_;
     static std::unique_ptr backend_;
-    static std::unique_ptr inference_;
+    static std::unique_ptr inference_;
 };
 
 std::filesystem::path ForwardPassTest::model_dir_;
 std::filesystem::path ForwardPassTest::baseline_dir_;
 std::string ForwardPassTest::skip_reason_;
 std::unique_ptr ForwardPassTest::backend_;
-std::unique_ptr ForwardPassTest::inference_;
+std::unique_ptr ForwardPassTest::inference_;
 
 TEST_F(ForwardPassTest, GenerateCoherentOutput) {
     const std::string prompt = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n";
diff --git a/compute/tests/compute/test_fp16bf16_integration.cpp b/compute/tests/compute/test_fp16bf16_integration.cpp
index dd27196..20f4bf1 100644
--- a/compute/tests/compute/test_fp16bf16_integration.cpp
+++ b/compute/tests/compute/test_fp16bf16_integration.cpp
@@ -1,5 +1,5 @@
 #include
-#include "compute/model/tinyllama_inference.h"
+#include "compute/model/llama_model.h"
 #include "compute/core/compute_backend.h"
 #include
 #include
@@ -26,12 +26,12 @@ class Fp16Bf16IntegrationTest : public ::testing::Test {
         backend_ = std::move(*backend_result);
         if (!backend_->initialize()) { skip_reason_ = "Backend init failed"; return; }
 
-        auto inf_result = TinyLlamaInference::from_model_dir(model_dir_, backend_.get());
+        auto inf_result = LlamaModel::from_model_dir(model_dir_, backend_.get());
         if (!inf_result) {
             skip_reason_ = "Failed to load bf16 model: " + inf_result.error().message;
             return;
         }
-        inference_ = std::make_unique(std::move(*inf_result));
+        inference_ = std::make_unique(std::move(*inf_result));
 
         std::cout << "Loaded bf16 model: " << inference_->config().model_type
                   << " hidden=" << inference_->config().hidden_size
@@ -55,13 +55,13 @@ class Fp16Bf16IntegrationTest : public ::testing::Test {
     static std::filesystem::path model_dir_;
     static std::string skip_reason_;
     static std::unique_ptr backend_;
-    static std::unique_ptr inference_;
+    static std::unique_ptr inference_;
 };
 
 std::filesystem::path Fp16Bf16IntegrationTest::model_dir_;
 std::string Fp16Bf16IntegrationTest::skip_reason_;
 std::unique_ptr Fp16Bf16IntegrationTest::backend_;
-std::unique_ptr Fp16Bf16IntegrationTest::inference_;
+std::unique_ptr Fp16Bf16IntegrationTest::inference_;
 
 TEST_F(Fp16Bf16IntegrationTest, ConfigIsUnquantized) {
     EXPECT_FALSE(inference_->config().quantization.has_value())
diff --git a/compute/tests/compute/test_llama3_integration.cpp b/compute/tests/compute/test_llama3_integration.cpp
index 5627447..9b4ad26 100644
--- a/compute/tests/compute/test_llama3_integration.cpp
+++ b/compute/tests/compute/test_llama3_integration.cpp
@@ -1,5 +1,5 @@
 #include
-#include "compute/model/tinyllama_inference.h"
+#include "compute/model/llama_model.h"
 #include "compute/core/compute_backend.h"
 #include
 #include
@@ -27,12 +27,12 @@ class Llama3IntegrationTest : public ::testing::Test {
         backend_ = std::move(*backend_result);
         if (!backend_->initialize()) { skip_reason_ = "Backend init failed"; return; }
 
-        auto inf_result = TinyLlamaInference::from_model_dir(kModelDir, backend_.get());
+        auto inf_result = LlamaModel::from_model_dir(kModelDir, backend_.get());
         if (!inf_result) {
             skip_reason_ = "Failed to load Llama-3: " + inf_result.error().message;
             return;
         }
-        inference_ = std::make_unique(std::move(*inf_result));
+        inference_ = std::make_unique(std::move(*inf_result));
 
         std::cout << "Loaded Llama-3 model: " << inference_->config().model_type
                   << " hidden=" << inference_->config().hidden_size
@@ -53,12 +53,12 @@
 
     static std::string skip_reason_;
     static std::unique_ptr backend_;
-    static std::unique_ptr inference_;
+    static std::unique_ptr inference_;
 };
 
 std::string Llama3IntegrationTest::skip_reason_;
 std::unique_ptr Llama3IntegrationTest::backend_;
-std::unique_ptr Llama3IntegrationTest::inference_;
+std::unique_ptr Llama3IntegrationTest::inference_;
 
 // ── Config ────────────────────────────────────────────────────────────────────
 
diff --git a/compute/tests/compute/test_mistral_integration.cpp b/compute/tests/compute/test_mistral_integration.cpp
index c019993..7cba1a3 100644
--- a/compute/tests/compute/test_mistral_integration.cpp
+++ b/compute/tests/compute/test_mistral_integration.cpp
@@ -1,5 +1,5 @@
 #include
-#include "compute/model/tinyllama_inference.h"
+#include "compute/model/llama_model.h"
 #include "compute/core/compute_backend.h"
 #include "test_config.h"
 #include
@@ -29,12 +29,12 @@ class MistralIntegrationTest : public ::testing::Test {
         backend_ = std::move(*backend_result);
         if (!backend_->initialize()) { skip_reason_ = "Backend init failed"; return; }
 
-        auto inf_result = TinyLlamaInference::from_model_dir(model_dir_, backend_.get());
+        auto inf_result = LlamaModel::from_model_dir(model_dir_, backend_.get());
         if (!inf_result) {
             skip_reason_ = "Failed to load Mistral: " + inf_result.error().message;
             return;
         }
-        inference_ = std::make_unique(std::move(*inf_result));
+        inference_ = std::make_unique(std::move(*inf_result));
 
         std::cout << "Loaded Mistral model: " << inference_->config().model_type
                   << " hidden=" << inference_->config().hidden_size
@@ -56,13 +56,13 @@ class MistralIntegrationTest : public ::testing::Test {
     static std::filesystem::path model_dir_;
     static std::string skip_reason_;
     static std::unique_ptr backend_;
-    static std::unique_ptr inference_;
+    static std::unique_ptr inference_;
 };
 
 std::filesystem::path MistralIntegrationTest::model_dir_;
 std::string MistralIntegrationTest::skip_reason_;
 std::unique_ptr MistralIntegrationTest::backend_;
-std::unique_ptr MistralIntegrationTest::inference_;
+std::unique_ptr MistralIntegrationTest::inference_;
 
 // Verify the model loads with correct Mistral architecture config
 TEST_F(MistralIntegrationTest, ConfigLoadsCorrectly) {
diff --git a/compute/tests/compute/test_qwen2_integration.cpp b/compute/tests/compute/test_qwen2_integration.cpp
index 68521e9..a4041d4 100644
--- a/compute/tests/compute/test_qwen2_integration.cpp
+++ b/compute/tests/compute/test_qwen2_integration.cpp
@@ -1,5 +1,5 @@
 #include
-#include "compute/model/tinyllama_inference.h"
+#include "compute/model/llama_model.h"
 #include "compute/core/compute_backend.h"
 #include "test_config.h"
 #include
@@ -28,12 +28,12 @@ class Qwen2IntegrationTest : public ::testing::Test {
         backend_ = std::move(*backend_result);
         if (!backend_->initialize()) { skip_reason_ = "Backend init failed"; return; }
 
-        auto inf_result = TinyLlamaInference::from_model_dir(model_dir_, backend_.get());
+        auto inf_result = LlamaModel::from_model_dir(model_dir_, backend_.get());
         if (!inf_result) {
             skip_reason_ = "Failed to load Qwen2.5: " + inf_result.error().message;
             return;
         }
-        inference_ = std::make_unique(std::move(*inf_result));
+        inference_ = std::make_unique(std::move(*inf_result));
 
         std::cout << "Loaded Qwen2.5 model: " << inference_->config().model_type
                   << " hidden=" << inference_->config().hidden_size
@@ -56,13 +56,13 @@ class Qwen2IntegrationTest : public ::testing::Test {
     static std::filesystem::path model_dir_;
     static std::string skip_reason_;
    static std::unique_ptr backend_;
-    static std::unique_ptr inference_;
+    static std::unique_ptr inference_;
 };
 
 std::filesystem::path Qwen2IntegrationTest::model_dir_;
 std::string Qwen2IntegrationTest::skip_reason_;
 std::unique_ptr Qwen2IntegrationTest::backend_;
-std::unique_ptr Qwen2IntegrationTest::inference_;
+std::unique_ptr Qwen2IntegrationTest::inference_;
 
 TEST_F(Qwen2IntegrationTest, ConfigLoadsCorrectly) {
     EXPECT_EQ(inference_->config().model_type, "qwen2");