diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py
index ad6a8b5b3a..b40251b8f8 100644
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py
+++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py
@@ -1,6 +1,5 @@
 # 为 diverse mode 定制设计的 int8kv flash decoding attention 实现，可以实现更高效的多样性采样
 import torch
-from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1
 from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py
deleted file mode 100644
index f51d611661..0000000000
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import torch
-from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
-
-
-def token_decode_attention_flash_decoding(
-    q,
-    infer_state,
-    cache_k,
-    cache_k_scale,
-    cache_v,
-    cache_v_scale,
-    out=None,
-    alloc_tensor_func=torch.empty,
-):
-    BLOCK_SEQ = 256
-    q_head_num, head_dim = q.shape[1], q.shape[2]
-    batch_size = infer_state.batch_size
-    max_kv_seq_len = infer_state.max_kv_seq_len
-    calcu_shape1 = (batch_size, q_head_num, head_dim)
-
-    from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2
-
-    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-
-    mid_o = alloc_tensor_func(
-        [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda"
-    )
-    mid_o_logexpsum = alloc_tensor_func(
-        [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda"
-    )
-
-    light_ops.group8_int8kv_flashdecoding_stage1(
-        BLOCK_SEQ,
-        mid_o,
-        mid_o_logexpsum,
-        1.0 / (head_dim ** 0.5),
-        q.view(calcu_shape1),
-        cache_k,
-        cache_k_scale,
-        cache_v,
-        cache_v_scale,
-        infer_state.req_manager.req_to_token_indexs,
-        infer_state.b_req_idx,
-        infer_state.b_seq_len,
-        infer_state.max_kv_seq_len,
-    )
-
-    flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
-    return o_tensor
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py
deleted file mode 100644
index b0a9b6245c..0000000000
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import torch
-from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
-
-
-def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty):
-    BLOCK_SEQ = 256
-    batch_size = infer_state.batch_size
-    q_head_num = q.shape[1]
-    head_dim = q.shape[2]
-    max_kv_seq_len = infer_state.max_kv_seq_len
-    calcu_shape1 = (batch_size, q_head_num, head_dim)
-
-    from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2
-
-    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
-
-    mid_o = alloc_tensor_func(
-        [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
-    )
-    mid_o_logexpsum = alloc_tensor_func(
-        [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
-    )
-
-    light_ops.fp16_flashdecoding_stage1(
-        BLOCK_SEQ,
-        mid_o,
-        mid_o_logexpsum,
-        1.0 / (head_dim ** 0.5),
-        q.view(calcu_shape1),
-        cache_k,
-        cache_v,
-        infer_state.req_manager.req_to_token_indexs,
-        infer_state.b_req_idx,
-        infer_state.b_seq_len,
-        infer_state.max_kv_seq_len,
-    )
-
-    flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
-    return o_tensor
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
index 59d1f825a3..1c01cbd638 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
@@ -17,16 +17,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import torch
 from lightllm.utils.sgl_utils import sgl_ops
-from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
 from lightllm.common.basemodel.triton_kernel.fused_moe.softmax_topk import softmax_topk
 from lightllm.common.triton_utils.autotuner import Autotuner
 
-use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]
-
 
 def fused_topk(
     hidden_states: torch.Tensor,
@@ -127,44 +123,6 @@ def biased_grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
-# This is used by the Deepseek-V2 model
-def cuda_grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    correction_bias: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-    scoring_func: str = "softmax",
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    assert light_ops is not None, "lightllm_kernel is not installed."
-
-    num_tokens = gating_output.shape[0]
-    topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32)
-    topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32)
-    token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32)
-    group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
-    if correction_bias is None:
-        correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)
-    light_ops.grouped_topk(
-        topk_weights,
-        correction_bias,
-        topk_indices,
-        token_expert_indices,
-        gating_output.float(),
-        num_expert_group,
-        topk_group,
-        topk,
-        renormalize,
-        scoring_func,
-        group_scores,
-    )
-
-    return topk_weights, topk_indices
-
-
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
@@ -184,34 +142,22 @@ def select_experts(
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
-        if use_cuda_grouped_topk:
-            topk_weights, topk_ids = cuda_grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                correction_bias=correction_bias,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                scoring_func=scoring_func,
-            )
-        else:
-            group_score_topk_num = 1
-            # for deepseek v3
-            if topk_group == 4 and num_expert_group == 8 and top_k == 8:
-                group_score_topk_num = 2
-
-            topk_weights, topk_ids = triton_grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                correction_bias=correction_bias,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                scoring_func=scoring_func,
-                group_score_used_topk_num=group_score_topk_num,
-            )
+        group_score_topk_num = 1
+        # for deepseek v3
+        if topk_group == 4 and num_expert_group == 8 and top_k == 8:
+            group_score_topk_num = 2
+
+        topk_weights, topk_ids = triton_grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            correction_bias=correction_bias,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+            scoring_func=scoring_func,
+            group_score_used_topk_num=group_score_topk_num,
+        )
 
     elif custom_routing_function is None:
         topk_weights, topk_ids = fused_topk(
diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py
index b3d29b0527..65ec6cd145 100644
--- a/lightllm/common/quantization/w8a8.py
+++ b/lightllm/common/quantization/w8a8.py
@@ -8,19 +8,12 @@
 from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm
-from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
 from .quantize_method import WeightPack
 
 
-if HAS_LIGHTLLM_KERNEL:
-
-    def scaled_fp8_quant(tensor, *args, **kwargs):
-        return light_ops.per_token_quant_bf16_fp8(tensor)
-
-else:
-    if HAS_VLLM:
-        scaled_fp8_quant = vllm_ops.scaled_fp8_quant
+if HAS_VLLM:
+    scaled_fp8_quant = vllm_ops.scaled_fp8_quant
 
 
 LIGHTLLM_USE_TRITON_FP8_SCALED_MM = os.getenv("LIGHTLLM_USE_TRITON_FP8_SCALED_MM", "False").upper() in [
     "ON",
diff --git a/lightllm/utils/light_utils.py b/lightllm/utils/light_utils.py
deleted file mode 100644
index 944a0fe15f..0000000000
--- a/lightllm/utils/light_utils.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from lightllm.utils.log_utils import init_logger
-
-logger = init_logger(__name__)
-try:
-    # TODO: lightllm_kernel release
-    import lightllm_kernel
-
-    light_ops = getattr(lightllm_kernel, "ops", lightllm_kernel)
-    HAS_LIGHTLLM_KERNEL = True
-except:
-    light_ops = None
-    HAS_LIGHTLLM_KERNEL = False
-    logger.warning("lightllm_kernel is not installed, you can't use the api of it.")
diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh
index 9dc46d8807..7f7181ef06 100644
--- a/test/acc/test_qwen3.sh
+++ b/test/acc/test_qwen3.sh
@@ -4,6 +4,11 @@ LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --mo
 
 # second
 export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
 
+# test quant
+LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --quant_type vllm-fp8w8a8
+
+export no_proxy="localhost,127.0.0.1,0.0.0.0,::1" HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code
+
 LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_tpsp_mix_mode
diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py
index a01bbf32d8..3e2555b339 100644
--- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py
+++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py
@@ -1,7 +1,6 @@
 import pytest
 import torch
 
-from lightllm.utils.light_utils import light_ops
 
 
 def alloc_tensor_func(shape, dtype, device):
@@ -41,17 +40,17 @@ def __init__(
 
 # @pytest.mark.parametrize("shared_seq_len", [512])
 @pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550])
 @pytest.mark.parametrize("batch_size", list(range(6, 121, 6)))
-def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len, batch_size):
+def test_token_decode_attention_flash_decoding_diverse_matches_normal_decode(shared_seq_len, batch_size):
     """
-    测试 int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding
-    与 ppl_int8kv_flash_decoding (baseline) 的对比。
+    Both the diverse and the normal paths are in-repo Triton implementations and should agree numerically
+    (no external CUDA extension). diverse: int8kv_flash_decoding_diverse; reference: int8kv/normal token_decode_attention_flash_decoding.
     """
     from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import (
         token_decode_attention_flash_decoding as diverse_attention,
     )
