Commit baf3c90

refactor: standardize interface for active kernel execution.
1 parent 1da759f commit baf3c90

4 files changed, +19 -15 lines

xllm/core/kernels/ops_api.cpp
Lines changed: 11 additions & 8 deletions

@@ -56,6 +56,12 @@ void apply_rotary(RotaryParams& params) {
 
 void active(ActivationParams& params) {
 #if defined(USE_MLU)
+  // Note: Derivation from input is uncertain; using explicit parameter for
+  // robustness.
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
+
   mlu::active(params.input,
               params.output,
               params.bias,
@@ -65,20 +71,17 @@ void active(ActivationParams& params) {
               params.start_expert_id,
               params.expert_size);
 #elif defined(USE_CUDA)
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
   cuda::act_and_mul(params.output, params.input, params.act_mode);
+#elif defined(USE_NPU)
+  params.output = npu::active(params.input, params.act_mode);
 #else
   LOG(FATAL) << "active not implemented";
 #endif
 }
 
-torch::Tensor active_tensor(ActivationParams& params) {
-#if defined(USE_NPU)
-  return npu::active(params.input, params.act_mode);
-#else
-  LOG(FATAL) << "active_tensor not implemented";
-#endif
-}
-
 void reshape_paged_cache(ReshapePagedCacheParams& params) {
 #if defined(USE_MLU)
   mlu::reshape_paged_cache(params.key,
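
For readability, here is a consolidated view of active() after this commit, reconstructed from the two hunks above; the mlu::active arguments that fall outside the diff context are elided rather than guessed.

void active(ActivationParams& params) {
#if defined(USE_MLU)
  // The output is now allocated here from explicit parameters instead of
  // being provided by the caller.
  params.output = torch::empty(
      {params.input.sizes()[0], params.intermediate_size / params.world_size},
      params.input.options());
  mlu::active(params.input,
              params.output,
              params.bias,
              /* ...arguments outside the diff context elided... */
              params.start_expert_id,
              params.expert_size);
#elif defined(USE_CUDA)
  params.output = torch::empty(
      {params.input.sizes()[0], params.intermediate_size / params.world_size},
      params.input.options());
  cuda::act_and_mul(params.output, params.input, params.act_mode);
#elif defined(USE_NPU)
  params.output = npu::active(params.input, params.act_mode);
#else
  LOG(FATAL) << "active not implemented";
#endif
}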

xllm/core/kernels/ops_api.h
Lines changed: 0 additions & 2 deletions

@@ -28,8 +28,6 @@ void apply_rotary(RotaryParams& params);
 
 void active(ActivationParams& params);
 
-torch::Tensor active_tensor(ActivationParams& params);
-
 void reshape_paged_cache(ReshapePagedCacheParams& params);
 
 void batch_prefill(AttentionParams& params);

xllm/core/kernels/param.h
Lines changed: 5 additions & 0 deletions

@@ -109,6 +109,11 @@ struct ActivationParams {
   // Expert size for MoE activation. Used when bias is provided.
   // Bias tensor shape must be [expert_size, in_channel].
   int64_t expert_size = 0;
+
+  // Note: Derivation from input is uncertain; using explicit parameter for
+  // robustness.
+  int64_t intermediate_size = 0;
+  int64_t world_size = 0;
 };
 
 // Reshape paged cache parameters

xllm/core/layers/common/dense_mlp.cpp
Lines changed: 3 additions & 5 deletions

@@ -89,17 +89,15 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) {
     // For w8a8 quantization, the active operation is fused with the down_proj
     return down_proj_->forward(gate_up);
   } else {
-    int64_t batch_size = gate_up.sizes()[0];
-    auto output = torch::empty(
-        {batch_size,
-         intermediate_size_ / parallel_args_.tp_group_->world_size()},
-        gate_up.options());
+    torch::Tensor output;
 
     xllm::kernel::ActivationParams activation_params;
     activation_params.input = gate_up;
     activation_params.output = output;
     activation_params.act_mode = hidden_act_;
     activation_params.is_gated = is_gated_;
+    activation_params.intermediate_size = intermediate_size_;
+    activation_params.world_size = parallel_args_.tp_group_->world_size();
     xllm::kernel::active(activation_params);
 
     return down_proj_->forward(output);
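
As a usage sketch, this is roughly how a caller drives the standardized interface after this change, using the names from DenseMLPImpl::forward above (gate_up, hidden_act_, is_gated_, intermediate_size_, parallel_args_). The placeholder output tensor is left out here, and reading the result back from activation_params.output is an assumption about the intended contract rather than a copy of the code.

// Sketch only: the kernel now allocates the output internally, so the caller
// no longer needs a torch::empty() call.
xllm::kernel::ActivationParams activation_params;
activation_params.input = gate_up;        // gate/up projection output
activation_params.act_mode = hidden_act_;
activation_params.is_gated = is_gated_;
activation_params.intermediate_size = intermediate_size_;
activation_params.world_size = parallel_args_.tp_group_->world_size();
xllm::kernel::active(activation_params);

// On the MLU/CUDA paths the result has shape
// [num_tokens, intermediate_size / world_size]; it is read back from the
// params struct here.
return down_proj_->forward(activation_params.output);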
