137 changes: 49 additions & 88 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -537,13 +537,13 @@ std::vector<paddle::Tensor> TextImageGatherScatter(
const bool is_scatter);

std::vector<paddle::Tensor> count_tokens_per_expert_func(
const paddle::Tensor& topk_ids,
int64_t num_experts,
bool compute_padded_cumsum = false);
void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids);
const paddle::Tensor& topk_ids, int64_t num_experts);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch);

std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
@@ -691,10 +691,6 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
bool renormalize,
float routed_scaling_factor);

std::vector<paddle::Tensor> FusedCastSigmoidBias(const paddle::Tensor& input,
const paddle::Tensor& bias,
std::string cast_type);

std::vector<paddle::Tensor> NoauxTcRedundant(
paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
@@ -785,11 +781,6 @@ std::vector<paddle::Tensor> BuildSamplingParams(
const int64_t token_num_output_cpu,
const int64_t increment_value);

std::vector<paddle::Tensor> BuildSamplingParamLogProb(
const paddle::Tensor& input_params,
const paddle::Tensor& token_num_per_batch,
int64_t token_num_output_cpu);

void SpecTokenPenaltyMultiScores(
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
@@ -884,22 +875,20 @@ void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& has_running_seqs,
const paddle::Tensor& step_input_ids,
const paddle::Tensor& adaptive_step_input_len,
const paddle::Tensor& step_output_ids,
const paddle::Tensor& step_output_len,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& is_paused,
const paddle::Tensor& mask_rollback,
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& end_tokens,
const paddle::Tensor& max_dec_len);

void NaiveUpdateModelStatus(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& next_tokens,
const paddle::Tensor& cu_seqlens_q_output);
const paddle::Tensor& max_dec_len,
const bool is_naive_mode,
const bool prefill_one_step_stop);

void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
@@ -982,17 +971,24 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& mask_rollback,
const paddle::Tensor& recompute_token_num,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& target_model_seq_lens_encoder,
const paddle::Tensor& target_model_seq_lens_decoder,
const paddle::Tensor& target_model_step_idx,
const paddle::Tensor& target_model_stop_flags,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& target_model_draft_tokens,
const int num_model_step,
const bool is_splitwise_prefill);
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
@@ -1025,17 +1021,7 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& last_seq_lens_this_time,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder);

std::vector<paddle::Tensor> EagleGatherHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& real_output_token_num);
const paddle::Tensor& step_idx);

void MTPStepPaddle(
const paddle::Tensor& base_model_stop_flags,
@@ -1149,16 +1135,13 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder);

void SpeculateGetAcceptTokensAndLogits(
const paddle::Tensor& token_ids,
const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num,
const paddle::Tensor& accept_tokens);
void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& logits,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& ori_cu_batch_token_offset,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num);

std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& ids_remove_padding,
@@ -1167,8 +1150,10 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states);
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback);

std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor& qkv,
@@ -1647,22 +1632,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");

/**
* cutlass_scaled_mm.cu
* cutlass_scaled_mm
* cutlass_scaled_mm_azp
*/
#ifdef ENABLE_SM75_EXT_OPS

🔴 Compatibility: the ENABLE_SM75_EXT_OPS macro is not defined anywhere in the setup_ops.py diff.

The pybind registrations of five operators — cutlass_scaled_mm, cutlass_scaled_mm_azp, static_scaled_fp8_quant, dynamic_scaled_fp8_quant, and dynamic_per_token_scaled_fp8_quant — are placed inside this macro guard. If setup_ops.py does not append -DENABLE_SM75_EXT_OPS, these operators will silently disappear on every SM level, and callers will hit an AttributeError at runtime.

Please confirm whether setup_ops.py already contains (or whether this PR should add) something like:

if cc >= 75:
    nvcc_compile_args += ["-DENABLE_SM75_EXT_OPS"]

Note: architecture.md currently documents only ENABLE_SM80_EXT_OPS (SM ≥ 80); no corresponding SM75 macro is mentioned.
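
As a concrete illustration of the failure mode (a sketch only; the import path below is hypothetical, not necessarily the project's actual one), a build without the macro only surfaces the problem when the operator is first looked up:

# Sketch: if cpp_extensions.cc is compiled without -DENABLE_SM75_EXT_OPS,
# the pybind registration never runs, so the attribute lookup fails at
# call time rather than at build time.
import fastdeploy_ops  # hypothetical import path for the compiled extension

if not hasattr(fastdeploy_ops, "cutlass_scaled_mm"):
    raise RuntimeError(
        "cutlass_scaled_mm is unavailable; rebuild with -DENABLE_SM75_EXT_OPS")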

/* cutlass_scaled_mm.cu: cutlass_scaled_mm, cutlass_scaled_mm_azp */
m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function");
m.def("cutlass_scaled_mm_azp",
&CutlassScaledMmAzp,
"cutlass_scaled_mm_azp function");

/**
* quantization/common.cu
* static_scaled_fp8_quant
* dynamic_scaled_fp8_quant
* dynamic_per_token_scaled_fp8_quant
*/
/* quantization/common.cu: static/dynamic scaled fp8 quant ops */
m.def("static_scaled_fp8_quant",
&StaticScaledFp8Quant,
"static_scaled_fp8_quant function",
@@ -1684,6 +1661,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("input"),
py::arg("scales"),
py::arg("scale_ub"));
#endif
#ifdef ENABLE_SM80_EXT_OPS
m.def("decode_mla_write_cache",
&DecodeMLAWriteCacheKernel,
@@ -1706,13 +1684,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");

m.def("fused_cast_sigmoid_bias",
&FusedCastSigmoidBias,
"Fused cast+sigmoid+bias for MoE gating scores",
py::arg("input"),
py::arg("bias"),
py::arg("cast_type") = std::string("float32"));

m.def("noaux_tc_redundant",
&NoauxTcRedundant,
"noaux_tc_redundant for MoE compute");
@@ -1788,10 +1759,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&BuildSamplingParams,
"build_sampling_params function");

m.def("build_sampling_params_logprob",
&BuildSamplingParamLogProb,
"build_sampling_params_logprob function");

m.def("speculate_get_token_penalty_multi_scores",
&SpecTokenPenaltyMultiScores,
"speculate_get_token_penalty_multi_scores function");
@@ -1811,10 +1778,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&UnifiedUpdateModelStatus,
"unified_update_model_status function");

m.def("naive_update_model_status",
&NaiveUpdateModelStatus,
"naive_update_model_status function");

m.def("speculate_set_value_by_flags_and_idx",
&SpeculateSetValueByFlagsAndIdx,
"speculate_set_value_by_flags_and_idx function");
@@ -1853,10 +1816,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&EagleGetSelfHiddenStates,
"eagle_get_self_hidden_states function");

m.def("eagle_gather_hidden_states",
&EagleGatherHiddenStates,
"eagle_gather_hidden_states function");

m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function");

m.def("speculate_step_paddle",
@@ -1893,9 +1852,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&SpeculateInsertFirstToken,
"speculate_insert_first_token function");

m.def("speculate_get_accept_tokens_and_logits",
&SpeculateGetAcceptTokensAndLogits,
"speculate_get_accept_tokens_and_logits function");
m.def("speculate_get_target_logits",
&SpeculateGetTargetLogits,
"speculate_get_target_logits function");
#endif

m.def("update_attn_mask_offsets",
@@ -1919,22 +1878,23 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("custom_numpy_to_tensor",
&CustomNumpyToTensor,
"custom_numpy_to_tensor function");
#ifdef ENABLE_SM80_EXT_OPS
m.def("prefill_permute_to_masked_gemm",
&PrefillPermuteToMaskedGemm,
py::arg("x"),
py::arg("scale"),
py::arg("topk_ids"),
py::arg("num_local_experts"),
py::arg("max_token_num"),
"Prefill permute to masked GEMM for MoE");
"prefill_permute_to_masked_gemm");

m.def("depermute_prefill_combine",
&DepermutePrefillCombine,
py::arg("x"),
py::arg("indice_map"),
py::arg("topk_weights"),
py::arg("num_worst_tokens"),
"Depermute and combine expert outputs for MoE prefill");
"depermute_prefill_combine");

m.def("radix_topk_ragged_transform",
&RadixTopkRaggedTransform,
@@ -1953,4 +1913,5 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_group_fp8_quant",
&PerTokenGroupQuantFp8,
"per_token_group_quant_fp8");
#endif
}