Skip to content

Commit eb6de92

Browse files
committed
feat: optimize ssm with jit and add inference sanity check
1 parent f750d0b commit eb6de92

2 files changed

Lines changed: 13 additions & 1 deletion

File tree

aetheris/cli/main.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,15 @@ def generate_command(args):
107107
repetition_penalty = args.repetition_penalty
108108

109109
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
110+
111+
# --- INFERENCE SANITY CHECK ---
112+
print(f"\n[SANITY CHECK] Inference Tokenizer: {tokenizer.name_or_path}")
113+
print(f"[SANITY CHECK] Vocab Size: {tokenizer.vocab_size}")
114+
print(f"[SANITY CHECK] Input IDs: {input_ids.tolist()}")
115+
decoded_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=False)
116+
print(f"[SANITY CHECK] Decoded Prompt: '{decoded_prompt}'\n")
117+
# ------------------------------
118+
110119
generated_ids = input_ids.clone()
111120
history_ids = set(input_ids[0].tolist())
112121

aetheris/modules/ssm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import torch.nn.functional as F
44
from ..config import AetherisConfig
55

6+
from typing import List
7+
8+
@torch.jit.script
69
def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
710
B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
811
"""Memory-efficient scan with reduced intermediate tensors."""
@@ -15,7 +18,7 @@ def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
1518
# Use in-place operations where possible
1619
# FORCE FLOAT32 for state to prevent underflow/overflow in long sequences
1720
h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=torch.float32)
18-
ys = []
21+
ys: List[torch.Tensor] = []
1922

2023
# Cast inputs to float32 for the scan
2124
# Note: This increases memory usage slightly but is critical for stability

0 commit comments

Comments
 (0)