10 changes: 9 additions & 1 deletion dflash/benchmark.py
@@ -337,7 +337,12 @@ def _run_mlx(args: argparse.Namespace) -> None:
logger.info(f"Loading target: {args.model}")
model, tokenizer = load(args.model)
logger.info(f"Loading draft: {args.draft_model}")
draft = load_draft(args.draft_model, sliding_window_size=args.draft_sliding_window_size)
draft = load_draft(
args.draft_model,
sliding_window_size=args.draft_sliding_window_size,
quantize_kv_bits=args.draft_quantize_kv_bits,
quantize_kv_group_size=args.draft_quantize_kv_group_size,
)
block_size = args.block_size if args.block_size is not None else int(draft.config.block_size)

dataset = load_and_process_dataset(args.dataset)
@@ -488,6 +493,9 @@ def main() -> None:
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--block-size", type=int, default=None)
parser.add_argument("--draft-sliding-window-size", type=int, default=None)
parser.add_argument("--draft-quantize-kv-bits", type=int, default=None, choices=[4, 8],
help="Quantize draft KV cache to int4/int8 (reduces memory ~4x/2x)")
parser.add_argument("--draft-quantize-kv-group-size", type=int, default=64)
parser.add_argument("--max-samples", type=int, default=None)

parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000")
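
The "~4x/2x" in the new `--draft-quantize-kv-bits` help string is roughly what group-wise affine quantization gives once the per-group metadata is counted. A back-of-the-envelope sketch, assuming one fp16 scale and one fp16 bias per group (an assumption about the storage layout, not something this diff spells out):

```python
# Approximate bits per KV element for group-wise affine quantization,
# assuming one fp16 scale and one fp16 bias amortized over each group.
def bits_per_element(bits: int, group_size: int = 64) -> float:
    overhead = 2 * 16 / group_size  # fp16 scale + fp16 bias per group
    return bits + overhead

fp16_bits = 16.0
print(fp16_bits / bits_per_element(4))  # ~3.6x smaller than an fp16 KV cache
print(fp16_bits / bits_per_element(8))  # ~1.9x smaller than an fp16 KV cache
```

So the savings land closer to 3.6x/1.9x with the default group size of 64, consistent with the "~" hedging in the help text.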
45 changes: 39 additions & 6 deletions dflash/model_mlx.py
@@ -9,7 +9,7 @@
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_lm.generate import generation_stream
from mlx_lm.models.cache import KVCache, RotatingKVCache, can_trim_prompt_cache, make_prompt_cache, trim_prompt_cache
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache, can_trim_prompt_cache, make_prompt_cache, trim_prompt_cache
from mlx_lm.models.qwen3 import MLP
from mlx_lm.models.rope_utils import initialize_rope
from mlx_lm.sample_utils import make_sampler
@@ -43,6 +43,8 @@ class DFlashConfig:
mask_token_id: int = 0
rope_scaling: Optional[Dict[str, Any]] = None
sliding_window_size: Optional[int] = None
quantize_kv_bits: Optional[int] = None
quantize_kv_group_size: int = 64


def _build_rope(
@@ -90,10 +92,20 @@ def __call__(self, x, x_ctx, rope, cache):
queries = rope(queries, offset=cache.offset + S)
ctx_keys = rope(ctx_keys, offset=cache.offset)
prop_keys = rope(prop_keys, offset=cache.offset + S)
keys, values = cache.update_and_fetch(ctx_keys, ctx_values)
keys = mx.concatenate([keys, prop_keys], axis=2)
values = mx.concatenate([values, prop_values], axis=2)
output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale)

if isinstance(cache, QuantizedKVCache):
# Dequant + concat path (saves memory, no bandwidth gain)
q_k, q_v = cache.update_and_fetch(ctx_keys, ctx_values)
k_deq = mx.dequantize(*q_k, group_size=cache.group_size, bits=cache.bits)
v_deq = mx.dequantize(*q_v, group_size=cache.group_size, bits=cache.bits)
keys = mx.concatenate([k_deq, prop_keys], axis=2)
values = mx.concatenate([v_deq, prop_values], axis=2)
output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale)
else:
keys, values = cache.update_and_fetch(ctx_keys, ctx_values)
keys = mx.concatenate([keys, prop_keys], axis=2)
values = mx.concatenate([values, prop_values], axis=2)
output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale)
return self.o_proj(output.transpose(0, 2, 1, 3).reshape(B, L, -1))
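
The quantized branch unpacks each `(data, scales, biases)` tuple returned by `cache.update_and_fetch` straight into `mx.dequantize`. A minimal, self-contained round trip with `mx.quantize`/`mx.dequantize` showing the same tuple shape (the array size and values are illustrative only):

```python
import mlx.core as mx

# Illustrative tensor quantized along its last axis; the last axis
# must be a multiple of group_size.
keys = mx.random.normal((128, 64)).astype(mx.float16)

q_k = mx.quantize(keys, group_size=64, bits=4)      # (packed data, scales, biases)
k_deq = mx.dequantize(*q_k, group_size=64, bits=4)  # back to a dense array

print(keys.shape == k_deq.shape)   # True: same shape after the round trip
print(mx.abs(keys - k_deq).max())  # small quantization error
```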


@@ -147,6 +159,14 @@ def bind(self, target_model):
def make_cache(self):
if self.config.sliding_window_size is not None:
return [RotatingKVCache(max_size=self.config.sliding_window_size, keep=0) for _ in self.layers]
if self.config.quantize_kv_bits is not None:
return [
QuantizedKVCache(
group_size=self.config.quantize_kv_group_size,
bits=self.config.quantize_kv_bits,
)
for _ in self.layers
]
return [KVCache() for _ in self.layers]

def __call__(self, inputs, target_hidden, cache):
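
A short usage sketch of the new cache selection. The draft model id and the `dflash.model_mlx` import path are placeholders/assumptions, not taken from this diff:

```python
from mlx_lm.models.cache import QuantizedKVCache

from dflash.model_mlx import load_draft  # assumed package path

# Hypothetical draft id; quantize_kv_bits/group_size are the new knobs.
draft = load_draft("org/dflash-draft", quantize_kv_bits=8, quantize_kv_group_size=64)
cache = draft.make_cache()

# One QuantizedKVCache per layer, carrying the settings the attention
# path later reads back via cache.bits and cache.group_size.
assert all(isinstance(c, QuantizedKVCache) for c in cache)
assert cache[0].bits == 8 and cache[0].group_size == 64
```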
@@ -162,11 +182,22 @@ def load(model_id: str):
return mlx_lm_load(model_id)


def load_draft(draft_id: str, sliding_window_size: Optional[int] = None) -> DFlashDraftModel:
def load_draft(
draft_id: str,
sliding_window_size: Optional[int] = None,
quantize_kv_bits: Optional[int] = None,
quantize_kv_group_size: int = 64,
) -> DFlashDraftModel:
if sliding_window_size is not None and sliding_window_size <= 0:
raise ValueError(
f"sliding_window_size must be positive or None, got {sliding_window_size}"
)
if quantize_kv_bits is not None and quantize_kv_bits not in (4, 8):
raise ValueError(
f"quantize_kv_bits must be 4, 8, or None, got {quantize_kv_bits}"
)
if sliding_window_size is not None and quantize_kv_bits is not None:
raise ValueError("sliding_window_size and quantize_kv_bits cannot be used together")
path = Path(snapshot_download(draft_id, allow_patterns=["*.safetensors", "*.json"]))
cfg = json.loads((path / "config.json").read_text())
config = DFlashConfig(
@@ -186,6 +217,8 @@ def load_draft(draft_id: str, sliding_window_size: Optional[int] = None) -> DFlashDraftModel:
mask_token_id=cfg["dflash_config"]["mask_token_id"],
rope_scaling=cfg.get("rope_scaling"),
sliding_window_size=sliding_window_size,
quantize_kv_bits=quantize_kv_bits,
quantize_kv_group_size=quantize_kv_group_size,
)
weights = {k: v for f in path.glob("*.safetensors") for k, v in mx.load(str(f)).items()}
model = DFlashDraftModel(config)
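
And a sketch of the new argument validation in `load_draft`, again with a placeholder draft id; both checks run before any download is attempted:

```python
from dflash.model_mlx import load_draft  # assumed package path

try:
    load_draft("org/dflash-draft", quantize_kv_bits=3)
except ValueError as e:
    print(e)  # quantize_kv_bits must be 4, 8, or None, got 3

try:
    load_draft("org/dflash-draft", sliding_window_size=2048, quantize_kv_bits=4)
except ValueError as e:
    print(e)  # sliding_window_size and quantize_kv_bits cannot be used together
```

Rejecting the combination outright avoids the silent precedence in `make_cache`, where a sliding-window cache would otherwise win over a quantized one.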