From 421e0052eb93dbe6fb80aaf5a62a2276e58815a3 Mon Sep 17 00:00:00 2001 From: boby-cloudforge Date: Sun, 3 May 2026 15:45:02 +0200 Subject: [PATCH] =?UTF-8?q?[CI]=E3=80=90Hackathon=2010th=20Spring=20No.45-?= =?UTF-8?q?part2=E3=80=91Add=20SM75/SM80=20compile=20guards=20for=20cutlas?= =?UTF-8?q?s=20and=20MoE=20tail=20ops?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- custom_ops/gpu_ops/cpp_extensions.cc | 137 ++++++++++----------------- custom_ops/setup_ops.py | 120 +++++++++++------------ 2 files changed, 105 insertions(+), 152 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 204ea33e50b..1a4d92c9a92 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -537,13 +537,13 @@ std::vector TextImageGatherScatter( const bool is_scatter); std::vector count_tokens_per_expert_func( - const paddle::Tensor& topk_ids, - int64_t num_experts, - bool compute_padded_cumsum = false); -void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids); + const paddle::Tensor& topk_ids, int64_t num_experts); +void GetPositionIdsAndMaskEncoderBatch( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids, + const paddle::Tensor& mask_encoder_batch); std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_nope, @@ -691,10 +691,6 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); -std::vector FusedCastSigmoidBias(const paddle::Tensor& input, - const paddle::Tensor& bias, - std::string cast_type); - std::vector NoauxTcRedundant( paddle::Tensor& scores, paddle::Tensor& scores_with_bias, @@ -785,11 +781,6 @@ std::vector BuildSamplingParams( const int64_t token_num_output_cpu, const int64_t increment_value); -std::vector BuildSamplingParamLogProb( - const paddle::Tensor& input_params, - const paddle::Tensor& token_num_per_batch, - int64_t token_num_output_cpu); - void SpecTokenPenaltyMultiScores( const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, @@ -884,22 +875,20 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& has_running_seqs, const paddle::Tensor& step_input_ids, + const paddle::Tensor& adaptive_step_input_len, const paddle::Tensor& step_output_ids, const paddle::Tensor& step_output_len, const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& is_paused, + const paddle::Tensor& mask_rollback, const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, const paddle::Tensor& step_idx, const paddle::Tensor& end_tokens, - const paddle::Tensor& max_dec_len); - -void NaiveUpdateModelStatus(const paddle::Tensor& accept_tokens, - const paddle::Tensor& accept_num, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& next_tokens, - const paddle::Tensor& cu_seqlens_q_output); + const paddle::Tensor& max_dec_len, + const bool is_naive_mode, + const bool prefill_one_step_stop); void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, @@ -982,17 +971,24 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, + const paddle::Tensor& batch_drop, const paddle::Tensor& pre_ids, + const paddle::Tensor& mask_rollback, + const paddle::Tensor& recompute_token_num, const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, - const paddle::Tensor& target_model_seq_lens_encoder, - const paddle::Tensor& target_model_seq_lens_decoder, - const paddle::Tensor& target_model_step_idx, - const paddle::Tensor& target_model_stop_flags, - const paddle::Tensor& max_dec_len, - const paddle::Tensor& target_model_draft_tokens, - const int num_model_step, - const bool is_splitwise_prefill); + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& base_model_seq_lens_encoder, + const paddle::Tensor& base_model_seq_lens_decoder, + const paddle::Tensor& base_model_step_idx, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& base_model_is_block_step, + const paddle::Tensor& base_model_draft_tokens, + const int max_draft_token, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const paddle::Tensor& draft_tokens, @@ -1025,17 +1021,7 @@ std::vector EagleGetSelfHiddenStates( const paddle::Tensor& input, const paddle::Tensor& last_seq_lens_this_time, const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder); - -std::vector EagleGatherHiddenStates( - const paddle::Tensor& input, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token_output, - const paddle::Tensor& cu_seqlens_q_output, - const paddle::Tensor& real_output_token_num); + const paddle::Tensor& step_idx); void MTPStepPaddle( const paddle::Tensor& base_model_stop_flags, @@ -1149,16 +1135,13 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder); -void SpeculateGetAcceptTokensAndLogits( - const paddle::Tensor& token_ids, - const paddle::Tensor& target_logits, - const paddle::Tensor& logits, - const paddle::Tensor& cu_batch_token_offset, - const paddle::Tensor& cu_seqlens_q_output, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& accept_num, - const paddle::Tensor& accept_tokens); +void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& ori_cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num); std::vector UpdateAttnMaskOffsets( const paddle::Tensor& ids_remove_padding, @@ -1167,8 +1150,10 @@ std::vector UpdateAttnMaskOffsets( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& attn_mask_offsets_full, + const paddle::Tensor& attn_mask_offsets_decoder, const paddle::Tensor& is_block_step, - const paddle::Tensor& decode_states); + const paddle::Tensor& decode_states, + const paddle::Tensor& mask_rollback); std::vector FusedNeoxRopeEmbedding( const paddle::Tensor& qkv, @@ -1647,22 +1632,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &GetPositionIdsAndMaskEncoderBatch, "get_position_ids_and_mask_encoder_batch function"); - /** - * cutlass_scaled_mm.cu - * cutlass_scaled_mm - * cutlass_scaled_mm_azp - */ +#ifdef ENABLE_SM75_EXT_OPS + /* cutlass_scaled_mm.cu: cutlass_scaled_mm, cutlass_scaled_mm_azp */ m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function"); m.def("cutlass_scaled_mm_azp", &CutlassScaledMmAzp, "cutlass_scaled_mm_azp function"); - /** - * quantization/common.cu - * static_scaled_fp8_quant - * dynamic_scaled_fp8_quant - * dynamic_per_token_scaled_fp8_quant - */ + /* quantization/common.cu: static/dynamic scaled fp8 quant ops */ m.def("static_scaled_fp8_quant", &StaticScaledFp8Quant, "static_scaled_fp8_quant function", @@ -1684,6 +1661,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("input"), py::arg("scales"), py::arg("scale_ub")); +#endif #ifdef ENABLE_SM80_EXT_OPS m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, @@ -1706,13 +1684,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); - m.def("fused_cast_sigmoid_bias", - &FusedCastSigmoidBias, - "Fused cast+sigmoid+bias for MoE gating scores", - py::arg("input"), - py::arg("bias"), - py::arg("cast_type") = std::string("float32")); - m.def("noaux_tc_redundant", &NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); @@ -1788,10 +1759,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &BuildSamplingParams, "build_sampling_params function"); - m.def("build_sampling_params_logprob", - &BuildSamplingParamLogProb, - "build_sampling_params_logprob function"); - m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); @@ -1811,10 +1778,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &UnifiedUpdateModelStatus, "unified_update_model_status function"); - m.def("naive_update_model_status", - &NaiveUpdateModelStatus, - "naive_update_model_status function"); - m.def("speculate_set_value_by_flags_and_idx", &SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function"); @@ -1853,10 +1816,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &EagleGetSelfHiddenStates, "eagle_get_self_hidden_states function"); - m.def("eagle_gather_hidden_states", - &EagleGatherHiddenStates, - "eagle_gather_hidden_states function"); - m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function"); m.def("speculate_step_paddle", @@ -1893,9 +1852,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &SpeculateInsertFirstToken, "speculate_insert_first_token function"); - m.def("speculate_get_accept_tokens_and_logits", - &SpeculateGetAcceptTokensAndLogits, - "speculate_get_accept_tokens_and_logits function"); + m.def("speculate_get_target_logits", + &SpeculateGetTargetLogits, + "speculate_get_target_logits function"); #endif m.def("update_attn_mask_offsets", @@ -1919,6 +1878,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("custom_numpy_to_tensor", &CustomNumpyToTensor, "custom_numpy_to_tensor function"); +#ifdef ENABLE_SM80_EXT_OPS m.def("prefill_permute_to_masked_gemm", &PrefillPermuteToMaskedGemm, py::arg("x"), @@ -1926,7 +1886,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("topk_ids"), py::arg("num_local_experts"), py::arg("max_token_num"), - "Prefill permute to masked GEMM for MoE"); + "prefill_permute_to_masked_gemm"); m.def("depermute_prefill_combine", &DepermutePrefillCombine, @@ -1934,7 +1894,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("indice_map"), py::arg("topk_weights"), py::arg("num_worst_tokens"), - "Depermute and combine expert outputs for MoE prefill"); + "depermute_prefill_combine"); m.def("radix_topk_ragged_transform", &RadixTopkRaggedTransform, @@ -1953,4 +1913,5 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("per_token_group_fp8_quant", &PerTokenGroupQuantFp8, "per_token_group_quant_fp8"); +#endif } diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 7b1bda32510..109028af6ec 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -315,7 +315,6 @@ def find_end_files(directory, end_str): "gpu_ops/swap_cache_batch.cu", "gpu_ops/swap_cache.cu", "gpu_ops/swap_cache_layout.cu", - "gpu_ops/swap_cache_optimized.cu", # 新增:优化的 KV cache 换入算子 "gpu_ops/step_system_cache.cu", "gpu_ops/cpp_extensions.cc", "gpu_ops/share_external_data.cu", @@ -331,7 +330,6 @@ def find_end_files(directory, end_str): "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", - "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length.cu", @@ -472,59 +470,58 @@ def find_end_files(directory, end_str): # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS. os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py") - # Use non-exclusive checks against sm_versions so that building for - # multiple architectures (e.g. [80,90,100]) compiles kernels for ALL - # of them instead of only the highest one. - has_sm90 = 90 in sm_versions - has_sm100 = 100 in sm_versions and nvcc_version >= 12.9 - has_generic_fp8 = not has_sm90 and not has_sm100 # SM89 or other - - if has_sm90 or has_sm100: - nvcc_compile_args += [ - "-O3", - "-DNDEBUG", - ] - - if has_sm90: - print("SM90: Running SM90-specific FP8 kernel auto-generation.") - os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") - os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py") - os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py") - - nvcc_compile_args += [ - "-DENABLE_SCALED_MM_SM90=1", - ] - sources += [ - "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu", - "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", - ] - - if has_sm100: - print("SM100 (Blackwell): Applying SM100 configurations.") - # Placeholder for SM100-specific kernel auto-generation scripts - # These might be needed if Blackwell has new FP8 hardware features - # not covered by existing generic CUTLASS templates or SM90 scripts. - # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).") - # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example - # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example - - # Add SM100 specific sources if any, e.g., for new hardware intrinsics - # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example - pass # No SM100 specific sources identified yet beyond what CUTLASS handles - - if has_generic_fp8: - # For SM89 (Ada) or other architectures without dedicated paths - print(f"SM{cc}: Running generic FP8 kernel auto-generation.") - os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") - os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") - - if not has_sm90 and cc >= 90: - # When cc >= 90 but SM90 is not in the target list (e.g. only [80,100]), - # still run generic FP8 auto-generation for non-SM90 paths. - print(f"SM{cc}: Running generic FP8 kernel auto-generation (no SM90 target).") + if cc >= 90: # Hopper and newer + # SM90 (Hopper) specific auto-generation and flags + if cc == 90: # Only for SM90 + nvcc_compile_args += [ + # The gencode for 90a is added in get_gencode_flags now + # "-gencode", + # "arch=compute_90a,code=compute_90a", + "-O3", + "-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a + ] + print("SM90: Running SM90-specific FP8 kernel auto-generation.") + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py") + os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py") + + nvcc_compile_args += [ + "-DENABLE_SCALED_MM_SM90=1", + ] + sources += [ + "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu", + "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", + ] + elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics + print("SM100 (Blackwell): Applying SM100 configurations.") + nvcc_compile_args += [ + # The gencode for 100a is added in get_gencode_flags + # "-gencode", + # "arch=compute_100a,code=compute_100a", + "-O3", # Common optimization flag + "-DNDEBUG", # Common debug flag + # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified + ] + # Placeholder for SM100-specific kernel auto-generation scripts + # These might be needed if Blackwell has new FP8 hardware features + # not covered by existing generic CUTLASS templates or SM90 scripts. + # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).") + # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example + # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example + + # Add SM100 specific sources if any, e.g., for new hardware intrinsics + # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example + pass # No SM100 specific sources identified yet beyond what CUTLASS handles + else: # For cc >= 89 but not 90 or 100 (e.g. SM89) + print(f"SM{cc}: Running generic FP8 kernel auto-generation.") + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") + + else: # For cc == 89 (Ada) + print("SM89: Running generic FP8 kernel auto-generation.") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") @@ -587,13 +584,14 @@ def find_end_files(directory, end_str): elif paddle.is_compiled_with_xpu(): assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops." elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): - _iluvatar_clang_cuda_flags = ["-Wno-non-pod-varargs", "-DPADDLE_DEV", "-DPADDLE_WITH_CUSTOM_DEVICE"] setup( name="fastdeploy_ops", ext_modules=CUDAExtension( extra_compile_args={ - "cxx": _iluvatar_clang_cuda_flags, - "nvcc": _iluvatar_clang_cuda_flags, + "nvcc": [ + "-DPADDLE_DEV", + "-DPADDLE_WITH_CUSTOM_DEVICE", + ] }, sources=[ "gpu_ops/save_with_output_msg.cc", @@ -619,7 +617,6 @@ def find_end_files(directory, end_str): "gpu_ops/get_img_boundaries.cc", "gpu_ops/fused_neox_rope_embedding.cu", "gpu_ops/get_output_ep.cc", - "gpu_ops/update_attn_mask_offsets.cu", "iluvatar_ops/moe_dispatch.cu", "iluvatar_ops/moe_reduce.cu", "iluvatar_ops/flash_attn_unpadded.cu", @@ -628,8 +625,6 @@ def find_end_files(directory, end_str): "iluvatar_ops/mixed_fused_attn.cu", "iluvatar_ops/w8a16_group_gemm.cu", "iluvatar_ops/w8a16_group_gemv.cu", - "iluvatar_ops/wi4a16_group_gemm.cu", - "iluvatar_ops/wi4a16_weight_quantize.cu", "iluvatar_ops/restore_tokens_per_expert.cu", "iluvatar_ops/runtime/iluvatar_context.cc", "iluvatar_ops/cpp_extensions.cc", @@ -687,7 +682,6 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", - "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", @@ -728,8 +722,6 @@ def find_end_files(directory, end_str): "-Igpu_ops", "-DPADDLE_DEV", "-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU", - "-Xcompiler", - "-Wno-non-pod-varargs", ], }