Commit 2c2ddb5

feat: update example script for 70B model inference with disk KV cache support
- Changed model from Qwen2.5-0.5B to Qwen2.5-72B for enhanced performance.
- Implemented disk-based KV cache to prevent OOM issues on 8 GB VRAM (see the sketch below).
- Updated user prompt in the example to reflect a new question.
- Removed outdated comments and added new ones for clarity.
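The disk-based KV cache named in the second bullet can be illustrated with a minimal sketch. This says nothing about rabbitllm's internals: the `DiskKVCache` class, its method names, and the memmap layout below are assumptions made purely for illustration, not the library's API.

```python
# Conceptual sketch of a disk-backed KV cache: per-layer K/V tensors live in
# memory-mapped files on disk, so VRAM only holds the slice attention needs.
# All names and the tensor layout here are hypothetical.
import os

import numpy as np
import torch


class DiskKVCache:
    """Keeps per-layer K/V tensors in memory-mapped files instead of VRAM."""

    def __init__(self, cache_dir, num_layers, num_kv_heads, head_dim,
                 max_seq_len, dtype=np.float16):
        self.layers = []
        for i in range(num_layers):
            path = os.path.join(cache_dir, f"layer_{i:03d}.kv")
            # One file per layer: [2 (K and V), max_seq_len, kv_heads, head_dim]
            self.layers.append(np.memmap(
                path, dtype=dtype, mode="w+",
                shape=(2, max_seq_len, num_kv_heads, head_dim)))
        self.seq_len = 0  # tokens written so far

    def append(self, layer_idx, k, v):
        # k, v: [new_tokens, kv_heads, head_dim], typically on the GPU.
        n = k.shape[0]
        mm = self.layers[layer_idx]
        mm[0, self.seq_len:self.seq_len + n] = k.detach().to("cpu", torch.float16).numpy()
        mm[1, self.seq_len:self.seq_len + n] = v.detach().to("cpu", torch.float16).numpy()

    def advance(self, n):
        # Call once per decode step, after every layer has appended n tokens.
        self.seq_len += n

    def get(self, layer_idx, device):
        # Page the cached prefix back onto the GPU only when attention needs it.
        mm = self.layers[layer_idx]
        k = torch.from_numpy(np.array(mm[0, :self.seq_len])).to(device)
        v = torch.from_numpy(np.array(mm[1, :self.seq_len])).to(device)
        return k, v
```

The trade is bandwidth for capacity: each decode step pays a host/disk round trip per layer, which is one reason the diff below also caps max_seq_len at 512.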
1 parent 155642f commit 2c2ddb5

1 file changed

Lines changed: 13 additions & 12 deletions

File tree

example.py

@@ -1,32 +1,33 @@
 #!/usr/bin/env python3
 """
-RabbitLLM example — minimal inference script.
-
-Run: python example.py
-Or: uv run python example.py
-
-Uses a small model (Qwen2.5-0.5B) for fast testing. For larger models or long
-context, see scripts/quickstart.py and the Configuration section in README.
+70B+ inference without quantization, with a disk-based KV cache (avoids OOM on 8 GB VRAM).
 """
 
+import tempfile
 import warnings
 
 import torch
 from rabbitllm import AutoModel
 
 with warnings.catch_warnings():
     warnings.filterwarnings("ignore", message=".*CUDA.*unknown error.*", category=UserWarning)
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Directory for the KV cache (on disk, not on the GPU)
+kv_cache_dir = tempfile.mkdtemp(prefix="rabbitllm_kv_")
+# For persistent use: kv_cache_dir = "./kv_cache"
 
 model = AutoModel.from_pretrained(
-    "Qwen/Qwen2.5-0.5B-Instruct",
+    "Qwen/Qwen2.5-72B-Instruct",
     device=device,
-    compression="4bit",
+    compression=None,           # no quantization, full precision
+    kv_cache_dir=kv_cache_dir,  # KV cache on disk → avoids OOM on 8 GB
+    max_seq_len=512,            # adjust if you need longer context
 )
 
 messages = [
     {"role": "system", "content": "You are a helpful assistant."},
-    {"role": "user", "content": "What is 2 + 2? Answer briefly."},
+    {"role": "user", "content": "What is the capital of France? Answer in one sentence."},
 ]
 
 input_text = model.tokenizer.apply_chat_template(
@@ -53,4 +54,4 @@
 )
 
 input_len = tokens["input_ids"].shape[1]
-print(model.tokenizer.decode(output.sequences[0][input_len:], skip_special_tokens=True))
+print(model.tokenizer.decode(output.sequences[0][input_len:], skip_special_tokens=True))
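For a sense of scale, here is a back-of-the-envelope estimate of the cache this configuration produces. The layer/head/dim values are assumptions based on Qwen2.5-72B's published configuration (80 layers, grouped-query attention with 8 KV heads, head dimension 128); verify them against the model's config.json.

```python
# KV cache footprint estimate; the model shape values below are assumed,
# not read from the actual config.json.
layers, kv_heads, head_dim = 80, 8, 128
bytes_per_token = 2 * layers * kv_heads * head_dim * 2  # K and V, fp16 = 2 bytes
print(f"{bytes_per_token / 1024:.0f} KiB per token")             # 320 KiB
print(f"{bytes_per_token * 512 / 2**20:.0f} MiB at 512 tokens")  # 160 MiB
```

About 160 MiB at max_seq_len=512: trivial on disk, but a meaningful slice of an 8 GB card once weights and activations are also resident, which is exactly the OOM the commit message says it avoids.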
