Commit 51ab721

refactor: standardize interface for active kernel execution.
1 parent c6974d1

4 files changed: +19 −15 lines

xllm/core/kernels/ops_api.cpp

Lines changed: 11 additions & 8 deletions
@@ -60,6 +60,12 @@ void apply_rotary(RotaryParams& params) {
 
 void active(ActivationParams& params) {
 #if defined(USE_MLU)
+  // Note: Derivation from input is uncertain; using explicit parameter for
+  // robustness.
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
+
   mlu::active(params.input,
               params.output,
               params.bias,
@@ -69,20 +75,17 @@ void active(ActivationParams& params) {
               params.start_expert_id,
               params.expert_size);
 #elif defined(USE_CUDA)
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
   cuda::act_and_mul(params.output, params.input, params.act_mode);
+#elif defined(USE_NPU)
+  params.output = npu::active(params.input, params.act_mode);
 #else
   LOG(FATAL) << "active not implemented";
 #endif
 }
 
-torch::Tensor active_tensor(ActivationParams& params) {
-#if defined(USE_NPU)
-  return npu::active(params.input, params.act_mode);
-#else
-  LOG(FATAL) << "active_tensor not implemented";
-#endif
-}
-
 void reshape_paged_cache(ReshapePagedCacheParams& params) {
 #if defined(USE_MLU)
   mlu::reshape_paged_cache(params.key,
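
For context, a minimal caller-side sketch of the standardized active() interface after this commit. This is an illustration, not code from the repository: it only sets fields that appear in this diff (input, act_mode, is_gated, intermediate_size, world_size), and the variable names and values (gate_up, hidden_act, tp_world_size) are assumed placeholders.

  // Hypothetical caller under tensor parallelism (placeholder variables).
  xllm::kernel::ActivationParams params;
  params.input = gate_up;                        // activation input, e.g. a gate_up projection
  params.act_mode = hidden_act;                  // activation mode, as set in dense_mlp.cpp
  params.is_gated = true;
  params.intermediate_size = intermediate_size;  // full, unsharded intermediate width
  params.world_size = tp_world_size;             // tensor-parallel group size
  xllm::kernel::active(params);
  // On MLU/CUDA, active() allocates params.output with shape
  // [params.input.sizes()[0], intermediate_size / world_size];
  // on NPU, it stores the tensor returned by npu::active().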

xllm/core/kernels/ops_api.h

Lines changed: 0 additions & 2 deletions
@@ -28,8 +28,6 @@ void apply_rotary(RotaryParams& params);
 
 void active(ActivationParams& params);
 
-torch::Tensor active_tensor(ActivationParams& params);
-
 void reshape_paged_cache(ReshapePagedCacheParams& params);
 
 void batch_prefill(AttentionParams& params);

xllm/core/kernels/param.h

Lines changed: 5 additions & 0 deletions
@@ -109,6 +109,11 @@ struct ActivationParams {
   // Expert size for MoE activation. Used when bias is provided.
   // Bias tensor shape must be [expert_size, in_channel].
   int64_t expert_size = 0;
+
+  // Note: Derivation from input is uncertain; using explicit parameter for
+  // robustness.
+  int64_t intermediate_size = 0;
+  int64_t world_size = 0;
 };
 
 // Reshape paged cache parameters
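
To make the intent of the two new fields concrete, a small worked example with assumed numbers (not taken from any model config):

  // Illustrative only: output shape allocated inside active() on MLU/CUDA.
  //   intermediate_size = 11008, world_size = 2 (tensor-parallel degree)
  //   params.output: {params.input.sizes()[0], 11008 / 2} == [num_tokens, 5504]
  // This matches the per-rank slice that dense_mlp.cpp previously allocated
  // itself and fed to the down projection.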

xllm/core/layers/common/dense_mlp.cpp

Lines changed: 3 additions & 5 deletions
@@ -90,17 +90,15 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) {
     // For w8a8 quantization, the active operation is fused with the down_proj
     return down_proj_->forward(gate_up);
   } else {
-    int64_t batch_size = gate_up.sizes()[0];
-    auto output = torch::empty(
-        {batch_size,
-         intermediate_size_ / parallel_args_.tp_group_->world_size()},
-        gate_up.options());
+    torch::Tensor output;
 
     xllm::kernel::ActivationParams activation_params;
     activation_params.input = gate_up;
     activation_params.output = output;
     activation_params.act_mode = hidden_act_;
     activation_params.is_gated = is_gated_;
+    activation_params.intermediate_size = intermediate_size_;
+    activation_params.world_size = parallel_args_.tp_group_->world_size();
     xllm::kernel::active(activation_params);
 
     return down_proj_->forward(output);
