Skip to content

Commit cce6213

Browse files
Add video demo
1 parent 7a7bed3 commit cce6213

6 files changed

Lines changed: 230 additions & 1 deletion

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ benchmarking/**/reports/**
7272

7373
out.log
7474

75+
assets/*
76+
!assets/*.png
77+
!assets/*.webp
78+
!assets/rsr_baseline_compare.mp4
79+
7580
integrations/**/*.json
7681
integrations/**/*.safetensors
7782

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@
44

55
Reference: [UIC-InDeXLab/RSR](https://github.com/UIC-InDeXLab/RSR)
66

7+
## Installation
8+
9+
**Prerequisites:** Python >= 3.10, a C compiler (for CPU kernels), and optionally CUDA for GPU support.
10+
11+
```bash
12+
git clone https://github.com/UIC-InDeXLab/RSR-Core.git
13+
cd RSR-Core
14+
pip install -e .
15+
```
16+
717
## Structure
818

919
```
@@ -24,6 +34,17 @@ RSR-core/
2434
└── tests/ # Unit and integration tests
2535
```
2636

37+
38+
## Demo
39+
40+
<!-- <p align="center">
41+
<a href="assets/rsr_baseline_compare.mp4">
42+
<img src="assets/rsr_baseline_compare.webp" alt="Comparison of the Hugging Face baseline and RSR inference on 1.58-bit LLM inference. Click to open the MP4 version." width="900" />
43+
</a>
44+
</p> -->
45+
46+
[![RSR vs Baseline](assets/rsr_baseline_compare.webp)](https://raw.githubusercontent.com/UIC-InDeXLab/RSR-core/main/assets/rsr_baseline_compare.mp4)
47+
2748
## Benchmark Results
2849

2950
### Matrix-Vector Multiplication

assets/rsr_baseline_compare.mp4

1.06 MB
Binary file not shown.

assets/rsr_baseline_compare.webp

28.9 MB
Loading

integrations/hf/model_infer.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import copy
1919
import json
2020
import sys
21+
import time
2122
from contextlib import contextmanager
2223
from pathlib import Path
2324
from typing import Any
@@ -56,6 +57,20 @@ def no_init_weights():
5657
)
5758

5859

60+
class GreenTextStreamer(TextStreamer):
    """TextStreamer that renders every generated token in ANSI green."""

    _GREEN = "\033[32m"
    _RESET = "\033[0m"

    def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
        # Wrap each finalized chunk in green escape codes. The line is only
        # terminated once the stream signals completion; intermediate chunks
        # are flushed immediately so tokens appear as they are generated.
        terminator = "\n" if stream_end else ""
        print(f"{self._GREEN}{text}{self._RESET}", end=terminator, flush=True)
72+
73+
5974
def _bitnet_act_quant(activation: torch.Tensor) -> torch.Tensor:
6075
"""Compatibility wrapper around the shared BitNet activation quantizer."""
6176
return bitnet_act_quant(activation)
@@ -575,6 +590,35 @@ def load_hf_model(
575590
return model, tokenizer
576591

577592

593+
_BOLD_CYAN = "\033[1;36m"
594+
_RESET = "\033[0m"
595+
596+
597+
def _print_inference_stats(n_tokens: int, elapsed: float) -> None:
598+
"""Print a bold-cyan summary table with token count, wall time, and throughput."""
599+
tok_per_sec = n_tokens / elapsed if elapsed > 0 else float("inf")
600+
rows = [
601+
("tokens", str(n_tokens)),
602+
("time", f"{elapsed:.3f} s"),
603+
("tok/s", f"{tok_per_sec:.1f}"),
604+
]
605+
w_label = max(len(r[0]) for r in rows) + 2
606+
w_value = max(len(r[1]) for r in rows) + 2
607+
top = f"┌{'─' * w_label}{'─' * w_value}┐"
608+
mid = f"├{'─' * w_label}{'─' * w_value}┤"
609+
bot = f"└{'─' * w_label}{'─' * w_value}┘"
610+
611+
def _line(s: str) -> None:
612+
print(f"{_BOLD_CYAN}{s}{_RESET}")
613+
614+
_line(top)
615+
for i, (label, value) in enumerate(rows):
616+
_line(f"│ {label:<{w_label - 2}}{value:>{w_value - 2}} │")
617+
if i < len(rows) - 1:
618+
_line(mid)
619+
_line(bot)
620+
621+
578622
@torch.inference_mode()
579623
def generate_text(
580624
model: nn.Module,
@@ -602,24 +646,30 @@ def generate_text(
602646

603647
streamer = None
604648
if stream:
605-
streamer = TextStreamer(
649+
streamer = GreenTextStreamer(
606650
tokenizer,
607651
skip_prompt=True,
608652
skip_special_tokens=True,
609653
)
610654

655+
if stream:
656+
print(f"{_BOLD_CYAN}▶ response{_RESET}")
657+
658+
t0 = time.perf_counter()
611659
output_ids = model.generate(
612660
**inputs,
613661
max_new_tokens=max_new_tokens,
614662
pad_token_id=tokenizer.pad_token_id,
615663
streamer=streamer,
616664
**generate_kwargs,
617665
) # type: ignore
666+
elapsed = time.perf_counter() - t0
618667

619668
prompt_length = inputs["input_ids"].shape[-1]
620669
generated_ids = output_ids[0, prompt_length:]
621670
if stream:
622671
print()
672+
_print_inference_stats(len(generated_ids), elapsed)
623673
return tokenizer.decode(generated_ids, skip_special_tokens=True)
624674

625675

tests/test_model_infer.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import torch.nn as nn
1010

1111
from integrations.hf.model_infer import (
12+
GreenTextStreamer,
1213
RSRLinear,
1314
_bitnet_act_quant,
1415
_detect_device_from_dir,
16+
_print_inference_stats,
1517
_resolve_module,
1618
_set_module,
1719
parse_args,
@@ -405,3 +407,154 @@ def test_default_mode_is_multiply(self):
405407
ws = torch.tensor([3.0])
406408
layer = RSRLinear("test", meta, arrays, weight_scale=ws)
407409
assert layer._weight_scale_mode == "multiply"
410+
411+
412+
# ---------------------------------------------------------------------------
413+
# GreenTextStreamer
414+
# ---------------------------------------------------------------------------
415+
416+
class TestGreenTextStreamer:
    """Behavioral checks for the ANSI-green streaming wrapper."""

    @staticmethod
    def _fresh_streamer():
        """Build a streamer around a real tokenizer (shared setup for each test)."""
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        return GreenTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def test_output_wrapped_in_green(self, capsys):
        """on_finalized_text prints text wrapped in ANSI green codes."""
        streamer = self._fresh_streamer()
        streamer.on_finalized_text("hello", stream_end=False)
        out = capsys.readouterr().out
        for fragment in ("\033[32m", "hello", "\033[0m"):
            assert fragment in out

    def test_stream_end_adds_newline(self, capsys):
        """stream_end=True terminates with a newline."""
        streamer = self._fresh_streamer()
        streamer.on_finalized_text("done", stream_end=True)
        assert capsys.readouterr().out.endswith("\n")

    def test_green_codes_present_without_tokenizer(self, capsys):
        """GreenTextStreamer wraps any text in green regardless of content."""
        streamer = self._fresh_streamer()
        pieces = ["The", " quick", " brown", " fox"]
        for piece in pieces:
            streamer.on_finalized_text(piece, stream_end=False)
        out = capsys.readouterr().out
        assert out.count("\033[32m") == len(pieces)
        assert out.count("\033[0m") == len(pieces)
450+
451+
452+
# ---------------------------------------------------------------------------
453+
# _print_inference_stats
454+
# ---------------------------------------------------------------------------
455+
456+
class TestPrintInferenceStats:
    """Checks on the formatted throughput summary table."""

    @staticmethod
    def _render(n_tokens, elapsed, capsys):
        """Run the stats printer and return captured stdout."""
        _print_inference_stats(n_tokens=n_tokens, elapsed=elapsed)
        return capsys.readouterr().out

    def test_contains_required_fields(self, capsys):
        out = self._render(42, 1.234, capsys)
        for field in ("tokens", "time", "tok/s"):
            assert field in out

    def test_token_count_displayed(self, capsys):
        assert "100" in self._render(100, 2.0, capsys)

    def test_elapsed_time_displayed(self, capsys):
        assert "3.500 s" in self._render(10, 3.5, capsys)

    def test_throughput_displayed(self, capsys):
        assert "25.0" in self._render(50, 2.0, capsys)  # 50 / 2.0

    def test_zero_elapsed_no_crash(self, capsys):
        out = self._render(10, 0.0, capsys)
        assert "tok/s" in out
        assert "inf" in out

    def test_table_borders(self, capsys):
        out = self._render(5, 0.5, capsys)
        for glyph in ("┌", "┐", "└", "┘", "│"):
            assert glyph in out

    def test_output_is_bold_cyan(self, capsys):
        out = self._render(10, 1.0, capsys)
        assert "\033[1;36m" in out  # bold cyan
        assert "\033[0m" in out  # reset after each line
498+
499+
# ---------------------------------------------------------------------------
500+
# Stream header ("▶ response")
501+
# ---------------------------------------------------------------------------
502+
503+
class TestStreamHeader:
    """Verifies the bold-cyan '▶ response' banner around streamed generation."""

    @staticmethod
    def _run_generate(monkeypatch, fake_ids, stream):
        """Drive generate_text with minimal model/tokenizer stubs.

        Stats printing is patched out so only the header (if any) reaches
        stdout alongside the streamed tokens.
        """
        import integrations.hf.model_infer as mi

        class _StubTokenizer:
            pad_token_id = 0

            def __call__(self, prompt, return_tensors):
                return {"input_ids": fake_ids[:, :1]}  # 1-token "prompt"

            def decode(self, ids, skip_special_tokens):
                return "ok"

        class _StubModel(torch.nn.Module):
            def parameters(self):
                return iter([torch.empty(1)])

            def generate(self, **kwargs):
                return fake_ids

        monkeypatch.setattr(mi, "_print_inference_stats", lambda *a, **k: None)
        mi.generate_text(
            _StubModel(), _StubTokenizer(), "hi",
            use_chat_template=False, stream=stream,
        )

    def test_header_printed_before_tokens(self, capsys, monkeypatch):
        """generate_text prints a bold-cyan '▶ response' line before streaming."""
        self._run_generate(monkeypatch, torch.tensor([[1, 2, 3, 4]]), stream=True)
        out = capsys.readouterr().out
        assert "▶ response" in out
        assert "\033[1;36m" in out

    def test_header_absent_without_stream(self, capsys, monkeypatch):
        """No header is printed when stream=False."""
        self._run_generate(monkeypatch, torch.tensor([[1, 2]]), stream=False)
        assert "▶ response" not in capsys.readouterr().out

0 commit comments

Comments
 (0)