-
-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy path example.py
More file actions
57 lines (48 loc) · 1.75 KB
/
example.py
File metadata and controls
57 lines (48 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python3
"""
70B+ inference without quantization, with KV cache on disk (avoids OOM on 8 GB VRAM).
"""
import tempfile
import warnings

import torch

from rabbitllm import AutoModel


def main() -> None:
    """Run one chat completion with the KV cache spilled to disk.

    Loads Qwen2.5-72B-Instruct at full precision (no quantization) and keeps
    the KV cache in a temporary on-disk directory so generation fits on an
    8 GB GPU. Prints only the newly generated tokens.
    """
    # Install the filter for the whole run. The original wrapped this in
    # `warnings.catch_warnings()`, which restores the previous filter state
    # when the `with` block exits — so the filter was discarded before any
    # CUDA call could emit the warning it was meant to silence.
    warnings.filterwarnings(
        "ignore", message=".*CUDA.*unknown error.*", category=UserWarning
    )

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # TemporaryDirectory cleans itself up on exit; `tempfile.mkdtemp` leaked
    # the KV-cache directory (and its contents) after every run.
    # For persistent use, replace with a fixed path, e.g. "./kv_cache".
    with tempfile.TemporaryDirectory(prefix="rabbitllm_kv_") as kv_cache_dir:
        model = AutoModel.from_pretrained(
            "Qwen/Qwen2.5-72B-Instruct",
            device=device,
            compression=None,  # no quantization, full precision
            kv_cache_dir=kv_cache_dir,  # KV cache to disk → avoids OOM on 8 GB
            max_seq_len=512,  # increase if you need longer context
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France? Answer in one sentence."},
        ]
        input_text = model.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Truncate to the model's configured context window (max_seq_len above).
        tokens = model.tokenizer(
            [input_text], return_tensors="pt", truncation=True, max_length=512
        )
        input_ids = tokens["input_ids"].to(device)

        # Some tokenizers omit the attention mask; default to "attend to all".
        attention_mask = tokens.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
        else:
            attention_mask = attention_mask.to(device)

        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=64,
            use_cache=True,
            do_sample=True,
            temperature=0.6,
            return_dict_in_generate=True,
        )

        # Slice off the prompt so only the generated continuation is printed.
        input_len = tokens["input_ids"].shape[1]
        print(model.tokenizer.decode(output.sequences[0][input_len:], skip_special_tokens=True))


if __name__ == "__main__":
    main()