4 files changed, +19 −17 lines.

Kernel dispatch for active(): the op now allocates its own output tensor on the MLU and CUDA paths, and the NPU path (previously the separate active_tensor()) is folded into the same function.

```diff
@@ -60,6 +60,12 @@ void apply_rotary(RotaryParams& params) {
 
 void active(ActivationParams& params) {
 #if defined(USE_MLU)
+  // The output width cannot be derived reliably from the input shape alone,
+  // so the caller passes intermediate_size and world_size explicitly.
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
+
   mlu::active(params.input,
               params.output,
               params.bias,
@@ -69,20 +75,17 @@ void active(ActivationParams& params) {
               params.start_expert_id,
               params.expert_size);
 #elif defined(USE_CUDA)
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
   cuda::act_and_mul(params.output, params.input, params.act_mode);
+#elif defined(USE_NPU)
+  params.output = npu::active(params.input, params.act_mode);
 #else
   LOG(FATAL) << "active not implemented";
 #endif
 }
 
-torch::Tensor active_tensor(ActivationParams& params) {
-#if defined(USE_NPU)
-  return npu::active(params.input, params.act_mode);
-#else
-  LOG(FATAL) << "active_tensor not implemented";
-#endif
-}
-
 void reshape_paged_cache(ReshapePagedCacheParams& params) {
 #if defined(USE_MLU)
   mlu::reshape_paged_cache(params.key,
```
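With this change, callers of active() no longer preallocate the output: they fill in the sizing fields and read params.output back after the call. A minimal sketch of the new calling convention, assuming the xllm kernel headers are available; the wrapper function and all concrete shapes are illustrative, not part of the diff:

```cpp
#include <torch/torch.h>
// Assumes the xllm kernel headers from this tree are on the include path.

torch::Tensor run_activation(const torch::Tensor& gate_up,
                             int64_t intermediate_size,
                             int64_t world_size) {
  xllm::kernel::ActivationParams params;
  params.input = gate_up;                        // e.g. [num_tokens, 2 * inter_per_rank]
  params.intermediate_size = intermediate_size;  // full, unsharded width
  params.world_size = world_size;                // tensor-parallel group size
  // act_mode / is_gated would be set as before; their values are model-specific.

  xllm::kernel::active(params);  // allocates and fills params.output internally

  // The result must be read back from the struct:
  // shape [num_tokens, intermediate_size / world_size].
  return params.output;
}
```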
The public header drops the now-redundant declaration:

```diff
@@ -28,8 +28,6 @@ void apply_rotary(RotaryParams& params);
 
 void active(ActivationParams& params);
 
-torch::Tensor active_tensor(ActivationParams& params);
-
 void reshape_paged_cache(ReshapePagedCacheParams& params);
 
 void batch_prefill(AttentionParams& params);
```
ActivationParams gains the two sizing fields:

```diff
@@ -109,6 +109,11 @@ struct ActivationParams {
   // Expert size for MoE activation. Used when bias is provided.
   // Bias tensor shape must be [expert_size, in_channel].
   int64_t expert_size = 0;
+
+  // The output width cannot be derived reliably from the input shape alone,
+  // so callers set these explicitly (see active()).
+  int64_t intermediate_size = 0;
+  int64_t world_size = 0;
 };
 
 // Reshape paged cache parameters
```
9090 // For w8a8 quantization, the active operation is fused with the down_proj
9191 return down_proj_->forward (gate_up);
9292 } else {
93- int64_t batch_size = gate_up.sizes ()[0 ];
94- auto output = torch::empty (
95- {batch_size,
96- intermediate_size_ / parallel_args_.tp_group_ ->world_size ()},
97- gate_up.options ());
93+ torch::Tensor output;
9894
9995 xllm::kernel::ActivationParams activation_params;
10096 activation_params.input = gate_up;
10197 activation_params.output = output;
10298 activation_params.act_mode = hidden_act_;
10399 activation_params.is_gated = is_gated_;
100+ activation_params.intermediate_size = intermediate_size_;
101+ activation_params.world_size = parallel_args_.tp_group_ ->world_size ();
104102 xllm::kernel::active (activation_params);
105103
106104 return down_proj_->forward (output);
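The return statement reads activation_params.output because assigning one torch::Tensor to another copies a handle, and active() rebinds the struct member to a freshly allocated tensor; a tensor the caller kept from before the call would never see the result. A standalone sketch of that aliasing behavior (all names hypothetical):

```cpp
#include <torch/torch.h>
#include <cassert>

struct Params {
  torch::Tensor output;
};

void fill(Params& p) {
  // Rebinds p.output to a fresh tensor; any handle the caller copied into
  // p.output beforehand is left pointing at the old state.
  p.output = torch::ones({2, 2});
}

int main() {
  torch::Tensor local;  // default-constructed: an undefined tensor handle
  Params p;
  p.output = local;     // copies the handle; no live alias going forward
  fill(p);
  assert(p.output.defined());  // the struct member holds the result...
  assert(!local.defined());    // ...but the caller's local was never updated
  return 0;
}
```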