72 changes: 45 additions & 27 deletions README.md
@@ -5,7 +5,7 @@
[![GUI Test](https://github.com/dexwritescode/neurons/actions/workflows/gui-test.yml/badge.svg)](https://github.com/dexwritescode/neurons/actions/workflows/gui-test.yml)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)

A from-scratch LLM inference engine and chat application. Built to understand how large language models actually work at the hardware level — using Metal/MLX, cuBLAS, and flash-attention directly rather than wrapping llama.cpp or Ollama.
A from-scratch LLM inference engine and chat application. Built to understand how large language models actually work at the hardware level — using Metal/MLX directly rather than wrapping llama.cpp or Ollama.

---

@@ -59,38 +59,39 @@ The GUI never links C++ directly. Locally it calls `libneurons_core.dylib` over
|---|---|---|
| Llama 2/3, TinyLlama | `mlx-community/Llama-3.2-3B-Instruct-4bit` | MLX |
| Mistral | `mlx-community/Mistral-7B-Instruct-v0.3-4bit` | MLX |
| Qwen2 / Qwen2.5 | `mlx-community/Qwen2.5-7B-Instruct-4bit` | MLX |
| Qwen2 / Qwen2.5 / Qwen3 | `mlx-community/Qwen2.5-7B-Instruct-4bit` | MLX |
| Qwen3 MoE | `mlx-community/Qwen3-30B-A3B-4bit`, `mlx-community/Qwen3.6-35B-A3B-4bit` | MLX |
| Gemma / Gemma2 / Gemma3 | `mlx-community/gemma-3-1b-it-qat-4bit` | MLX |
| fp16 / bf16 unquantized | any base HuggingFace safetensors repo | MLX |

All models are downloaded directly from HuggingFace in their `mlx-community` MLX-quantized variants for Apple Silicon. CUDA (cuBLAS + flash-attention) and ROCm backends are on the roadmap.
All models are downloaded directly from HuggingFace in their `mlx-community` MLX-quantized variants for Apple Silicon. CUDA and ROCm backends are on the roadmap.
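
In practice, "4-bit MLX-quantized" means the safetensors carry packed 4-bit integer codes plus per-group scales and biases, dequantized as `w ≈ scale * q + bias`. A minimal sketch of that affine dequantization, assuming MLX's default group size of 64 and low-bits-first packing; this is an illustration, not the engine's actual loader:

```cpp
#include <cstdint>
#include <vector>

// Dequantize one row of a 4-bit, group-quantized weight (MLX-style).
// Assumptions: 8 codes packed low-bits-first per uint32, each group of
// group_size elements shares one scale/bias pair, and group_size divides
// the row length.
std::vector<float> dequantize_row(const std::vector<uint32_t>& packed,
                                  const std::vector<float>& scales,
                                  const std::vector<float>& biases,
                                  std::size_t group_size = 64) {
  std::vector<float> out;
  out.reserve(packed.size() * 8);  // 8 four-bit codes per uint32
  for (std::size_t i = 0; i < packed.size(); ++i) {
    for (int j = 0; j < 8; ++j) {
      uint32_t q = (packed[i] >> (4 * j)) & 0xF;  // code in 0..15
      std::size_t g = (i * 8 + j) / group_size;   // group index
      out.push_back(scales[g] * static_cast<float>(q) + biases[g]);
    }
  }
  return out;
}
```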

---

## Architecture

```
┌──────────────────────────────────────────────────────┐
│    Flutter GUI (macOS / Windows / Linux / mobile)    │
│        dart:ffi (local) · gRPC (remote nodes)        │
└──────────────────────┬───────────────────────────────┘
                       │ dart:ffi / gRPC
┌──────────────────────▼───────────────────────────────┐
│           libneurons_core (C FFI surface)            │
│      NeuronsServiceImpl → LanguageModel::load()      │
└──────────────────────┬───────────────────────────────┘
                       │ LanguageModel (interface)
          ┌────────────┼────────────┐
          ▼            ▼            ▼
     LlamaModel    GemmaModel    (future)
   Llama/Mistral    Gemma 1-3
     Qwen2/2.5    GeGLU/QKV-norm
          └────────────┬────────────┘
                       │ ComputeBackend (interface)
          ┌────────────┼────────────┬──────────────┐
          ▼            ▼            ▼              ▼
     MLXBackend   CUDABackend  ROCmBackend    CPUBackend
       (done)      (roadmap)    (roadmap)      (roadmap)
```mermaid
graph TD
GUI["Flutter GUI (macOS) — dart:ffi · gRPC"]
Core["libneurons_core — C FFI · NeuronsServiceImpl"]
LM["LanguageModel::load()"]

Llama["LlamaModel — Llama 2/3 · Mistral · Qwen2/2.5/3"]
Gemma["GemmaModelMLX — Gemma / Gemma2 / Gemma3"]
Qwen3Moe["Qwen3MoeModelMLX — Qwen3 MoE"]

Backend["ComputeBackend (interface)"]
MLX["MLXBackend — Apple Silicon · Metal · mx::compile"]
Roadmap["CUDA / ROCm (roadmap)"]

GUI --> Core
Core --> LM
LM --> Llama
LM --> Gemma
LM --> Qwen3Moe
Llama & Gemma & Qwen3Moe --> Backend
Backend --> MLX
Backend -.-> Roadmap
```
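
For orientation, a sketch of the two seams in the diagram. `LanguageModel`, `ComputeBackend`, and `LanguageModel::load()` are named above, and `SamplingParams` appears in the tests below; the exact signatures here are assumptions, not the repo's API:

```cpp
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace compute {

// Appears in the integration tests below; fields beyond these may exist.
struct SamplingParams {
  float temperature = 1.0f;  // 0.0f = greedy decoding
  int top_k = 0;
};

// Models talk to hardware only through this interface, so MLXBackend can
// later be joined by CUDA/ROCm implementations without touching the models.
class ComputeBackend {
 public:
  virtual ~ComputeBackend() = default;
  virtual std::string name() const = 0;  // e.g. "mlx"
};

// The FFI/gRPC layer sees only this interface; LlamaModel, GemmaModelMLX,
// and Qwen3MoeModelMLX are interchangeable behind it.
class LanguageModel {
 public:
  virtual ~LanguageModel() = default;

  // Named in the diagram; the parameter list here is illustrative.
  static std::unique_ptr<LanguageModel> load(const std::string& model_dir);

  // Streaming generation as exercised by the tests below: the callback
  // receives each new token id and returns false to stop early. (The real
  // API returns a result object carrying error info; bool is a stand-in.)
  virtual bool generate(const std::vector<int>& prompt_ids,
                        int max_new_tokens,
                        const SamplingParams& params,
                        const std::function<bool(int)>& on_token) = 0;
};

}  // namespace compute
```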

---
@@ -156,6 +157,22 @@ make gui # build dylib + Flutter macOS release app

---

## Performance

Measured on Apple Silicon (M2 Max 64 GB), greedy decoding (temperature=0), release build:

| Model | Params | Active params | tok/s |
|---|---|---|---|
| TinyLlama 1.1B 4-bit | 1.1B | 1.1B | ~265 |
| Gemma 3 1B 4-bit | 1B | 1B | ~190 |
| Llama-3.1 8B 4-bit | 8B | 8B | ~61 |
| Mistral 7B 4-bit | 7B | 7B | ~57 |
| Qwen3.6 35B-A3B 4-bit | 35B | 3.6B | ~77 |

MoE models run near the speed of a dense 3-4B model because only a small fraction of parameters are active per token. Decode uses GPU-pipelined generation with `mx::compile` — the first generation per session incurs a one-time compilation cost.
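
The compile-then-pipeline pattern looks roughly like this: a sketch against the public MLX C++ API (`mx::compile`, `mx::async_eval`), with a stand-in matmul in place of the engine's real decode step. If `mx::async_eval` is unavailable in a given MLX version, `mx::eval` gives the unpipelined equivalent:

```cpp
#include <vector>
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Stand-in for one decode step: logits = hidden @ W. The real engine
  // compiles the full transformer step; this is illustration only.
  auto step = [](const std::vector<mx::array>& inputs) {
    return std::vector<mx::array>{mx::matmul(inputs[0], inputs[1])};
  };

  // The first call traces and compiles the graph (the one-time cost noted
  // above); later calls with the same shapes/dtypes reuse the kernels.
  auto compiled_step = mx::compile(step);

  auto hidden = mx::random::normal({1, 2048});
  auto weights = mx::random::normal({2048, 32000});

  for (int i = 0; i < 8; ++i) {
    auto out = compiled_step({hidden, weights});
    // async_eval dispatches GPU work without blocking, letting the CPU
    // enqueue step i+1 while step i runs: the "GPU-pipelined" part.
    mx::async_eval(out);
  }
  return 0;
}
```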

---

## Quick start

### Download and run a model in the terminal
Expand Down Expand Up @@ -249,15 +266,16 @@ Neurons/
| Phase | Status | Description |
|---|---|---|
| A–E | ✅ | MLX backend, KV cache, sampling, Llama/Gemma/Qwen/Mistral |
| F | ✅ | Model family support (fp16/bf16, Gemma3, Qwen2.5) |
| F | ✅ | Model family support (fp16/bf16, Gemma3, Qwen2.5, Qwen3, Qwen3 MoE) |
| G–I | ✅ | gRPC service, Flutter GUI, CLI, OpenAI HTTP, logging |
| O | ✅ | MLX performance — GPU-pipelined decode, mx::compile, batched prefill |
| J | 🚧 | File attach + RAG (embeddings, sqlite-vec) |
| K | 🚧 | Multi-node: routing, speculative decoding, failover |
| L.1–2 | ✅ | MCP client runtime — stdio/SSE transport, JSON-RPC 2.0, McpManager |
| L.3 | ✅ | MCP gRPC extensions — server/permission RPCs, tool approval flow |
| L.4–6 | 🚧 | MCP GUI — settings, permissions table, live approval prompt |
| L.8 | 🚧 | Built-in MCP servers (filesystem, shell) |
| B/C | 🚧 | CUDA (cuBLAS + flash-attention) and ROCm backends |
| B/C | 🚧 | CUDA and ROCm backends |

---

36 changes: 36 additions & 0 deletions compute/tests/compute/test_gemma_integration.cpp
@@ -2,7 +2,9 @@
#include <gmock/gmock.h>
#include "../../src/compute/model/language_model.h"
#include "../../src/compute/core/compute_backend.h"
#include <chrono>
#include <filesystem>
#include <iostream>
#include <string>

namespace {
@@ -122,4 +124,38 @@ TEST_F(GemmaIntegrationTest, GenerateCapitalOfFrance) {
EXPECT_THAT(output, ::testing::HasSubstr("Paris"));
}

TEST_F(GemmaIntegrationTest, GenerateThroughput) {
const std::string prompt =
"<start_of_turn>user\n"
"Write a detailed paragraph about the history of France, "
"including the French Revolution, Napoleon, and the World Wars.<end_of_turn>\n"
"<start_of_turn>model\n";

auto ids = model_->tokenizer().encode(prompt, /*add_special_tokens=*/true);
ASSERT_FALSE(ids.empty());

compute::SamplingParams greedy;
greedy.temperature = 0.0f;

// Warmup: trigger mx::compile before measuring steady-state decode.
model_->generate(ids, /*max_new_tokens=*/8, greedy, [](int) { return true; });

int token_count = 0;
auto start = std::chrono::steady_clock::now();
auto result = model_->generate(ids, /*max_new_tokens=*/128, greedy,
[&](int /*tok*/) { ++token_count; return true; });
double elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start).count();

ASSERT_TRUE(result.has_value()) << result.error().message;
ASSERT_GT(token_count, 0) << "Model produced no tokens";

double tok_s = token_count * 1000.0 / elapsed_ms;
std::cout << "Gemma 3 1B throughput: " << tok_s << " tok/s ("
<< token_count << " tokens in " << elapsed_ms << " ms)" << std::endl;

// Baseline (debug build, warmed): ~60 tok/s. Floor = baseline / 2.
EXPECT_GE(tok_s, 30.0) << "throughput regression: " << tok_s << " tok/s";
}

} // namespace
39 changes: 39 additions & 0 deletions compute/tests/compute/test_llama3_integration.cpp
@@ -1,6 +1,7 @@
#include <gtest/gtest.h>
#include "compute/model/llama_model.h"
#include "compute/core/compute_backend.h"
#include <chrono>
#include <filesystem>
#include <iostream>
#include <string>
@@ -156,6 +157,44 @@ TEST_F(Llama3IntegrationTest, GenerateCapitalOfFrance) {
<< "Expected 'Paris' in output, got: \"" << decoded_so_far << "\"";
}

TEST_F(Llama3IntegrationTest, GenerateThroughput) {
const std::string prompt =
"<|begin_of_text|>"
"<|start_header_id|>system<|end_header_id|>\n\n"
"You are a helpful assistant.<|eot_id|>\n"
"<|start_header_id|>user<|end_header_id|>\n\n"
"Write a detailed paragraph about the history of France, "
"including the French Revolution, Napoleon, and the World Wars."
"<|eot_id|>\n"
"<|start_header_id|>assistant<|end_header_id|>\n\n";

auto token_ids = inference_->tokenizer().encode(prompt, /*add_special_tokens=*/false);
ASSERT_FALSE(token_ids.empty());

SamplingParams greedy;
greedy.temperature = 0.0f;

// Warmup: trigger mx::compile before measuring steady-state decode.
inference_->generate(token_ids, /*max_new_tokens=*/8, greedy, [](int) { return true; });

int token_count = 0;
auto start = std::chrono::steady_clock::now();
auto result = inference_->generate(token_ids, /*max_new_tokens=*/128, greedy,
[&](int /*tok*/) { ++token_count; return true; });
double elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start).count();

ASSERT_TRUE(result.has_value()) << result.error().message;
ASSERT_GT(token_count, 0) << "Model produced no tokens";

double tok_s = token_count * 1000.0 / elapsed_ms;
std::cout << "Llama-3.1 8B throughput: " << tok_s << " tok/s ("
<< token_count << " tokens in " << elapsed_ms << " ms)" << std::endl;

// Baseline (debug build, warmed): ~68 tok/s. Floor = baseline / 2.
EXPECT_GE(tok_s, 33.0) << "throughput regression: " << tok_s << " tok/s";
}

} // namespace compute

#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED)
33 changes: 33 additions & 0 deletions compute/tests/compute/test_mistral_integration.cpp
@@ -2,6 +2,7 @@
#include "compute/model/llama_model.h"
#include "compute/core/compute_backend.h"
#include "test_config.h"
#include <chrono>
#include <filesystem>
#include <iostream>
#include <string>
@@ -274,6 +275,38 @@ TEST_F(MistralIntegrationTest, GenerateCapitalOfFranceWithSampling) {
EXPECT_GT(ratio, 0.85f) << "Output contains too many non-printable chars";
}

TEST_F(MistralIntegrationTest, GenerateThroughput) {
const std::string prompt =
"[INST] Write a detailed paragraph about the history of France, "
"including the French Revolution, Napoleon, and the World Wars. [/INST]";
auto token_ids = inference_->tokenizer().encode(prompt, /*add_special_tokens=*/true);
ASSERT_FALSE(token_ids.empty());

SamplingParams greedy;
greedy.temperature = 0.0f;
greedy.top_k = 0;

// Warmup: trigger mx::compile before measuring steady-state decode.
inference_->generate(token_ids, /*max_new_tokens=*/8, greedy, [](int) { return true; });

int token_count = 0;
auto start = std::chrono::steady_clock::now();
auto result = inference_->generate(token_ids, /*max_new_tokens=*/128, greedy,
[&](int /*tok*/) { ++token_count; return true; });
double elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start).count();

ASSERT_TRUE(result.has_value()) << result.error().message;
ASSERT_GT(token_count, 0) << "Model produced no tokens";

double tok_s = token_count * 1000.0 / elapsed_ms;
std::cout << "Mistral 7B throughput: " << tok_s << " tok/s ("
<< token_count << " tokens in " << elapsed_ms << " ms)" << std::endl;

// Baseline (debug build, warmed): ~44 tok/s. Floor = baseline / 2.
EXPECT_GE(tok_s, 22.0) << "throughput regression: " << tok_s << " tok/s";
}

} // namespace compute

#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED)
35 changes: 35 additions & 0 deletions compute/tests/compute/test_qwen3_moe_integration.cpp
@@ -2,6 +2,7 @@
#include "compute/model/language_model.h"
#include "compute/core/compute_backend.h"
#include "test_config.h"
#include <chrono>
#include <filesystem>
#include <iostream>
#include <string>
@@ -101,6 +102,40 @@ TEST_F(Qwen3MoeIntegrationTest, GenerateCapitalOfFrance) {
EXPECT_TRUE(mentions_paris) << "Expected Paris, got: \"" << decoded << "\"";
}

TEST_F(Qwen3MoeIntegrationTest, GenerateThroughput) {
const std::string prompt =
"<|im_start|>user\n"
"What is the capital of France?<|im_end|>\n"
"<|im_start|>assistant\n";

auto token_ids = model_->tokenizer().encode(prompt, /*add_special_tokens=*/false);
ASSERT_FALSE(token_ids.empty());

SamplingParams greedy;
greedy.temperature = 0.0f;
greedy.top_k = 0;

// Warmup: trigger mx::compile before measuring steady-state decode.
model_->generate(token_ids, /*max_new_tokens=*/8, greedy, [](int) { return true; });

int token_count = 0;
auto start = std::chrono::steady_clock::now();
auto result = model_->generate(token_ids, /*max_new_tokens=*/128, greedy,
[&](int /*tok*/) { ++token_count; return true; });
double elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - start).count();

ASSERT_TRUE(result.has_value()) << result.error().message;
ASSERT_GT(token_count, 0) << "Model produced no tokens";

double tok_s = token_count * 1000.0 / elapsed_ms;
std::cout << "Qwen3 MoE 30B-A3B throughput: " << tok_s << " tok/s ("
<< token_count << " tokens in " << elapsed_ms << " ms)" << std::endl;

// Baseline (debug build, warmed): ~23 tok/s. Floor = baseline / 2.
EXPECT_GE(tok_s, 11.0) << "throughput regression: " << tok_s << " tok/s";
}

} // namespace compute

#endif // defined(__APPLE__) && defined(__aarch64__) && defined(MLX_BACKEND_ENABLED)