diff --git a/dflash/benchmark.py b/dflash/benchmark.py
index a273e50..6cb3f59 100644
--- a/dflash/benchmark.py
+++ b/dflash/benchmark.py
@@ -337,7 +337,12 @@ def _run_mlx(args: argparse.Namespace) -> None:
     logger.info(f"Loading target: {args.model}")
     model, tokenizer = load(args.model)
     logger.info(f"Loading draft: {args.draft_model}")
-    draft = load_draft(args.draft_model, sliding_window_size=args.draft_sliding_window_size)
+    draft = load_draft(
+        args.draft_model,
+        sliding_window_size=args.draft_sliding_window_size,
+        quantize_kv_bits=args.draft_quantize_kv_bits,
+        quantize_kv_group_size=args.draft_quantize_kv_group_size,
+    )
     block_size = args.block_size if args.block_size is not None else int(draft.config.block_size)
     dataset = load_and_process_dataset(args.dataset)
 
@@ -488,6 +493,9 @@ def main() -> None:
     parser.add_argument("--draft-model", type=str, default=None)
     parser.add_argument("--block-size", type=int, default=None)
     parser.add_argument("--draft-sliding-window-size", type=int, default=None)
+    parser.add_argument("--draft-quantize-kv-bits", type=int, default=None, choices=[4, 8],
+                        help="Quantize draft KV cache to int4/int8 (reduces memory ~4x/2x)")
+    parser.add_argument("--draft-quantize-kv-group-size", type=int, default=64)
     parser.add_argument("--max-samples", type=int, default=None)
     parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000")
 
diff --git a/dflash/model_mlx.py b/dflash/model_mlx.py
index 1a6293f..1da9e91 100644
--- a/dflash/model_mlx.py
+++ b/dflash/model_mlx.py
@@ -9,7 +9,7 @@ import mlx.nn as nn
 
 from huggingface_hub import snapshot_download
 from mlx_lm.generate import generation_stream
-from mlx_lm.models.cache import KVCache, RotatingKVCache, can_trim_prompt_cache, make_prompt_cache, trim_prompt_cache
+from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache, can_trim_prompt_cache, make_prompt_cache, trim_prompt_cache
 from mlx_lm.models.qwen3 import MLP
 from mlx_lm.models.rope_utils import initialize_rope
 from mlx_lm.sample_utils import make_sampler
@@ -43,6 +43,8 @@ class DFlashConfig:
     mask_token_id: int = 0
     rope_scaling: Optional[Dict[str, Any]] = None
     sliding_window_size: Optional[int] = None
+    quantize_kv_bits: Optional[int] = None
+    quantize_kv_group_size: int = 64
 
 
 def _build_rope(
@@ -90,10 +92,20 @@ def __call__(self, x, x_ctx, rope, cache):
         queries = rope(queries, offset=cache.offset + S)
         ctx_keys = rope(ctx_keys, offset=cache.offset)
         prop_keys = rope(prop_keys, offset=cache.offset + S)
-        keys, values = cache.update_and_fetch(ctx_keys, ctx_values)
-        keys = mx.concatenate([keys, prop_keys], axis=2)
-        values = mx.concatenate([values, prop_values], axis=2)
-        output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale)
+
+        if isinstance(cache, QuantizedKVCache):
+            # Dequant + concat path (saves memory, no bandwidth gain)
+            q_k, q_v = cache.update_and_fetch(ctx_keys, ctx_values)
+            k_deq = mx.dequantize(*q_k, group_size=cache.group_size, bits=cache.bits)
+            v_deq = mx.dequantize(*q_v, group_size=cache.group_size, bits=cache.bits)
+            keys = mx.concatenate([k_deq, prop_keys], axis=2)
+            values = mx.concatenate([v_deq, prop_values], axis=2)
+            output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale)
+        else:
+            keys, values = cache.update_and_fetch(ctx_keys, ctx_values)
+            keys = mx.concatenate([keys, prop_keys], axis=2)
+            values = mx.concatenate([values, prop_values], axis=2)
+            output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale)
         return self.o_proj(output.transpose(0, 2, 1, 3).reshape(B, L, -1))
 
 
@@ -147,6 +159,14 @@ def bind(self, target_model):
     def make_cache(self):
         if self.config.sliding_window_size is not None:
             return [RotatingKVCache(max_size=self.config.sliding_window_size, keep=0) for _ in self.layers]
+        if self.config.quantize_kv_bits is not None:
+            return [
+                QuantizedKVCache(
+                    group_size=self.config.quantize_kv_group_size,
+                    bits=self.config.quantize_kv_bits,
+                )
+                for _ in self.layers
+            ]
         return [KVCache() for _ in self.layers]
 
     def __call__(self, inputs, target_hidden, cache):
@@ -162,11 +182,22 @@ def load(model_id: str):
     return mlx_lm_load(model_id)
 
 
-def load_draft(draft_id: str, sliding_window_size: Optional[int] = None) -> DFlashDraftModel:
+def load_draft(
+    draft_id: str,
+    sliding_window_size: Optional[int] = None,
+    quantize_kv_bits: Optional[int] = None,
+    quantize_kv_group_size: int = 64,
+) -> DFlashDraftModel:
     if sliding_window_size is not None and sliding_window_size <= 0:
         raise ValueError(
             f"sliding_window_size must be positive or None, got {sliding_window_size}"
         )
+    if quantize_kv_bits is not None and quantize_kv_bits not in (4, 8):
+        raise ValueError(
+            f"quantize_kv_bits must be 4, 8, or None, got {quantize_kv_bits}"
+        )
+    if sliding_window_size is not None and quantize_kv_bits is not None:
+        raise ValueError("sliding_window_size and quantize_kv_bits cannot be used together")
     path = Path(snapshot_download(draft_id, allow_patterns=["*.safetensors", "*.json"]))
     cfg = json.loads((path / "config.json").read_text())
     config = DFlashConfig(
@@ -186,6 +217,8 @@ def load_draft(draft_id: str, sliding_window_size: Optional[int] = None) -> DFla
         mask_token_id=cfg["dflash_config"]["mask_token_id"],
         rope_scaling=cfg.get("rope_scaling"),
         sliding_window_size=sliding_window_size,
+        quantize_kv_bits=quantize_kv_bits,
+        quantize_kv_group_size=quantize_kv_group_size,
    )
     weights = {k: v for f in path.glob("*.safetensors") for k, v in mx.load(str(f)).items()}
     model = DFlashDraftModel(config)
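
For reviewers unfamiliar with mlx_lm's QuantizedKVCache, below is a minimal standalone sketch of the quantize/dequantize round trip that the attention path above performs. It mirrors the calls in the diff; the shapes are illustrative and not taken from this change.

    import mlx.core as mx
    from mlx_lm.models.cache import QuantizedKVCache

    # Illustrative shapes: (batch, kv_heads, seq_len, head_dim). The head
    # dimension must be divisible by the quantization group size.
    ctx_keys = mx.random.normal((1, 2, 128, 64))
    ctx_values = mx.random.normal((1, 2, 128, 64))

    cache = QuantizedKVCache(group_size=64, bits=4)

    # update_and_fetch quantizes on write and returns the stored cache as
    # (packed_weights, scales, biases) tuples rather than plain arrays.
    q_k, q_v = cache.update_and_fetch(ctx_keys, ctx_values)

    # The attention path dequantizes back to full precision before running
    # SDPA, which is why the in-diff comment says "saves memory, no
    # bandwidth gain".
    keys = mx.dequantize(*q_k, group_size=cache.group_size, bits=cache.bits)
    values = mx.dequantize(*q_v, group_size=cache.group_size, bits=cache.bits)
    print(keys.shape)  # (1, 2, 128, 64)

The "~4x/2x" in the CLI help is approximate because each group of 64 values also stores a scale and a bias: assuming a 16-bit source with 16-bit scales/biases, the effective saving is about 3.6x at 4 bits (36 bytes vs 128 per group) and about 1.9x at 8 bits.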
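Usage sketch for the new flags. The entry point and the model IDs below are placeholders (the diff only shows benchmark.py's argument parser, not how the module is launched):

    # Run the MLX benchmark with the draft model's KV cache quantized to int4.
    python -m dflash.benchmark \
        --model <target-model-id> \
        --draft-model <draft-model-id> \
        --draft-quantize-kv-bits 4 \
        --draft-quantize-kv-group-size 64

Note that --draft-sliding-window-size and --draft-quantize-kv-bits are mutually exclusive: load_draft raises a ValueError if both are set, since RotatingKVCache and QuantizedKVCache are alternative cache types for the same layers.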