 import sys
 from typing import Optional, Dict, Any

+import torch
+
 # Global state
 _is_enabled = False
 _quant_config: Optional[Dict[str, Any]] = None
@@ -236,6 +238,14 @@ def patched_init(self, *args, **kwargs):
             patch_loader()
         except Exception:
             pass
+
+        # Patch the model so weights are quantized right after loading
+        # Only patch when the actual engine module is imported (not during recursion)
+        if module_name in ('diffulex.diffulex', 'diffulex.engine.tp_worker'):
+            try:
+                _patch_model_for_weight_quantization(module)
+            except Exception:
+                pass


 # Convenience function for configuring quantization from CLI args
@@ -333,3 +343,132 @@ def auto_enable_from_config(config):
     }

     return enable(config=quant_config)
+
+
+def _patch_model_for_weight_quantization(module):
+    """
+    Patch model initialization to quantize weights after loading.
+
+    This ensures online quantization (INT8/FP8) is applied to the weights
+    once, immediately after model creation, not during each forward pass.
+    """
+    # Find the Diffulex engine class exposed by this module
+    DiffulexClass = None
+    for attr_name in ('DiffulexTPWorker', 'Diffulex', 'DiffulexDPWorker'):
+        if hasattr(module, attr_name):
+            DiffulexClass = getattr(module, attr_name)
+            break
+
+    if DiffulexClass is None:
+        return
+
+    # Don't wrap __init__ twice if this runs again for the same class
+    if getattr(DiffulexClass.__init__, '_weight_quant_patched', False):
+        return
+
+    original_init = DiffulexClass.__init__
+
+    def patched_init(self, *args, **kwargs):
+        # Call original init
+        original_init(self, *args, **kwargs)
+
+        # After initialization, quantize weights if needed
+        _quantize_model_weights(self)
+
+    patched_init._weight_quant_patched = True
+    DiffulexClass.__init__ = patched_init
+
+
+def _quantize_model_weights(model_wrapper):
+    """
+    Quantize all linear-layer weights in the model.
+
+    This is called once after model loading to pre-quantize weights.
+    """
+    from .context import get_linear_strategy
+    from .layer_mixin import LinearQuantizationMixin
+
+    # Check if already quantized (avoids duplicate quantization in multi-worker setups)
+    if getattr(model_wrapper, '_weights_quantized', False):
+        return
+
+    # Get model runner
+    model_runner = getattr(model_wrapper, 'model_runner', None)
+    if model_runner is None:
+        return
+
+    model = getattr(model_runner, 'model', None)
+    if model is None:
+        return
+
+    # Get the current quantization config; bail out if quantization was never configured
+    if not _quant_config:
+        return
+    weight_method = _quant_config.get('weights', {}).get('method', 'bf16')
+
+    # Skip if not online quantization
+    if weight_method in ('bf16', 'none'):
+        return
+
+    # Skip offline quantization (GPTQ/AWQ/Marlin) - those weights are already quantized
+    if any(fmt in weight_method.lower() for fmt in ('gptq', 'awq', 'marlin')):
+        return
+
+    # Mark as quantized to avoid duplicate work
+    model_wrapper._weights_quantized = True
+
+    print(f"[Quantization] Pre-quantizing model weights to {weight_method}...")
+
+    # Get strategy (the attn strategy is used for all linear layers)
+    strategy = get_linear_strategy('attn')
+    if strategy is None:
+        return
+
+    quantized_count = 0
+    total_saved_bytes = 0
+
+    # Iterate over all modules and pre-quantize the quantizable linear layers
+    for name, submodule in model.named_modules():
+        if not isinstance(submodule, LinearQuantizationMixin):
+            continue
+
+        # Skip layers that are already quantized
+        if submodule.has_quantized_weight() or submodule.has_offline_quantized_weight():
+            continue
+
+        # Quantize the weight
+        try:
+            weight = submodule.weight
+            if weight is None or weight.dtype != torch.bfloat16:
+                continue
+
+            original_size = weight.numel() * weight.element_size()
+
+            # Use the strategy to quantize the weight for its kernel
+            q_weight, w_meta = strategy.quantize_weight_for_kernel(weight)
+            w_scale = w_meta.get('scale')
+            w_zero = w_meta.get('zero_point')
+
+            # Store the quantized weight on the layer
+            submodule.set_quantized_weight(q_weight, w_scale, w_zero)
+
+            # Delete the original weight to free memory
+            if hasattr(submodule, 'weight'):
+                delattr(submodule, 'weight')
+            if 'weight' in submodule._parameters:
+                del submodule._parameters['weight']
+
+            quantized_size = q_weight.numel() * q_weight.element_size()
+            total_saved_bytes += original_size - quantized_size
+            quantized_count += 1
+
+        except Exception as e:
+            # Log but continue with the remaining layers
+            print(f"[Quantization] Warning: failed to quantize {name}: {e}")
+            continue
+
+    if quantized_count > 0:
+        saved_mb = total_saved_bytes / (1024 ** 2)
+        print(f"[Quantization] Pre-quantized {quantized_count} layers to {weight_method}")
+        print(f"[Quantization] Estimated memory saved: {saved_mb:.1f} MB")
+
+        # Force CUDA synchronization to get accurate memory stats
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            mem_allocated = torch.cuda.memory_allocated() / 1024 ** 3
+            print(f"[Quantization] Current GPU memory: {mem_allocated:.2f} GB")
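
Note on the config shape this pass depends on: `_quantize_model_weights` only reads the `weights` → `method` nesting from `_quant_config`, which ultimately comes from the dict passed to `enable(config=...)`. A minimal, runnable illustration of that contract (the `"int8"` method string is an assumed example value, not one confirmed by this diff):

```python
# Illustrative config only: the patch reads just quant_config['weights']['method'].
quant_config = {"weights": {"method": "int8"}}

weight_method = quant_config.get("weights", {}).get("method", "bf16")
assert weight_method not in ("bf16", "none")                                    # online path taken
assert not any(f in weight_method.lower() for f in ("gptq", "awq", "marlin"))   # not offline-quantized
print(f"online weight quantization method: {weight_method}")
```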
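Note on what `strategy.quantize_weight_for_kernel(weight)` is expected to hand back: a quantized tensor plus a metadata dict carrying `scale` and optionally `zero_point`. Below is a generic, self-contained sketch of a symmetric per-output-channel INT8 quantizer that matches that shape; the helper name `int8_quantize_per_channel` is invented for illustration and is not Diffulex's actual strategy code.

```python
import torch

def int8_quantize_per_channel(weight: torch.Tensor):
    """Illustrative symmetric per-output-channel INT8 weight quantization.

    Returns (q_weight, meta) in the shape the patch above expects:
    meta['scale'] holds one scale per output row; meta['zero_point'] is None
    because the scheme is symmetric.
    """
    # One scale per output row, guarded against all-zero rows
    max_abs = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scale = max_abs / 127.0
    # Scale, round, and clamp into the signed INT8 range
    q_weight = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)
    return q_weight, {"scale": scale.squeeze(1), "zero_point": None}

if __name__ == "__main__":
    w = torch.randn(4096, 4096, dtype=torch.bfloat16)
    q, meta = int8_quantize_per_channel(w.float())
    # INT8 storage is 1 byte/element vs. 2 for bf16, hence the roughly 50% memory saving
    print(q.dtype, tuple(q.shape), tuple(meta["scale"].shape))
```

Once the layer holds the INT8 copy and its scales via `set_quantized_weight`, the original bf16 parameter is redundant, which is why the patch deletes it to realize the memory saving it reports.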