Skip to content

Commit e9a538f

Browse files
author
Ralf Waldukat
committed
fix: critical fixes for recurrent/hybrid model support
After external code review (GPT-5.2), fixed 4 critical issues:

1. CRITICAL: Fixed tokens[:-1] bug in prefix matching
   - Was silently breaking prefix matching for ALL models
   - Caused false rewind detection and cache inefficiency
   - Impact: Transformers AND recurrent models

2. CRITICAL: Implement proper reset() for recurrent models
   - Now actually clears llama_memory backend state
   - Root cause fix for 'sequence positions not consecutive' crash
   - Without this, reset was a no-op for recurrent models

3. CRITICAL: Enforce strict append policy for recurrent models
   - Prevents KV cache rewinding that's impossible without state snapshots
   - Forces full reset on history edits instead of crashing

4. Performance: Cache _is_recurrent to avoid repeated FFI calls

5. Documentation: Simplified comments and updated docstring

6. Testing: All existing tests pass + Mistral-Small-3.2-24B validated

Resolves multi-turn crashes for Nemotron-A3B, Mamba, RWKV, Jamba models.

Reviewed-by: GPT-5.2 (OpenAI)
Tested-by: pytest + Mistral-Small-3.2-24B
Fixes: abetlen#2108 (recurrent model crashes)
Compatible-with: abetlen#2109 (Granite-Docling/SmolVLM special tokens)
1 parent 02d6bee commit e9a538f

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

llama_cpp/llama.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ def __init__(
192192
type_v: KV cache data type for V (default: f16)
193193
spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
194194
195+
Note:
196+
Recurrent and hybrid models (Mamba, RWKV, Nemotron-A3B, Jamba) cannot
197+
rewind their state and require full reset on history edits. This is handled
198+
automatically to maintain compatibility. Standard transformers are unaffected.
199+
195200
Raises:
196201
ValueError: If the model path does not exist.
197202
@@ -553,6 +558,11 @@ def free_lora_adapter():
553558

554559
self._sampler = None
555560

561+
# Cache recurrent/hybrid model detection to avoid repeated FFI calls
562+
self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
563+
self._model.model
564+
) or llama_cpp.llama_model_is_hybrid(self._model.model)
565+
556566
@property
557567
def ctx(self) -> llama_cpp.llama_context_p:
558568
return self._ctx.ctx
@@ -580,6 +590,19 @@ def eval_logits(self) -> Deque[List[float]]:
580590
maxlen=self._n_ctx if self._logits_all else 1,
581591
)
582592

593+
@property
594+
def _is_recurrent(self) -> bool:
595+
"""Check if model is recurrent (SSM) or hybrid (SSM+Attention).
596+
597+
These models (Mamba, RWKV, Nemotron, Jamba, etc.) cannot rewind their
598+
recurrent state without snapshots. Only strict forward progression or
599+
full reset is allowed.
600+
601+
Returns:
602+
True if model has recurrent state that cannot be rewound.
603+
"""
604+
return self._is_recurrent_model
605+
583606
def tokenize(
584607
self, text: bytes, add_bos: bool = True, special: bool = False
585608
) -> List[int]:
@@ -638,6 +661,11 @@ def reset(self):
638661
"""Reset the model state."""
639662
self.n_tokens = 0
640663

664+
if self._is_recurrent:
665+
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
666+
if mem is not None:
667+
llama_cpp.llama_memory_clear(mem, True)
668+
641669
def eval(self, tokens: Sequence[int]):
642670
"""Evaluate a list of tokens.
643671
@@ -888,11 +916,22 @@ def generate(
888916
# Check for kv cache prefix match
889917
if reset and self.n_tokens > 0:
890918
longest_prefix = 0
891-
for a, b in zip(self._input_ids, tokens[:-1]):
919+
for a, b in zip(self._input_ids, tokens):
892920
if a == b:
893921
longest_prefix += 1
894922
else:
895923
break
924+
925+
# Recurrent models cannot rewind state; reset if needed
926+
if self._is_recurrent and longest_prefix < self.n_tokens:
927+
longest_prefix = 0
928+
reset = True
929+
if self.verbose:
930+
print(
931+
"Llama.generate: recurrent model requires full state reset",
932+
file=sys.stderr,
933+
)
934+
896935
if longest_prefix > 0:
897936
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
898937
reset = False

0 commit comments

Comments
 (0)