Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions diffulex/extensions/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@
VllmCutlassW4A8,
VllmFp8LinearOp,
# Triton kernels
Fp8KVAttentionKernel,
fp8_kv_attention_forward,
chunked_prefill_attn_unified_fp8,
_HAS_TRITON_KERNELS,
)

Expand Down Expand Up @@ -170,6 +169,8 @@
# Offline quantization
from .quantize_model import quantize_model



__all__ = [
# Bootstrap
"enable",
Expand Down Expand Up @@ -198,8 +199,7 @@
"VllmAllSparkW8A16",
"VllmCutlassW4A8",
"VllmFp8LinearOp",
"Fp8KVAttentionKernel",
"fp8_kv_attention_forward",
"chunked_prefill_attn_unified_fp8",
"_HAS_TRITON_KERNELS",

# Configuration
Expand Down
217 changes: 127 additions & 90 deletions diffulex/extensions/quantization/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,20 @@ def enable(config: Optional[Dict[str, Any]] = None,

_quant_config = config

# Setup import hooks first (before any diffulex imports)
_setup_import_hooks()

# Import and initialize all components
try:
# 0. Import and patch loader first (before auto_model imports it)
# This ensures load_model is patched before auto_model binds to it
try:
import diffulex.utils.loader
from .loader_patch import patch_loader
patch_loader()
except Exception:
pass

# 1. Import strategies to register them
from . import strategies # noqa: F401

Expand Down Expand Up @@ -121,7 +133,59 @@ def enable(config: Optional[Dict[str, Any]] = None,
# 5. Setup import hooks for post-import patching
_setup_import_hooks()

# 6. Explicitly import and patch Attention class for FP8 kernel
# Attention is lazily imported via __getattr__, so we need to trigger the import
try:
from diffulex.attention import Attention
from .kv_cache_patch import patch_attention_class
patch_attention_class()
except Exception:
pass

# 7. Patch already-imported model runners (e.g., d2f imported before hook was set)
# Also patch ModelRunnerBase class directly to ensure allocate_kv_cache uses quantization
try:
from .kv_cache_patch import patch_model_runner, patch_allocate_kv_cache_method
import sys

# Explicitly import and patch ModelRunnerBase class
# This must be done before any runner instances are created
try:
from diffulex.engine.model_runner import ModelRunnerBase
patch_allocate_kv_cache_method(ModelRunnerBase)
print(f"[Quantization] Patched ModelRunnerBase.allocate_kv_cache")
except ImportError as e:
print(f"[Quantization] Warning: Could not patch ModelRunnerBase: {e}")

# Also patch instance __init__ for any runtime setup
for mod_name in list(sys.modules.keys()):
if 'model_runner' in mod_name:
mod = sys.modules[mod_name]
for attr_name in dir(mod):
attr = getattr(mod, attr_name)
if isinstance(attr, type) and 'runner' in attr_name.lower():
# Patch the class
original_init = attr.__init__
def make_patched_init(orig_init):
def patched_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
patch_model_runner(self)
return patched_init
attr.__init__ = make_patched_init(original_init)
except Exception:
pass

_is_enabled = True

# Eager patch loader if it was already imported before hook was set
try:
import sys
if 'diffulex.utils.loader' in sys.modules:
from .loader_patch import patch_loader
patch_loader()
except Exception:
pass

return True

except Exception as e:
Expand Down Expand Up @@ -198,20 +262,24 @@ def _post_import_patch(module_name: str, module):

This handles patching of modules that are imported after enable() is called.
"""
# Patch model runner
if 'model_runner' in module_name or hasattr(module, 'ModelRunner'):
# Patch model runner classes
if 'model_runner' in module_name or module_name.endswith('.d2f'):
try:
from .kv_cache_patch import patch_model_runner

# Patch class
if hasattr(module, 'ModelRunner'):
original_init = module.ModelRunner.__init__

def patched_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
patch_model_runner(self)

module.ModelRunner.__init__ = patched_init
# Patch all runner classes in the module
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and 'runner' in attr_name.lower():
original_init = attr.__init__

def make_patched_init(orig_init):
def patched_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
patch_model_runner(self)
return patched_init

attr.__init__ = make_patched_init(original_init)
except Exception:
pass

Expand All @@ -231,11 +299,24 @@ def patched_init(self, *args, **kwargs):
except Exception:
pass

# Patch loader
if 'loader' in module_name:
# Patch Attention class for FP8 custom kernel
print(f"[_post_import_patch] Checking module: {module_name}, 'attn_impl' in name: {'attn_impl' in module_name}")
if 'attn_impl' in module_name or module_name == 'diffulex.attention.attn_impl':
print(f"[_post_import_patch] Patching Attention for {module_name}")
try:
from .loader_patch import patch_loader
patch_loader()
from .kv_cache_patch import patch_attention_class
patch_attention_class()
print(f"[_post_import_patch] Patching Attention succeeded for {module_name}")
except Exception as e:
print(f"[_post_import_patch] Patching Attention failed: {e}")

# Patch loader when diffulex.utils is imported (loader is a submodule)
if module_name == 'diffulex.utils':
try:
# Check if loader submodule exists
if hasattr(module, 'loader'):
from .loader_patch import patch_loader
patch_loader()
except Exception:
pass

Expand Down Expand Up @@ -379,96 +460,52 @@ def patched_init(self, *args, **kwargs):

def _quantize_model_weights(model_wrapper):
"""
Quantize all linear layer weights in the model.
Verify offline quantized weights are properly loaded.

This is called once after model loading to pre-quantize weights.
Raises error if user specified GPTQ/AWQ but weights are not loaded.
"""
from .context import get_linear_strategy
from .layer_mixin import LinearQuantizationMixin
from .context import get_linear_strategy

# Check if already quantized (avoid duplicate quantization in multi-worker setup)
if getattr(model_wrapper, '_weights_quantized', False):
return
model_wrapper._weights_quantized = True

# Get model runner
model_runner = getattr(model_wrapper, 'model_runner', None)
if model_runner is None:
return

model = getattr(model_runner, 'model', None)
if model is None:
# Get current quantization config
global _quant_config
if _quant_config is None:
return

# Get current quantization config
weight_method = _quant_config.get('weights', {}).get('method', 'bf16')

# Skip if not online quantization
if weight_method in ['bf16', 'none']:
# Skip if not offline quantization
if weight_method not in ['gptq_w4a16', 'gptq_w8a16', 'awq_w4a16', 'gptq_marlin_w4a16', 'awq_marlin_w4a16']:
return

# Skip if offline quantization (GPTQ/AWQ) - those are already quantized
if any(fmt in weight_method.lower() for fmt in ['gptq', 'awq', 'marlin']):
# Get model
model_runner = getattr(model_wrapper, 'model_runner', None)
if model_runner is None:
return

# Mark as quantized to avoid duplicate work
model_wrapper._weights_quantized = True

print(f"[Quantization] Pre-quantizing model weights to {weight_method}...")

# Get strategy
strategy = get_linear_strategy('attn') # Use attn strategy for all
if strategy is None:
model = getattr(model_runner, 'model', None)
if model is None:
return

quantized_count = 0
total_saved_bytes = 0

# Iterate through all modules
# Check offline quantized layers
offline_count = 0
total_count = 0
for name, module in model.named_modules():
# Check if this is a quantized linear layer
if isinstance(module, LinearQuantizationMixin):
# Skip if already quantized
if module.has_quantized_weight() or module.has_offline_quantized_weight():
continue

# Quantize weight
try:
weight = module.weight
if weight is None or weight.dtype != torch.bfloat16:
continue

original_size = weight.numel() * weight.element_size()

# Use strategy to quantize weight
q_weight, w_meta = strategy.quantize_weight_for_kernel(weight)
w_scale = w_meta.get('scale')
w_zero = w_meta.get('zero_point')

# Store quantized weight
module.set_quantized_weight(q_weight, w_scale, w_zero)

# Delete original weight to save memory
if hasattr(module, 'weight'):
delattr(module, 'weight')
if 'weight' in module._parameters:
del module._parameters['weight']

quantized_size = q_weight.numel() * q_weight.element_size()
total_saved_bytes += (original_size - quantized_size)
quantized_count += 1

except Exception as e:
# Log but continue
print(f"[Quantization] Warning: Failed to quantize {name}: {e}")
continue

if quantized_count > 0:
saved_mb = total_saved_bytes / (1024 ** 2)
print(f"[Quantization] Pre-quantized {quantized_count} layers to {weight_method}")
print(f"[Quantization] Estimated memory saved: {saved_mb:.1f} MB")

# Force CUDA synchronization to get accurate memory stats
if torch.cuda.is_available():
torch.cuda.synchronize()
mem_allocated = torch.cuda.memory_allocated() / 1024**3
print(f"[Quantization] Current GPU memory: {mem_allocated:.2f} GB")
total_count += 1
if module.has_offline_quantized_weight():
offline_count += 1

if offline_count == 0 and total_count > 0:
raise RuntimeError(
f"Quantization mismatch: weight_quant_method='{weight_method}' specified, "
f"but no offline quantized weights found in model. "
f"Please ensure you're loading a {weight_method.upper()} quantized model, "
f"or set weight_quant_method='bf16' for non-quantized models."
)

if offline_count > 0:
print(f"[Quantization] {offline_count}/{total_count} layers using {weight_method}")
10 changes: 4 additions & 6 deletions diffulex/extensions/quantization/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@
# Import custom Triton kernels
try:
from .triton_kernels import (
Fp8KVAttentionKernel,
fp8_kv_attention_forward,
chunked_prefill_attn_unified_fp8,
_HAS_FP8_UNIFIED_KERNEL,
)
_HAS_TRITON_KERNELS = True
except ImportError:
_HAS_TRITON_KERNELS = False
Fp8KVAttentionKernel = None
fp8_kv_attention_forward = None
chunked_prefill_attn_unified_fp8 = None

__all__ = [
# Registry
Expand All @@ -70,7 +69,6 @@
"VllmCutlassW4A8",
"VllmFp8LinearOp",
# Triton kernels
"Fp8KVAttentionKernel",
"fp8_kv_attention_forward",
"chunked_prefill_attn_unified_fp8",
"_HAS_TRITON_KERNELS",
]
11 changes: 10 additions & 1 deletion diffulex/extensions/quantization/kernels/kernel_availability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
import os
from typing import Set, Optional
import torch

# Track which warnings have been issued to avoid spamming
_issued_warnings: Set[str] = set()
Expand All @@ -27,14 +28,22 @@ def is_strict_mode() -> bool:


def check_vllm_op_available(op_name: str) -> bool:
    """Check if a vLLM custom op is available via vllm._custom_ops.

    Args:
        op_name: Attribute name to look up on the ``vllm._custom_ops`` module.

    Returns:
        True if ``vllm._custom_ops`` can be imported and exposes ``op_name``;
        False if the import fails or the attribute is missing.
    """
    try:
        import vllm._custom_ops as ops
        return hasattr(ops, op_name)
    except (ImportError, AttributeError):
        return False


def check_torch_c_op_available(op_name: str) -> bool:
    """Check if a vLLM custom op is available via torch.ops._C.

    Args:
        op_name: Operator name to look up in the ``torch.ops._C`` namespace.

    Returns:
        True if ``torch.ops._C`` exposes ``op_name``; False if the lookup
        fails for any of the handled reasons.
    """
    try:
        return hasattr(torch.ops._C, op_name)
    except (ImportError, AttributeError, RuntimeError):
        # hasattr() only swallows AttributeError. Some PyTorch versions
        # report an unregistered operator as RuntimeError instead, and an
        # availability probe must answer False rather than propagate that.
        return False


def check_kernel_available(kernel_name: str, op_checker: Optional[callable] = None) -> bool:
"""
Check if a kernel is available.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@
Pure Triton implementations for operations not covered by vLLM kernels.
"""

# Unified FP8 kernel (Stage 1 + Stage 2)
try:
from .fp8_kv_attention import (
Fp8KVAttentionKernel,
fp8_kv_attention_forward,
from .chunked_prefill_attn_unified_fp8 import (
chunked_prefill_attn_unified_fp8,
)
_HAS_FP8_KERNEL = True
_HAS_FP8_UNIFIED_KERNEL = True
except ImportError:
_HAS_FP8_KERNEL = False
_HAS_FP8_UNIFIED_KERNEL = False

__all__ = [
"Fp8KVAttentionKernel",
"fp8_kv_attention_forward",
"_HAS_FP8_KERNEL",
"chunked_prefill_attn_unified_fp8",
"_HAS_FP8_UNIFIED_KERNEL",
]
Loading