diff --git a/pyproject.toml b/pyproject.toml index a45e940..2848209 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "tomli-w>=1.0", "rich>=13.0", "huggingface-hub>=0.20", + "gguf>=0.10", ] [project.optional-dependencies] diff --git a/src/arc_llama/cli.py b/src/arc_llama/cli.py index a0ec3dd..994d98e 100644 --- a/src/arc_llama/cli.py +++ b/src/arc_llama/cli.py @@ -309,16 +309,21 @@ def list_models(ctx: click.Context) -> None: table.add_column("Port") table.add_column("ctx") table.add_column("KV") + table.add_column("Spec") table.add_column("Path") for m in cfg.models: r = m.recipe or {} kv = f"{r.get('cache_type_k','f16')}/{r.get('cache_type_v','f16')}" + spec = r.get("spec_type", "—") + if r.get("ubatch_size"): + spec += f" (ub={r['ubatch_size']})" table.add_row( m.name, m.gpu_pci_slot, str(m.port), str(r.get("ctx", "?")), kv, + spec, m.path, ) console.print(table) @@ -353,6 +358,14 @@ def list_models(ctx: click.Context) -> None: help="KV-class hint, used for VRAM estimation.", ) @click.option("--alias", "aliases", multiple=True, help="Extra match strings (repeatable).") +@click.option( + "--spec-type", "spec_type", default=None, + help="Speculative decoding type (e.g. draft-mtp). Auto-detected for MTP models.", +) +@click.option( + "--ubatch-size", "ubatch_size", type=int, default=None, + help="Ubatch size (-ub). Auto-set to 8 for MTP models.", +) @click.option( "--from-hf", is_flag=True, help="Treat SOURCE as a Hugging Face spec (`org/repo` or `org/repo:Q4_K_M`).", @@ -370,6 +383,8 @@ def add( display_name: str, kv_class: str, aliases: tuple[str, ...], + spec_type: str | None, + ubatch_size: int | None, from_hf: bool, hf_token: str | None, ) -> None: @@ -419,6 +434,10 @@ def add( if kv_type is not None: overrides["cache_type_k"] = kv_type overrides["cache_type_v"] = kv_type + if spec_type is not None: + overrides["spec_type"] = spec_type + if ubatch_size is not None: + overrides["ubatch_size"] = ubatch_size try: mc = add_local_model( @@ -596,6 +615,24 @@ def _on_signal(signum: int, _frame) -> None: # noqa: ANN001 uvicorn.run(app, host=cfg.server.host, port=cfg.server.port, log_level="info") +# =========================================================================== +# mtp-info +# =========================================================================== + +@cli.command("mtp-info") +@click.argument("path", type=click.Path(exists=True, dir_okay=False, path_type=Path)) +def mtp_info_cmd(path: Path) -> None: + """Inspect a GGUF file for MTP-relevant metadata.""" + from arc_llama.gguf_meta import mtp_info + info = mtp_info(path) + console.print(f"[bold]GGUF:[/bold] {info['path']}") + console.print(f" architecture: {info['architecture']}") + console.print(f" block_count: {info['block_count']}") + console.print(f" nextn_predict_layers: {info['nextn_predict_layers']}") + console.print(f" has_mtp_heads: {info['has_mtp_heads']}") + console.print(f" is_hybrid_ssm: {info['is_hybrid_ssm']}") + + # =========================================================================== # systemd # =========================================================================== diff --git a/src/arc_llama/config.py b/src/arc_llama/config.py index 3b1b7be..b9e40f9 100644 --- a/src/arc_llama/config.py +++ b/src/arc_llama/config.py @@ -140,6 +140,8 @@ def launch_recipe(self) -> LaunchRecipe: temp=r.get("temp"), top_p=r.get("top_p"), top_k=r.get("top_k"), + spec_type=r.get("spec_type"), + ubatch_size=r.get("ubatch_size"), extra_flags=list(r.get("extra_flags", [])), ) diff --git a/src/arc_llama/gguf_meta.py b/src/arc_llama/gguf_meta.py new file mode 100644 index 0000000..f69d8d6 --- /dev/null +++ b/src/arc_llama/gguf_meta.py @@ -0,0 +1,112 @@ +"""Lightweight GGUF metadata peeking for arc-llama. + +Uses llama.cpp's `gguf-py` to read key metadata (architecture, +nextn_predict_layers, block_count) without loading tensor data. +""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +import gguf # type: ignore[import-untyped] + +log = logging.getLogger("arc_llama.gguf_meta") + + +# --------------------------------------------------------------------------- +# Metadata reading +# --------------------------------------------------------------------------- + +def read_gguf_meta(path: Path | str) -> dict[str, Any]: + """Read a GGUF file and return a small dict of metadata we care about. + + Returns empty dict if the file can't be read. + """ + p = Path(path) + if not p.exists(): + return {} + try: + reader = gguf.GGUFReader(p) + except Exception as exc: + log.debug("gguf read failed for %s: %s", p, exc) + return {} + + meta: dict[str, Any] = {} + arch_field = reader.get_field(gguf.Keys.General.ARCHITECTURE) + if arch_field is not None: + meta["architecture"] = str(arch_field.contents()) + + arch = meta.get("architecture", "") + if arch: + # nextn_predict_layers is the definitive MTP signal in the pr-22673 branch + nextn_field = reader.get_field(f"{arch}.nextn_predict_layers") + if nextn_field is not None: + try: + meta["nextn_predict_layers"] = int(nextn_field.contents()) + except (TypeError, ValueError): + pass + # layer count, useful for diagnostics + n_layer_field = reader.get_field(f"{arch}.block_count") + if n_layer_field is not None: + try: + meta["block_count"] = int(n_layer_field.contents()) + except (TypeError, ValueError): + pass + + return meta + + +# --------------------------------------------------------------------------- +# MTP detection +# --------------------------------------------------------------------------- + +def has_mtp_heads(path: Path | str) -> bool: + """Return True if the GGUF at *path* contains real MTP heads. + + Checks metadata — not the filename. The canonical signal is + ``nextn_predict_layers > 0`` in the GGUF kv store. Stand-alone MTP-only + GGUFs (architecture ``qwen35_mtp`` / ``qwen35moe_mtp``) also count. + """ + meta = read_gguf_meta(path) + if not meta: + return False + arch = meta.get("architecture", "") + if arch in ("qwen35_mtp", "qwen35moe_mtp"): + return True + nextn = meta.get("nextn_predict_layers", 0) + if isinstance(nextn, int) and nextn > 0: + return True + return False + + +def is_hybrid_ssm(path: Path | str) -> bool: + """Return True if the GGUF is a hybrid SSM+attention architecture. + + Today this means the Qwen3.5/3.6 family (dense or MoE) which use GDN + (gated delta net) layers — a recurrent state-space-like attention + hybrid. These architectures are known to perform poorly with SYCL MTP + speculative decoding on Xe2 (Battlemage, Lunar Lake). + """ + meta = read_gguf_meta(path) + arch = meta.get("architecture", "") + # qwen35 and qwen35moe (with or without the _mtp suffix) are the + # known hybrid SSM+attention families today. + return arch.startswith("qwen35") + + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- + +def mtp_info(path: Path | str) -> dict[str, Any]: + """Return a human-readable summary of MTP-relevant metadata.""" + meta = read_gguf_meta(path) + return { + "path": str(path), + "architecture": meta.get("architecture", "unknown"), + "block_count": meta.get("block_count", "unknown"), + "nextn_predict_layers": meta.get("nextn_predict_layers", 0), + "has_mtp_heads": has_mtp_heads(path), + "is_hybrid_ssm": is_hybrid_ssm(path), + } diff --git a/src/arc_llama/launcher.py b/src/arc_llama/launcher.py index fd2026a..3c38b69 100644 --- a/src/arc_llama/launcher.py +++ b/src/arc_llama/launcher.py @@ -24,6 +24,7 @@ from arc_llama.arch import Arch, ArchProfile, profile_for from arc_llama.config import Config, GPUConfig, ModelConfig +from arc_llama.gguf_meta import has_mtp_heads, is_hybrid_ssm log = logging.getLogger("arc_llama.launcher") @@ -110,6 +111,41 @@ def build_plan( profile = profile_for(arch) env = build_env(profile, gpu.sycl_index) recipe = model.launch_recipe() + + # --- MTP head detection & safety wiring --- + mtp_present = has_mtp_heads(model.path) + hybrid_ssm = is_hybrid_ssm(model.path) + + # 1. Auto-inject -ub 8 for MTP models (prevents SSM compute-buffer OOM). + if mtp_present and recipe.ubatch_size is None: + recipe.ubatch_size = 8 + log.info( + "[%s] MTP heads detected; auto-setting ubatch_size=8", + model.name, + ) + + # 2. Warn if the user explicitly asked for draft-mtp on a model that + # does not actually contain MTP heads. + if recipe.spec_type == "draft-mtp" and not mtp_present: + log.warning( + "[%s] recipe.spec_type='draft-mtp' but GGUF has no MTP heads " + "(nextn_predict_layers == 0). Speculative decoding will likely " + "degenerate or crash.", + model.name, + ) + + # 3. Backend recommendation for hybrid SSM + MTP on Xe2 (Battlemage, + # Lunar Lake). GDN sequential state passes make SYCL MTP net-negative. + if mtp_present and hybrid_ssm and arch in (Arch.BATTLEMAGE, Arch.LUNAR_LAKE): + log.info( + "[%s] Hybrid SSM+attention model with MTP heads on Xe2 (%s): " + "SYCL MTP speculative decoding is net-negative here because GDN " + "layers force serial state passes. Consider a Vulkan backend " + "build for ~+9%% throughput with --spec-type draft-mtp.", + model.name, + arch.value, + ) + argv: list[str] = [ cfg.paths.llama_server, "-m", model.path, diff --git a/src/arc_llama/models.py b/src/arc_llama/models.py index 7898910..b274b99 100644 --- a/src/arc_llama/models.py +++ b/src/arc_llama/models.py @@ -20,6 +20,7 @@ Config, ModelConfig, ) +from arc_llama.gguf_meta import has_mtp_heads from arc_llama.recipes import default_recipe log = logging.getLogger("arc_llama.models") @@ -127,6 +128,14 @@ def add_local_model( "cache_type_k": recipe.cache_type_k.value, "cache_type_v": recipe.cache_type_v.value, } + # Auto-enable draft-mtp for models that actually carry MTP heads. + if has_mtp_heads(p): + recipe_dict["spec_type"] = "draft-mtp" + recipe_dict["ubatch_size"] = 8 + log.info( + "model %s has MTP heads; auto-enabling spec_type=draft-mtp, ubatch_size=8", + name, + ) if recipe_overrides: recipe_dict.update(recipe_overrides) mc = ModelConfig( @@ -304,6 +313,21 @@ def register_discovered( model_file_mb=rp.stat().st_size // (1024 * 1024), kv_class=kv_class, ) + recipe_dict: dict[str, Any] = { + "n_gpu_layers": recipe.n_gpu_layers, + "ctx": recipe.ctx, + "parallel": recipe.parallel, + "cache_type_k": recipe.cache_type_k.value, + "cache_type_v": recipe.cache_type_v.value, + } + # Auto-enable draft-mtp for discovered models that carry MTP heads. + if has_mtp_heads(rp): + recipe_dict["spec_type"] = "draft-mtp" + recipe_dict["ubatch_size"] = 8 + log.info( + "discovered %s has MTP heads; auto-enabling spec_type=draft-mtp, ubatch_size=8", + rp.name, + ) name = short_name_from_path(rp, used_names) used_names.add(name) port = port_start @@ -317,13 +341,7 @@ def register_discovered( gpu_pci_slot=gpu_pci_slot, display_name=infer_display_name(rp), kv_class=kv_class, - recipe={ - "n_gpu_layers": recipe.n_gpu_layers, - "ctx": recipe.ctx, - "parallel": recipe.parallel, - "cache_type_k": recipe.cache_type_k.value, - "cache_type_v": recipe.cache_type_v.value, - }, + recipe=recipe_dict, aliases=[rp.name], ) cfg.models.append(mc) diff --git a/src/arc_llama/recipes.py b/src/arc_llama/recipes.py index 74e909b..306d4f6 100644 --- a/src/arc_llama/recipes.py +++ b/src/arc_llama/recipes.py @@ -47,6 +47,10 @@ class LaunchRecipe: temp: float | None = None top_p: float | None = None top_k: int | None = None + spec_type: str | None = None + """Speculative decoding type, e.g. 'draft-mtp'.""" + ubatch_size: int | None = None + """Ubatch size (-ub). Auto-set to 8 for MTP models to avoid SSM compute-buffer OOM.""" extra_flags: list[str] = field(default_factory=list) """Anything else the user wants appended to the command line verbatim.""" @@ -66,6 +70,10 @@ def to_argv(self) -> list[str]: argv += ["--top-p", str(self.top_p)] if self.top_k is not None: argv += ["--top-k", str(self.top_k)] + if self.spec_type: + argv += ["--spec-type", self.spec_type] + if self.ubatch_size is not None: + argv += ["-ub", str(self.ubatch_size)] argv += list(self.extra_flags) return argv diff --git a/src/arc_llama/server.py b/src/arc_llama/server.py index 21f9da1..3b7effb 100644 --- a/src/arc_llama/server.py +++ b/src/arc_llama/server.py @@ -175,7 +175,8 @@ async def admin_edit_model(name: str, request: Request) -> dict: """Update a model's recipe in-place. Body is a partial recipe dict — only provided fields change. Recognised - fields: `ctx`, `cache_type_k`, `cache_type_v`, `parallel`, `kv_class`. + fields: `ctx`, `cache_type_k`, `cache_type_v`, `parallel`, `kv_class`, + `spec_type`, `ubatch_size`. If the model is currently loaded, the server is stopped first; callers decide whether to reload it afterwards via /admin/load. """ @@ -233,6 +234,19 @@ async def admin_edit_model(name: str, request: Request) -> dict: ) model.kv_class = v changed.append("kv_class") + if "spec_type" in body: + v = str(body["spec_type"]) + recipe["spec_type"] = v + changed.append("spec_type") + if "ubatch_size" in body: + try: + ub = int(body["ubatch_size"]) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail="ubatch_size must be an integer") + if not (1 <= ub <= 4096): + raise HTTPException(status_code=400, detail="ubatch_size must be 1..4096") + recipe["ubatch_size"] = ub + changed.append("ubatch_size") if not changed: raise HTTPException(status_code=400, detail="no recognised fields to edit") model.recipe = recipe diff --git a/tests/test_gguf_meta.py b/tests/test_gguf_meta.py new file mode 100644 index 0000000..52e65bd --- /dev/null +++ b/tests/test_gguf_meta.py @@ -0,0 +1,65 @@ +"""Tests for arc_llama.gguf_meta — MTP head detection from GGUF metadata.""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from arc_llama.gguf_meta import has_mtp_heads, is_hybrid_ssm, mtp_info, read_gguf_meta + +# Real GGUFs on the host's storage, discovered during exploration. +_BASE_QWEN = Path("/mnt/storage/models/qwen3.6-27b/Qwen_Qwen3.6-27B-Q4_K_M.gguf") +_MTP_QWEN = Path("/mnt/storage/models/qwen3.6-27b/Qwen3.6-27B-MTP-UD-Q4_K_XL.gguf") + + +def _have_fixtures() -> bool: + return _BASE_QWEN.exists() and _MTP_QWEN.exists() + + +class TestReadGgufMeta: + @pytest.mark.skipif(not _have_fixtures(), reason="fixture GGUFs not on disk") + def test_reads_architecture(self): + meta = read_gguf_meta(_BASE_QWEN) + assert meta["architecture"] == "qwen35" + + @pytest.mark.skipif(not _have_fixtures(), reason="fixture GGUFs not on disk") + def test_reads_nextn_and_layers(self): + meta = read_gguf_meta(_MTP_QWEN) + assert meta["nextn_predict_layers"] == 1 + assert meta["block_count"] == 65 + + def test_missing_file_returns_empty(self): + meta = read_gguf_meta("/nonexistent/file.gguf") + assert meta == {} + + +class TestMtpDetection: + @pytest.mark.skipif(not _have_fixtures(), reason="fixture GGUFs not on disk") + def test_base_qwen_has_no_mtp(self): + assert has_mtp_heads(_BASE_QWEN) is False + + @pytest.mark.skipif(not _have_fixtures(), reason="fixture GGUFs not on disk") + def test_mtp_qwen_has_mtp(self): + assert has_mtp_heads(_MTP_QWEN) is True + + def test_missing_file_is_false(self): + assert has_mtp_heads("/nonexistent.gguf") is False + + +class TestHybridSsmDetection: + @pytest.mark.skipif(not _have_fixtures(), reason="fixture GGUFs not on disk") + def test_qwen_is_hybrid_ssm(self): + assert is_hybrid_ssm(_BASE_QWEN) is True + assert is_hybrid_ssm(_MTP_QWEN) is True + + def test_missing_file_is_false(self): + assert is_hybrid_ssm("/nonexistent.gguf") is False + + +class TestMtpInfo: + @pytest.mark.skipif(not _have_fixtures(), reason="fixture GGUFs not on disk") + def test_summary_keys(self): + info = mtp_info(_MTP_QWEN) + assert info["has_mtp_heads"] is True + assert info["is_hybrid_ssm"] is True + assert info["nextn_predict_layers"] == 1 diff --git a/tests/test_launcher.py b/tests/test_launcher.py new file mode 100644 index 0000000..24a8544 --- /dev/null +++ b/tests/test_launcher.py @@ -0,0 +1,227 @@ +"""Tests for arc_llama.launcher — env construction, command-line building.""" +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from arc_llama.arch import Arch +from arc_llama.config import Config, GPUConfig, ModelConfig +from arc_llama.launcher import LlamaServer, build_env, build_plan + + +class TestBuildEnv: + def test_sets_device_selector(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(os, "environ", {"PATH": "/usr/bin"}) + from arc_llama.arch import profile_for + profile = profile_for(Arch.BATTLEMAGE) + env = build_env(profile, sycl_index=2) + assert env["ONEAPI_DEVICE_SELECTOR"] == "level_zero:2" + + def test_strips_bad_vars(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(os, "environ", { + "PATH": "/usr/bin", + "GGML_SYCL_DISABLE_OPT": "1", + "SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS": "1", + }) + from arc_llama.arch import profile_for + profile = profile_for(Arch.BATTLEMAGE) + env = build_env(profile, sycl_index=0) + assert "GGML_SYCL_DISABLE_OPT" not in env + assert "SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS" not in env + + def test_applies_arch_env(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(os, "environ", {"PATH": "/usr/bin"}) + from arc_llama.arch import profile_for + profile = profile_for(Arch.BATTLEMAGE) + env = build_env(profile, sycl_index=0) + assert env["SYCL_CACHE_PERSISTENT"] == "0" + assert env["ZES_ENABLE_SYSMAN"] == "1" + + +class TestBuildPlan: + def test_includes_model_and_port(self): + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0") + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage") + plan = build_plan(cfg, model, gpu) + assert plan.argv[0] == "/bin/llama-server" + assert "-m" in plan.argv + assert "/m.gguf" in plan.argv + assert "--port" in plan.argv + assert "18080" in plan.argv + assert plan.backend_url == "http://127.0.0.1:18080" + + def test_uses_custom_host(self): + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0") + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage") + plan = build_plan(cfg, model, gpu, host="0.0.0.0") + assert plan.backend_url == "http://0.0.0.0:18080" + assert "--host" in plan.argv + assert "0.0.0.0" in plan.argv + + def test_env_has_device_selector(self): + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0") + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=2, arch="battlemage") + plan = build_plan(cfg, model, gpu) + assert plan.env["ONEAPI_DEVICE_SELECTOR"] == "level_zero:2" + + def test_fake_path_no_mtp_no_ub_injected(self): + """Non-existent GGUF → no MTP heads → -ub should NOT appear.""" + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0") + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage") + plan = build_plan(cfg, model, gpu) + assert "-ub" not in plan.argv + + +# Real GGUF fixtures for MTP integration tests. +_MTP_QWEN = Path("/mnt/storage/models/qwen3.6-27b/Qwen3.6-27B-MTP-UD-Q4_K_XL.gguf") + + +def _have_mtp_fixture() -> bool: + return _MTP_QWEN.exists() + + +class TestBuildPlanMtp: + @pytest.mark.skipif(not _have_mtp_fixture(), reason="MTP fixture GGUF not on disk") + def test_auto_injects_ub_8_for_mtp(self): + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig( + name="mtp-qwen", + path=str(_MTP_QWEN), + port=18080, + gpu_pci_slot="00:00.0", + ) + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage") + plan = build_plan(cfg, model, gpu) + assert "-ub" in plan.argv + idx = plan.argv.index("-ub") + assert plan.argv[idx + 1] == "8" + + @pytest.mark.skipif(not _have_mtp_fixture(), reason="MTP fixture GGUF not on disk") + def test_user_ubatch_size_not_overridden(self): + """If the recipe already has ubatch_size, don't stomp it.""" + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig( + name="mtp-qwen", + path=str(_MTP_QWEN), + port=18080, + gpu_pci_slot="00:00.0", + recipe={"ubatch_size": 16}, + ) + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage") + plan = build_plan(cfg, model, gpu) + idx = plan.argv.index("-ub") + assert plan.argv[idx + 1] == "16" + + @pytest.mark.skipif(not _have_mtp_fixture(), reason="MTP fixture GGUF not on disk") + def test_mtp_on_lunar_lake_also_gets_ub_8(self): + """Xe2 iGPU (Lunar Lake) is the same generation as Battlemage — same fix.""" + cfg = Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()) + model = ModelConfig( + name="mtp-qwen", + path=str(_MTP_QWEN), + port=18080, + gpu_pci_slot="00:00.0", + ) + gpu = GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="lunar_lake") + plan = build_plan(cfg, model, gpu) + assert "-ub" in plan.argv + idx = plan.argv.index("-ub") + assert plan.argv[idx + 1] == "8" + + +class TestLlamaServerLifecycle: + def test_not_running_before_start(self): + plan = build_plan( + Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()), + ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0"), + GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage"), + ) + srv = LlamaServer(plan) + assert srv.is_running is False + + def test_start_log_dir_creates_parents(self, tmp_path: Path): + plan = build_plan( + Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()), + ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0"), + GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage"), + ) + srv = LlamaServer(plan) + log_dir = tmp_path / "deep" / "logs" + # We can't actually start a fake binary without mocking Popen, + # but we can at least assert the log_dir path would be used. + assert not log_dir.exists() + # Mock Popen to avoid actually spawning + import subprocess + original_popen = subprocess.Popen + called = {} + + def _fake_popen(*args, **kwargs): + called["args"] = args + called["kwargs"] = kwargs + class FakeProc: + pid = 12345 + def poll(self): + return None + return FakeProc() + + subprocess.Popen = _fake_popen + try: + srv.start(log_dir=log_dir) + assert log_dir.exists() + finally: + subprocess.Popen = original_popen + assert srv.is_running is True + + @pytest.mark.asyncio + async def test_wait_ready_true_when_healthy(self, monkeypatch: pytest.MonkeyPatch): + plan = build_plan( + Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()), + ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0"), + GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage"), + ) + srv = LlamaServer(plan) + # Pretend it's running + srv.process = type("P", (), {"poll": lambda self: None, "pid": 1})() + srv.started_at = 0.0 + + import httpx + original_get = httpx.AsyncClient.get + + async def _fake_get(self, url): + if "/health" in url: + return type("R", (), {"status_code": 200, "json": lambda self: {"status": "ok"}})() + return type("R", (), {"status_code": 404})() + + monkeypatch.setattr(httpx.AsyncClient, "get", _fake_get) + ready = await srv.wait_ready(timeout=2.0) + assert ready is True + + @pytest.mark.asyncio + async def test_wait_ready_false_on_crash(self): + plan = build_plan( + Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()), + ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0"), + GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage"), + ) + srv = LlamaServer(plan) + # Simulate crashed process + srv.process = type("P", (), {"poll": lambda self: 1})() + ready = await srv.wait_ready(timeout=1.0) + assert ready is False + + def test_stop_idempotent(self): + plan = build_plan( + Config(paths=type("P", (), {"llama_server": "/bin/llama-server"})()), + ModelConfig(name="m", path="/m.gguf", port=18080, gpu_pci_slot="00:00.0"), + GPUConfig(pci_slot="00:00.0", sycl_index=0, arch="battlemage"), + ) + srv = LlamaServer(plan) + # Should not raise when not running + srv.stop() + assert srv.is_running is False diff --git a/tests/test_recipes.py b/tests/test_recipes.py new file mode 100644 index 0000000..c05a0af --- /dev/null +++ b/tests/test_recipes.py @@ -0,0 +1,194 @@ +"""Tests for arc_llama.recipes — VRAM math, recipe generation, KV sizing.""" +from __future__ import annotations + +import pytest + +from arc_llama.arch import Arch +from arc_llama.recipes import ( + DEFAULT_CTX_CAP, + KVCacheType, + LaunchRecipe, + default_recipe, + estimate_kv_bytes, + suggest_ctx, +) + + +class TestEstimateKvBytes: + def test_default_f16(self): + # 70 KiB/token at f16 for default class + assert estimate_kv_bytes(1024, KVCacheType.F16, "default") == 1024 * 70 * 1024 + + def test_moe_a3b_q8(self): + # 20 KiB/token f16, q8_0 halves it + assert estimate_kv_bytes(4096, KVCacheType.Q8_0, "moe_a3b") == 4096 * 20 * 1024 * 0.5 + + def test_gemma_swa_f16(self): + assert estimate_kv_bytes(8192, KVCacheType.F16, "gemma_swa") == 8192 * 16 * 1024 + + def test_unknown_class_fallback(self): + # Unknown kv_class falls back to default f16 per-token + assert estimate_kv_bytes(1000, KVCacheType.F16, "no_such_class") == 1000 * 70 * 1024 + + +class TestSuggestCtx: + def test_basic_fit(self): + # 24 GB VRAM, 4 GB model, q8_0 KV + ctx = suggest_ctx( + vram_mb=24 * 1024, + model_file_mb=4 * 1024, + kv_type=KVCacheType.Q8_0, + ) + assert ctx >= 4096 + assert ctx <= DEFAULT_CTX_CAP + assert ctx % 4096 == 0 + + def test_oom_returns_minimum(self): + # Model bigger than VRAM + ctx = suggest_ctx( + vram_mb=1024, + model_file_mb=2048, + kv_type=KVCacheType.F16, + ) + assert ctx == 4096 + + def test_ctx_cap_clamps(self): + # Ridiculous VRAM should still be capped + ctx = suggest_ctx( + vram_mb=1024 * 1024, # 1 TB + model_file_mb=1, + kv_type=KVCacheType.Q8_0, + ctx_cap=131072, + ) + assert ctx == 131072 + + def test_f32_doubles_bytes(self): + # Tight VRAM so the difference is visible below the cap + f16_ctx = suggest_ctx( + vram_mb=8 * 1024, + model_file_mb=6 * 1024, + kv_type=KVCacheType.F16, + ctx_cap=1_000_000, + ) + f32_ctx = suggest_ctx( + vram_mb=8 * 1024, + model_file_mb=6 * 1024, + kv_type=KVCacheType.F32, + ctx_cap=1_000_000, + ) + assert f32_ctx < f16_ctx + + def test_q4_saves_more_than_q8(self): + # Tight VRAM so the difference is visible + q8_ctx = suggest_ctx( + vram_mb=8 * 1024, + model_file_mb=6 * 1024, + kv_type=KVCacheType.Q8_0, + ctx_cap=1_000_000, + ) + q4_ctx = suggest_ctx( + vram_mb=8 * 1024, + model_file_mb=6 * 1024, + kv_type=KVCacheType.Q4_0, + ctx_cap=1_000_000, + ) + assert q4_ctx > q8_ctx + + +class TestDefaultRecipe: + def test_battlemage_prefers_q8(self): + r = default_recipe( + Arch.BATTLEMAGE, + vram_mb=24 * 1024, + model_file_mb=4 * 1024, + ) + assert r.cache_type_k == KVCacheType.Q8_0 + assert r.cache_type_v == KVCacheType.Q8_0 + assert r.n_gpu_layers == 999 + + def test_unknown_arch_is_conservative(self): + r = default_recipe( + Arch.UNKNOWN, + vram_mb=8 * 1024, + model_file_mb=4 * 1024, + ) + assert r.ctx >= 4096 + assert r.cache_type_k == KVCacheType.Q8_0 + + def test_no_prefer_q8_gives_f16(self): + r = default_recipe( + Arch.BATTLEMAGE, + vram_mb=24 * 1024, + model_file_mb=4 * 1024, + prefer_q8_kv=False, + ) + assert r.cache_type_k == KVCacheType.F16 + assert r.cache_type_v == KVCacheType.F16 + + def test_moe_class_increases_ctx(self): + # MoE has smaller per-token KV, so same VRAM → larger ctx + dense = default_recipe( + Arch.BATTLEMAGE, + vram_mb=8 * 1024, + model_file_mb=6 * 1024, + kv_class="default", + prefer_q8_kv=False, + ) + moe = default_recipe( + Arch.BATTLEMAGE, + vram_mb=8 * 1024, + model_file_mb=6 * 1024, + kv_class="moe_a3b", + prefer_q8_kv=False, + ) + assert moe.ctx > dense.ctx + + +class TestLaunchRecipeArgv: + def test_all_fields_present(self): + r = LaunchRecipe( + n_gpu_layers=999, + ctx=32768, + parallel=2, + cache_type_k=KVCacheType.Q8_0, + cache_type_v=KVCacheType.Q5_1, + threads=8, + temp=0.7, + top_p=0.9, + top_k=40, + extra_flags=["--reasoning", "off"], + ) + argv = r.to_argv() + assert argv == [ + "-ngl", "999", + "-c", "32768", + "--parallel", "2", + "--cache-type-k", "q8_0", + "--cache-type-v", "q5_1", + "-t", "8", + "--temp", "0.7", + "--top-p", "0.9", + "--top-k", "40", + "--reasoning", "off", + ] + + def test_optional_fields_omitted(self): + r = LaunchRecipe() + argv = r.to_argv() + assert "-t" not in argv + assert "--temp" not in argv + assert "--top-p" not in argv + assert "--top-k" not in argv + assert "--spec-type" not in argv + assert "-ub" not in argv + + def test_spec_type_and_ubatch_size(self): + r = LaunchRecipe( + spec_type="draft-mtp", + ubatch_size=8, + ) + argv = r.to_argv() + assert "--spec-type" in argv + assert argv[argv.index("--spec-type") + 1] == "draft-mtp" + assert "-ub" in argv + assert argv[argv.index("-ub") + 1] == "8"