-    from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import (
-        token_decode_attention_flash_decoding as baseline_attention,
+    from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.normal import (
+        token_decode_attention_flash_decoding as normal_decode,
     )
 
     num_heads = 32
@@ -89,7 +88,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
     b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
device="cuda") b_mark_shared_group[mark_shared_group_size - 1 :: mark_shared_group_size] = mark_shared_group_size - # 创建 baseline 的 infer_state (不需要 b_shared_seq_len) + # 标准 int8 decode(单路径 Triton) baseline_infer_state = MockInferState( batch_size=batch_size, max_kv_seq_len=max_len_in_batch, @@ -98,7 +97,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le b_seq_len=b_seq_len, ) - # 创建 diverse 的 infer_state + # diverse:多流 + 共享前缀(Triton) diverse_infer_state = MockInferState( batch_size=batch_size, max_kv_seq_len=max_len_in_batch, @@ -110,7 +109,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le ) # 运行 baseline - baseline_out = baseline_attention( + normal_out = normal_decode( q=q.clone(), infer_state=baseline_infer_state, cache_k=cache_k, @@ -131,11 +130,10 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le ) print(f"\nshared_seq_len={shared_seq_len}\nbatch_size={batch_size}") - print(f"baseline_out: {baseline_out[0, 0, :4]}") + print(f"normal_out: {normal_out[0, 0, :4]}") print(f"diverse_out: {diverse_out[0, 0, :4]}") - print(f"max diff: {(baseline_out - diverse_out).abs().max()}") + print(f"max diff: {(normal_out - diverse_out).abs().max()}") - # 与 baseline 对比 assert torch.allclose( - baseline_out, diverse_out, atol=1e-2, rtol=1e-2 - ), f"Diverse attention output should match baseline for shared_seq_len={shared_seq_len}" + normal_out, diverse_out, atol=1e-2, rtol=1e-2 + ), f"diverse vs normal decode mismatch for shared_seq_len={shared_seq_len}" diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py index c7d4442543..985a33289f 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py @@ -140,9 +140,9 @@ def test_flash_decode_stage2_execution(shared_seq_len): if __name__ == "__main__": - import importlib + # 可选:对 Triton diverse stage2 做 cudagraph bench(仅本仓库内核,无外部 CUDA 扩展)。 + import triton - from lightllm.utils.light_utils import light_ops batch_sizes = [8, 16, 32, 64] seq_lens = [32, 64, 128, 256] @@ -150,7 +150,6 @@ def test_flash_decode_stage2_execution(shared_seq_len): results = [] for batch in batch_sizes: for seq in seq_lens: - # Clear GPU cache to reduce CUDA Graph capture failures. 
             torch.cuda.empty_cache()
 
             setup_tensors = create_tensors(
@@ -161,133 +160,33 @@ def test_flash_decode_stage2_execution(shared_seq_len):
                 kv_len=seq,
                 req_to_tokens_len=seq,
             )
