Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"tomli-w>=1.0",
"rich>=13.0",
"huggingface-hub>=0.20",
"gguf>=0.10",
]

[project.optional-dependencies]
Expand Down
37 changes: 37 additions & 0 deletions src/arc_llama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,21 @@ def list_models(ctx: click.Context) -> None:
table.add_column("Port")
table.add_column("ctx")
table.add_column("KV")
table.add_column("Spec")
table.add_column("Path")
for m in cfg.models:
r = m.recipe or {}
kv = f"{r.get('cache_type_k','f16')}/{r.get('cache_type_v','f16')}"
spec = r.get("spec_type", "—")
if r.get("ubatch_size"):
spec += f" (ub={r['ubatch_size']})"
table.add_row(
m.name,
m.gpu_pci_slot,
str(m.port),
str(r.get("ctx", "?")),
kv,
spec,
m.path,
)
console.print(table)
Expand Down Expand Up @@ -353,6 +358,14 @@ def list_models(ctx: click.Context) -> None:
help="KV-class hint, used for VRAM estimation.",
)
@click.option("--alias", "aliases", multiple=True, help="Extra match strings (repeatable).")
@click.option(
"--spec-type", "spec_type", default=None,
help="Speculative decoding type (e.g. draft-mtp). Auto-detected for MTP models.",
)
@click.option(
"--ubatch-size", "ubatch_size", type=int, default=None,
help="Ubatch size (-ub). Auto-set to 8 for MTP models.",
)
@click.option(
"--from-hf", is_flag=True,
help="Treat SOURCE as a Hugging Face spec (`org/repo` or `org/repo:Q4_K_M`).",
Expand All @@ -370,6 +383,8 @@ def add(
display_name: str,
kv_class: str,
aliases: tuple[str, ...],
spec_type: str | None,
ubatch_size: int | None,
from_hf: bool,
hf_token: str | None,
) -> None:
Expand Down Expand Up @@ -419,6 +434,10 @@ def add(
if kv_type is not None:
overrides["cache_type_k"] = kv_type
overrides["cache_type_v"] = kv_type
if spec_type is not None:
overrides["spec_type"] = spec_type
if ubatch_size is not None:
overrides["ubatch_size"] = ubatch_size

try:
mc = add_local_model(
Expand Down Expand Up @@ -596,6 +615,24 @@ def _on_signal(signum: int, _frame) -> None: # noqa: ANN001
uvicorn.run(app, host=cfg.server.host, port=cfg.server.port, log_level="info")


# ===========================================================================
# mtp-info
# ===========================================================================

@cli.command("mtp-info")
@click.argument("path", type=click.Path(exists=True, dir_okay=False, path_type=Path))
def mtp_info_cmd(path: Path) -> None:
    """Inspect a GGUF file for MTP-relevant metadata."""
    # Function-scope import, as in the original: the GGUF helper module is
    # only pulled in when this sub-command actually runs.
    from arc_llama.gguf_meta import mtp_info

    summary = mtp_info(path)

    # Header line first, then one " key: value" line per reported field, in
    # the fixed order below.
    console.print(f"[bold]GGUF:[/bold] {summary['path']}")
    reported_keys = (
        "architecture",
        "block_count",
        "nextn_predict_layers",
        "has_mtp_heads",
        "is_hybrid_ssm",
    )
    for key in reported_keys:
        console.print(f" {key}: {summary[key]}")


# ===========================================================================
# systemd
# ===========================================================================
Expand Down
2 changes: 2 additions & 0 deletions src/arc_llama/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def launch_recipe(self) -> LaunchRecipe:
temp=r.get("temp"),
top_p=r.get("top_p"),
top_k=r.get("top_k"),
spec_type=r.get("spec_type"),
ubatch_size=r.get("ubatch_size"),
extra_flags=list(r.get("extra_flags", [])),
)

Expand Down
112 changes: 112 additions & 0 deletions src/arc_llama/gguf_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Lightweight GGUF metadata peeking for arc-llama.

Uses llama.cpp's `gguf-py` to read key metadata (architecture,
nextn_predict_layers, block_count) without loading tensor data.
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import gguf # type: ignore[import-untyped]

log = logging.getLogger("arc_llama.gguf_meta")


# ---------------------------------------------------------------------------
# Metadata reading
# ---------------------------------------------------------------------------

def _decode_field(reader: Any, key: str) -> Any:
    """Return the decoded value stored under *key* in *reader*, or None.

    None is returned both when the key is absent and when the field cannot
    be decoded. Decoding problems are logged at debug level instead of being
    raised, so that metadata peeking stays best-effort.
    """
    field = reader.get_field(key)
    if field is None:
        return None
    try:
        return field.contents()
    except (AttributeError, IndexError, TypeError, ValueError) as exc:
        # NOTE(review): AttributeError is included because older gguf-py
        # releases may not provide ReaderField.contents() -- confirm against
        # the minimum version pinned for this project.
        log.debug("gguf field %s could not be decoded: %s", key, exc)
        return None


def read_gguf_meta(path: Path | str) -> dict[str, Any]:
    """Read a GGUF file and return a small dict of metadata we care about.

    Possible keys in the result (each present only when it could be read):

    * ``architecture`` (str) -- value of the general architecture key.
    * ``nextn_predict_layers`` (int) -- ``<arch>.nextn_predict_layers``.
    * ``block_count`` (int) -- ``<arch>.block_count``.

    Returns an empty dict if the file does not exist or can't be opened as
    GGUF. Individual fields that fail to decode are skipped rather than
    aborting the whole read.
    """
    p = Path(path)
    if not p.exists():
        return {}
    try:
        reader = gguf.GGUFReader(p)
    except Exception as exc:
        log.debug("gguf read failed for %s: %s", p, exc)
        return {}

    meta: dict[str, Any] = {}
    arch_value = _decode_field(reader, gguf.Keys.General.ARCHITECTURE)
    if arch_value is not None:
        meta["architecture"] = str(arch_value)

    arch = meta.get("architecture", "")
    if arch:
        # nextn_predict_layers is the definitive MTP signal in the pr-22673
        # branch; block_count (layer count) is useful for diagnostics.
        # Both are integer fields namespaced by the architecture string.
        for name in ("nextn_predict_layers", "block_count"):
            raw = _decode_field(reader, f"{arch}.{name}")
            if raw is None:
                continue
            try:
                meta[name] = int(raw)
            except (TypeError, ValueError):
                # Non-integer payload: leave the key out of the result.
                log.debug("gguf field %s.%s is not an integer: %r", arch, name, raw)

    return meta


# ---------------------------------------------------------------------------
# MTP detection
# ---------------------------------------------------------------------------

def _meta_has_mtp_heads(meta: dict[str, Any]) -> bool:
    """Decide from an already-read metadata dict whether MTP heads exist."""
    if not meta:
        return False
    if meta.get("architecture", "") in ("qwen35_mtp", "qwen35moe_mtp"):
        return True
    nextn = meta.get("nextn_predict_layers", 0)
    return isinstance(nextn, int) and nextn > 0


def _meta_is_hybrid_ssm(meta: dict[str, Any]) -> bool:
    """Decide from an already-read metadata dict whether the arch is hybrid SSM."""
    # qwen35 and qwen35moe (with or without the _mtp suffix) are the
    # known hybrid SSM+attention families today.
    return meta.get("architecture", "").startswith("qwen35")


def has_mtp_heads(path: Path | str) -> bool:
    """Return True if the GGUF at *path* contains real MTP heads.

    Checks metadata — not the filename. The canonical signal is
    ``nextn_predict_layers > 0`` in the GGUF kv store. Stand-alone MTP-only
    GGUFs (architecture ``qwen35_mtp`` / ``qwen35moe_mtp``) also count.

    Returns False when no metadata could be read from *path*.
    """
    return _meta_has_mtp_heads(read_gguf_meta(path))


def is_hybrid_ssm(path: Path | str) -> bool:
    """Return True if the GGUF is a hybrid SSM+attention architecture.

    Today this means the Qwen3.5/3.6 family (dense or MoE) which use GDN
    (gated delta net) layers — a recurrent state-space-like attention
    hybrid. These architectures are known to perform poorly with SYCL MTP
    speculative decoding on Xe2 (Battlemage, Lunar Lake).

    The check is a prefix match of the GGUF architecture string against
    ``qwen35``; it returns False when no architecture could be read.
    """
    return _meta_is_hybrid_ssm(read_gguf_meta(path))


# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------

def mtp_info(path: Path | str) -> dict[str, Any]:
    """Return a human-readable summary of MTP-relevant metadata.

    The file is parsed once and every derived flag is computed from that
    single metadata snapshot. Result keys: ``path``, ``architecture``,
    ``block_count``, ``nextn_predict_layers``, ``has_mtp_heads``,
    ``is_hybrid_ssm``. Missing ``architecture`` / ``block_count`` are
    reported as ``"unknown"`` and a missing ``nextn_predict_layers`` as 0.
    """
    meta = read_gguf_meta(path)
    return {
        "path": str(path),
        "architecture": meta.get("architecture", "unknown"),
        "block_count": meta.get("block_count", "unknown"),
        "nextn_predict_layers": meta.get("nextn_predict_layers", 0),
        "has_mtp_heads": _meta_has_mtp_heads(meta),
        "is_hybrid_ssm": _meta_is_hybrid_ssm(meta),
    }
36 changes: 36 additions & 0 deletions src/arc_llama/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from arc_llama.arch import Arch, ArchProfile, profile_for
from arc_llama.config import Config, GPUConfig, ModelConfig
from arc_llama.gguf_meta import has_mtp_heads, is_hybrid_ssm

log = logging.getLogger("arc_llama.launcher")

Expand Down Expand Up @@ -110,6 +111,41 @@ def build_plan(
profile = profile_for(arch)
env = build_env(profile, gpu.sycl_index)
recipe = model.launch_recipe()

# --- MTP head detection & safety wiring ---
mtp_present = has_mtp_heads(model.path)
hybrid_ssm = is_hybrid_ssm(model.path)

# 1. Auto-inject -ub 8 for MTP models (prevents SSM compute-buffer OOM).
if mtp_present and recipe.ubatch_size is None:
recipe.ubatch_size = 8
log.info(
"[%s] MTP heads detected; auto-setting ubatch_size=8",
model.name,
)

# 2. Warn if the user explicitly asked for draft-mtp on a model that
# does not actually contain MTP heads.
if recipe.spec_type == "draft-mtp" and not mtp_present:
log.warning(
"[%s] recipe.spec_type='draft-mtp' but GGUF has no MTP heads "
"(nextn_predict_layers == 0). Speculative decoding will likely "
"degenerate or crash.",
model.name,
)

# 3. Backend recommendation for hybrid SSM + MTP on Xe2 (Battlemage,
# Lunar Lake). GDN sequential state passes make SYCL MTP net-negative.
if mtp_present and hybrid_ssm and arch in (Arch.BATTLEMAGE, Arch.LUNAR_LAKE):
log.info(
"[%s] Hybrid SSM+attention model with MTP heads on Xe2 (%s): "
"SYCL MTP speculative decoding is net-negative here because GDN "
"layers force serial state passes. Consider a Vulkan backend "
"build for ~+9%% throughput with --spec-type draft-mtp.",
model.name,
arch.value,
)

argv: list[str] = [
cfg.paths.llama_server,
"-m", model.path,
Expand Down
32 changes: 25 additions & 7 deletions src/arc_llama/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Config,
ModelConfig,
)
from arc_llama.gguf_meta import has_mtp_heads
from arc_llama.recipes import default_recipe

log = logging.getLogger("arc_llama.models")
Expand Down Expand Up @@ -127,6 +128,14 @@ def add_local_model(
"cache_type_k": recipe.cache_type_k.value,
"cache_type_v": recipe.cache_type_v.value,
}
# Auto-enable draft-mtp for models that actually carry MTP heads.
if has_mtp_heads(p):
recipe_dict["spec_type"] = "draft-mtp"
recipe_dict["ubatch_size"] = 8
log.info(
"model %s has MTP heads; auto-enabling spec_type=draft-mtp, ubatch_size=8",
name,
)
if recipe_overrides:
recipe_dict.update(recipe_overrides)
mc = ModelConfig(
Expand Down Expand Up @@ -304,6 +313,21 @@ def register_discovered(
model_file_mb=rp.stat().st_size // (1024 * 1024),
kv_class=kv_class,
)
recipe_dict: dict[str, Any] = {
"n_gpu_layers": recipe.n_gpu_layers,
"ctx": recipe.ctx,
"parallel": recipe.parallel,
"cache_type_k": recipe.cache_type_k.value,
"cache_type_v": recipe.cache_type_v.value,
}
# Auto-enable draft-mtp for discovered models that carry MTP heads.
if has_mtp_heads(rp):
recipe_dict["spec_type"] = "draft-mtp"
recipe_dict["ubatch_size"] = 8
log.info(
"discovered %s has MTP heads; auto-enabling spec_type=draft-mtp, ubatch_size=8",
rp.name,
)
name = short_name_from_path(rp, used_names)
used_names.add(name)
port = port_start
Expand All @@ -317,13 +341,7 @@ def register_discovered(
gpu_pci_slot=gpu_pci_slot,
display_name=infer_display_name(rp),
kv_class=kv_class,
recipe={
"n_gpu_layers": recipe.n_gpu_layers,
"ctx": recipe.ctx,
"parallel": recipe.parallel,
"cache_type_k": recipe.cache_type_k.value,
"cache_type_v": recipe.cache_type_v.value,
},
recipe=recipe_dict,
aliases=[rp.name],
)
cfg.models.append(mc)
Expand Down
8 changes: 8 additions & 0 deletions src/arc_llama/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class LaunchRecipe:
temp: float | None = None
top_p: float | None = None
top_k: int | None = None
spec_type: str | None = None
"""Speculative decoding type, e.g. 'draft-mtp'."""
ubatch_size: int | None = None
"""Ubatch size (-ub). Auto-set to 8 for MTP models to avoid SSM compute-buffer OOM."""
extra_flags: list[str] = field(default_factory=list)
"""Anything else the user wants appended to the command line verbatim."""

Expand All @@ -66,6 +70,10 @@ def to_argv(self) -> list[str]:
argv += ["--top-p", str(self.top_p)]
if self.top_k is not None:
argv += ["--top-k", str(self.top_k)]
if self.spec_type:
argv += ["--spec-type", self.spec_type]
if self.ubatch_size is not None:
argv += ["-ub", str(self.ubatch_size)]
argv += list(self.extra_flags)
return argv

Expand Down
16 changes: 15 additions & 1 deletion src/arc_llama/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ async def admin_edit_model(name: str, request: Request) -> dict:
"""Update a model's recipe in-place.

Body is a partial recipe dict — only provided fields change. Recognised
fields: `ctx`, `cache_type_k`, `cache_type_v`, `parallel`, `kv_class`.
fields: `ctx`, `cache_type_k`, `cache_type_v`, `parallel`, `kv_class`,
`spec_type`, `ubatch_size`.
If the model is currently loaded, the server is stopped first; callers
decide whether to reload it afterwards via /admin/load.
"""
Expand Down Expand Up @@ -233,6 +234,19 @@ async def admin_edit_model(name: str, request: Request) -> dict:
)
model.kv_class = v
changed.append("kv_class")
if "spec_type" in body:
v = str(body["spec_type"])
recipe["spec_type"] = v
changed.append("spec_type")
if "ubatch_size" in body:
try:
ub = int(body["ubatch_size"])
except (TypeError, ValueError):
raise HTTPException(status_code=400, detail="ubatch_size must be an integer")
if not (1 <= ub <= 4096):
raise HTTPException(status_code=400, detail="ubatch_size must be 1..4096")
recipe["ubatch_size"] = ub
changed.append("ubatch_size")
if not changed:
raise HTTPException(status_code=400, detail="no recognised fields to edit")
model.recipe = recipe
Expand Down
Loading
Loading