diff --git a/configs/lingbot_fast/kv_calib/lingbot_fast_i2v_kv_sagequant_calib.json b/configs/lingbot_fast/kv_calib/lingbot_fast_i2v_kv_sagequant_calib.json deleted file mode 100755 index c8ef1e8ad..000000000 --- a/configs/lingbot_fast/kv_calib/lingbot_fast_i2v_kv_sagequant_calib.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 161, - "text_len": 512, - "target_height": 720, - "target_width": 1280, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1.0, - "sample_shift": 10.0, - "enable_cfg": false, - "cpu_offload": true, - "offload_granularity": "block", - "t5_cpu_offload": true, - "vae_cpu_offload": true, - "use_image_encoder": false, - "dit_original_ckpt": "path/to/lingbot_world_fast/", - "ar_config": { - "local_attn_size": 21, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "sink_size": 3, - "kv_quant": { - "calibrate": true, - "calib_path": "calib_kv.pt", - "quant_scheme": "sage", - "k_cache_type": "int8", - "v_cache_type": "fp8" - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/lingbot_fast/kv_calib/lingbot_fast_i2v_kv_turboquant_calib.json b/configs/lingbot_fast/kv_calib/lingbot_fast_i2v_kv_turboquant_calib.json deleted file mode 100755 index e35557cc5..000000000 --- a/configs/lingbot_fast/kv_calib/lingbot_fast_i2v_kv_turboquant_calib.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 161, - "text_len": 512, - "target_height": 720, - "target_width": 1280, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1.0, - "sample_shift": 10.0, - "enable_cfg": false, - "cpu_offload": false, - "offload_granularity": "block", - "t5_cpu_offload": false, - "vae_cpu_offload": false, - "use_image_encoder": false, - "dit_original_ckpt": "path/to/lingbot_world_fast/", - "ar_config": { - "local_attn_size": 21, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "sink_size": 3, - "kv_quant": { - "calibrate": true, - "quant_scheme": "turboquant", - "codebook_dir": "path/to/turboquant_codebooks", - "key_bits": 4, - "value_bits": 2 - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/lingbot_fast/lingbot_fast_i2v.json b/configs/lingbot_fast/lingbot_fast_i2v.json index a650617f8..5a2e64e02 100644 --- a/configs/lingbot_fast/lingbot_fast_i2v.json +++ b/configs/lingbot_fast/lingbot_fast_i2v.json @@ -15,7 +15,7 @@ "t5_cpu_offload": false, "vae_cpu_offload": false, "use_image_encoder": false, - "dit_original_ckpt": "path/to/lingbot_world_fast/", + "dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast", "ar_config": { "local_attn_size": 21, "num_frame_per_chunk": 3, diff --git a/configs/lingbot_fast/lingbot_fast_i2v_kv_sagequant.json b/configs/lingbot_fast/lingbot_fast_i2v_kv_sagequant.json deleted file mode 100755 index 1e56f0682..000000000 --- a/configs/lingbot_fast/lingbot_fast_i2v_kv_sagequant.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 161, - "text_len": 512, - "target_height": 720, - "target_width": 1280, - "self_attn_1_type": "sage_attn2_k_int8_v_fp8", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1.0, - "sample_shift": 10.0, - "enable_cfg": false, - "cpu_offload": false, - "offload_granularity": "block", - "t5_cpu_offload": false, - "vae_cpu_offload": false, - "use_image_encoder": false, - "dit_original_ckpt": "path/to/lingbot_world_fast/", - "ar_config": { - "local_attn_size": 21, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "sink_size": 3, - "kv_quant": { - "calibrate": false, - "calib_path": "calib_kv.pt", - "quant_scheme": "sage", - "k_cache_type": "int8", - "v_cache_type": "fp8" - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/lingbot_fast/lingbot_fast_i2v_kv_turboquant.json b/configs/lingbot_fast/lingbot_fast_i2v_kv_turboquant.json deleted file mode 100755 index 204804e3b..000000000 --- a/configs/lingbot_fast/lingbot_fast_i2v_kv_turboquant.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 161, - "text_len": 512, - "target_height": 720, - "target_width": 1280, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1.0, - "sample_shift": 10.0, - "enable_cfg": false, - "cpu_offload": false, - "offload_granularity": "block", - "t5_cpu_offload": false, - "vae_cpu_offload": false, - "use_image_encoder": false, - "dit_original_ckpt": "path/to/lingbot_world_fast/", - "ar_config": { - "local_attn_size": 21, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "sink_size": 3, - "kv_quant": { - "quant_scheme": "turboquant", - "calibrate": false, - "codebook_dir": "path/to/turboquant_codebooks", - "key_bits": 4, - "value_bits": 2 - }, - "kv_offload": true - }, - "causal_rope_type": "triton" -} diff --git a/configs/matrix_game2/matrix_game2_universal.json b/configs/matrix_game2/matrix_game2_universal.json index 9e84f5a20..b66ebda78 100644 --- a/configs/matrix_game2/matrix_game2_universal.json +++ b/configs/matrix_game2/matrix_game2_universal.json @@ -17,7 +17,7 @@ "local_attn_size": 6, "num_frame_per_chunk": 3, "denoising_step_list": [1000.0, 908.8427, 713.9794], - "kv_offload": true + "kv_offload": false }, "sub_model_folder": "base_distilled_model", "sub_model_name": "base_distill.safetensors", diff --git a/configs/seko_talk/ar/seko_talk_ar.json b/configs/seko_talk/ar/seko_talk_ar.json index 4cc0cf073..77c1fb663 100644 --- a/configs/seko_talk/ar/seko_talk_ar.json +++ b/configs/seko_talk/ar/seko_talk_ar.json @@ -35,7 +35,7 @@ "step_kv_cache": true, "sink_size": 2, "local_attn_size": 21, - "kv_offload": false, - "async_vae_decode": true + "kv_offload": true, + "async_vae_decode": false } } diff --git a/configs/seko_talk/ar/seko_talk_ar_kv_dist.json b/configs/seko_talk/ar/seko_talk_ar_kv_dist.json index d409e17bc..bf3e776f6 100644 --- a/configs/seko_talk/ar/seko_talk_ar_kv_dist.json +++ b/configs/seko_talk/ar/seko_talk_ar_kv_dist.json @@ -17,11 +17,11 @@ "offload_granularity": "block", "use_31_block": true, "dit_quantized": true, - "dit_quantized_ckpt": "/models/seko_ar/converted_fp8.safetensors", + "dit_quantized_ckpt": "/SekoTalk-Distill-AR/converted_fp8.safetensors", "dit_quant_scheme": "fp8-sgl", "adapter_quantized": true, "adapter_quant_scheme": "fp8", - "adapter_model_path": "/models/seko_ar/-audio_adapter_fp8.pt", + "adapter_model_path": "/SekoTalk-Distill-AR/-audio_adapter_fp8.pt", "audio_feature_dim": 1024, "audio_projection_dim": 1024, "audio_num_tokens": 32, diff --git a/configs/seko_talk/ar/seko_talk_ar_kv_longlivequant.json b/configs/seko_talk/ar/seko_talk_ar_kv_dist_5090.json similarity index 64% rename from configs/seko_talk/ar/seko_talk_ar_kv_longlivequant.json rename to configs/seko_talk/ar/seko_talk_ar_kv_dist_5090.json index 4e1d36411..04827a476 100644 --- a/configs/seko_talk/ar/seko_talk_ar_kv_longlivequant.json +++ b/configs/seko_talk/ar/seko_talk_ar_kv_dist_5090.json @@ -17,11 +17,13 @@ "offload_granularity": "block", "use_31_block": true, "dit_quantized": true, - "dit_quantized_ckpt": "/models/seko-distill-ar/converted_fp8.safetensors", + "dit_quantized_ckpt": "/SekoTalk-Distill-AR/converted_fp8.safetensors", "dit_quant_scheme": "fp8-sgl", + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, "adapter_quantized": true, "adapter_quant_scheme": "fp8", - "adapter_model_path": "/models/seko-distill-ar/-audio_adapter_fp8.pt", + "adapter_model_path": "/SekoTalk-Distill-AR/-audio_adapter_fp8.pt", "audio_feature_dim": 1024, "audio_projection_dim": 1024, "audio_num_tokens": 32, @@ -30,17 +32,22 @@ "look_ahead": 0.0, "audio_feat_window_neighbor_frame": 0, "audio_feat_fps": 50, + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "t5_cpu_offload": true, + "clip_cpu_offload": true, + "clip_quantized": false, + "parallel": { + "seq_p_size": 8, + "seq_p_fp8_comm": true, + "seq_p_attn_type": "ulysses" + }, "ar_config": { - "num_frame_per_chunk": 1, + "num_frame_per_chunk": 20, "step_kv_cache": true, "sink_size": 2, "local_attn_size": 21, "kv_offload": false, - "async_vae_decode": true, - "kv_quant": { - "quant_scheme": "longlive_fp4", - "scale_rule": "mse", - "backend": "triton" - } + "async_vae_decode": true } } diff --git a/configs/self_forcing/kv_calib/wan_t2v_sf_sagequant_calib.json b/configs/self_forcing/kv_calib/wan_t2v_sf_sagequant_calib.json deleted file mode 100755 index c12fc5a2a..000000000 --- a/configs/self_forcing/kv_calib/wan_t2v_sf_sagequant_calib.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "path/to/checkpoints/self_forcing_dmd.pt", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": true, - "calib_path": "calib_kv.pt", - "quant_scheme": "sage", - "k_cache_type": "int8", - "v_cache_type": "fp8" - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/self_forcing/kv_calib/wan_t2v_sf_sagequant_calib_14b.json b/configs/self_forcing/kv_calib/wan_t2v_sf_sagequant_calib_14b.json deleted file mode 100755 index 143f09687..000000000 --- a/configs/self_forcing/kv_calib/wan_t2v_sf_sagequant_calib_14b.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "path/to/krea-realtime-video-14b.safetensors", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": true, - "calib_path": "calib_kv.pt", - "quant_scheme": "sage", - "k_cache_type": "int8", - "v_cache_type": "fp8" - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/self_forcing/kv_calib/wan_t2v_sf_turboquant_calib.json b/configs/self_forcing/kv_calib/wan_t2v_sf_turboquant_calib.json deleted file mode 100755 index c5d0ceab6..000000000 --- a/configs/self_forcing/kv_calib/wan_t2v_sf_turboquant_calib.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "path/to/checkpoints/self_forcing_dmd.pt", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": true, - "quant_scheme": "turboquant", - "codebook_dir": "path/to/turboquant_codebooks", - "key_bits": 8, - "value_bits": 8 - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/self_forcing/kv_calib/wan_t2v_sf_turboquant_calib_14b.json b/configs/self_forcing/kv_calib/wan_t2v_sf_turboquant_calib_14b.json deleted file mode 100755 index ea84a045e..000000000 --- a/configs/self_forcing/kv_calib/wan_t2v_sf_turboquant_calib_14b.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "path/to/krea-realtime-video-14b.safetensors", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": true, - "quant_scheme": "turboquant", - "codebook_dir": "path/to/turboquant_codebooks", - "key_bits": 4, - "value_bits": 2 - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/self_forcing/wan_t2v_sf.json b/configs/self_forcing/wan_t2v_sf.json index eddfe4591..1647b7412 100755 --- a/configs/self_forcing/wan_t2v_sf.json +++ b/configs/self_forcing/wan_t2v_sf.json @@ -4,9 +4,9 @@ "text_len": 512, "target_height": 480, "target_width": 832, - "self_attn_1_type": "flash_attn2", - "cross_attn_1_type": "flash_attn2", - "cross_attn_2_type": "flash_attn2", + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", "sample_guide_scale": 1, "sample_shift": 5.0, "enable_cfg": false, @@ -17,6 +17,7 @@ "local_attn_size": -1, "num_frame_per_chunk": 3, "timesteps_index": [0, 179, 358, 679], - "kv_offload": false + "kv_offload": false, + "async_vae_decode": true } } diff --git a/configs/self_forcing/wan_t2v_sf_kv_sagequant.json b/configs/self_forcing/wan_t2v_sf_kv_sagequant.json deleted file mode 100755 index dc1ec33bb..000000000 --- a/configs/self_forcing/wan_t2v_sf_kv_sagequant.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2_k_int8_v_fp8", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "/data/nvme4/gushiqiao/Self-Forcing/checkpoints/self_forcing_dmd.pt", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": false, - "calib_path": "calib_kv.pt", - "quant_scheme": "sage", - "k_cache_type": "int8", - "v_cache_type": "fp8" - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/self_forcing/wan_t2v_sf_kv_turboquant.json b/configs/self_forcing/wan_t2v_sf_kv_turboquant.json deleted file mode 100755 index 7e992d9a9..000000000 --- a/configs/self_forcing/wan_t2v_sf_kv_turboquant.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "path/to/self_forcing_dmd.pt", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": false, - "quant_scheme": "turboquant", - "codebook_dir": "path/to/turboquant_codebooks", - "key_bits": 8, - "value_bits": 8 - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/configs/self_forcing/wan_t2v_sf_kv_turboquant_14b.json b/configs/self_forcing/wan_t2v_sf_kv_turboquant_14b.json deleted file mode 100755 index 6259e6986..000000000 --- a/configs/self_forcing/wan_t2v_sf_kv_turboquant_14b.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "infer_steps": 4, - "target_video_length": 81, - "text_len": 512, - "target_height": 480, - "target_width": 832, - "self_attn_1_type": "sage_attn2", - "cross_attn_1_type": "sage_attn2", - "cross_attn_2_type": "sage_attn2", - "sample_guide_scale": 1, - "sample_shift": 5.0, - "enable_cfg": false, - "cpu_offload": false, - "dit_original_ckpt": "path/to/krea-realtime-video-14b.safetensors", - "ar_config": { - "local_attn_size": -1, - "num_frame_per_chunk": 3, - "timesteps_index": [0, 179, 358, 679], - "kv_quant": { - "calibrate": false, - "quant_scheme": "turboquant", - "codebook_dir": "path/to/turboquant_codebooks", - "key_bits": 4, - "value_bits": 2 - }, - "kv_offload": false - }, - "causal_rope_type": "triton" -} diff --git a/lightx2v/common/kvcache/__init__.py b/lightx2v/common/kvcache/__init__.py index c2cffe2e1..1459eca49 100755 --- a/lightx2v/common/kvcache/__init__.py +++ b/lightx2v/common/kvcache/__init__.py @@ -1,14 +1,11 @@ -from .calib import CalibRollingKVCachePool from .manager import KVCacheManager -from .quant import LongLiveQuantRollingKVCachePool, SageQuantRollingKVCachePool, StepLongLiveQuantRollingKVCachePool +from .quant import KIVIQuantRollingKVCachePool, StepKiviQuantRollingKVCachePool from .rolling import RollingKVCachePool, SpatialRollingKVCachePool __all__ = [ "KVCacheManager", "RollingKVCachePool", "SpatialRollingKVCachePool", - "CalibRollingKVCachePool", - "SageQuantRollingKVCachePool", - "LongLiveQuantRollingKVCachePool", - "StepLongLiveQuantRollingKVCachePool", + "KIVIQuantRollingKVCachePool", + "StepKiviQuantRollingKVCachePool", ] diff --git a/lightx2v/common/kvcache/calib.py b/lightx2v/common/kvcache/calib.py deleted file mode 100755 index 151aca27d..000000000 --- a/lightx2v/common/kvcache/calib.py +++ /dev/null @@ -1,319 +0,0 @@ -import torch - -try: - from sageattention.triton.quant_per_thread import quant_key_per_thread_int8_kernel -except ImportError: - quant_key_per_thread_int8_kernel = None - -from .rolling import RollingKVCachePool -from .utils import tq_fw_generate_rotation_matrix - - -class CalibRollingKVCachePool(RollingKVCachePool): - _BLKK = 128 - _SCALES_PER_BLK = 4 # WARPK=128 ⇒ 4 thread groups per block per head - - def __init__( - self, - num_layers: int, - cache_size: int, - num_heads: int, - head_dim: int, - dtype: torch.dtype, - device: torch.device, - num_steps: int = 1, - *, - turboquant_calibrate: bool = False, - key_bits: int = 3, - turboquant_seed: int = 42, - per_layer_compressors: bool = True, - ) -> None: - self._num_steps = num_steps - self.current_step: int = 0 - self._turboquant_calibrate = bool(turboquant_calibrate) - self._tq_key_bits = int(key_bits) - self._turboquant_seed = int(turboquant_seed) - self._tq_per_layer = bool(per_layer_compressors) - if self._turboquant_calibrate and self._tq_key_bits < 2: - raise ValueError("TurboQuantProd calibration requires key_bits >= 2") - super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device) - - def _init_kv_buffer(self) -> None: - super()._init_kv_buffer() - S = self._num_steps - L, H, D = self._num_layers, self._num_heads, self._head_dim - self._captured_window_size = torch.zeros(S, L, dtype=torch.long, device="cpu") - - if self._turboquant_calibrate: - self._tq_hist_k = torch.zeros(4096, dtype=torch.int64, device=self._device) - return - - BLK = self._BLKK - max_blks = (self._cache_size + BLK - 1) // BLK - self._km = torch.zeros(S, L, 1, H, D, dtype=torch.float32, device=self._device) - self._v_channel_max = torch.zeros(S, L, H, D, dtype=torch.float32, device=self._device) - self._k_block_scale_calib = torch.zeros( - S, - L, - max_blks, - H, - self._SCALES_PER_BLK, - dtype=torch.float32, - device=self._device, - ) - self._capture_flag = torch.zeros(S, L, dtype=torch.bool, device=self._device) - - def _calib_k_buffer(self, layer_id: int) -> torch.Tensor: - return self._k_buffer[layer_id] - - def _calib_v_buffer(self, layer_id: int) -> torch.Tensor: - return self._v_buffer[layer_id] - - def _quant_key(self, k: torch.Tensor, km: torch.Tensor | None = None, BLKK: int = 128, WARPK: int = 128): - """Run sage's per_thread int8 K-quantisation kernel on ``k``. - - Returns ``(k_int8, k_scale)`` where ``k`` is ``[B, kv_len, H, D]`` (NHD). - The km subtraction (if any) is done in ``k.dtype`` to match sage's - behaviour exactly — sage does ``k - km`` in bf16, NOT fp32. - - This is the source-of-truth quantisation used both at calibration time - (to capture the per-block scale we'll later replay) and as a reference - for the preset-scale quantisation path. - """ - if km is not None: - km_lowp = km.to(k.dtype) if km.dtype != k.dtype else km - k = k - km_lowp - - k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) - b, kv_len, h_kv, head_dim = k.shape - - stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) - stride_bz_ko, stride_h_ko, stride_seq_ko = ( - k_int8.stride(0), - k_int8.stride(2), - k_int8.stride(1), - ) - - num_blk = (kv_len + BLKK - 1) // BLKK - scales_per_blk = (BLKK // WARPK) * 4 - k_scale = torch.empty( - (b, h_kv, num_blk * scales_per_blk), - device=k.device, - dtype=torch.float32, - ) - - grid = (num_blk * scales_per_blk, h_kv, b) - quant_key_per_thread_int8_kernel[grid]( - k, - k_int8, - k_scale, - kv_len, - stride_bz_k, - stride_h_k, - stride_seq_k, - stride_bz_ko, - stride_h_ko, - stride_seq_ko, - k_scale.stride(0), - k_scale.stride(1), - C=head_dim, - BLK=WARPK, - ) - return k_int8, k_scale - - def capture_attn( - self, - layer_id: int, - attn_start: int, - local_end: int, - ) -> None: - """Capture calibration data from the current buffer state. - - For TurboQuant calibration, only the rotated-unit-K histogram needed - for empirical codebook export is collected. Otherwise this captures - (km, v_channel_max, k_block_scale) from exactly what sage_attn would see. - - Parameters - ---------- - attn_start : start position of the attention window in the buffer - (may not be 128-aligned). - local_end : end position (exclusive) — the buffer's current valid - length for this layer. - - The captured K slice is aligned down to the nearest 128 boundary - so per-block scales map cleanly to buffer block indices. - """ - BLK = self._BLKK - aligned_start = (attn_start // BLK) * BLK - step, layer = self.current_step, layer_id - - k_full = self._calib_k_buffer(layer_id)[aligned_start:local_end] # [kv_len_a, H, D] bf16 - kv_len_a = k_full.size(0) - if kv_len_a == 0: - return - - prev_window = int(self._captured_window_size[step, layer].item()) - if 0 < prev_window >= kv_len_a: - return - self._captured_window_size[step, layer] = kv_len_a - - if self._turboquant_calibrate: - self._capture_turboquant_marginals(layer_id, k_full) - return - - v_full = self._calib_v_buffer(layer_id)[aligned_start:local_end] # [kv_len_a, H, D] bf16 - - # ---- km (bf16 mean to match sage) ---- - km_lowp = k_full.mean(dim=0, keepdim=True) # bf16 [1, H, D] - self._km[step, layer] = km_lowp.to(torch.float32) - - # ---- k_block_scale via sage's quant kernel on (k - km) ---- - k_batch = k_full.unsqueeze(0).contiguous() # [1, kv_len_a, H, D] - _, k_scale_raw = self._quant_key(k_batch, km_lowp) # [1, H, num_blk*4] - num_blk_local = (kv_len_a + BLK - 1) // BLK - k_scale_local = k_scale_raw[0].reshape(self._num_heads, num_blk_local, self._SCALES_PER_BLK).permute(1, 0, 2) # [num_blk_local, H, 4] - blk_offset = aligned_start // BLK - self._k_block_scale_calib[step, layer, blk_offset : blk_offset + num_blk_local] = k_scale_local - self._v_channel_max[step, layer] = v_full.float().abs().amax(dim=0) # [H, D] - self._capture_flag[step, layer] = True - - def _capture_turboquant_marginals( - self, - layer_id: int, - k_full: torch.Tensor, - ) -> None: - """Histogram rotated coordinate marginals (same convention as TurboQuant inference).""" - D = self._head_dim - dev = self._device - nb = self._tq_hist_k.numel() - - seed_k = self._turboquant_seed + layer_id * 7 if self._tq_per_layer else self._turboquant_seed - Pi_k = tq_fw_generate_rotation_matrix(D, dev, torch.float32, seed=seed_k) - x = k_full.float() - norms = x.norm(dim=-1, keepdim=True).clamp(min=1e-10) - x_unit = x / norms - y = torch.matmul(x_unit, Pi_k.T).clamp(-1.0 + 1e-7, 1.0 - 1e-7) - idx = ((y + 1.0) * (0.5 * (nb - 1))).long().clamp(0, nb - 1) - ones = torch.ones(idx.numel(), dtype=torch.int64, device=dev) - self._tq_hist_k.scatter_add_(0, idx.reshape(-1), ones) - - def export_calibration(self) -> dict[str, torch.Tensor]: - if self._turboquant_calibrate: - return {"_turboquant_hist_k": self._tq_hist_k.clone()} - - v_scale = self._v_channel_max.clamp(min=1e-5) / 448.0 - out: dict[str, torch.Tensor] = { - "km": self._km.clone(), - "v_scale": v_scale, - "k_block_scale": self._k_block_scale_calib.clone(), - } - return out - - def reset(self) -> None: - super().reset() - self._captured_window_size.zero_() - if self._turboquant_calibrate: - self._tq_hist_k.zero_() - return - - self._km.zero_() - self._v_channel_max.zero_() - self._k_block_scale_calib.zero_() - self._capture_flag.zero_() - - -class StepCalibRollingKVCachePool(CalibRollingKVCachePool): - """Step-isolated calibration pool for step-dependent reference K/V.""" - - def _step(self) -> int: - return int(self.current_step) - - def _init_kv_buffer(self) -> None: - S = self._num_steps - L, N, H, D = self._num_layers, self._cache_size, self._num_heads, self._head_dim - self._k_buffer = torch.zeros(S, L, N, H, D, dtype=self._dtype, device=self._device) - self._v_buffer = torch.zeros(S, L, N, H, D, dtype=self._dtype, device=self._device) - self._global_end = torch.zeros(S, L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(S, L, dtype=torch.long, device="cpu") - self._captured_window_size = torch.zeros(S, L, dtype=torch.long, device="cpu") - - if self._turboquant_calibrate: - self._tq_hist_k = torch.zeros(4096, dtype=torch.int64, device=self._device) - return - - BLK = self._BLKK - max_blks = (self._cache_size + BLK - 1) // BLK - self._km = torch.zeros(S, L, 1, H, D, dtype=torch.float32, device=self._device) - self._v_channel_max = torch.zeros(S, L, H, D, dtype=torch.float32, device=self._device) - self._k_block_scale_calib = torch.zeros( - S, - L, - max_blks, - H, - self._SCALES_PER_BLK, - dtype=torch.float32, - device=self._device, - ) - self._capture_flag = torch.zeros(S, L, dtype=torch.bool, device=self._device) - - def _calib_k_buffer(self, layer_id: int) -> torch.Tensor: - return self._k_buffer[self._step(), layer_id] - - def _calib_v_buffer(self, layer_id: int) -> torch.Tensor: - return self._v_buffer[self._step(), layer_id] - - def store_kv( - self, - k: torch.Tensor, - v: torch.Tensor, - start_idx: int, - end_idx: int, - layer_id: int, - ) -> None: - self._calib_k_buffer(layer_id)[start_idx:end_idx] = k - self._calib_v_buffer(layer_id)[start_idx:end_idx] = v - - def k_cache( - self, - layer_id: int, - attn_start: int | None = None, - local_end: int | None = None, - ) -> torch.Tensor: - kb = self._calib_k_buffer(layer_id) - if attn_start is None and local_end is None: - return kb - return kb[attn_start:local_end] - - def v_cache( - self, - layer_id: int, - attn_start: int | None = None, - local_end: int | None = None, - ) -> torch.Tensor: - vb = self._calib_v_buffer(layer_id) - if attn_start is None and local_end is None: - return vb - return vb[attn_start:local_end] - - def get_global_end(self, layer_id: int) -> int: - return int(self._global_end[self._step(), layer_id].item()) - - def get_local_end(self, layer_id: int) -> int: - return int(self._local_end[self._step(), layer_id].item()) - - def set_ends(self, layer_id: int, global_end: int, local_end: int) -> None: - step = self._step() - self._global_end[step, layer_id] = global_end - self._local_end[step, layer_id] = local_end - - def roll_window(self, layer_id: int, sink_tokens: int, num_evicted: int) -> None: - num_kept = self.get_local_end(layer_id) - num_evicted - sink_tokens - if num_kept <= 0: - return - src_start = sink_tokens + num_evicted - src_end = src_start + num_kept - dst_start = sink_tokens - dst_end = dst_start + num_kept - kb, vb = self._calib_k_buffer(layer_id), self._calib_v_buffer(layer_id) - kb[dst_start:dst_end].copy_(kb[src_start:src_end].clone()) - vb[dst_start:dst_end].copy_(vb[src_start:src_end].clone()) diff --git a/lightx2v/common/kvcache/manager.py b/lightx2v/common/kvcache/manager.py index ede5dbc14..b1a1e6288 100755 --- a/lightx2v/common/kvcache/manager.py +++ b/lightx2v/common/kvcache/manager.py @@ -1,6 +1,3 @@ -import json -import os - import torch import torch.distributed as dist from loguru import logger @@ -8,16 +5,7 @@ from lightx2v.utils.envs import GET_DTYPE from .base import BaseKVCachePool -from .calib import CalibRollingKVCachePool, StepCalibRollingKVCachePool -from .quant import ( - KIVIQuantRollingKVCachePool, - LongLiveQuantRollingKVCachePool, - SageQuantRollingKVCachePool, - StepKiviQuantRollingKVCachePool, - StepLongLiveQuantRollingKVCachePool, - StepTurboQuantRollingKVCachePool, - TurboQuantRollingKVCachePool, -) +from .quant import KIVIQuantRollingKVCachePool, StepKiviQuantRollingKVCachePool from .rolling import RollingKVCachePool, SpatialRollingKVCachePool, StepRollingKVCachePool from .utils import * @@ -51,29 +39,6 @@ def _step_fp_kwargs(config, ar_config, _kv_quant): } -def _sage_kwargs(_config, ar_config, kv_quant): - return { - "k_cache_type": kv_quant.get("k_cache_type", "int8"), - "v_cache_type": kv_quant.get("v_cache_type", "fp8"), - "calib_path": kv_quant.get("calib_path", None), - "kv_offload": ar_config.get("kv_offload", False), - } - - -def _turboquant_kwargs(_config, ar_config, kv_quant): - return { - "key_bits": kv_quant.get("key_bits", 3), - "value_bits": kv_quant.get("value_bits", 2), - "seed": kv_quant.get("turboquant_seed", kv_quant.get("seed", 42)), - "per_layer_compressors": kv_quant.get("per_layer_compressors", True), - "kv_offload": ar_config.get("kv_offload", False), - "codebook_dir": kv_quant.get("codebook_dir"), - "codebook_cache_dir": kv_quant.get("codebook_cache_dir"), - "export_missing_codebooks": kv_quant.get("export_missing_codebooks", False), - "value_group_size": kv_quant.get("value_group_size", 32), - } - - def _kivi_kwargs(_config, ar_config, kv_quant): return { "k_cache_type": kv_quant.get("k_cache_type", "int4"), @@ -89,42 +54,6 @@ def _step_kivi_kwargs(config, ar_config, kv_quant): return kwargs -def _step_turboquant_kwargs(config, ar_config, kv_quant): - kwargs = _turboquant_kwargs(config, ar_config, kv_quant) - kwargs["num_steps"] = config.get("infer_steps", ar_config.get("cache_step", 1)) - return kwargs - - -def _longlive_fp4_kwargs(_config, ar_config, kv_quant, *, frame_seq_length: int | None = None): - block_token_size = kv_quant.get("block_token_size") - if block_token_size is None and frame_seq_length is not None: - block_token_size = frame_seq_length * ar_config.get("num_frame_per_chunk", 1) - return { - "block_token_size": block_token_size, - "scale_rule": kv_quant.get("scale_rule", "mse"), - "backend": kv_quant.get("backend", "pytorch"), - "kv_offload": ar_config.get("kv_offload", False), - } - - -def _step_longlive_fp4_kwargs(config, ar_config, kv_quant, *, frame_seq_length: int | None = None): - kwargs = _longlive_fp4_kwargs(config, ar_config, kv_quant, frame_seq_length=frame_seq_length) - kwargs["num_steps"] = config.get("infer_steps", ar_config.get("cache_step", 1)) - return kwargs - - -def _calib_kwargs(config, _ar_config, kv_quant): - kwargs = {"num_steps": config.get("infer_steps", 1)} - if kv_quant.get("quant_scheme") == "turboquant": - kwargs.update( - turboquant_calibrate=True, - key_bits=kv_quant.get("key_bits", 3), - turboquant_seed=kv_quant.get("turboquant_seed", kv_quant.get("seed", 42)), - per_layer_compressors=kv_quant.get("per_layer_compressors", True), - ) - return kwargs - - def _get_self_attn_kv_cache_entry(scheme: str, step: bool): entry = SELF_ATTN_KV_CACHE_REGISTRY.get((scheme, bool(step))) if entry is None: @@ -134,20 +63,8 @@ def _get_self_attn_kv_cache_entry(scheme: str, step: bool): register_self_attn_kv_cache("fp", RollingKVCachePool, kwargs_builder=_fp_kwargs) register_self_attn_kv_cache("fp", StepRollingKVCachePool, step=True, kwargs_builder=_step_fp_kwargs) -register_self_attn_kv_cache("calib", CalibRollingKVCachePool, kwargs_builder=_calib_kwargs) -register_self_attn_kv_cache("calib", StepCalibRollingKVCachePool, step=True, kwargs_builder=_calib_kwargs) -register_self_attn_kv_cache("sage", SageQuantRollingKVCachePool, kwargs_builder=_sage_kwargs) -register_self_attn_kv_cache("turboquant", TurboQuantRollingKVCachePool, kwargs_builder=_turboquant_kwargs) -register_self_attn_kv_cache("turboquant", StepTurboQuantRollingKVCachePool, step=True, kwargs_builder=_step_turboquant_kwargs) register_self_attn_kv_cache("kivi", KIVIQuantRollingKVCachePool, kwargs_builder=_kivi_kwargs) register_self_attn_kv_cache("kivi", StepKiviQuantRollingKVCachePool, step=True, kwargs_builder=_step_kivi_kwargs) -register_self_attn_kv_cache("longlive_fp4", LongLiveQuantRollingKVCachePool, kwargs_builder=_longlive_fp4_kwargs) -register_self_attn_kv_cache( - "longlive_fp4", - StepLongLiveQuantRollingKVCachePool, - step=True, - kwargs_builder=_step_longlive_fp4_kwargs, -) def build_self_attn_kv_cache(config, ar_config, kv_size, dtype, device, *, frame_seq_length: int | None = None, num_heads: int | None = None): @@ -158,25 +75,17 @@ def build_self_attn_kv_cache(config, ar_config, kv_size, dtype, device, *, frame scheme = "fp" step = ar_config.get("step_kv_cache", False) else: - quant_scheme = kv_quant.get("quant_scheme", "sage") - registered_schemes = {registered_scheme for registered_scheme, _step in SELF_ATTN_KV_CACHE_REGISTRY if registered_scheme not in {"fp", "calib"}} - if config.get("parallel"): - assert quant_scheme in {"kivi", "longlive_fp4"}, f"Invalid quant_scheme: {quant_scheme} for parallel inference" - assert quant_scheme in registered_schemes, f"Invalid quant_scheme: {quant_scheme}" + quant_scheme = kv_quant.get("quant_scheme", "kivi") + if quant_scheme != "kivi": + raise NotImplementedError(f"Only quant_scheme='kivi' is supported, got {quant_scheme!r}.") if kv_quant.get("calibrate", False): - scheme = "calib" - step = ar_config.get("step_kv_cache", False) + raise NotImplementedError("KV calibration caches were removed; only KIVI inference cache is supported.") else: scheme = quant_scheme step = ar_config.get("step_kv_cache", False) - if step and scheme == "sage": - raise NotImplementedError("step_kv_cache does not support quant_scheme='sage'. Use step_kv_cache with quant_scheme='kivi', or disable step_kv_cache for sage.") cache_cls, kwargs_builder = _get_self_attn_kv_cache_entry(scheme, step) - extra = {} - if scheme == "longlive_fp4": - extra["frame_seq_length"] = frame_seq_length - return cache_cls(**common, **kwargs_builder(config, ar_config, kv_quant or {}, **extra)) + return cache_cls(**common, **kwargs_builder(config, ar_config, kv_quant or {})) class KVCacheManager: @@ -342,7 +251,6 @@ def _create_matrix_action_kv_caches(self) -> None: self.action_keyboard_kv_cache._init_kv_buffer() if ac.get("enable_mouse", False): - kv_offload = bool(self.ar_config.get("kv_offload", False)) self.action_mouse_kv_cache = SpatialRollingKVCachePool( spatial_len=self.frame_seq_length, num_layers=num_layers, @@ -351,70 +259,6 @@ def _create_matrix_action_kv_caches(self) -> None: head_dim=head_dim, dtype=self.dtype, device=self.device, - kv_offload=kv_offload, + kv_offload=False, ) self.action_mouse_kv_cache._init_kv_buffer() - - def save_calibration(self) -> None: - """Auto-save calibration if running in calibrate mode with calib_path.""" - kv_quant = self.ar_config.get("kv_quant") - if not kv_quant or not isinstance(kv_quant, dict): - return - if not kv_quant.get("calibrate", False): - return - output_path = kv_quant.get("calib_path", "calib_kv.pt") - pool = self.self_attn_kv_cache - if not isinstance(pool, CalibRollingKVCachePool): - return - calib = pool.export_calibration() - hk = calib.pop("_turboquant_hist_k", None) - - rank = 0 - world_size = 1 - pg = None - if dist.is_available() and dist.is_initialized(): - if self.sp_group is not None: - rank = dist.get_rank(self.sp_group) - world_size = dist.get_world_size(self.sp_group) - pg = self.sp_group - else: - rank = dist.get_rank() - world_size = dist.get_world_size() - - if hk is not None: - hk_acc = hk.to(device=self.device, dtype=torch.int64) - if world_size > 1: - dist.all_reduce(hk_acc, op=dist.ReduceOp.SUM, group=pg) - if rank == 0: - out_dir = kv_quant.get("codebook_dir") - if not out_dir: - out_dir = os.path.dirname(os.path.abspath(output_path)) or "." - os.makedirs(out_dir, exist_ok=True) - head_dim = self.config["dim"] // self.config["num_heads"] - books = build_turboquant_codebooks_from_calib_histograms( - hk_acc.cpu(), - head_dim=head_dim, - key_bits=kv_quant.get("key_bits", 3), - ) - for fname, cb_dict in books.items(): - fpath = os.path.join(out_dir, fname) - with open(fpath, "w", encoding="utf-8") as f: - json.dump(cb_dict, f, indent=2) - logger.info("[KVCacheManager] TurboQuant empirical codebook written {!r}", fpath) - - if not calib: - return - - save_path = output_path - if world_size > 1: - save_path = ranked_calib_path(output_path, rank) - torch.save(calib, save_path) - logger.info( - "[KVCacheManager] calibration saved to {} (rank {}/{}) — km {}, v_scale {}, k_block_scale {}", - save_path, - rank, - world_size, - list(calib["km"].shape), - list(calib["v_scale"].shape), - list(calib["k_block_scale"].shape), - ) diff --git a/lightx2v/common/kvcache/quant.py b/lightx2v/common/kvcache/quant.py index 8917705f4..bfef86ebd 100755 --- a/lightx2v/common/kvcache/quant.py +++ b/lightx2v/common/kvcache/quant.py @@ -1,1182 +1,10 @@ -import os - import torch -import torch.distributed as dist from loguru import logger from .kernel import * from .rolling import RollingKVCachePool from .utils import * -try: - from fouroversix import quantize_to_fp4 - from fouroversix.quantize.quantized_tensor import QuantizedTensor -except ImportError: - QuantizedTensor = None - quantize_to_fp4 = None - - -# ============================================================================= -# Generic token-ring helper for quantized rolling caches. -# -# Logical layout exposed to callers: -# [sink logical tokens][recent logical tokens] -# -# Physical layout in cache tensors: -# [sink fixed region][recent ring region] -# -# `roll_window()` is O(1): it updates head/len metadata and never moves the -# kept window. `k_cache()` / `v_cache()` materialize a contiguous logical range -# by reading one or more physical chunks and concatenating them. -# ============================================================================= - - -class _QuantTokenRingMixin: - def _ring_index(self, layer_id: int): - return int(layer_id) - - def _init_ring_metadata(self, *shape: int) -> None: - self._ring_active = torch.zeros(*shape, dtype=torch.bool, device="cpu") - self._ring_sink_len = torch.zeros(*shape, dtype=torch.long, device="cpu") - self._ring_recent_head = torch.zeros(*shape, dtype=torch.long, device="cpu") - self._ring_recent_len = torch.zeros(*shape, dtype=torch.long, device="cpu") - - def _ring_is_active(self, layer_id: int) -> bool: - return bool(self._ring_active[self._ring_index(layer_id)].item()) - - def _set_ring_active(self, layer_id: int, value: bool) -> None: - self._ring_active[self._ring_index(layer_id)] = bool(value) - - def _get_sink_len(self, layer_id: int) -> int: - return int(self._ring_sink_len[self._ring_index(layer_id)].item()) - - def _set_sink_len(self, layer_id: int, value: int) -> None: - self._ring_sink_len[self._ring_index(layer_id)] = int(value) - - def _get_recent_head(self, layer_id: int) -> int: - return int(self._ring_recent_head[self._ring_index(layer_id)].item()) - - def _set_recent_head(self, layer_id: int, value: int) -> None: - self._ring_recent_head[self._ring_index(layer_id)] = int(value) - - def _get_recent_len(self, layer_id: int) -> int: - return int(self._ring_recent_len[self._ring_index(layer_id)].item()) - - def _set_recent_len(self, layer_id: int, value: int) -> None: - self._ring_recent_len[self._ring_index(layer_id)] = int(value) - - def _zero_tensors(self, names: list[str]) -> None: - for name in names: - getattr(self, name).zero_() - - def _reset_ring(self) -> None: - self._zero_tensors(["_ring_active", "_ring_sink_len", "_ring_recent_head", "_ring_recent_len"]) - - def _reset_ends(self) -> None: - self._zero_tensors(["_global_end", "_local_end"]) - - def _recent_capacity(self, layer_id: int) -> int: - return int(self._cache_size) - self._get_sink_len(layer_id) - - def _recent_offset_to_physical_chunks( - self, - layer_id: int, - recent_offset: int, - length: int, - ) -> list[tuple[int, int]]: - if length <= 0: - return [] - sink_len = self._get_sink_len(layer_id) - cap = self._recent_capacity(layer_id) - if cap <= 0: - raise RuntimeError("ring cache has no recent capacity") - head = self._get_recent_head(layer_id) - pos = (head + int(recent_offset)) % cap - first = min(int(length), cap - pos) - out = [(sink_len + pos, sink_len + pos + first)] - remain = int(length) - first - if remain > 0: - out.append((sink_len, sink_len + remain)) - return out - - def _logical_to_physical_chunks( - self, - layer_id: int, - start: int, - end: int, - ) -> list[tuple[int, int]]: - start = int(start) - end = int(end) - if end <= start: - return [] - if not self._ring_is_active(layer_id): - return [(start, end)] - - sink_len = self._get_sink_len(layer_id) - recent_len = self._get_recent_len(layer_id) - logical_end = sink_len + recent_len - if end > logical_end: - raise RuntimeError(f"ring read exceeds logical end: read=[{start},{end}), logical_end={logical_end}, sink={sink_len}, recent_len={recent_len}") - - chunks: list[tuple[int, int]] = [] - # Fixed sink: logical == physical. - s0, s1 = start, min(end, sink_len) - if s1 > s0: - chunks.append((s0, s1)) - - # Recent ring. - r0, r1 = max(start, sink_len), end - if r1 > r0: - chunks.extend( - self._recent_offset_to_physical_chunks( - layer_id, - recent_offset=r0 - sink_len, - length=r1 - r0, - ) - ) - return chunks - - def _logical_store_chunks( - self, - layer_id: int, - start: int, - end: int, - ) -> list[tuple[int, int]]: - # Store range uses the same mapping as read range. If it extends the - # recent length, update recent metadata after physical stores. - return self._logical_to_physical_chunks_for_store(layer_id, start, end) - - def _logical_to_physical_chunks_for_store( - self, - layer_id: int, - start: int, - end: int, - ) -> list[tuple[int, int]]: - start = int(start) - end = int(end) - if end <= start: - return [] - if not self._ring_is_active(layer_id): - return [(start, end)] - - sink_len = self._get_sink_len(layer_id) - chunks: list[tuple[int, int]] = [] - - if start < sink_len: - s1 = min(end, sink_len) - chunks.append((start, s1)) - if s1 == end: - return chunks - start = sink_len - - recent_offset = start - sink_len - length = end - start - chunks.extend(self._recent_offset_to_physical_chunks(layer_id, recent_offset, length)) - self._set_recent_len(layer_id, max(self._get_recent_len(layer_id), recent_offset + length)) - return chunks - - def roll_window(self, layer_id: int, sink_tokens: int, num_evicted: int) -> None: - old_local_end = self.get_local_end(layer_id) - sink_tokens = int(sink_tokens) - num_evicted = int(num_evicted) - num_kept = old_local_end - num_evicted - sink_tokens - if num_kept <= 0: - self._set_ring_active(layer_id, True) - self._set_sink_len(layer_id, sink_tokens) - self._set_recent_head(layer_id, 0) - self._set_recent_len(layer_id, 0) - return - - if not self._ring_is_active(layer_id): - self._set_ring_active(layer_id, True) - self._set_sink_len(layer_id, sink_tokens) - cap = self._recent_capacity(layer_id) - if num_kept > cap: - raise RuntimeError(f"ring kept tokens {num_kept} exceed recent capacity {cap}") - # Before first roll physical layout is contiguous: - # [sink][evicted][kept]. Recent ring starts after sink, so the - # kept range physical offset inside recent ring is num_evicted. - self._set_recent_head(layer_id, num_evicted % cap) - self._set_recent_len(layer_id, num_kept) - return - - if sink_tokens != self._get_sink_len(layer_id): - raise RuntimeError(f"ring sink size changed: old={self._get_sink_len(layer_id)}, new={sink_tokens}") - cap = self._recent_capacity(layer_id) - recent_len = self._get_recent_len(layer_id) - if num_evicted > recent_len: - raise RuntimeError(f"ring evict exceeds recent length: evict={num_evicted}, recent_len={recent_len}") - self._set_recent_head(layer_id, (self._get_recent_head(layer_id) + num_evicted) % cap) - self._set_recent_len(layer_id, recent_len - num_evicted) - - -# ============================================================================= -# SageQuant -# ============================================================================= - - -class SageQuantRollingKVCachePool(_QuantTokenRingMixin, RollingKVCachePool): - _BLKK = 128 - _SCALES_PER_BLK = 4 - _PERM_16_VAL = [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] - _INV_PERM_16_VAL = [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15] - - def __init__( - self, - num_layers: int, - cache_size: int, - num_heads: int, - head_dim: int, - dtype: torch.dtype, - device: torch.device, - k_cache_type: str = "int8", - v_cache_type: str = "fp8", - calib_path: str = None, - kv_offload: bool = False, - ) -> None: - assert k_cache_type in ["int8"] - assert v_cache_type in ["fp8", "fp16"] - self._k_cache_type = k_cache_type - self._v_cache_type = v_cache_type - self._calib_path = calib_path - self.current_step: int = 0 - self._PERM_16 = torch.tensor(self._PERM_16_VAL, dtype=torch.long, device=device) - self._INV_PERM_16 = torch.tensor(self._INV_PERM_16_VAL, dtype=torch.long, device=device) - self._load_calib(device=device) - super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device, kv_offload=kv_offload) - - def _init_kv_buffer(self) -> None: - if self._kv_offload: - self._init_kv_buffer_offload() - return - L, N, H, D = self._num_layers, self._cache_size, self._num_heads, self._head_dim - self._k_buffer = torch.zeros(L, N, H, D, dtype=torch.int8, device=self._device) - v_dtype = torch.float8_e4m3fn if self._v_cache_type == "fp8" else torch.float16 - self._v_buffer = torch.zeros(L, N, H, D, dtype=v_dtype, device=self._device) - self._global_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._init_ring_metadata(L) - - def _init_kv_buffer_offload(self) -> None: - L, N, H, D = self._num_layers, self._cache_size, self._num_heads, self._head_dim - v_dtype = torch.float8_e4m3fn if self._v_cache_type == "fp8" else torch.float16 - self._k_cpu = torch.zeros(L, N, H, D, dtype=torch.int8, device="cpu").pin_memory() - self._v_cpu = torch.zeros(L, N, H, D, dtype=v_dtype, device="cpu").pin_memory() - self._k_gpu_buf = torch.zeros(N, H, D, dtype=torch.int8, device=self._device) - self._v_gpu_buf = torch.zeros(N, H, D, dtype=v_dtype, device=self._device) - self._global_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._init_ring_metadata(L) - self._init_offload_state((L,)) - - def _load_calib(self, device=torch.device("cuda")) -> None: - load_path = self._calib_path - if dist.is_available() and dist.is_initialized() and self._calib_path is not None: - rank = dist.get_rank() - rank_path = ranked_calib_path(self._calib_path, rank) - if os.path.exists(rank_path): - load_path = rank_path - calib = torch.load(load_path, map_location=device, weights_only=True) - self._calib_km = calib["km"].to(device=device, dtype=torch.float32) - self._calib_v_scale = calib["v_scale"].to(device=device, dtype=torch.float32) - self._calib_k_block_scale = calib["k_block_scale"].to(device=device, dtype=torch.float32) - - def _sage_k_storage(self, layer_id: int) -> torch.Tensor: - if self._kv_offload: - return self._k_cpu[layer_id] - return self._k_buffer[layer_id] - - def _sage_v_storage(self, layer_id: int) -> torch.Tensor: - if self._kv_offload: - return self._v_cpu[layer_id] - return self._v_buffer[layer_id] - - def _lookup_km(self, layer_id: int) -> torch.Tensor | None: - km_cal = self._calib_km - if km_cal.dim() == 5: - return km_cal[self.current_step, layer_id].unsqueeze(0) - return km_cal[layer_id].unsqueeze(0) - - def _lookup_v_scale(self, layer_id: int) -> torch.Tensor: - vs_cal = self._calib_v_scale - if vs_cal.dim() == 4: - return vs_cal[self.current_step, layer_id] - return vs_cal[layer_id] - - def _lookup_k_block_scale(self, layer_id: int, blk_start: int, num_blk: int) -> torch.Tensor: - return self._calib_k_block_scale[self.current_step, layer_id, blk_start : blk_start + num_blk] - - def _quant_key(self, k_smoothed: torch.Tensor, preset_scale: torch.Tensor, start_idx: int, BLKK: int = 128) -> torch.Tensor: - chunk_len, H, D = k_smoothed.shape - num_blk = preset_scale.size(0) - k_int8 = torch.empty_like(k_smoothed, dtype=torch.int8) - preset_scale_c = preset_scale.contiguous() - grid = (num_blk * 4, H, 1) - quant_key_per_thread_int8_static_scale_kernel[grid]( - k_smoothed, - k_int8, - preset_scale_c, - chunk_len, - start_idx, - 0, - k_smoothed.stride(1), - k_smoothed.stride(0), - 0, - k_int8.stride(1), - k_int8.stride(0), - preset_scale_c.stride(0), - preset_scale_c.stride(1), - C=D, - BLK=BLKK, - ) - return k_int8 - - def _physical_store_kv(self, k: torch.Tensor, v: torch.Tensor, p0: int, p1: int, layer_id: int) -> None: - km = self._lookup_km(layer_id) - if km is not None: - k_smoothed = k - km.to(k.dtype).squeeze(0) - else: - k_smoothed = k - blk_start = p0 // self._BLKK - last_blk = (p1 - 1) // self._BLKK - num_blk = last_blk - blk_start + 1 - preset_scale = self._lookup_k_block_scale(layer_id, blk_start, num_blk) - k_int8 = self._quant_key(k_smoothed, preset_scale, p0, self._BLKK) - v_scale = self._lookup_v_scale(layer_id) - v_fp8 = quant_value_per_channel_fp8_static_scale_kernel(v, v_scale, fp8_max=448.0) - - if self._kv_offload: - self._check_layer_loaded(layer_id) - self._k_gpu_buf[p0:p1].copy_(k_int8) - self._v_gpu_buf[p0:p1].copy_(v_fp8) - self._k_cpu[layer_id, p0:p1].copy_(k_int8, non_blocking=True) - self._v_cpu[layer_id, p0:p1].copy_(v_fp8, non_blocking=True) - else: - self._k_buffer[layer_id, p0:p1] = k_int8 - self._v_buffer[layer_id, p0:p1] = v_fp8 - - def store_kv(self, k: torch.Tensor, v: torch.Tensor, start_idx: int, end_idx: int, layer_id: int) -> None: - length = int(end_idx) - int(start_idx) - if length <= 0: - return - chunks = self._logical_to_physical_chunks_for_store(layer_id, start_idx, end_idx) - off = 0 - for p0, p1 in chunks: - n = p1 - p0 - self._physical_store_kv(k[off : off + n].contiguous(), v[off : off + n].contiguous(), p0, p1, layer_id) - off += n - if self._kv_offload: - self._record_cpu_update(layer_id) - - def _copy_layer_to_gpu(self, layer_id: int) -> None: - self._k_gpu_buf.copy_(self._k_cpu[layer_id], non_blocking=True) - self._v_gpu_buf.copy_(self._v_cpu[layer_id], non_blocking=True) - - def _k_scale_for_physical_chunks(self, layer_id: int, chunks: list[tuple[int, int]]) -> torch.Tensor: - scales = [] - for p0, p1 in chunks: - a0 = (p0 // self._BLKK) * self._BLKK - blk_s = a0 // self._BLKK - blk_e = (p1 + self._BLKK - 1) // self._BLKK - scales.append(self._calib_k_block_scale[self.current_step, layer_id, blk_s:blk_e]) - if len(scales) == 1: - sc = scales[0] - else: - sc = torch.cat(scales, dim=0) - return sc.permute(1, 0, 2).reshape(1, self._num_heads, -1).contiguous() - - def k_cache(self, layer_id: int, attn_start: int, local_end: int): - chunks = self._logical_to_physical_chunks(layer_id, attn_start, local_end) - if self._kv_offload: - self._check_layer_loaded(layer_id) - parts = [self._k_gpu_buf[p0:p1] for p0, p1 in chunks] - k_int8 = (parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)).unsqueeze(0).contiguous() - k_scale = self._k_scale_for_physical_chunks(layer_id, chunks) - return k_int8, k_scale - - parts = [self._sage_k_storage(layer_id)[p0:p1] for p0, p1 in chunks] - k_int8 = (parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)).unsqueeze(0).contiguous() - k_scale = self._k_scale_for_physical_chunks(layer_id, chunks) - return k_int8, k_scale - - def _transpose_permute_v(self, v: torch.Tensor) -> torch.Tensor: - kv_len, H, D = v.shape - padded_len = (kv_len + 127) // 128 * 128 - if padded_len > kv_len: - v_t = v.new_zeros(D, H, padded_len) - v_t[:, :, :kv_len].copy_(v.permute(2, 1, 0)) - else: - v_t = v.permute(2, 1, 0).contiguous() - v_t = v_t.view(D, H, -1, 16)[:, :, :, self._PERM_16].contiguous() - return v_t.view(1, D, H, padded_len) - - def v_cache(self, layer_id: int, attn_start: int, local_end: int): - chunks = self._logical_to_physical_chunks(layer_id, attn_start, local_end) - if self._kv_offload: - self._check_layer_loaded(layer_id) - parts = [self._v_gpu_buf[p0:p1] for p0, p1 in chunks] - v = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0) - else: - parts = [self._sage_v_storage(layer_id)[p0:p1] for p0, p1 in chunks] - v = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0) - return self._transpose_permute_v(v), self._lookup_v_scale(layer_id).unsqueeze(0).contiguous() - - def reset(self) -> None: - if self._kv_offload: - self.sync_all() - self._zero_tensors(["_k_cpu", "_v_cpu", "_k_gpu_buf", "_v_gpu_buf"]) - self._reset_offload_state() - else: - self._zero_tensors(["_k_buffer", "_v_buffer"]) - self._reset_ends() - self._reset_ring() - - -# ============================================================================= -# TurboQuant -# ============================================================================= - - -class TurboQuantRollingKVCachePool(_QuantTokenRingMixin, RollingKVCachePool): - def __init__( - self, - num_layers: int, - cache_size: int, - num_heads: int, - head_dim: int, - dtype: torch.dtype, - device: torch.device, - key_bits: int = 3, - value_bits: int = 2, - seed: int = 42, - per_layer_compressors: bool = True, - kv_offload: bool = False, - *, - codebook_dir: str | None = None, - codebook_cache_dir: str | None = None, - export_missing_codebooks: bool = False, - value_group_size: int = 32, - ) -> None: - self._key_bits = int(key_bits) - self._value_bits = int(value_bits) - self._seed_base = int(seed) - self._per_layer_compressors = bool(per_layer_compressors) - self._n_layers = int(num_layers) - self._value_group_size = int(value_group_size) - if self._key_bits < 2: - raise ValueError("TurboQuantProd requires key_bits >= 2") - if head_dim % self._value_group_size != 0: - raise ValueError(f"head_dim {head_dim} must divide value_group_size {self._value_group_size}") - - device_t = torch.device(str(device)) - inf_dtype = torch.float32 - nk_bits = self._key_bits - 1 - cb_key = tq_fw_load_codebook_record(head_dim, nk_bits, codebook_dir, codebook_cache_dir, export_missing_codebooks) - self._inf_nk = tq_fw_packed_width(head_dim, nk_bits) - self._inf_nqjl = (head_dim + 7) // 8 - - def _make_k_mod(seed_k: int) -> torch.nn.Module: - return TurboQuantProdInference(head_dim, self._key_bits, device_t, seed_k, cb_key, dtype=inf_dtype) - - if self._per_layer_compressors: - self._k_inference_modules = [_make_k_mod(self._seed_base + lid * 7) for lid in range(self._n_layers)] - else: - km = _make_k_mod(self._seed_base) - self._k_inference_modules = [km for _ in range(self._n_layers)] - - self._inf_v_width = tq_value_group_packed_width(head_dim, self._value_bits) - self._inf_v_n_groups = head_dim // self._value_group_size - super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device, kv_offload=kv_offload) - - def _k_mod_inf(self, layer_id: int) -> torch.nn.Module: - return self._k_inference_modules[layer_id] - - def _init_kv_buffer(self) -> None: - if self._kv_offload: - self._init_kv_buffer_offload() - return - L, N, H = self._num_layers, self._cache_size, self._num_heads - ng = self._inf_v_n_groups - d = self._device - self._k_packed = torch.zeros(L, N, H, self._inf_nk, dtype=torch.uint8, device=d) - self._k_norms = torch.zeros(L, N, H, dtype=torch.float16, device=d) - self._k_qjl_packed = torch.zeros(L, N, H, self._inf_nqjl, dtype=torch.uint8, device=d) - self._k_res_norms = torch.zeros(L, N, H, dtype=torch.float16, device=d) - self._v_group_data = torch.zeros(L, N, H, self._inf_v_width, dtype=torch.uint8, device=d) - self._v_group_scales = torch.zeros(L, N, H, ng, dtype=torch.float16, device=d) - self._v_group_zeros = torch.zeros(L, N, H, ng, dtype=torch.float16, device=d) - self._global_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._init_ring_metadata(L) - - def _init_kv_buffer_offload(self) -> None: - L, N, H = self._num_layers, self._cache_size, self._num_heads - ng = self._inf_v_n_groups - self._k_packed_cpu = torch.zeros(L, N, H, self._inf_nk, dtype=torch.uint8, device="cpu").pin_memory() - self._k_norms_cpu = torch.zeros(L, N, H, dtype=torch.float16, device="cpu").pin_memory() - self._k_qjl_packed_cpu = torch.zeros(L, N, H, self._inf_nqjl, dtype=torch.uint8, device="cpu").pin_memory() - self._k_res_norms_cpu = torch.zeros(L, N, H, dtype=torch.float16, device="cpu").pin_memory() - self._v_group_data_cpu = torch.zeros(L, N, H, self._inf_v_width, dtype=torch.uint8, device="cpu").pin_memory() - self._v_group_scales_cpu = torch.zeros(L, N, H, ng, dtype=torch.float16, device="cpu").pin_memory() - self._v_group_zeros_cpu = torch.zeros(L, N, H, ng, dtype=torch.float16, device="cpu").pin_memory() - d = self._device - self._k_packed_gpu = torch.zeros(N, H, self._inf_nk, dtype=torch.uint8, device=d) - self._k_norms_gpu = torch.zeros(N, H, dtype=torch.float16, device=d) - self._k_qjl_packed_gpu = torch.zeros(N, H, self._inf_nqjl, dtype=torch.uint8, device=d) - self._k_res_norms_gpu = torch.zeros(N, H, dtype=torch.float16, device=d) - self._v_group_data_gpu = torch.zeros(N, H, self._inf_v_width, dtype=torch.uint8, device=d) - self._v_group_scales_gpu = torch.zeros(N, H, ng, dtype=torch.float16, device=d) - self._v_group_zeros_gpu = torch.zeros(N, H, ng, dtype=torch.float16, device=d) - self._global_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._init_ring_metadata(L) - self._init_offload_state((L,)) - - def _store_arrays(self, layer_id: int, p0: int, p1: int, tensors: dict[str, torch.Tensor]) -> None: - if self._kv_offload: - self._check_layer_loaded(layer_id) - self._k_packed_gpu[p0:p1].copy_(tensors["mse"]) - self._k_norms_gpu[p0:p1].copy_(tensors["norms"]) - self._k_qjl_packed_gpu[p0:p1].copy_(tensors["qjl"]) - self._k_res_norms_gpu[p0:p1].copy_(tensors["res"]) - self._v_group_data_gpu[p0:p1].copy_(tensors["v_data"]) - self._v_group_scales_gpu[p0:p1].copy_(tensors["v_scales"]) - self._v_group_zeros_gpu[p0:p1].copy_(tensors["v_zeros"]) - self._k_packed_cpu[layer_id, p0:p1].copy_(tensors["mse"], non_blocking=True) - self._k_norms_cpu[layer_id, p0:p1].copy_(tensors["norms"], non_blocking=True) - self._k_qjl_packed_cpu[layer_id, p0:p1].copy_(tensors["qjl"], non_blocking=True) - self._k_res_norms_cpu[layer_id, p0:p1].copy_(tensors["res"], non_blocking=True) - self._v_group_data_cpu[layer_id, p0:p1].copy_(tensors["v_data"], non_blocking=True) - self._v_group_scales_cpu[layer_id, p0:p1].copy_(tensors["v_scales"], non_blocking=True) - self._v_group_zeros_cpu[layer_id, p0:p1].copy_(tensors["v_zeros"], non_blocking=True) - else: - self._k_packed[layer_id, p0:p1].copy_(tensors["mse"]) - self._k_norms[layer_id, p0:p1].copy_(tensors["norms"]) - self._k_qjl_packed[layer_id, p0:p1].copy_(tensors["qjl"]) - self._k_res_norms[layer_id, p0:p1].copy_(tensors["res"]) - self._v_group_data[layer_id, p0:p1].copy_(tensors["v_data"]) - self._v_group_scales[layer_id, p0:p1].copy_(tensors["v_scales"]) - self._v_group_zeros[layer_id, p0:p1].copy_(tensors["v_zeros"]) - - def _physical_store_kv(self, k: torch.Tensor, v: torch.Tensor, p0: int, p1: int, layer_id: int) -> None: - chunk_len = p1 - p0 - k_bhsd = k.unsqueeze(0).transpose(1, 2).contiguous() - v_bhsd = v.unsqueeze(0).transpose(1, 2).contiguous() - with torch.no_grad(): - ck = self._k_mod_inf(layer_id).compress_bhsd(k_bhsd) - cv = tq_group_quantize_values(v_bhsd, self._value_bits, self._value_group_size) - tensors = { - "mse": ck["mse_idx_bytes"][0].transpose(0, 1).contiguous(), - "norms": ck["vec_norms"][0].transpose(0, 1).contiguous(), - "qjl": ck["qjl_bytes"][0].transpose(0, 1).contiguous(), - "res": ck["residual_norms"][0].transpose(0, 1).contiguous(), - "v_data": cv["data"][0].transpose(0, 1).contiguous(), - "v_scales": cv["scales"][0].transpose(0, 1).contiguous(), - "v_zeros": cv["zeros"][0].transpose(0, 1).contiguous(), - } - self._store_arrays(layer_id, p0, p1, tensors) - - def store_kv(self, k: torch.Tensor, v: torch.Tensor, start_idx: int, end_idx: int, layer_id: int) -> None: - chunks = self._logical_to_physical_chunks_for_store(layer_id, start_idx, end_idx) - off = 0 - for p0, p1 in chunks: - n = p1 - p0 - self._physical_store_kv(k[off : off + n], v[off : off + n], p0, p1, layer_id) - off += n - if self._kv_offload: - self._record_cpu_update(layer_id) - - def _copy_layer_to_gpu(self, layer_id: int) -> None: - self._k_packed_gpu.copy_(self._k_packed_cpu[layer_id], non_blocking=True) - self._k_norms_gpu.copy_(self._k_norms_cpu[layer_id], non_blocking=True) - self._k_qjl_packed_gpu.copy_(self._k_qjl_packed_cpu[layer_id], non_blocking=True) - self._k_res_norms_gpu.copy_(self._k_res_norms_cpu[layer_id], non_blocking=True) - self._v_group_data_gpu.copy_(self._v_group_data_cpu[layer_id], non_blocking=True) - self._v_group_scales_gpu.copy_(self._v_group_scales_cpu[layer_id], non_blocking=True) - self._v_group_zeros_gpu.copy_(self._v_group_zeros_cpu[layer_id], non_blocking=True) - - @staticmethod - def _sh_extra_to_bhs(extra_sh: torch.Tensor) -> torch.Tensor: - return extra_sh.unsqueeze(0).permute(0, 2, 1, 3).contiguous() - - def _decompress_k_from_arrays(self, layer_id: int, arrays: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: - packed, norms, qjl, res = arrays - kv_len = packed.size(0) - idx_bytes = packed.unsqueeze(0).permute(0, 2, 1, 3).contiguous() - norms_bhs = norms.unsqueeze(0).transpose(1, 2).contiguous() - qjl_bhs = self._sh_extra_to_bhs(qjl) - res_bhs = res.unsqueeze(0).transpose(1, 2).contiguous() - comp = { - "mse_idx_bytes": idx_bytes, - "qjl_bytes": qjl_bhs, - "residual_norms": res_bhs, - "vec_norms": norms_bhs, - "shape": (1, self._num_heads, kv_len, self._head_dim), - "mse_bits": self._key_bits - 1, - } - with torch.no_grad(): - out = self._k_mod_inf(layer_id).decompress_bhsd(comp) - return out[0].transpose(0, 1).to(dtype=self._dtype) - - def k_cache(self, layer_id: int, attn_start: int, local_end: int) -> torch.Tensor: - chunks = self._logical_to_physical_chunks(layer_id, attn_start, local_end) - if self._kv_offload: - self._check_layer_loaded(layer_id) - parts = [] - for p0, p1 in chunks: - parts.append( - ( - self._k_packed_gpu[p0:p1], - self._k_norms_gpu[p0:p1], - self._k_qjl_packed_gpu[p0:p1], - self._k_res_norms_gpu[p0:p1], - ) - ) - arrays = tuple(torch.cat([part[i] for part in parts], dim=0) for i in range(4)) if len(parts) > 1 else parts[0] - return self._decompress_k_from_arrays(layer_id, arrays) - - parts = [] - for p0, p1 in chunks: - arrays = ( - self._k_packed[layer_id, p0:p1], - self._k_norms[layer_id, p0:p1], - self._k_qjl_packed[layer_id, p0:p1], - self._k_res_norms[layer_id, p0:p1], - ) - parts.append(self._decompress_k_from_arrays(layer_id, arrays)) - return parts[0] if len(parts) == 1 else torch.cat(parts, dim=0).contiguous() - - def _decompress_v_from_arrays(self, arrays: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: - data, scales, zeros = arrays - kv_len = data.size(0) - comp = { - "data": data.unsqueeze(0).permute(0, 2, 1, 3).contiguous(), - "scales": scales.unsqueeze(0).transpose(1, 2).contiguous(), - "zeros": zeros.unsqueeze(0).transpose(1, 2).contiguous(), - "bits": self._value_bits, - "group_size": self._value_group_size, - "shape": (1, self._num_heads, kv_len, self._head_dim), - } - with torch.no_grad(): - out = tq_group_dequantize_values(comp) - return out[0].transpose(0, 1).to(dtype=self._dtype) - - def v_cache(self, layer_id: int, attn_start: int, local_end: int) -> torch.Tensor: - chunks = self._logical_to_physical_chunks(layer_id, attn_start, local_end) - if self._kv_offload: - self._check_layer_loaded(layer_id) - parts = [] - for p0, p1 in chunks: - parts.append( - ( - self._v_group_data_gpu[p0:p1], - self._v_group_scales_gpu[p0:p1], - self._v_group_zeros_gpu[p0:p1], - ) - ) - arrays = tuple(torch.cat([part[i] for part in parts], dim=0) for i in range(3)) if len(parts) > 1 else parts[0] - return self._decompress_v_from_arrays(arrays) - - parts = [] - for p0, p1 in chunks: - arrays = ( - self._v_group_data[layer_id, p0:p1], - self._v_group_scales[layer_id, p0:p1], - self._v_group_zeros[layer_id, p0:p1], - ) - parts.append(self._decompress_v_from_arrays(arrays)) - return parts[0] if len(parts) == 1 else torch.cat(parts, dim=0).contiguous() - - def reset(self) -> None: - if self._kv_offload: - self.sync_all() - self._zero_tensors( - [ - "_k_packed_cpu", - "_k_norms_cpu", - "_k_qjl_packed_cpu", - "_k_res_norms_cpu", - "_v_group_data_cpu", - "_v_group_scales_cpu", - "_v_group_zeros_cpu", - "_k_packed_gpu", - "_k_norms_gpu", - "_k_qjl_packed_gpu", - "_k_res_norms_gpu", - "_v_group_data_gpu", - "_v_group_scales_gpu", - "_v_group_zeros_gpu", - ] - ) - self._reset_offload_state() - else: - self._zero_tensors( - [ - "_k_packed", - "_k_norms", - "_k_qjl_packed", - "_k_res_norms", - "_v_group_data", - "_v_group_scales", - "_v_group_zeros", - ] - ) - self._reset_ends() - self._reset_ring() - - -class StepTurboQuantRollingKVCachePool(TurboQuantRollingKVCachePool): - def __init__(self, num_steps: int, *args, **kwargs) -> None: - self.num_steps = int(num_steps) - self._current_step = 0 - super().__init__(*args, **kwargs) - - @property - def current_step(self) -> int: - return self._current_step - - @current_step.setter - def current_step(self, value: int) -> None: - value = int(value) - if value == self._current_step: - return - if getattr(self, "_kv_offload", False) and hasattr(self, "_prefetch_stream"): - self.sync_all() - self._reset_offload_state() - self._current_step = value - - def _step(self) -> int: - return int(self._current_step) - - def _ring_index(self, layer_id: int): - return (self._step(), int(layer_id)) - - def _init_kv_buffer(self) -> None: - if self._kv_offload: - # Use parent non-step allocation, then add step dimension manually. - T, L, N, H = self.num_steps, self._num_layers, self._cache_size, self._num_heads - ng = self._inf_v_n_groups - self._k_packed_cpu = torch.zeros(T, L, N, H, self._inf_nk, dtype=torch.uint8, device="cpu").pin_memory() - self._k_norms_cpu = torch.zeros(T, L, N, H, dtype=torch.float16, device="cpu").pin_memory() - self._k_qjl_packed_cpu = torch.zeros(T, L, N, H, self._inf_nqjl, dtype=torch.uint8, device="cpu").pin_memory() - self._k_res_norms_cpu = torch.zeros(T, L, N, H, dtype=torch.float16, device="cpu").pin_memory() - self._v_group_data_cpu = torch.zeros(T, L, N, H, self._inf_v_width, dtype=torch.uint8, device="cpu").pin_memory() - self._v_group_scales_cpu = torch.zeros(T, L, N, H, ng, dtype=torch.float16, device="cpu").pin_memory() - self._v_group_zeros_cpu = torch.zeros(T, L, N, H, ng, dtype=torch.float16, device="cpu").pin_memory() - d = self._device - self._k_packed_gpu = torch.zeros(N, H, self._inf_nk, dtype=torch.uint8, device=d) - self._k_norms_gpu = torch.zeros(N, H, dtype=torch.float16, device=d) - self._k_qjl_packed_gpu = torch.zeros(N, H, self._inf_nqjl, dtype=torch.uint8, device=d) - self._k_res_norms_gpu = torch.zeros(N, H, dtype=torch.float16, device=d) - self._v_group_data_gpu = torch.zeros(N, H, self._inf_v_width, dtype=torch.uint8, device=d) - self._v_group_scales_gpu = torch.zeros(N, H, ng, dtype=torch.float16, device=d) - self._v_group_zeros_gpu = torch.zeros(N, H, ng, dtype=torch.float16, device=d) - self._global_end = torch.zeros(T, L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(T, L, dtype=torch.long, device="cpu") - self._init_ring_metadata(T, L) - self._init_offload_state((T, L)) - return - - T, L, N, H = self.num_steps, self._num_layers, self._cache_size, self._num_heads - ng = self._inf_v_n_groups - d = self._device - self._k_packed = torch.zeros(T, L, N, H, self._inf_nk, dtype=torch.uint8, device=d) - self._k_norms = torch.zeros(T, L, N, H, dtype=torch.float16, device=d) - self._k_qjl_packed = torch.zeros(T, L, N, H, self._inf_nqjl, dtype=torch.uint8, device=d) - self._k_res_norms = torch.zeros(T, L, N, H, dtype=torch.float16, device=d) - self._v_group_data = torch.zeros(T, L, N, H, self._inf_v_width, dtype=torch.uint8, device=d) - self._v_group_scales = torch.zeros(T, L, N, H, ng, dtype=torch.float16, device=d) - self._v_group_zeros = torch.zeros(T, L, N, H, ng, dtype=torch.float16, device=d) - self._global_end = torch.zeros(T, L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(T, L, dtype=torch.long, device="cpu") - self._init_ring_metadata(T, L) - - # Step-aware storage helpers. - def _store_arrays(self, layer_id: int, p0: int, p1: int, tensors: dict[str, torch.Tensor]) -> None: - s = self._step() - if self._kv_offload: - self._check_layer_loaded(layer_id) - self._k_packed_gpu[p0:p1].copy_(tensors["mse"]) - self._k_norms_gpu[p0:p1].copy_(tensors["norms"]) - self._k_qjl_packed_gpu[p0:p1].copy_(tensors["qjl"]) - self._k_res_norms_gpu[p0:p1].copy_(tensors["res"]) - self._v_group_data_gpu[p0:p1].copy_(tensors["v_data"]) - self._v_group_scales_gpu[p0:p1].copy_(tensors["v_scales"]) - self._v_group_zeros_gpu[p0:p1].copy_(tensors["v_zeros"]) - self._k_packed_cpu[s, layer_id, p0:p1].copy_(tensors["mse"], non_blocking=True) - self._k_norms_cpu[s, layer_id, p0:p1].copy_(tensors["norms"], non_blocking=True) - self._k_qjl_packed_cpu[s, layer_id, p0:p1].copy_(tensors["qjl"], non_blocking=True) - self._k_res_norms_cpu[s, layer_id, p0:p1].copy_(tensors["res"], non_blocking=True) - self._v_group_data_cpu[s, layer_id, p0:p1].copy_(tensors["v_data"], non_blocking=True) - self._v_group_scales_cpu[s, layer_id, p0:p1].copy_(tensors["v_scales"], non_blocking=True) - self._v_group_zeros_cpu[s, layer_id, p0:p1].copy_(tensors["v_zeros"], non_blocking=True) - else: - self._k_packed[s, layer_id, p0:p1].copy_(tensors["mse"]) - self._k_norms[s, layer_id, p0:p1].copy_(tensors["norms"]) - self._k_qjl_packed[s, layer_id, p0:p1].copy_(tensors["qjl"]) - self._k_res_norms[s, layer_id, p0:p1].copy_(tensors["res"]) - self._v_group_data[s, layer_id, p0:p1].copy_(tensors["v_data"]) - self._v_group_scales[s, layer_id, p0:p1].copy_(tensors["v_scales"]) - self._v_group_zeros[s, layer_id, p0:p1].copy_(tensors["v_zeros"]) - - def _offload_index(self, layer_id: int) -> tuple[int, ...]: - return (self._step(), int(layer_id)) - - def _copy_layer_to_gpu(self, layer_id: int) -> None: - s = self._step() - self._k_packed_gpu.copy_(self._k_packed_cpu[s, layer_id], non_blocking=True) - self._k_norms_gpu.copy_(self._k_norms_cpu[s, layer_id], non_blocking=True) - self._k_qjl_packed_gpu.copy_(self._k_qjl_packed_cpu[s, layer_id], non_blocking=True) - self._k_res_norms_gpu.copy_(self._k_res_norms_cpu[s, layer_id], non_blocking=True) - self._v_group_data_gpu.copy_(self._v_group_data_cpu[s, layer_id], non_blocking=True) - self._v_group_scales_gpu.copy_(self._v_group_scales_cpu[s, layer_id], non_blocking=True) - self._v_group_zeros_gpu.copy_(self._v_group_zeros_cpu[s, layer_id], non_blocking=True) - - def get_global_end(self, layer_id: int) -> int: - return int(self._global_end[self._step(), layer_id].item()) - - def get_local_end(self, layer_id: int) -> int: - return int(self._local_end[self._step(), layer_id].item()) - - def set_ends(self, layer_id: int, global_end: int, local_end: int) -> None: - s = self._step() - self._global_end[s, layer_id] = int(global_end) - self._local_end[s, layer_id] = int(local_end) - - -# ============================================================================= -# LongLive FP4 -# ============================================================================= - - -class LongLiveQuantRollingKVCachePool(_QuantTokenRingMixin, RollingKVCachePool): - def __init__( - self, - num_layers: int, - cache_size: int, - num_heads: int, - head_dim: int, - dtype: torch.dtype, - device: torch.device, - *, - block_token_size: int | None = None, - scale_rule: str = "mse", - backend: str = "pytorch", - kv_offload: bool = False, - ) -> None: - self._block_token_size = int(block_token_size or cache_size) - if self._block_token_size <= 0: - raise ValueError(f"block_token_size must be positive, got {block_token_size}") - self._quant_config = build_fp4_quant_config(scale_rule=scale_rule, backend=backend) - self._dequant_backend = normalize_dequant_backend(backend) - n_alloc = cdiv(int(cache_size), self._block_token_size) * self._block_token_size - self._max_blocks = n_alloc // self._block_token_size - super().__init__(num_layers, n_alloc, num_heads, head_dim, dtype, device, kv_offload=kv_offload) - - def _make_zero_block_qt(self) -> QuantizedTensor: - h, d, blk = self._num_heads, self._head_dim, self._block_token_size - zero = torch.zeros(blk * h, d, dtype=self._dtype, device=self._device) - return quantize_to_fp4(zero, self._quant_config) - - @staticmethod - def _clone_qt_to(qt: QuantizedTensor, device: torch.device | str, *, pin_memory: bool = False) -> QuantizedTensor: - def clone_tensor(t: torch.Tensor | None) -> torch.Tensor | None: - if t is None: - return None - out = t.detach().to(device=device).clone() - return out.pin_memory() if pin_memory else out - - return QuantizedTensor( - values=clone_tensor(qt.values), - scale_factors=clone_tensor(qt.scale_factors), - amax=clone_tensor(qt.amax), - dtype=qt.dtype, - original_shape=qt.original_shape, - scale_rule=qt.scale_rule, - padded_shape=qt.padded_shape, - ) - - @staticmethod - def _copy_qt(dst: QuantizedTensor, src: QuantizedTensor, *, non_blocking: bool = True) -> None: - dst.values.copy_(src.values, non_blocking=non_blocking) - dst.scale_factors.copy_(src.scale_factors, non_blocking=non_blocking) - if dst.amax is not None and src.amax is not None: - dst.amax.copy_(src.amax, non_blocking=non_blocking) - - def _make_qt_blocks( - self, - zero_qt: QuantizedTensor, - shape: tuple[int, ...], - device: torch.device | str, - *, - pin_memory: bool = False, - ): - if len(shape) == 1: - return [self._clone_qt_to(zero_qt, device, pin_memory=pin_memory) for _ in range(shape[0])] - return [self._make_qt_blocks(zero_qt, shape[1:], device, pin_memory=pin_memory) for _ in range(shape[0])] - - def _reset_qt_blocks(self, blocks, zero_qt: QuantizedTensor) -> None: - if not blocks: - return - if hasattr(blocks[0], "values"): - for block in blocks: - self._copy_qt(block, zero_qt, non_blocking=False) - return - for sub_blocks in blocks: - self._reset_qt_blocks(sub_blocks, zero_qt) - - def _init_kv_buffer(self) -> None: - if self._kv_offload: - self._init_kv_buffer_offload() - return - zero_qt = self._make_zero_block_qt() - self._k_blocks = [[clone_quantized_tensor(zero_qt) for _ in range(self._max_blocks)] for _ in range(self._num_layers)] - self._v_blocks = [[clone_quantized_tensor(zero_qt) for _ in range(self._max_blocks)] for _ in range(self._num_layers)] - self._global_end = torch.zeros(self._num_layers, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(self._num_layers, dtype=torch.long, device="cpu") - self._init_ring_metadata(self._num_layers) - - def _init_kv_buffer_offload(self) -> None: - zero_qt = self._make_zero_block_qt() - self._k_blocks_cpu = self._make_qt_blocks(zero_qt, (self._num_layers, self._max_blocks), "cpu", pin_memory=True) - self._v_blocks_cpu = self._make_qt_blocks(zero_qt, (self._num_layers, self._max_blocks), "cpu", pin_memory=True) - self._k_blocks_gpu = self._make_qt_blocks(zero_qt, (self._max_blocks,), self._device) - self._v_blocks_gpu = self._make_qt_blocks(zero_qt, (self._max_blocks,), self._device) - self._global_end = torch.zeros(self._num_layers, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(self._num_layers, dtype=torch.long, device="cpu") - self._init_ring_metadata(self._num_layers) - self._init_offload_state((self._num_layers,)) - - def _layer_k_blocks(self, layer_id: int): - if self._kv_offload: - return self._k_blocks_gpu - return self._k_blocks[layer_id] - - def _layer_v_blocks(self, layer_id: int): - if self._kv_offload: - return self._v_blocks_gpu - return self._v_blocks[layer_id] - - def _cpu_layer_k_blocks(self, layer_id: int): - return self._k_blocks_cpu[layer_id] - - def _cpu_layer_v_blocks(self, layer_id: int): - return self._v_blocks_cpu[layer_id] - - def _copy_layer_to_gpu(self, layer_id: int) -> None: - for dst, src in zip(self._k_blocks_gpu, self._cpu_layer_k_blocks(layer_id), strict=True): - self._copy_qt(dst, src) - for dst, src in zip(self._v_blocks_gpu, self._cpu_layer_v_blocks(layer_id), strict=True): - self._copy_qt(dst, src) - - def _quantize_block(self, k_nhd: torch.Tensor, v_nhd: torch.Tensor): - blk, h, d = self._block_token_size, self._num_heads, self._head_dim - if k_nhd.shape[0] != blk: - raise ValueError(f"K block token count {k_nhd.shape[0]} != block_token_size {blk}") - k2d = k_smooth(k_nhd).reshape(blk * h, d).contiguous() - v2d = v_nhd.reshape(blk * h, d).contiguous() - return quantize_to_fp4(k2d, self._quant_config), quantize_to_fp4(v2d, self._quant_config) - - def _dequant_token_range(self, blocks: list[QuantizedTensor], start: int, end: int) -> torch.Tensor: - return dequantize_token_range( - blocks, - start, - end, - cache_size=self._cache_size, - num_heads=self._num_heads, - head_dim=self._head_dim, - block_token_size=self._block_token_size, - dtype=self._dtype, - device=self._device, - backend=self._dequant_backend, - ) - - def _pad_nhd_to_blocks(self, k_nhd: torch.Tensor, v_nhd: torch.Tensor): - blk = self._block_token_size - t_len = k_nhd.size(0) - t_pad = cdiv(t_len, blk) * blk - if t_len == t_pad: - return k_nhd, v_nhd - pad = t_pad - t_len - return ( - torch.cat((k_nhd, k_nhd.new_zeros(pad, *k_nhd.shape[1:])), dim=0), - torch.cat((v_nhd, v_nhd.new_zeros(pad, *v_nhd.shape[1:])), dim=0), - ) - - def _write_blocks_from_nhd(self, k_nhd: torch.Tensor, v_nhd: torch.Tensor, layer_id: int, physical_start: int) -> None: - blk = self._block_token_size - t_len = k_nhd.size(0) - if t_len % blk != 0: - raise RuntimeError(f"longlive_fp4 store length {t_len} is not a multiple of block_token_size {blk}") - b0 = physical_start // blk - if self._kv_offload: - self._check_layer_loaded(layer_id) - for i in range(t_len // blk): - bi = b0 + i - ts, te = i * blk, (i + 1) * blk - k_qt, v_qt = self._quantize_block(k_nhd[ts:te], v_nhd[ts:te]) - self._copy_qt(self._layer_k_blocks(layer_id)[bi], k_qt) - self._copy_qt(self._layer_v_blocks(layer_id)[bi], v_qt) - if self._kv_offload: - self._copy_qt(self._cpu_layer_k_blocks(layer_id)[bi], k_qt) - self._copy_qt(self._cpu_layer_v_blocks(layer_id)[bi], v_qt) - - def _physical_store_kv(self, k: torch.Tensor, v: torch.Tensor, p0: int, p1: int, layer_id: int) -> None: - blk = self._block_token_size - s0 = (p0 // blk) * blk - e1 = min(cdiv(p1, blk) * blk, self._cache_size) - parts_k, parts_v = [], [] - if s0 < p0: - parts_k.append(self._dequant_token_range(self._layer_k_blocks(layer_id), s0, p0)) - parts_v.append(self._dequant_token_range(self._layer_v_blocks(layer_id), s0, p0)) - parts_k.append(k) - parts_v.append(v) - if p1 < e1: - parts_k.append(self._dequant_token_range(self._layer_k_blocks(layer_id), p1, e1)) - parts_v.append(self._dequant_token_range(self._layer_v_blocks(layer_id), p1, e1)) - k_cat, v_cat = self._pad_nhd_to_blocks(torch.cat(parts_k, dim=0), torch.cat(parts_v, dim=0)) - self._write_blocks_from_nhd(k_cat, v_cat, layer_id, s0) - - def store_kv(self, k: torch.Tensor, v: torch.Tensor, start_idx: int, end_idx: int, layer_id: int) -> None: - chunks = self._logical_to_physical_chunks_for_store(layer_id, start_idx, end_idx) - off = 0 - for p0, p1 in chunks: - n = p1 - p0 - self._physical_store_kv(k[off : off + n], v[off : off + n], p0, p1, layer_id) - off += n - if self._kv_offload: - self._record_cpu_update(layer_id) - - def _read_chunks(self, layer_id: int, start: int, end: int, kind: str) -> torch.Tensor: - self._check_layer_loaded(layer_id) - chunks = self._logical_to_physical_chunks(layer_id, start, end) - blocks = self._layer_k_blocks(layer_id) if kind == "k" else self._layer_v_blocks(layer_id) - outs = [self._dequant_token_range(blocks, p0, p1) for p0, p1 in chunks] - if not outs: - return torch.empty(0, self._num_heads, self._head_dim, dtype=self._dtype, device=self._device) - return outs[0] if len(outs) == 1 else torch.cat(outs, dim=0).contiguous() - - def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None): - if attn_start is None or local_end is None: - raise ValueError("longlive_fp4 k_cache requires attn_start and local_end") - return self._read_chunks(layer_id, attn_start, local_end, "k") - - def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None): - if attn_start is None or local_end is None: - raise ValueError("longlive_fp4 v_cache requires attn_start and local_end") - return self._read_chunks(layer_id, attn_start, local_end, "v") - - def reset(self) -> None: - zero_qt = self._make_zero_block_qt() - if self._kv_offload: - self.sync_all() - self._reset_qt_blocks(self._k_blocks_cpu, zero_qt) - self._reset_qt_blocks(self._v_blocks_cpu, zero_qt) - self._reset_qt_blocks(self._k_blocks_gpu, zero_qt) - self._reset_qt_blocks(self._v_blocks_gpu, zero_qt) - self._reset_offload_state() - else: - self._reset_qt_blocks(self._k_blocks, zero_qt) - self._reset_qt_blocks(self._v_blocks, zero_qt) - self._reset_ends() - self._reset_ring() - - -class StepLongLiveQuantRollingKVCachePool(LongLiveQuantRollingKVCachePool): - def __init__(self, num_steps: int, *args, **kwargs) -> None: - self.num_steps = int(num_steps) - self._current_step = 0 - super().__init__(*args, **kwargs) - - @property - def current_step(self) -> int: - return self._current_step - - @current_step.setter - def current_step(self, value: int) -> None: - value = int(value) - if value == self._current_step: - return - if self._kv_offload and hasattr(self, "_prefetch_stream"): - self.sync_all() - self._current_step = value - self._reset_offload_state() - return - self._current_step = value - - def _step(self) -> int: - return int(self._current_step) - - def _ring_index(self, layer_id: int): - return (self._step(), int(layer_id)) - - def _init_kv_buffer(self) -> None: - if self._kv_offload: - self._init_kv_buffer_offload() - return - zero_qt = self._make_zero_block_qt() - self._k_blocks = [[[clone_quantized_tensor(zero_qt) for _ in range(self._max_blocks)] for _ in range(self._num_layers)] for _ in range(self.num_steps)] - self._v_blocks = [[[clone_quantized_tensor(zero_qt) for _ in range(self._max_blocks)] for _ in range(self._num_layers)] for _ in range(self.num_steps)] - self._global_end = torch.zeros(self.num_steps, self._num_layers, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(self.num_steps, self._num_layers, dtype=torch.long, device="cpu") - self._init_ring_metadata(self.num_steps, self._num_layers) - - def _init_kv_buffer_offload(self) -> None: - zero_qt = self._make_zero_block_qt() - self._k_blocks_cpu = self._make_qt_blocks( - zero_qt, - (self.num_steps, self._num_layers, self._max_blocks), - "cpu", - pin_memory=True, - ) - self._v_blocks_cpu = self._make_qt_blocks( - zero_qt, - (self.num_steps, self._num_layers, self._max_blocks), - "cpu", - pin_memory=True, - ) - self._k_blocks_gpu = self._make_qt_blocks(zero_qt, (self._max_blocks,), self._device) - self._v_blocks_gpu = self._make_qt_blocks(zero_qt, (self._max_blocks,), self._device) - self._global_end = torch.zeros(self.num_steps, self._num_layers, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(self.num_steps, self._num_layers, dtype=torch.long, device="cpu") - self._init_ring_metadata(self.num_steps, self._num_layers) - self._init_offload_state((self.num_steps, self._num_layers)) - - def _layer_k_blocks(self, layer_id: int): - if self._kv_offload: - return self._k_blocks_gpu - return self._k_blocks[self._step()][layer_id] - - def _layer_v_blocks(self, layer_id: int): - if self._kv_offload: - return self._v_blocks_gpu - return self._v_blocks[self._step()][layer_id] - - def _cpu_layer_k_blocks(self, layer_id: int): - return self._k_blocks_cpu[self._step()][layer_id] - - def _cpu_layer_v_blocks(self, layer_id: int): - return self._v_blocks_cpu[self._step()][layer_id] - - def _offload_index(self, layer_id: int) -> tuple[int, ...]: - return (self._step(), int(layer_id)) - - def get_global_end(self, layer_id: int) -> int: - return int(self._global_end[self._step(), layer_id].item()) - - def get_local_end(self, layer_id: int) -> int: - return int(self._local_end[self._step(), layer_id].item()) - - def set_ends(self, layer_id: int, global_end: int, local_end: int) -> None: - self._global_end[self._step(), layer_id] = int(global_end) - self._local_end[self._step(), layer_id] = int(local_end) - - -# ============================================================================= -# KIVI classes are intentionally kept in their dedicated ring implementation. -# For KIVI, use the previously generated kivi_quant_cache_ring_no_padding.py and -# add offload there, because KIVI's packed [H,D,T/pack] layout needs its own -# physical chunk materialization. # ============================================================================= diff --git a/lightx2v/common/kvcache/rolling.py b/lightx2v/common/kvcache/rolling.py index 4251af3a0..5630d8272 100755 --- a/lightx2v/common/kvcache/rolling.py +++ b/lightx2v/common/kvcache/rolling.py @@ -1,5 +1,4 @@ import torch -from loguru import logger from .base import BaseKVCachePool from .utils import _kvcache_dma_stream_priority @@ -18,12 +17,6 @@ class RollingKVCachePool(BaseKVCachePool): recent ring head and shrinks recent_len. ``k_cache`` / ``v_cache`` return a contiguous logical tensor, concatenating at most two recent ring fragments. - Offload mode follows HuggingFace ``OffloadedStaticCache``: pinned CPU - tensors are the authoritative cache, while one GPU staging buffer holds the - layer currently being computed. ``store_kv`` updates both the GPU buffer and - the CPU cache; ``end_layer`` asynchronously prefetches the next layer into - the same staging buffer. - """ def __init__( @@ -137,9 +130,13 @@ def _logical_chunks(self, layer_id: int, start: int, end: int) -> list[tuple[int # Buffer accessors. Step/spatial variants override these. # --------------------------------------------------------------------- def _k_layer(self, layer_id: int) -> torch.Tensor: + if self._kv_offload: + return self._k_gpu_buf return self._k_buffer[layer_id] def _v_layer(self, layer_id: int) -> torch.Tensor: + if self._kv_offload: + return self._v_gpu_buf return self._v_buffer[layer_id] def _k_cpu_layer(self, layer_id: int) -> torch.Tensor: @@ -161,23 +158,28 @@ def _init_kv_buffer(self): self._init_ring_metadata() def _init_kv_buffer_offload(self) -> None: - """OffloadedStaticCache-style: CPU authority + one GPU staging layer.""" + from loguru import logger + L, N, H, D = self._num_layers, self._cache_size, self._num_heads, self._head_dim + d = self._device + + # CPU pinned buffers hold the authoritative per-layer physical ring; a + # single GPU staging buffer holds the layer currently being computed. self._k_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() self._v_cpu = torch.empty(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() - self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) - self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) + self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=d) + self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=d) + self._global_end = torch.zeros(L, dtype=torch.long, device="cpu") self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") self._init_ring_metadata() - self._init_offload_state((L,)) gpu_mb = (self._k_gpu_buf.nbytes + self._v_gpu_buf.nbytes) / (1024 * 1024) cpu_mb = (self._k_cpu.nbytes + self._v_cpu.nbytes) / (1024 * 1024) logger.info( - "[RollingKVCachePool+ring+offload] OffloadedStaticCache-style: CPU authority, one GPU staging layer N={} tokens: {:.1f} MB, CPU pinned: {:.1f} MB", - N, + "[{}+offload] GPU staging layer: {:.1f} MB, CPU pinned: {:.1f} MB", + self.__class__.__name__, gpu_mb, cpu_mb, ) @@ -206,12 +208,16 @@ def _cpu_update_event(self, layer_id: int) -> torch.cuda.Event: return event def _offload_events(self) -> list[torch.cuda.Event]: - return [self._load_done, *self._flatten_events(self._cpu_update_done)] + return [self._load_done, self._staging_free, *self._flatten_events(self._cpu_update_done)] def _init_offload_state(self, event_shape: tuple[int, ...]) -> None: pr = _kvcache_dma_stream_priority() self._prefetch_stream = torch.cuda.Stream(device=self._device, priority=pr) self._load_done = torch.cuda.Event() + # Single GPU staging buffer is shared across layers; this event marks + # that the compute stream has finished reading the currently-loaded + # layer, so the prefetch stream may safely overwrite the buffer. + self._staging_free = torch.cuda.Event() self._cpu_update_done = self._make_event_tree(event_shape) self._loaded_layer = -1 cur = torch.cuda.current_stream() @@ -235,7 +241,11 @@ def _prefetch_layer(self, layer_id: int) -> None: if layer_id >= self._num_layers: return with torch.cuda.stream(self._prefetch_stream): + # CPU authoritative buffer for this layer must be up to date, and + # the compute stream must be done reading the staging buffer's + # previous contents before we overwrite the single shared buffer. self._prefetch_stream.wait_event(self._cpu_update_event(layer_id)) + self._prefetch_stream.wait_event(self._staging_free) self._copy_layer_to_gpu(layer_id) self._load_done.record(self._prefetch_stream) self._loaded_layer = int(layer_id) @@ -263,6 +273,9 @@ def end_layer(self, layer_id: int, next_prefetch: int | None = None) -> None: """CPU cache is updated directly by ``store_kv``; prefetch the next layer.""" if not self._kv_offload: return + # All compute-stream work that reads the staging buffer for this layer + # (store + attention) is now enqueued; let the prefetch overwrite it. + self._staging_free.record(torch.cuda.current_stream()) next_layer = int(layer_id) + 1 if next_prefetch is None else int(next_prefetch) self._prefetch_layer(next_layer) @@ -296,34 +309,25 @@ def store_kv( ) -> None: if end_idx <= start_idx: return - if not self._kv_offload: - kb, vb = self._k_layer(layer_id), self._v_layer(layer_id) - for logical_s, phys_s, n in self._logical_chunks(layer_id, start_idx, end_idx): - ks = logical_s - start_idx - ke = ks + n - kb[phys_s : phys_s + n].copy_(k[ks:ke]) - vb[phys_s : phys_s + n].copy_(v[ks:ke]) - self._update_ring_len_after_store(layer_id, end_idx) - return - self._check_layer_loaded(layer_id) - kb, vb = self._k_gpu_buf, self._v_gpu_buf - k_cpu, v_cpu = self._k_cpu_layer(layer_id), self._v_cpu_layer(layer_id) + kb, vb = self._k_layer(layer_id), self._v_layer(layer_id) + kcpu = self._k_cpu_layer(layer_id) if self._kv_offload else None + vcpu = self._v_cpu_layer(layer_id) if self._kv_offload else None for logical_s, phys_s, n in self._logical_chunks(layer_id, start_idx, end_idx): ks = logical_s - start_idx ke = ks + n kb[phys_s : phys_s + n].copy_(k[ks:ke]) vb[phys_s : phys_s + n].copy_(v[ks:ke]) - k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True) - v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True) + if self._kv_offload: + kcpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True) + vcpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True) self._update_ring_len_after_store(layer_id, end_idx) - self._record_cpu_update(layer_id) + if self._kv_offload: + self._record_cpu_update(layer_id) def _read_logical(self, layer_id: int, attn_start: int, local_end: int, which: str) -> torch.Tensor: - if self._kv_offload: - base = self._k_gpu_buf if which == "k" else self._v_gpu_buf - else: - base = self._k_layer(layer_id) if which == "k" else self._v_layer(layer_id) + self._check_layer_loaded(layer_id) + base = self._k_layer(layer_id) if which == "k" else self._v_layer(layer_id) chunks = self._logical_chunks(layer_id, attn_start, local_end) if not chunks: return torch.empty(0, self._num_heads, self._head_dim, device=self._device, dtype=self._dtype) @@ -336,12 +340,8 @@ def k_cache( attn_start: int | None = None, local_end: int | None = None, ) -> torch.Tensor: - if not self._kv_offload: - if attn_start is None and local_end is None: - attn_start, local_end = 0, self.get_local_end(layer_id) - return self._read_logical(layer_id, int(attn_start), int(local_end), "k") if attn_start is None and local_end is None: - return self._k_gpu_buf + attn_start, local_end = 0, self.get_local_end(layer_id) return self._read_logical(layer_id, int(attn_start), int(local_end), "k") def v_cache( @@ -350,12 +350,8 @@ def v_cache( attn_start: int | None = None, local_end: int | None = None, ) -> torch.Tensor: - if not self._kv_offload: - if attn_start is None and local_end is None: - attn_start, local_end = 0, self.get_local_end(layer_id) - return self._read_logical(layer_id, int(attn_start), int(local_end), "v") if attn_start is None and local_end is None: - return self._v_gpu_buf + attn_start, local_end = 0, self.get_local_end(layer_id) return self._read_logical(layer_id, int(attn_start), int(local_end), "v") def get_global_end(self, layer_id: int) -> int: @@ -381,16 +377,20 @@ def roll_window(self, layer_id: int, sink_tokens: int, num_evicted: int) -> None self._ring_set(layer_id, active=True, sink=sink, head=head, recent_len=recent_len) def reset(self) -> None: - if not self._kv_offload: + if self._kv_offload: + self.sync_all() + self._k_cpu.zero_() + self._v_cpu.zero_() + self._k_gpu_buf.zero_() + self._v_gpu_buf.zero_() self._global_end.zero_() self._local_end.zero_() self._init_ring_metadata() + self._reset_offload_state() return - self.sync_all() self._global_end.zero_() self._local_end.zero_() self._init_ring_metadata() - self._reset_offload_state() class StepRollingKVCachePool(RollingKVCachePool): @@ -438,9 +438,13 @@ def _meta_idx(self, layer_id: int): return (self._step(), int(layer_id)) def _k_layer(self, layer_id: int) -> torch.Tensor: + if self._kv_offload: + return self._k_gpu_buf return self._k_buffer[self._step(), layer_id] def _v_layer(self, layer_id: int) -> torch.Tensor: + if self._kv_offload: + return self._v_gpu_buf return self._v_buffer[self._step(), layer_id] def _k_cpu_layer(self, layer_id: int) -> torch.Tensor: @@ -454,7 +458,7 @@ def _offload_index(self, layer_id: int) -> tuple[int, ...]: def _init_kv_buffer(self) -> None: if self._kv_offload: - self._init_kv_buffer_offload_step() + self._init_kv_buffer_offload() return S, L, N, H, D = self.num_steps, self._num_layers, self._cache_size, self._num_heads, self._head_dim self._k_buffer = torch.empty(S, L, N, H, D, dtype=self._dtype, device=self._device) @@ -463,24 +467,28 @@ def _init_kv_buffer(self) -> None: self._local_end = torch.zeros(S, L, dtype=torch.long, device="cpu") self._init_ring_metadata() - def _init_kv_buffer_offload_step(self) -> None: + def _init_kv_buffer_offload(self) -> None: + from loguru import logger + S, L, N, H, D = self.num_steps, self._num_layers, self._cache_size, self._num_heads, self._head_dim + d = self._device + self._k_cpu = torch.empty(S, L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() self._v_cpu = torch.empty(S, L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() - self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) - self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=self._device) + self._k_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=d) + self._v_gpu_buf = torch.empty(N, H, D, dtype=self._dtype, device=d) + self._global_end = torch.zeros(S, L, dtype=torch.long, device="cpu") self._local_end = torch.zeros(S, L, dtype=torch.long, device="cpu") self._init_ring_metadata() - self._init_offload_state((S, L)) gpu_mb = (self._k_gpu_buf.nbytes + self._v_gpu_buf.nbytes) / (1024 * 1024) cpu_mb = (self._k_cpu.nbytes + self._v_cpu.nbytes) / (1024 * 1024) logger.info( - "[StepRollingKVCachePool+ring+offload] steps={}, one GPU staging layer N={} tokens: {:.1f} MB, CPU pinned: {:.1f} MB", + "[{}+offload] steps={}, GPU staging layer: {:.1f} MB, CPU pinned: {:.1f} MB", + self.__class__.__name__, self.num_steps, - N, gpu_mb, cpu_mb, ) @@ -515,6 +523,8 @@ def __init__( *, kv_offload: bool = False, ) -> None: + if kv_offload: + raise ValueError("SpatialRollingKVCachePool does not support kv_offload.") self._spatial_len = int(spatial_len) super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device, kv_offload=kv_offload) @@ -523,9 +533,6 @@ def spatial_len(self) -> int: return self._spatial_len def _init_kv_buffer(self) -> None: - if self._kv_offload: - self._init_kv_buffer_offload_spatial() - return L, S, N, H, D = self._num_layers, self._spatial_len, self._cache_size, self._num_heads, self._head_dim self._k_buffer = torch.empty(L, S, N, H, D, dtype=self._dtype, device=self._device) self._v_buffer = torch.empty(L, S, N, H, D, dtype=self._dtype, device=self._device) @@ -533,26 +540,6 @@ def _init_kv_buffer(self) -> None: self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") self._init_ring_metadata() - def _init_kv_buffer_offload_spatial(self) -> None: - L, S, N, H, D = self._num_layers, self._spatial_len, self._cache_size, self._num_heads, self._head_dim - self._k_cpu = torch.empty(L, S, N, H, D, dtype=self._dtype, device="cpu").pin_memory() - self._v_cpu = torch.empty(L, S, N, H, D, dtype=self._dtype, device="cpu").pin_memory() - self._k_gpu_buf = torch.empty(S, N, H, D, dtype=self._dtype, device=self._device) - self._v_gpu_buf = torch.empty(S, N, H, D, dtype=self._dtype, device=self._device) - self._global_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._local_end = torch.zeros(L, dtype=torch.long, device="cpu") - self._init_ring_metadata() - - self._init_offload_state((L,)) - - gpu_mb = (self._k_gpu_buf.nbytes + self._v_gpu_buf.nbytes) / (1024 * 1024) - cpu_mb = (self._k_cpu.nbytes + self._v_cpu.nbytes) / (1024 * 1024) - logger.info( - "[SpatialRollingKVCachePool+ring+offload] one GPU staging layer: {:.1f} MB, CPU pinned: {:.1f} MB", - gpu_mb, - cpu_mb, - ) - def _k_layer(self, layer_id: int) -> torch.Tensor: return self._k_buffer[layer_id] @@ -560,40 +547,24 @@ def _v_layer(self, layer_id: int) -> torch.Tensor: return self._v_buffer[layer_id] def _k_cpu_layer(self, layer_id: int) -> torch.Tensor: - return self._k_cpu[layer_id] + raise NotImplementedError("SpatialRollingKVCachePool does not support kv_offload.") def _v_cpu_layer(self, layer_id: int) -> torch.Tensor: - return self._v_cpu[layer_id] + raise NotImplementedError("SpatialRollingKVCachePool does not support kv_offload.") def store_kv(self, k: torch.Tensor, v: torch.Tensor, start_idx: int, end_idx: int, layer_id: int) -> None: if end_idx <= start_idx: return - if not self._kv_offload: - kb, vb = self._k_layer(layer_id), self._v_layer(layer_id) - for logical_s, phys_s, n in self._logical_chunks(layer_id, start_idx, end_idx): - ks = logical_s - start_idx - ke = ks + n - kb[:, phys_s : phys_s + n].copy_(k[:, ks:ke]) - vb[:, phys_s : phys_s + n].copy_(v[:, ks:ke]) - self._update_ring_len_after_store(layer_id, end_idx) - return - self._check_layer_loaded(layer_id) - k_cpu, v_cpu = self._k_cpu_layer(layer_id), self._v_cpu_layer(layer_id) + kb, vb = self._k_layer(layer_id), self._v_layer(layer_id) for logical_s, phys_s, n in self._logical_chunks(layer_id, start_idx, end_idx): ks = logical_s - start_idx ke = ks + n - self._k_gpu_buf[:, phys_s : phys_s + n].copy_(k[:, ks:ke]) - self._v_gpu_buf[:, phys_s : phys_s + n].copy_(v[:, ks:ke]) - k_cpu[:, phys_s : phys_s + n].copy_(k[:, ks:ke], non_blocking=True) - v_cpu[:, phys_s : phys_s + n].copy_(v[:, ks:ke], non_blocking=True) + kb[:, phys_s : phys_s + n].copy_(k[:, ks:ke]) + vb[:, phys_s : phys_s + n].copy_(v[:, ks:ke]) self._update_ring_len_after_store(layer_id, end_idx) - self._record_cpu_update(layer_id) def _read_logical(self, layer_id: int, attn_start: int, local_end: int, which: str) -> torch.Tensor: - if self._kv_offload: - base = self._k_gpu_buf if which == "k" else self._v_gpu_buf - else: - base = self._k_layer(layer_id) if which == "k" else self._v_layer(layer_id) + base = self._k_layer(layer_id) if which == "k" else self._v_layer(layer_id) chunks = self._logical_chunks(layer_id, attn_start, local_end) if not chunks: return torch.empty(self._spatial_len, 0, self._num_heads, self._head_dim, device=self._device, dtype=self._dtype) @@ -601,19 +572,11 @@ def _read_logical(self, layer_id: int, attn_start: int, local_end: int, which: s return parts[0] if len(parts) == 1 else torch.cat(parts, dim=1) def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor: - if not self._kv_offload: - if attn_start is None and local_end is None: - attn_start, local_end = 0, self.get_local_end(layer_id) - return self._read_logical(layer_id, int(attn_start), int(local_end), "k") if attn_start is None and local_end is None: - return self._k_gpu_buf + attn_start, local_end = 0, self.get_local_end(layer_id) return self._read_logical(layer_id, int(attn_start), int(local_end), "k") def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor: - if not self._kv_offload: - if attn_start is None and local_end is None: - attn_start, local_end = 0, self.get_local_end(layer_id) - return self._read_logical(layer_id, int(attn_start), int(local_end), "v") if attn_start is None and local_end is None: - return self._v_gpu_buf + attn_start, local_end = 0, self.get_local_end(layer_id) return self._read_logical(layer_id, int(attn_start), int(local_end), "v") diff --git a/lightx2v/common/kvcache/utils.py b/lightx2v/common/kvcache/utils.py index 6fd414918..bbc679351 100755 --- a/lightx2v/common/kvcache/utils.py +++ b/lightx2v/common/kvcache/utils.py @@ -1,32 +1,9 @@ -import json import math -import os -import numpy as np import torch -import torch.nn.functional as Fn -from loguru import logger from packaging.version import parse -from scipy import integrate, special - -from .kernel import fp4_dequantize - -try: - from lightx2v_kernel.kv_cache import dequantize_kv_cache_fp4 -except ImportError: - dequantize_kv_cache_fp4 = None - -try: - from fouroversix import QuantizationConfig, QuantizeBackend - from fouroversix.quantize.quantized_tensor import QuantizedTensor, from_blocked -except ImportError: - QuantizedTensor = None - from_blocked = None - QuantizationConfig = None - QuantizeBackend = None _KV_TORCH_VER = None -_VALID_DEQUANT_BACKENDS = frozenset({"cuda", "triton", "pytorch"}) def _kvcache_dma_stream_priority() -> int: @@ -39,15 +16,6 @@ def _kvcache_dma_stream_priority() -> int: return 1 if _KV_TORCH_VER >= parse("2.7") else 0 -def ranked_calib_path(path: str, rank: int) -> str: - if not path: - return path - dot = path.rfind(".") - if dot <= 0: - return f"{path}.rank{rank}" - return f"{path[:dot]}.rank{rank}{path[dot:]}" - - def cdiv(n: int, m: int) -> int: return (n + m - 1) // m @@ -56,759 +24,3 @@ def lcm(a: int, b: int) -> int: if a == 0 or b == 0: return max(a, b) or 1 return a * b // math.gcd(a, b) - - -def compute_analytical_turboquant_codebook(head_dim: int, bits: int) -> dict: - """Lloyd-Max codebook on the sphere marginal (Beta on [-1,1]); returns JSON-serializable dict.""" - - def beta_pdf(x: np.ndarray, d: int) -> np.ndarray: - if d <= 2: - raise ValueError(f"head_dim d={d} too small for TurboQuant codebook (need d>=3)") - log_const = special.gammaln(d / 2.0) - 0.5 * np.log(np.pi) - special.gammaln((d - 1) / 2.0) - exponent = (d - 3) / 2.0 - x = np.clip(x, -1 + 1e-15, 1 - 1e-15) - log_val = log_const + exponent * np.log(1 - x**2) - return np.exp(log_val) - - def conditional_mean(lo: float, hi: float, d: int) -> float: - num, _ = integrate.quad(lambda x: x * beta_pdf(np.array([x]), d)[0], lo, hi) - den, _ = integrate.quad(lambda x: beta_pdf(np.array([x]), d)[0], lo, hi) - if den < 1e-30: - return (lo + hi) / 2.0 - return num / den - - def mse_cost(centroids: np.ndarray, d: int) -> float: - n = len(centroids) - boundaries = np.zeros(n + 1) - boundaries[0] = -1.0 - boundaries[-1] = 1.0 - for i in range(n - 1): - boundaries[i + 1] = (centroids[i] + centroids[i + 1]) / 2.0 - cost = 0.0 - for i in range(n): - lo, hi = boundaries[i], boundaries[i + 1] - c = centroids[i] - val, _ = integrate.quad(lambda x: (x - c) ** 2 * beta_pdf(np.array([x]), d)[0], lo, hi) - cost += val - return cost - - d, n_clusters = head_dim, 2**bits - x_grid = np.linspace(-1 + 1e-10, 1 - 1e-10, 10000) - pdf_vals = beta_pdf(x_grid, d) - cdf_vals = np.cumsum(pdf_vals) * (x_grid[1] - x_grid[0]) - cdf_vals /= cdf_vals[-1] - quantile_edges = np.linspace(0, 1, n_clusters + 1) - centroids = np.zeros(n_clusters) - for i in range(n_clusters): - q_lo, q_hi = quantile_edges[i], quantile_edges[i + 1] - q_mid = (q_lo + q_hi) / 2.0 - idx = min(int(np.searchsorted(cdf_vals, q_mid)), len(x_grid) - 1) - centroids[i] = x_grid[idx] - - prev_cost = float("inf") - cost = 0.0 - for _ in range(200): - boundaries = np.zeros(n_clusters + 1) - boundaries[0] = -1.0 - boundaries[-1] = 1.0 - for i in range(n_clusters - 1): - boundaries[i + 1] = (centroids[i] + centroids[i + 1]) / 2.0 - new_centroids = np.zeros(n_clusters) - for i in range(n_clusters): - new_centroids[i] = conditional_mean(boundaries[i], boundaries[i + 1], d) - cost = mse_cost(new_centroids, d) - centroids = new_centroids - if abs(prev_cost - cost) < 1e-12: - break - prev_cost = cost - - boundaries = np.zeros(n_clusters + 1) - boundaries[0] = -1.0 - boundaries[-1] = 1.0 - for i in range(n_clusters - 1): - boundaries[i + 1] = (centroids[i] + centroids[i + 1]) / 2.0 - - return { - "centroids": centroids.tolist(), - "boundaries": boundaries.tolist(), - "mse_per_coord": float(cost), - "mse_total": float(cost * d), - "d": d, - "bits": bits, - "source": "analytical", - } - - -def export_turboquant_codebook_json( - head_dim: int, - bits: int, - out_dir: str, -) -> str: - """Pre-compute Lloyd-Max codebook (sphere marginal Beta on [-1,1]) and save JSON. - - Output format matches ``/turboquant`` filename ``codebook_d{d}_b{b}.json`` (loadable via inference engine). - Requires numpy + scipy. - """ - os.makedirs(out_dir, exist_ok=True) - path = os.path.join(out_dir, f"codebook_d{head_dim}_b{bits}.json") - if os.path.isfile(path): - return path - - cb = compute_analytical_turboquant_codebook(head_dim, bits) - cb.pop("source", None) - with open(path, "w", encoding="utf-8") as f: - json.dump(cb, f, indent=2) - logger.info("[TurboQuant] wrote codebook {!r} (d={}, bits={})", path, head_dim, bits) - return path - - -def tq_fw_load_codebook_record( - head_dim: int, - bits: int, - codebook_dir: str | None, - codebook_cache_dir: str | None, - export_missing: bool, -) -> dict: - """Load codebook JSON dict; optional compute+write to cache dir.""" - subdirs = [p for p in (codebook_dir, codebook_cache_dir) if p] - name = f"codebook_d{head_dim}_b{bits}.json" - for ddir in subdirs: - p = os.path.join(ddir, name) - if os.path.isfile(p): - with open(p, "r", encoding="utf-8") as f: - return json.load(f) - if export_missing and codebook_cache_dir: - export_turboquant_codebook_json(head_dim, bits, codebook_cache_dir) - p = os.path.join(codebook_cache_dir, name) - with open(p, "r", encoding="utf-8") as f: - return json.load(f) - raise FileNotFoundError(f"TurboQuant codebook not found: {name} under {subdirs or '(no dirs)'}; run export_turboquant_codebook_json(...) or set codebook_cache_dir + export_missing_codebooks.") - - -def tq_fw_pack_indices(indices: torch.Tensor, bits: int) -> torch.Tensor: - """Bit-pack integer indices (aligned with /turboquant ``quantizer._pack_indices``).""" - - d = indices.shape[-1] - batch_shape = indices.shape[:-1] - if bits == 1: - vals_per_byte = 8 - elif bits == 2: - vals_per_byte = 4 - elif bits <= 4: - vals_per_byte = 2 - bits = 4 - else: - return indices.to(torch.uint8) - - padded_d = ((d + vals_per_byte - 1) // vals_per_byte) * vals_per_byte - if padded_d > d: - indices = Fn.pad(indices.to(torch.uint8), (0, padded_d - d), value=0) - reshaped = indices.to(torch.uint8).reshape(*batch_shape, -1, vals_per_byte) - shifts = torch.arange(vals_per_byte, device=indices.device, dtype=torch.uint8) * bits - packed = (reshaped << shifts).sum(dim=-1, dtype=torch.uint8) - return packed - - -def tq_fw_unpack_indices(packed: torch.Tensor, bits: int, d: int) -> torch.Tensor: - batch_shape = packed.shape[:-1] - if bits == 1: - vals_per_byte = 8 - elif bits == 2: - vals_per_byte = 4 - elif bits <= 4: - vals_per_byte = 2 - bits = 4 - else: - return packed.long() - - mask = (1 << bits) - 1 - shifts = torch.arange(vals_per_byte, device=packed.device, dtype=torch.uint8) * bits - unpacked = (packed.unsqueeze(-1) >> shifts) & mask - unpacked = unpacked.reshape(*batch_shape, -1) - return unpacked[..., :d].long() - - -def tq_fw_packed_width(head_dim: int, bits: int) -> int: - if bits > 4: - return head_dim - if bits == 1: - vpb = 8 - elif bits == 2: - vpb = 4 - else: - vpb = 2 - padded_d = ((head_dim + vpb - 1) // vpb) * vpb - return padded_d // vpb - - -def tq_fw_generate_rotation_matrix( - d: int, - device: torch.device, - dtype: torch.dtype = torch.float32, - seed: int = 42, -) -> torch.Tensor: - rng = torch.Generator(device="cpu") - rng.manual_seed(seed) - G = torch.randn(d, d, generator=rng, dtype=torch.float32) - Q, R = torch.linalg.qr(G) - diag_sign = torch.sign(torch.diag(R)) - Q = Q * diag_sign.unsqueeze(0) - return Q.to(device=device, dtype=dtype) - - -def tq_fw_generate_qjl_matrix( - d: int, - device: torch.device, - dtype: torch.dtype = torch.float32, - seed: int = 12345, -) -> torch.Tensor: - rng = torch.Generator(device="cpu") - rng.manual_seed(seed) - S = torch.randn(d, d, generator=rng, dtype=torch.float32) - return S.to(device=device, dtype=dtype) - - -def tq_fw_rotate_forward(x: torch.Tensor, Pi: torch.Tensor) -> torch.Tensor: - return torch.matmul(x, Pi.T) - - -def tq_fw_rotate_backward(y: torch.Tensor, Pi: torch.Tensor) -> torch.Tensor: - return torch.matmul(y, Pi) - - -def tq_fw_pack_qjl_signs(projected: torch.Tensor) -> torch.Tensor: - signs = (projected > 0).to(torch.uint8) - d = signs.shape[-1] - if d % 8 != 0: - signs = torch.nn.functional.pad(signs, (0, 8 - d % 8), value=0) - signs_reshaped = signs.reshape(*signs.shape[:-1], -1, 8) - powers = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=signs.device, dtype=torch.uint8) - return (signs_reshaped * powers).sum(dim=-1, dtype=torch.uint8) - - -def tq_fw_unpack_qjl_signs(packed: torch.Tensor, dim: int) -> torch.Tensor: - powers = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=packed.device, dtype=torch.uint8) - unpacked = ((packed.unsqueeze(-1) & powers) > 0).float() - signs = unpacked.reshape(*packed.shape[:-1], -1)[..., :dim] - return 2.0 * signs - 1.0 - - -def tq_group_quantize_values(v: torch.Tensor, bits: int, group_size: int) -> dict: - """Group min-max quantize V; ``v`` shape (B,H,S,D). Returns packed data + scales + zeros.""" - orig_shape = v.shape - d = orig_shape[-1] - n_groups = d // group_size - if d % group_size != 0: - raise ValueError(f"head_dim {d} must divide value_group_size {group_size}") - v_grouped = v.reshape(*orig_shape[:-1], n_groups, group_size) - v_min = v_grouped.min(dim=-1, keepdim=True).values - v_max = v_grouped.max(dim=-1, keepdim=True).values - n_levels = 2**bits - 1 - scale = (v_max - v_min) / n_levels - scale = scale.clamp(min=1e-10) - zero = v_min - v_q = ((v_grouped - zero) / scale).round().clamp(0, n_levels).to(torch.uint8) - v_q_flat = v_q.reshape(*orig_shape[:-1], d) - if bits == 2: - v_4 = v_q_flat.reshape(*orig_shape[:-1], d // 4, 4) - packed = v_4[..., 0] | (v_4[..., 1] << 2) | (v_4[..., 2] << 4) | (v_4[..., 3] << 6) - elif bits == 4: - v_2 = v_q_flat.reshape(*orig_shape[:-1], d // 2, 2) - packed = v_2[..., 0] | (v_2[..., 1] << 4) - else: - packed = v_q_flat - return { - "data": packed, - "scales": scale.squeeze(-1).to(torch.float16), - "zeros": zero.squeeze(-1).to(torch.float16), - "bits": bits, - "group_size": group_size, - "shape": tuple(orig_shape), - } - - -def tq_group_dequantize_values(comp: dict) -> torch.Tensor: - bits = int(comp["bits"]) - group_size = int(comp["group_size"]) - packed = comp["data"] - d = comp["shape"][-1] - batch_shape = comp["shape"][:-1] - if bits == 2: - v0 = packed & 0x03 - v1 = (packed >> 2) & 0x03 - v2 = (packed >> 4) & 0x03 - v3 = (packed >> 6) & 0x03 - data = torch.stack([v0, v1, v2, v3], dim=-1).reshape(*batch_shape, packed.shape[-1] * 4) - elif bits == 4: - v0 = packed & 0x0F - v1 = (packed >> 4) & 0x0F - data = torch.stack([v0, v1], dim=-1).reshape(*batch_shape, packed.shape[-1] * 2) - else: - data = packed - data = data.float() - n_groups = d // group_size - data = data.reshape(*batch_shape, n_groups, group_size) - scales = comp["scales"].unsqueeze(-1).float() - zeros = comp["zeros"].unsqueeze(-1).float() - return (data * scales + zeros).reshape(*batch_shape, d) - - -def tq_value_group_packed_width(head_dim: int, bits: int) -> int: - if bits == 2: - return head_dim // 4 - if bits == 4: - return head_dim // 2 - return head_dim - - -def tq_lloyd_max_from_histogram_counts( - hist_counts, - n_centroids: int, - max_iter: int = 150, -): - """1D Lloyd-Max on a uniform histogram over [-1, 1]. ``hist_counts`` shape (n_bins,).""" - - n_bins = int(hist_counts.shape[0]) - edges = np.linspace(-1.0, 1.0, n_bins + 1) - centers = (edges[:-1] + edges[1:]) / 2.0 - w = hist_counts.astype(np.float64) - total = w.sum() - if total < 1e-30: - raise ValueError("TurboQuant calib histogram is empty") - w /= total - - cdf = np.cumsum(w) - targets = (np.arange(n_centroids, dtype=np.float64) + 0.5) / n_centroids - centroids = np.sort(np.interp(targets, cdf, centers)) - - for _ in range(max_iter): - boundaries = np.zeros(n_centroids + 1) - boundaries[0] = -1.0 - boundaries[-1] = 1.0 - for i in range(n_centroids - 1): - boundaries[i + 1] = (centroids[i] + centroids[i + 1]) / 2.0 - - assign = np.searchsorted(boundaries, centers, side="right") - 1 - assign = np.clip(assign, 0, n_centroids - 1) - - new_c = np.zeros(n_centroids) - for j in range(n_centroids): - mask = assign == j - ww = w[mask].sum() - if ww > 1e-30: - new_c[j] = (w[mask] * centers[mask]).sum() / ww - else: - new_c[j] = centroids[j] - - if np.max(np.abs(new_c - centroids)) < 1e-9: - centroids = new_c - break - centroids = new_c - - boundaries = np.zeros(n_centroids + 1) - boundaries[0] = -1.0 - boundaries[-1] = 1.0 - for i in range(n_centroids - 1): - boundaries[i + 1] = (centroids[i] + centroids[i + 1]) / 2.0 - - assign = np.searchsorted(boundaries, centers, side="right") - 1 - assign = np.clip(assign, 0, n_centroids - 1) - mse = 0.0 - for j in range(n_centroids): - mask = assign == j - if w[mask].sum() > 0: - mse += float((w[mask] * (centroids[j] - centers[mask]) ** 2).sum()) - - return centroids, boundaries, mse - - -def turboquant_codebook_dict_from_histogram( - hist: torch.Tensor, - head_dim: int, - bits: int, - *, - n_bins: int = 4096, -) -> dict: - """Build TurboQuant JSON codebook dict from accumulated marginal histogram (rotated unit keys/values).""" - - hc = hist.detach().cpu().numpy().astype(np.float64) - if hc.shape[0] != n_bins: - raise ValueError(f"hist length {hc.shape[0]} != n_bins {n_bins}") - - n_centroids = 2**bits - if hc.sum() < 1: - logger.warning( - "[TurboQuant calib] empty histogram for d={}, bits={}; using analytical codebook.", - head_dim, - bits, - ) - cb = compute_analytical_turboquant_codebook(head_dim, bits) - cb.pop("source", None) - cb["source"] = "analytical_fallback" - return cb - - centroids, boundaries, mse_coord = tq_lloyd_max_from_histogram_counts(hc, n_centroids) - return { - "centroids": centroids.tolist(), - "boundaries": boundaries.tolist(), - "mse_per_coord": float(mse_coord), - "mse_total": float(mse_coord * head_dim), - "d": head_dim, - "bits": bits, - "source": "empirical_histogram", - } - - -def build_turboquant_codebooks_from_calib_histograms( - hist_k: torch.Tensor, - *, - head_dim: int, - key_bits: int, - n_bins: int = 4096, -) -> dict[str, dict]: - """Produce filename -> codebook dict for JSON export (inference loader compatible).""" - out: dict[str, dict] = {} - if key_bits < 2: - raise ValueError("TurboQuantProd requires key_bits >= 2") - b_k = key_bits - 1 - ck = turboquant_codebook_dict_from_histogram(hist_k, head_dim, b_k, n_bins=n_bins) - out[f"codebook_d{head_dim}_b{b_k}.json"] = ck - return out - - -class TurboQuantMSEInference(torch.nn.Module): - """TurboQuant MSE stage: rotation + Lloyd-Max via ``searchsorted`` + bit-pack.""" - - def __init__( - self, - dim: int, - bits: int, - device: torch.device, - seed: int, - codebook: dict, - dtype: torch.dtype = torch.float32, - ): - super().__init__() - self.dim = dim - self.bits = bits - self.register_buffer("Pi", tq_fw_generate_rotation_matrix(dim, device, dtype, seed=seed)) - c = torch.tensor(codebook["centroids"], device=device, dtype=dtype) - b = torch.tensor(codebook["boundaries"], device=device, dtype=dtype) - self.register_buffer("centroids", c) - self.register_buffer("boundaries", b) - self.register_buffer("decision_boundaries", b[1:-1].contiguous()) - - @torch.no_grad() - def compress_bhsd(self, x: torch.Tensor) -> dict: - norms = x.norm(dim=-1, keepdim=False) - x_unit = x / (norms.unsqueeze(-1) + 1e-10) - y = tq_fw_rotate_forward(x_unit.float(), self.Pi) - indices = torch.searchsorted(self.decision_boundaries, y.contiguous()) - packed = tq_fw_pack_indices(indices, self.bits) - B, H, S, D = x.shape - return { - "idx_bytes": packed, - "vec_norms": norms.to(torch.float16), - "shape": (B, H, S, D), - "bits": self.bits, - } - - @torch.no_grad() - def decompress_bhsd(self, comp: dict) -> torch.Tensor: - B, H, S, D = comp["shape"] - bits = int(comp["bits"]) - idx = tq_fw_unpack_indices(comp["idx_bytes"], bits, D) - y_hat = self.centroids[idx] - x_hat = tq_fw_rotate_backward(y_hat, self.Pi) - return x_hat * comp["vec_norms"].unsqueeze(-1).float() - - -class TurboQuantProdInference(torch.nn.Module): - """TurboQuant inner-product path: (key_bits-1) MSE + QJL on residual.""" - - def __init__( - self, - dim: int, - bits: int, - device: torch.device, - seed: int, - codebook_mse: dict, - dtype: torch.dtype = torch.float32, - ): - super().__init__() - assert bits >= 2, "TurboQuantProd needs key_bits >= 2" - self.dim = dim - self.bits = bits - self.mse_bits = bits - 1 - self.qjl_scale = math.sqrt(math.pi / 2.0) / dim - self.mse = TurboQuantMSEInference(dim, self.mse_bits, device, seed, codebook_mse, dtype=dtype) - self.register_buffer("S", tq_fw_generate_qjl_matrix(dim, device, dtype, seed=seed + 1000)) - - @torch.no_grad() - def compress_bhsd(self, x: torch.Tensor) -> dict: - mse_c = self.mse.compress_bhsd(x) - x_mse = self.mse.decompress_bhsd(mse_c) - residual = x - x_mse - residual_norms = residual.norm(dim=-1) - projected = torch.matmul(residual.float(), self.S.T) - qjl_packed = tq_fw_pack_qjl_signs(projected) - B, H, S, D = x.shape - return { - "mse_idx_bytes": mse_c["idx_bytes"], - "qjl_bytes": qjl_packed, - "residual_norms": residual_norms.to(torch.float16), - "vec_norms": mse_c["vec_norms"], - "shape": (B, H, S, D), - "mse_bits": self.mse_bits, - } - - @torch.no_grad() - def decompress_bhsd(self, comp: dict) -> torch.Tensor: - B, H, S, D = comp["shape"] - mse_c = { - "idx_bytes": comp["mse_idx_bytes"], - "vec_norms": comp["vec_norms"], - "shape": (B, H, S, D), - "bits": int(comp["mse_bits"]), - } - x_mse = self.mse.decompress_bhsd(mse_c) - signs = tq_fw_unpack_qjl_signs(comp["qjl_bytes"], D) - x_qjl = torch.matmul(signs, self.S) - x_qjl = x_qjl * (self.qjl_scale * comp["residual_norms"].unsqueeze(-1).float()) - return x_mse + x_qjl - - -def normalize_dequant_backend(backend: str | None) -> str: - """Map ``kv_quant.backend`` to a KV dequant implementation name.""" - if backend is None or not str(backend).strip(): - raise ValueError( - f"kv_quant.backend is required for longlive_fp4 dequant. Choose one of: {', '.join(sorted(_VALID_DEQUANT_BACKENDS))}", - ) - name = str(backend).strip().lower() - if name == "torch": - name = "pytorch" - if name == "transformer_engine": - name = "cuda" - if name not in _VALID_DEQUANT_BACKENDS: - allowed = ", ".join(sorted(_VALID_DEQUANT_BACKENDS)) - raise ValueError(f"Unsupported KV dequant backend {backend!r}. Expected one of: {allowed}") - return name - - -def scale_rule_to_fp4_limits(scale_rule) -> tuple[float, float]: - if hasattr(scale_rule, "max_allowed_e2m1_value") and hasattr( - scale_rule, - "max_allowed_e4m3_value", - ): - return ( - float(scale_rule.max_allowed_e2m1_value()), - float(scale_rule.max_allowed_e4m3_value()), - ) - - normalized = str(scale_rule).lower() - if "." in normalized: - normalized = normalized.rsplit(".", 1)[-1] - normalized = normalized.strip().strip("\"'") - - if normalized == "static_4": - return 4.0, 448.0 - if normalized == "static_6": - return 6.0, 448.0 - if normalized in {"mse", "mae", "l1_norm", "abs_max"}: - return 6.0, 256.0 - - raise ValueError(f"Unsupported FP4 scale_rule: {scale_rule}") - - -def _dequant_blocks_cuda( - values: list[torch.Tensor], - scale_factors: list[torch.Tensor], - amax_list: list[torch.Tensor], - *, - num_heads: int, - block_token_size: int, - dtype: torch.dtype, - scale_rule, -) -> torch.Tensor: - """Fused parallel dequant via ``lightx2v_kernel`` (optional LongLive op fallback).""" - if not values or values[0].device.type != "cuda": - raise RuntimeError("KV dequant backend=cuda requires CUDA tensors.") - - e2m1_max, e4m3_max = scale_rule_to_fp4_limits(scale_rule) - out = dequantize_kv_cache_fp4( - values, - scale_factors, - amax_list, - num_heads=num_heads, - block_token_size=block_token_size, - dtype=dtype, - e2m1_max=e2m1_max, - e4m3_max=e4m3_max, - ) - return out[0] - - -def _global_scale_for_qt(qt: QuantizedTensor) -> torch.Tensor: - e2m1_max, e4m3_max = scale_rule_to_fp4_limits(qt.scale_rule) - return qt.amax / (e2m1_max * e4m3_max) - - -def _dequant_qt_triton(qt: QuantizedTensor, dtype: torch.dtype) -> torch.Tensor: - """Per-block NVFP4 dequant via LightX2V Triton (fouroversix tensor layout).""" - block_size = qt.dtype.block_size() - padded_shape = qt.padded_shape - scales_2d = from_blocked( - qt.scale_factors, - (padded_shape[0], padded_shape[1] // block_size), - ) - return fp4_dequantize( - qt.values, - scales_2d, - _global_scale_for_qt(qt), - block_size=block_size, - dtype=dtype, - ) - - -def _dequant_blocks_pytorch( - blocks: list[QuantizedTensor], - num_heads: int, - head_dim: int, - block_token_size: int, - dtype: torch.dtype, -) -> torch.Tensor: - parts = [qt.dequantize(dtype).view(block_token_size, num_heads, head_dim) for qt in blocks] - return torch.cat(parts, dim=0) - - -def _dequant_blocks_triton( - blocks: list[QuantizedTensor], - num_heads: int, - head_dim: int, - block_token_size: int, - dtype: torch.dtype, - device: torch.device, -) -> torch.Tensor: - n_blks = len(blocks) - h, d = num_heads, head_dim - out = torch.zeros( - [1, n_blks * block_token_size, h, d], - dtype=dtype, - device=device, - ) - for block_idx, qt in enumerate(blocks): - deq = _dequant_qt_triton(qt, dtype) - deq = deq[: block_token_size * h] - t_start = block_idx * block_token_size - t_end = t_start + block_token_size - out[0, t_start:t_end, :, :] = deq.view(block_token_size, h, d) - return out[0] - - -def dequantize_kv_blocks( - blocks: list[QuantizedTensor], - num_heads: int, - head_dim: int, - block_token_size: int, - dtype: torch.dtype, - device: torch.device, - *, - backend: str, -) -> torch.Tensor: - """ - Dequantize block list to ``[T, num_heads, head_dim]`` (T = len(blocks) * block_token_size). - - ``backend`` must be one of ``cuda``, ``triton``, ``pytorch`` (``torch`` aliases ``pytorch``). - """ - if not blocks: - return torch.empty(0, num_heads, head_dim, device=device, dtype=dtype) - - mode = normalize_dequant_backend(backend) - scale_rule = blocks[0].scale_rule - values = [qt.values for qt in blocks] - scale_factors = [qt.scale_factors for qt in blocks] - amax_list = [qt.amax for qt in blocks] - - if mode == "cuda": - return _dequant_blocks_cuda( - values, - scale_factors, - amax_list, - num_heads=num_heads, - block_token_size=block_token_size, - dtype=dtype, - scale_rule=scale_rule, - ) - - if mode == "triton": - return _dequant_blocks_triton( - blocks, - num_heads, - head_dim, - block_token_size, - dtype, - device, - ) - - return _dequant_blocks_pytorch(blocks, num_heads, head_dim, block_token_size, dtype) - - -def dequantize_token_range( - blocks: list[QuantizedTensor], - attn_start: int, - local_end: int, - *, - cache_size: int, - num_heads: int, - head_dim: int, - block_token_size: int, - dtype: torch.dtype, - device: torch.device, - backend: str, -) -> torch.Tensor: - if local_end <= attn_start: - return torch.empty(0, num_heads, head_dim, device=device, dtype=dtype) - - t0 = attn_start - t1 = min(local_end, cache_size) - b0 = t0 // block_token_size - b1 = (t1 - 1) // block_token_size - sub = blocks[b0 : b1 + 1] - nhd = dequantize_kv_blocks( - sub, - num_heads, - head_dim, - block_token_size, - dtype, - device, - backend=backend, - ) - off0 = t0 - b0 * block_token_size - off1 = t1 - b0 * block_token_size - return nhd[off0:off1].contiguous() - - -def k_smooth(k: torch.Tensor) -> torch.Tensor: - """Per-head mean removal before K quantization (LongLive).""" - return k - k.mean(dim=-1, keepdim=True) - - -def clone_quantized_tensor(qt: QuantizedTensor) -> QuantizedTensor: - return QuantizedTensor( - values=qt.values.clone(), - scale_factors=qt.scale_factors.clone(), - amax=qt.amax.clone() if qt.amax is not None else None, - dtype=qt.dtype, - original_shape=qt.original_shape, - scale_rule=qt.scale_rule, - padded_shape=qt.padded_shape, - ) - - -def build_fp4_quant_config( - *, - scale_rule: str = "mse", - backend: str | None = None, -) -> QuantizationConfig: - backend_enum = QuantizeBackend(backend) if backend is not None else None - return QuantizationConfig(scale_rule=scale_rule, backend=backend_enum) diff --git a/lightx2v/common/ops/attn/utils/all2all.py b/lightx2v/common/ops/attn/utils/all2all.py index 037eb00b1..9bc717048 100644 --- a/lightx2v/common/ops/attn/utils/all2all.py +++ b/lightx2v/common/ops/attn/utils/all2all.py @@ -1,8 +1,58 @@ import torch import torch.distributed as dist +try: + from sageattn3_sparse import dequant_fp4 as dequant_fp4_sage3 + from sageattn3_sparse import quant_fp4 as quant_fp4_sage3 +except ImportError: + quant_fp4_sage3 = None + dequant_fp4_sage3 = None -def all2all_seq2head(input, group=None): + +def _fp8_all_to_all(input_t, group=None): + """All-to-all with per-token fp8 compression along the last dim. + + ``input_t`` must be contiguous with dim 0 == world_size (the all-to-all + split dim). Only the quantized payload + per-token scale cross the wire, + roughly halving bf16/fp16 communication volume. Returns the dequantized + tensor in ``input_t``'s original dtype. + """ + from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm + + orig_dtype = input_t.dtype + shape = input_t.shape + hidden = shape[-1] + q, scale = quant_fp8_vllm(input_t.reshape(-1, hidden).contiguous()) + q = q.reshape(shape) + scale = scale.reshape(*shape[:-1], 1).contiguous() + out_q = torch.empty_like(q) + out_scale = torch.empty_like(scale) + dist.all_to_all_single(out_q, q, group=group) + dist.all_to_all_single(out_scale, scale, group=group) + return dequant_fp8_vllm(out_q, out_scale, orig_dtype) + + +def _fp4_all_to_all(input_t, group=None): + """All-to-all with SageAttention3 FP4 compression along the last dim.""" + if quant_fp4_sage3 is None or dequant_fp4_sage3 is None: + raise ImportError("sageattn3_sparse quant_fp4/dequant_fp4 is required for seq_p_fp4_comm.") + + shape = input_t.shape + hidden = shape[-1] + q, scale = quant_fp4_sage3(input_t.reshape(1, 1, -1, hidden).contiguous()) + q = q.reshape(*shape[:-1], hidden // 2).contiguous() + scale = scale.reshape(*shape[:-1], hidden // 16).contiguous() + out_q = torch.empty_like(q) + out_scale = torch.empty_like(scale) + dist.all_to_all_single(out_q, q, group=group) + dist.all_to_all_single(out_scale, scale, group=group) + return dequant_fp4_sage3( + out_q.reshape(1, 1, -1, hidden // 2), + out_scale.reshape(1, 1, -1, hidden // 16), + ).reshape(shape) + + +def all2all_seq2head(input, group=None, use_fp8_comm=False, use_fp4_comm=False): """ 将输入张量从 [seq_len/N, heads, hidden_dims] 转换为 [seq_len, heads/N, hidden_dims] 的格式。 @@ -14,6 +64,7 @@ def all2all_seq2head(input, group=None): """ # 确保输入是一个3D张量 assert input.dim() == 3, f"input must be 3D tensor" + assert not (use_fp8_comm and use_fp4_comm), "use_fp8_comm and use_fp4_comm can't be enabled at the same time." # 获取当前进程的世界大小 world_size = dist.get_world_size(group=group) @@ -30,11 +81,14 @@ def all2all_seq2head(input, group=None): .contiguous() # 确保内存连续 ) - # 创建一个与输入张量相同形状的输出张量 - output = torch.empty_like(input_t) - # 执行 all-to-all 操作,将输入张量的内容分发到所有进程 - dist.all_to_all_single(output, input_t, group=group) + if use_fp8_comm: + output = _fp8_all_to_all(input_t, group=group) + elif use_fp4_comm: + output = _fp4_all_to_all(input_t, group=group) + else: + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) # 重塑输出张量为 [seq_len, heads/N, hidden_dims] 形状 output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous() diff --git a/lightx2v/models/networks/wan/infer/audio/transformer_infer.py b/lightx2v/models/networks/wan/infer/audio/transformer_infer.py index 777ae1c07..0db1a169d 100755 --- a/lightx2v/models/networks/wan/infer/audio/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/audio/transformer_infer.py @@ -298,8 +298,20 @@ def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, sh else: start_frame = segment_idx * frames q_rope, k_rope = self._apply_rope_sp(q, k, grid_sizes, freqs, start_frame) - k_to_store = all2all_seq2head(k_rope, group=self.seq_p_group) - v_to_store = all2all_seq2head(v, group=self.seq_p_group) + use_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False) + use_fp4_comm = self.config["parallel"].get("seq_p_fp4_comm", False) + k_to_store = all2all_seq2head( + k_rope, + group=self.seq_p_group, + use_fp8_comm=use_fp8_comm, + use_fp4_comm=use_fp4_comm, + ) + v_to_store = all2all_seq2head( + v, + group=self.seq_p_group, + use_fp8_comm=use_fp8_comm, + use_fp4_comm=use_fp4_comm, + ) kv_cache.store_kv(k_to_store, v_to_store, local_start_idx, local_end_idx, self.block_idx) else: kv_cache.store_kv(k, v, local_start_idx, local_end_idx, self.block_idx) diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py index 54ec9842d..27dd42cc1 100755 --- a/lightx2v/models/runners/wan/wan_audio_runner.py +++ b/lightx2v/models/runners/wan/wan_audio_runner.py @@ -1193,12 +1193,6 @@ def init_kv_cache_manager(self): if sp_group is not None or torch.distributed.get_world_size() > 1: torch.distributed.barrier(group=sp_group) - def end_run(self): - kv_mgr = getattr(getattr(self, "model", None), "kv_cache_manager", None) - if kv_mgr is not None: - kv_mgr.save_calibration() - super().end_run() - @ProfilingContext4DebugL2("Run DiT") def run_main(self): try: diff --git a/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py b/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py index 22c394df2..9fdf10b1c 100755 --- a/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py +++ b/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py @@ -12,7 +12,7 @@ from lightx2v.utils.envs import * from lightx2v.utils.profiler import * from lightx2v.utils.registry_factory import RUNNER_REGISTER -from lightx2v.utils.utils import get_rank_and_world_size, wan_vae_to_comfy +from lightx2v.utils.utils import get_rank_and_world_size from lightx2v.utils.video_recorder import VideoRecorder try: @@ -33,11 +33,6 @@ class LingbotFastRunner(LingbotRunner): def __init__(self, config): WanRunner.__init__(self, config) self.control_type = config.get("control_type", "cam") - self.is_live = config.get("is_live", False) - if self.is_live: - self.width = self.config["target_width"] - self.height = self.config["target_height"] - self.run_main = self.run_main_live def load_transformer(self): wan_model_kwargs = { @@ -99,19 +94,15 @@ def run_segment(self, segment_idx=0): return self.model.scheduler.stream_output - def decode_segment_latents(self, segment_idx: int, latents: torch.Tensor) -> torch.Tensor: - return self.run_vae_decoder(latents.detach().clone()) + def decode_segment_latents(self, segment_idx: int, segment_latents: torch.Tensor) -> torch.Tensor: + is_first = segment_idx == 0 + is_last = segment_idx == self.video_segment_num - 1 + return self.vae_decoder.cached_decode_withflag(segment_latents.to(GET_DTYPE()), is_first, is_last) def init_run(self): self.init_kv_cache_manager() super().init_run() - def end_run(self): - kv_mgr = getattr(getattr(self, "model", None), "kv_cache_manager", None) - if kv_mgr is not None: - kv_mgr.save_calibration() - super().end_run() - @ProfilingContext4DebugL2("Run DiT") def run_main(self, total_steps=None): self.init_run() @@ -123,49 +114,47 @@ def run_main(self, total_steps=None): self.vae_decoder = self.load_vae_decoder() vae_decoder = AsyncVAEChunkDecoder.from_config(self.config, device=torch.device("cuda"), vae_decoder=self.vae_decoder) - with ( - no_sync_profiling(enabled=vae_decoder.is_async), - ProfilingContext4DebugL1( + with no_sync_profiling(enabled=vae_decoder.is_async): + with ProfilingContext4DebugL1( f"AR chunk total {self.video_segment_num} chunks", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, metrics_labels=["DefaultRunner"], - ), - ): - try: - for segment_idx in range(self.video_segment_num): - logger.info(f"start chunk {segment_idx + 1}/{self.video_segment_num}") - with ProfilingContext4DebugL1( - f"chunk end2end {segment_idx + 1}/{self.video_segment_num}", - recorder_mode=GET_RECORDER_MODE(), - metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, - metrics_labels=["DefaultRunner"], - ): - self.check_stop() - self.init_run_segment(segment_idx) - latents = self.run_segment(segment_idx) - - with ProfilingContext4DebugL1("step_pre_in_rerun"): - self.model.scheduler.step_pre( - seg_index=segment_idx, - step_index=self.model.scheduler.infer_steps - 1, - is_rerun=True, - ) - with ProfilingContext4DebugL1("infer_main_in_rerun"): - self.model.infer(self.inputs) - - vae_decoder.submit(self.decode_segment_latents, segment_idx, latents) - torch.cuda.empty_cache() - decoded_chunks = vae_decoder.finish() - finally: - if "vae_decoder" in locals(): - vae_decoder.finish() - if lazy_vae: - del self.vae_decoder - torch.cuda.empty_cache() - gc.collect() - - self.gen_video = torch.cat(decoded_chunks, dim=0) + ): + try: + for segment_idx in range(self.video_segment_num): + logger.info(f"start chunk {segment_idx + 1}/{self.video_segment_num}") + with ProfilingContext4DebugL1( + f"chunk end2end {segment_idx + 1}/{self.video_segment_num}", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, + metrics_labels=["DefaultRunner"], + ): + self.check_stop() + self.init_run_segment(segment_idx) + latents = self.run_segment(segment_idx) + + with ProfilingContext4DebugL1("step_pre_in_rerun"): + self.model.scheduler.step_pre( + seg_index=segment_idx, + step_index=self.model.scheduler.infer_steps - 1, + is_rerun=True, + ) + with ProfilingContext4DebugL1("infer_main_in_rerun"): + self.model.infer(self.inputs) + + vae_decoder.submit(self.decode_segment_latents, segment_idx, latents) + torch.cuda.empty_cache() + decoded_chunks = vae_decoder.finish() + finally: + if "vae_decoder" in locals(): + vae_decoder.finish() + if lazy_vae: + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + + self.gen_video = torch.cat(decoded_chunks, dim=2) self.gen_video_final = self.gen_video gen_video_final = self.process_images_after_vae_decoder() self.end_run() @@ -186,53 +175,3 @@ def init_video_recorder(self): livestream_url=output_video_path, fps=record_fps, ) - - @ProfilingContext4DebugL1("End run segment") - def end_run_segment(self, segment_idx=None): - with ProfilingContext4DebugL1("step_pre_in_rerun"): - self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True) - with ProfilingContext4DebugL1("infer_main_in_rerun"): - self.model.infer(self.inputs) - - self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video - if self.is_live: - if self.video_recorder: - stream_video = wan_vae_to_comfy(self.gen_video) - self.video_recorder.pub_video(stream_video) - - torch.cuda.empty_cache() - - @ProfilingContext4DebugL2("Run DiT") - def run_main_live(self, total_steps=None): - try: - self.init_video_recorder() - logger.info(f"init video_recorder: {self.video_recorder}") - rank, world_size = get_rank_and_world_size() - if rank == world_size - 1: - assert self.video_recorder is not None - self.video_recorder.start(self.width, self.height) - if world_size > 1 and dist is not None: - dist.barrier() - self.init_run() - if self.config.get("compile", False): - self.model.select_graph_for_compile(self.input_info) - - for segment_idx in range(self.video_segment_num): - logger.info(f"start segment {segment_idx + 1}/{self.video_segment_num}") - with ProfilingContext4DebugL1( - f"segment end2end {segment_idx + 1}/{self.video_segment_num}", - recorder_mode=GET_RECORDER_MODE(), - metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, - metrics_labels=["DefaultRunner"], - ): - self.check_stop() - self.init_run_segment(segment_idx) - latents = self.run_segment(segment_idx) - self.gen_video = self.run_vae_decoder(latents) - self.end_run_segment(segment_idx) - finally: - if hasattr(self.model, "inputs"): - self.end_run() - if self.video_recorder: - self.video_recorder.stop() - self.video_recorder = None diff --git a/lightx2v/models/runners/wan/wan_sf_runner.py b/lightx2v/models/runners/wan/wan_sf_runner.py index e101bbb28..dead7573b 100755 --- a/lightx2v/models/runners/wan/wan_sf_runner.py +++ b/lightx2v/models/runners/wan/wan_sf_runner.py @@ -7,13 +7,12 @@ from lightx2v.models.networks.wan.sf_model import WanSFModel from lightx2v.models.runners.wan.wan_runner import WanRunner, build_wan_model_with_lora from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler -from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE from lightx2v.server.metrics import monitor_cli from lightx2v.utils.async_vae import AsyncVAEChunkDecoder from lightx2v.utils.envs import * from lightx2v.utils.profiler import * from lightx2v.utils.registry_factory import RUNNER_REGISTER -from lightx2v.utils.utils import get_rank_and_world_size, wan_vae_to_comfy +from lightx2v.utils.utils import get_rank_and_world_size from lightx2v.utils.video_recorder import VideoRecorder @@ -21,12 +20,6 @@ class WanSFRunner(WanRunner): def __init__(self, config): super().__init__(config) - self.is_live = config.get("is_live", False) - if self.is_live: - self.vae_cls = WanSFVAE - self.width = self.config["target_width"] - self.height = self.config["target_height"] - self.run_main = self.run_main_live def load_transformer(self): wan_model_kwargs = {"model_path": self.config["model_path"], "config": self.config, "device": self.init_device} @@ -64,10 +57,7 @@ def get_video_segment_num(self): def run_vae_decoder(self, latents): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): self.vae_decoder = self.load_vae_decoder() - if self.is_live: - images = self.vae_decoder.decode(latents.to(GET_DTYPE()), use_cache=True) - else: - images = self.vae_decoder.decode(latents.to(GET_DTYPE())) + images = self.vae_decoder.decode(latents.to(GET_DTYPE()), use_cache=True) if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): del self.vae_decoder torch.cuda.empty_cache() @@ -78,12 +68,6 @@ def init_run(self): self.init_kv_cache_manager() super().init_run() - def end_run(self): - kv_mgr = getattr(getattr(self, "model", None), "kv_cache_manager", None) - if kv_mgr is not None: - kv_mgr.save_calibration() - super().end_run() - def run_segment(self, segment_idx=0): infer_steps = self.model.scheduler.infer_steps for step_index in range(infer_steps): @@ -110,8 +94,10 @@ def run_segment(self, segment_idx=0): return self.model.scheduler.stream_output - def decode_segment_latents(self, segment_idx: int, latents: torch.Tensor) -> torch.Tensor: - return self.run_vae_decoder(latents.detach().clone()) + def decode_segment_latents(self, segment_idx: int, segment_latents: torch.Tensor) -> torch.Tensor: + is_first = segment_idx == 0 + is_last = segment_idx == self.video_segment_num - 1 + return self.vae_decoder.cached_decode_withflag(segment_latents.to(GET_DTYPE()), is_first, is_last) def init_video_recorder(self): output_video_path = self.input_info.save_result_path @@ -136,13 +122,6 @@ def end_run_segment(self, segment_idx=None): self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True) with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"): self.model.infer(self.inputs) - - self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video - if self.is_live: - if self.video_recorder: - stream_video = wan_vae_to_comfy(self.gen_video) - self.video_recorder.pub_video(stream_video) - torch.cuda.empty_cache() @ProfilingContext4DebugL2("Run DiT") @@ -156,89 +135,48 @@ def run_main(self, total_steps=None): self.vae_decoder = self.load_vae_decoder() vae_decoder = AsyncVAEChunkDecoder.from_config(self.config, device=torch.device("cuda"), vae_decoder=self.vae_decoder) - with ( - no_sync_profiling(enabled=vae_decoder.is_async), - ProfilingContext4DebugL1( + with no_sync_profiling(enabled=vae_decoder.is_async): + with ProfilingContext4DebugL1( f"AR chunk total {self.video_segment_num} chunks", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, metrics_labels=["DefaultRunner"], - ), - ): - try: - for segment_idx in range(self.video_segment_num): - logger.info(f"start chunk {segment_idx + 1}/{self.video_segment_num}") - with ProfilingContext4DebugL1( - f"chunk end2end {segment_idx + 1}/{self.video_segment_num}", - recorder_mode=GET_RECORDER_MODE(), - metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, - metrics_labels=["DefaultRunner"], - ): - self.check_stop() - self.init_run_segment(segment_idx) - latents = self.run_segment(segment_idx) - - with ProfilingContext4DebugL1("step_pre_in_rerun"): - self.model.scheduler.step_pre( - seg_index=segment_idx, - step_index=self.model.scheduler.infer_steps - 1, - is_rerun=True, - ) - with ProfilingContext4DebugL1("infer_main_in_rerun"): - self.model.infer(self.inputs) - - vae_decoder.submit(self.decode_segment_latents, segment_idx, latents) - torch.cuda.empty_cache() - decoded_chunks = vae_decoder.finish() - finally: - if "vae_decoder" in locals(): - vae_decoder.finish() - if lazy_vae: - del self.vae_decoder - torch.cuda.empty_cache() - gc.collect() - - self.gen_video = torch.cat(decoded_chunks, dim=0) + ): + try: + for segment_idx in range(self.video_segment_num): + logger.info(f"start chunk {segment_idx + 1}/{self.video_segment_num}") + with ProfilingContext4DebugL1( + f"chunk end2end {segment_idx + 1}/{self.video_segment_num}", + recorder_mode=GET_RECORDER_MODE(), + metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, + metrics_labels=["DefaultRunner"], + ): + self.check_stop() + self.init_run_segment(segment_idx) + latents = self.run_segment(segment_idx) + + with ProfilingContext4DebugL1("step_pre_in_rerun"): + self.model.scheduler.step_pre( + seg_index=segment_idx, + step_index=self.model.scheduler.infer_steps - 1, + is_rerun=True, + ) + with ProfilingContext4DebugL1("infer_main_in_rerun"): + self.model.infer(self.inputs) + + vae_decoder.submit(self.decode_segment_latents, segment_idx, latents) + torch.cuda.empty_cache() + decoded_chunks = vae_decoder.finish() + finally: + if "vae_decoder" in locals(): + vae_decoder.finish() + if lazy_vae: + del self.vae_decoder + torch.cuda.empty_cache() + gc.collect() + + self.gen_video = torch.cat(decoded_chunks, dim=2) self.gen_video_final = self.gen_video gen_video_final = self.process_images_after_vae_decoder() self.end_run() return gen_video_final - - @ProfilingContext4DebugL2("Run DiT") - def run_main_live(self, total_steps=None): - try: - self.init_video_recorder() - logger.info(f"init video_recorder: {self.video_recorder}") - rank, world_size = get_rank_and_world_size() - if rank == world_size - 1: - assert self.video_recorder is not None, "video_recorder is required for stream audio input for rank 2" - self.video_recorder.start(self.width, self.height) - if world_size > 1: - dist.barrier() - self.init_run() - if self.config.get("compile", False): - self.model.select_graph_for_compile(self.input_info) - - for segment_idx in range(self.video_segment_num): - logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}") - with ProfilingContext4DebugL1( - f"segment end2end {segment_idx + 1}/{self.video_segment_num}", - recorder_mode=GET_RECORDER_MODE(), - metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration, - metrics_labels=["DefaultRunner"], - ): - self.check_stop() - # 1. default do nothing - self.init_run_segment(segment_idx) - # 2. main inference loop - latents = self.run_segment(segment_idx) - # 3. vae decoder - self.gen_video = self.run_vae_decoder(latents) - # 4. default do nothing - self.end_run_segment(segment_idx) - finally: - if hasattr(self.model, "inputs"): - self.end_run() - if self.video_recorder: - self.video_recorder.stop() - self.video_recorder = None diff --git a/scripts/lingbot/run_lingbot_fast_i2v.sh b/scripts/lingbot/run_lingbot_fast_i2v.sh index a57c3d9b8..8f4b38962 100755 --- a/scripts/lingbot/run_lingbot_fast_i2v.sh +++ b/scripts/lingbot/run_lingbot_fast_i2v.sh @@ -13,7 +13,7 @@ python -m lightx2v.infer \ --model_cls lingbot_world_fast \ --task i2v \ --model_path $model_path \ ---config_json /data/nvme4/gushiqiao/new/LightX2V/configs/lingbot_fast/lingbot_fast_i2v_kv_turboquant.json \ +--config_json /data/nvme4/gushiqiao/new/LightX2V/configs/lingbot_fast/lingbot_fast_i2v.json \ --prompt "A serene lakeside scene with a lone tree standing in calm water, surrounded by distant snow-capped mountains under a bright blue sky with drifting white clouds — gentle ripples reflect the tree and sky, creating a tranquil, meditative atmosphere." \ --negative_prompt "" \ --image_path /data/nvme4/gushiqiao/lingbot-world/examples/03/image.jpg \ diff --git a/scripts/matrix_game2/run_matrix_game2_gta_drive.sh b/scripts/matrix_game2/run_matrix_game2_gta_drive.sh index 7f08656b7..adc13a1cb 100644 --- a/scripts/matrix_game2/run_matrix_game2_gta_drive.sh +++ b/scripts/matrix_game2/run_matrix_game2_gta_drive.sh @@ -1,8 +1,8 @@ #!/bin/bash # set path firstly -lightx2v_path=path to Lightx2v -model_path=path to Skywork/Matrix-Game-2.0 +lightx2v_path=/data/nvme4/gushiqiao/new/LightX2V +model_path=/data/nvme4/models/mgv2 export CUDA_VISIBLE_DEVICES=0 @@ -15,6 +15,6 @@ python -m lightx2v.infer \ --model_path $model_path \ --config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive.json \ --prompt '' \ ---image_path gta_drive/0003.png \ +--image_path /data/nvme4/gushiqiao/0003.png \ --save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive.mp4 \ --seed 42 diff --git a/scripts/matrix_game2/run_matrix_game2_universal.sh b/scripts/matrix_game2/run_matrix_game2_universal.sh index d7480efc3..a3aadcb30 100644 --- a/scripts/matrix_game2/run_matrix_game2_universal.sh +++ b/scripts/matrix_game2/run_matrix_game2_universal.sh @@ -2,8 +2,8 @@ #!/bin/bash # set path firstly -lightx2v_path=path to Lightx2v -model_path=path to Skywork/Matrix-Game-2.0 +lightx2v_path=/data/nvme4/gushiqiao/new/LightX2V +model_path=/data/nvme4/models/mgv2 export CUDA_VISIBLE_DEVICES=0 @@ -16,6 +16,6 @@ python -m lightx2v.infer \ --model_path $model_path \ --config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_universal.json \ --prompt '' \ ---image_path universal/0007.png \ +--image_path /data/nvme4/gushiqiao/0007.png \ --save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_universal.mp4 \ --seed 42 diff --git a/scripts/seko_talk/ar/run_seko_talk_ar_01_base.sh b/scripts/seko_talk/ar/run_seko_talk_ar_01_base.sh index 2f9f045ba..e598c631f 100644 --- a/scripts/seko_talk/ar/run_seko_talk_ar_01_base.sh +++ b/scripts/seko_talk/ar/run_seko_talk_ar_01_base.sh @@ -1,21 +1,21 @@ #!/bin/bash -lightx2v_path=/mnt/devsft_afs_1/gushiqiao/LightX2V -model_path=/models/seko_ar +lightx2v_path=/data/nvme4/gushiqiao/new/LightX2V +model_path=/data/nvme5/gushiqiao/models/SekoTalk-Distill-AR/ -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export CUDA_VISIBLE_DEVICES=0 # set environment variables source ${lightx2v_path}/scripts/base/base.sh -torchrun --nproc-per-node 8 -m lightx2v.infer \ +python -m lightx2v.infer \ --model_cls seko_talk_ar \ --task rs2v \ --model_path $model_path \ --config_json ${lightx2v_path}/configs/seko_talk/ar/seko_talk_ar_kv_dist.json \ --prompt "In a high-fidelity realistic lifestyle aesthetic, a young woman is captured lounging comfortably on a plush beige sectional sofa within a bright, minimalist interior defined by clean white walls and soft, diffused natural lighting. The subject has shoulder-length chestnut brown hair, striking blue eyes, and is dressed in a cozy, loose-fitting beige knit sweater paired with black pants. She wears a single white wireless earbud in her right ear and a ring on her left hand. Beside her on the sofa cushion lies a closed dark-colored laptop or tablet and a white earbud charging case. Throughout the scene, she leans back in a relaxed posture, her left elbow resting on the top of the sofa cushion with her hand gently supporting her head, while her right hand rests on her drawn-up knee. She is actively speaking, displaying a natural and engaging expression with rhythmic lip movements and subtle facial animations that suggest a casual conversation or vlog recording. Her movements are fluid and grounded; she occasionally shifts her weight slightly against the cushions and uses small, nuanced hand gestures with her right hand, lifting it briefly from her knee to emphasize her words before settling it back down. The camera maintains a fixed, static medium shot, framing her centrally to capture her upper body and the immediate cozy environment, creating an intimate and serene atmosphere without any camera movement or shifts in focus." \ --negative_prompt "low quality,blurry,pixelated,low resolution,noise,artifacts,poor lighting, overexposed, underexposed, distorted, unnatural, deformed, weird,scared,anatomy,mutated, wrong proportions, extra limbs,floating objects, disconnected, gravity-defying, impossible shadows,wrong lighting,non-existent reflections,inconsistent perspective, repetitive, monotonous, monotonous, generic, watermark, ugly, high contrast, bad photo, font, username, error, logo, words, letters, digits, autograph, trademark, name, twisted face, (poorly drawn hands, malformed hands, missing fingers, unnatural hand positions, blur hand, multiple fingers, multiple arms), static, naked, artifacts, oversaturated" \ ---image_path "/mnt/devsft_afs_1/gushiqiao/1_素材图.png" \ ---audio_path "/mnt/devsft_afs_1/gushiqiao/1_素材图.mp3" \ +--image_path "/data/nvme4/gushiqiao/new/example/1_素材图.png" \ +--audio_path "/data/nvme4/gushiqiao/new/example/1_素材图.mp3" \ --save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk_ar_sp2.mp4 \ --seed 0 diff --git a/scripts/self_forcing/run_wan_t2v_sf.sh b/scripts/self_forcing/run_wan_t2v_sf.sh index 39fe6047a..b31e35dc3 100755 --- a/scripts/self_forcing/run_wan_t2v_sf.sh +++ b/scripts/self_forcing/run_wan_t2v_sf.sh @@ -1,9 +1,9 @@ #!/bin/bash # set path firstly -lightx2v_path=path to Lightx2v -model_path=path to Wan2.1-T2V-1.3B -export CUDA_VISIBLE_DEVICES=0 +lightx2v_path=/data/nvme4/gushiqiao/new/LightX2V +model_path=/data/nvme0/gushiqiao/models/official_models/Wan2.1-T2V-1.3B/ +export CUDA_VISIBLE_DEVICES=7 # set environment variables source ${lightx2v_path}/scripts/base/base.sh