4 files changed: +19 −15 lines

File 1 of 4 (kernel op implementations):

@@ -56,6 +56,12 @@ void apply_rotary(RotaryParams& params) {
 
 void active(ActivationParams& params) {
 #if defined(USE_MLU)
+  // Note: the intermediate size cannot be reliably derived from the input
+  // shape, so it is taken from the explicit parameters.
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
+
   mlu::active(params.input,
               params.output,
               params.bias,
@@ -65,20 +71,17 @@ void active(ActivationParams& params) {
               params.start_expert_id,
               params.expert_size);
 #elif defined(USE_CUDA)
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
   cuda::act_and_mul(params.output, params.input, params.act_mode);
+#elif defined(USE_NPU)
+  params.output = npu::active(params.input, params.act_mode);
 #else
   LOG(FATAL) << "active not implemented";
 #endif
 }
7384
74- torch::Tensor active_tensor (ActivationParams& params) {
75- #if defined(USE_NPU)
76- return npu::active (params.input , params.act_mode );
77- #else
78- LOG (FATAL) << " active_tensor not implemented" ;
79- #endif
80- }
81-
 void reshape_paged_cache(ReshapePagedCacheParams& params) {
 #if defined(USE_MLU)
   mlu::reshape_paged_cache(params.key,
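
With this change, every backend branch of active() allocates or produces params.output itself instead of relying on the caller to pre-size it: the MLU and CUDA branches size it as {num_tokens, intermediate_size / world_size}, while the NPU branch takes whatever npu::active() returns. A minimal standalone sketch of that shape arithmetic (not part of the diff; the concrete sizes and the gated-input width are illustrative assumptions):

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t intermediate_size = 4096;  // full intermediate width
  const int64_t world_size = 2;            // tensor-parallel group size
  // A gated activation reads an input twice the local width.
  torch::Tensor input = torch::randn({8, 2 * intermediate_size / world_size});
  // Mirrors the allocation added to active(): one row per input token,
  // one column per local intermediate channel.
  torch::Tensor output = torch::empty(
      {input.sizes()[0], intermediate_size / world_size}, input.options());
  std::cout << output.sizes() << '\n';  // prints [8, 2048]
  return 0;
}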

File 2 of 4 (kernel op declarations):

@@ -28,8 +28,6 @@ void apply_rotary(RotaryParams& params);
 
 void active(ActivationParams& params);
 
-torch::Tensor active_tensor(ActivationParams& params);
-
 void reshape_paged_cache(ReshapePagedCacheParams& params);
 
 void batch_prefill(AttentionParams& params);

File 3 of 4 (kernel parameter structs):

@@ -109,6 +109,11 @@ struct ActivationParams {
   // Expert size for MoE activation. Used when bias is provided.
   // Bias tensor shape must be [expert_size, in_channel].
   int64_t expert_size = 0;
+
+  // Note: the intermediate size cannot be reliably derived from the input
+  // shape, so it and the TP world size are provided explicitly.
+  int64_t intermediate_size = 0;
+  int64_t world_size = 0;
 };
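
Both new fields default to 0, so a caller that forgets to set them would make the division in active() divide by zero. A defensive guard of this kind could run before the allocation (a sketch, not part of the diff; it assumes the glog-style CHECK macros implied by the existing LOG(FATAL) calls are available):

  // Sketch: validate the explicit sizing fields before dividing by them.
  CHECK_GT(params.world_size, 0) << "ActivationParams::world_size not set";
  CHECK_GT(params.intermediate_size, 0)
      << "ActivationParams::intermediate_size not set";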
 
 // Reshape paged cache parameters

File 4 of 4 (DenseMLP forward):

@@ -89,17 +89,15 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) {
     // For w8a8 quantization, the active operation is fused with the down_proj
     return down_proj_->forward(gate_up);
   } else {
-    int64_t batch_size = gate_up.sizes()[0];
-    auto output = torch::empty(
-        {batch_size,
-         intermediate_size_ / parallel_args_.tp_group_->world_size()},
-        gate_up.options());
+    torch::Tensor output;
 
     xllm::kernel::ActivationParams activation_params;
     activation_params.input = gate_up;
     activation_params.output = output;
     activation_params.act_mode = hidden_act_;
     activation_params.is_gated = is_gated_;
+    activation_params.intermediate_size = intermediate_size_;
+    activation_params.world_size = parallel_args_.tp_group_->world_size();
     xllm::kernel::active(activation_params);
 
     return down_proj_->forward(output);
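
One caveat in this hunk: if ActivationParams holds its tensors by value, the reassignment of params.output inside active() never reaches the caller's local output, so down_proj_->forward(output) would still see a default-constructed tensor. A caller-side sketch that reads the result back after the kernel call (an assumption about the intended usage, not something the diff itself shows):

    xllm::kernel::active(activation_params);
    // Read back the tensor that active() allocated into the params struct.
    output = activation_params.output;
    return down_proj_->forward(output);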