-
-            # Outputs for CUDA implementation
-            mid_out_cuda = setup_tensors["mid_out"].clone()
-            mid_out_logsumexp_cuda = setup_tensors["mid_out_logsumexp"].clone()
-
-            # Outputs for Triton implementation
-            mid_out_triton = setup_tensors["mid_out"].clone()
-            mid_out_logsumexp_triton = setup_tensors["mid_out_logsumexp"].clone()
-
-            # Run CUDA to get reference
-            light_ops.group8_int8kv_flashdecoding_diverse_stage2(
-                setup_tensors["block_seq"],
-                mid_out_cuda,
-                mid_out_logsumexp_cuda,
-                1.0 / (setup_tensors["head_dim"] ** 0.5),
-                setup_tensors["q"],
-                setup_tensors["k"],
-                setup_tensors["k_scale"],
-                setup_tensors["v"],
-                setup_tensors["v_scale"],
-                setup_tensors["Req_to_tokens"],
-                setup_tensors["B_req_idx"],
-                setup_tensors["b_seq_len"],
-                setup_tensors["b_shared_seq_len"],
-                setup_tensors["max_len_in_batch"],
-            )
-
-            # Run Triton
-            flash_decode_stage2(
-                q=setup_tensors["q"],
-                k=setup_tensors["k"],
-                k_scale=setup_tensors["k_scale"],
-                v=setup_tensors["v"],
-                v_scale=setup_tensors["v_scale"],
-                Req_to_tokens=setup_tensors["Req_to_tokens"],
-                B_req_idx=setup_tensors["B_req_idx"],
-                B_Seqlen=setup_tensors["b_seq_len"],
-                b_shared_seq_len=setup_tensors["b_shared_seq_len"],
-                max_len_in_batch=setup_tensors["max_len_in_batch"],
-                mid_out=mid_out_triton,
-                mid_out_logsumexp=mid_out_logsumexp_triton,
-                block_seq=setup_tensors["block_seq"],
-            )
-
-            # Compare results
-            diff_mid_out = torch.abs(mid_out_cuda - mid_out_triton)
-            diff_logsumexp = torch.abs(mid_out_logsumexp_cuda - mid_out_logsumexp_triton)
-            max_diff_out = diff_mid_out.max().item()
-            max_diff_logsumexp = diff_logsumexp.max().item()
-            mean_diff_out = diff_mid_out.mean().item()
-            mean_diff_logsumexp = diff_logsumexp.mean().item()
-
-            cos_sim_out = torch.nn.functional.cosine_similarity(
-                mid_out_cuda.flatten(), mid_out_triton.flatten(), dim=0
-            ).item()
-            cos_sim_logsumexp = torch.nn.functional.cosine_similarity(
-                mid_out_logsumexp_cuda.flatten(), mid_out_logsumexp_triton.flatten(), dim=0
-            ).item()
-
-            print(f"\n[batch={batch}, seq={seq}] Consistency check:")
-            print("  mid_out:")
-            print(f"    max_diff: {max_diff_out:.6f}, mean_diff: {mean_diff_out:.6f}, cosine_sim: {cos_sim_out:.8f}")
-            print("  logsumexp:")
-            print(
-                f"    max_diff: {max_diff_logsumexp:.6f}, "
-                f"mean_diff: {mean_diff_logsumexp:.6f}, "
-                f"cosine_sim: {cos_sim_logsumexp:.8f}"
-            )
-
-            # Performance
-            fn_cuda = lambda: light_ops.group8_int8kv_flashdecoding_diverse_stage2(
-                setup_tensors["block_seq"],
-                setup_tensors["mid_out"],
-                setup_tensors["mid_out_logsumexp"],
-                1.0 / (setup_tensors["head_dim"] ** 0.5),
-                setup_tensors["q"],
-                setup_tensors["k"],
-                setup_tensors["k_scale"],
-                setup_tensors["v"],
-                setup_tensors["v_scale"],
-                setup_tensors["Req_to_tokens"],
-                setup_tensors["B_req_idx"],
-                setup_tensors["b_seq_len"],
-                setup_tensors["b_shared_seq_len"],
-                setup_tensors["max_len_in_batch"],
-            )
-            ms_cuda = triton.testing.do_bench_cudagraph(fn_cuda, rep=100)
-
-            fn_triton = lambda: flash_decode_stage2(
-                q=setup_tensors["q"],
-                k=setup_tensors["k"],
-                k_scale=setup_tensors["k_scale"],
-                v=setup_tensors["v"],
-                v_scale=setup_tensors["v_scale"],
-                Req_to_tokens=setup_tensors["Req_to_tokens"],
-                B_req_idx=setup_tensors["B_req_idx"],
-                B_Seqlen=setup_tensors["b_seq_len"],
-                b_shared_seq_len=setup_tensors["b_shared_seq_len"],
-                max_len_in_batch=setup_tensors["max_len_in_batch"],
mid_out=setup_tensors["mid_out"], - mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], - block_seq=setup_tensors["block_seq"], - ) - ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) - - results.append( - { - "batch_size": batch, - "seq_len": seq, - "triton_ms": ms_triton, - "cuda_ms": ms_cuda, - } - ) + st = setup_tensors + + def bench_stage2(): + flash_decode_stage2( + q=st["q"], + k=st["k"], + k_scale=st["k_scale"], + v=st["v"], + v_scale=st["v_scale"], + Req_to_tokens=st["Req_to_tokens"], + B_req_idx=st["B_req_idx"], + B_Seqlen=st["b_seq_len"], + b_shared_seq_len=st["b_shared_seq_len"], + max_len_in_batch=st["max_len_in_batch"], + mid_out=st["mid_out"], + mid_out_logsumexp=st["mid_out_logsumexp"], + block_seq=st["block_seq"], + ) + + ms = triton.testing.do_bench_cudagraph(bench_stage2, rep=100) + results.append({"batch_size": batch, "seq_len": seq, "flash_decode_stage2_ms": ms}) print(results[-1]) - del setup_tensors print(f"\n{'='*80}") - print("SUMMARY - Performance Comparison") - print(f"{'='*80}") - print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12} {'cuda_ms':<12} {'vs cuda':<10}") + print(f"{'batch_size':<10} {'seq_len':<10} {'flash_decode_stage2_ms':<22}") print(f"{'-'*80}") for r in results: - vs_cuda = f"{r['cuda_ms']/r['triton_ms']:.2f}x" - emoji = "🎉" if r["triton_ms"] < r["cuda_ms"] else "" - print( - f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f} {r['cuda_ms']:<12.3f}" - f"{vs_cuda:<10} {emoji}" - ) + print(f"{r['batch_size']:<10} {r['seq_len']:<10} {r['flash_decode_stage2_ms']:<22.4f}") print(f"{'='*80}")