Skip to content

Commit d0e2673

Browse files
committed
feat: add onerec worker impl.
1 parent ea71989 commit d0e2673

File tree

8 files changed

+489
-4
lines changed

8 files changed

+489
-4
lines changed

xllm/core/runtime/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ cc_library(
2222
dit_worker.h
2323
embed_worker_impl.h
2424
embed_vlm_worker_impl.h
25+
rec_worker_impl.h
26+
llmrec_worker_impl.h
27+
onerec_worker_impl.h
2528
worker_client.h
2629
xservice_client.h
2730
speculative_worker_impl.h
@@ -38,6 +41,9 @@ cc_library(
3841
dit_worker.cpp
3942
embed_worker_impl.cpp
4043
embed_vlm_worker_impl.cpp
44+
rec_worker_impl.cpp
45+
llmrec_worker_impl.cpp
46+
onerec_worker_impl.cpp
4147
worker_client.cpp
4248
xservice_client.cpp
4349
params_utils.cpp
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "llmrec_worker_impl.h"
17+
18+
#include <glog/logging.h>
19+
#include <torch/torch.h>
20+
21+
#include <algorithm>
22+
#include <vector>
23+
24+
#include "common/types.h"
25+
#include "core/layers/word_embedding.h"
26+
27+
namespace xllm {
28+
29+
LlmRecWorkerImpl::LlmRecWorkerImpl(const ParallelArgs& parallel_args,
30+
const torch::Device& device,
31+
const runtime::Options& options)
32+
: RecWorkerImpl(parallel_args, device, options) {}
33+
34+
void LlmRecWorkerImpl::prepare_work_before_execute(
35+
const ForwardInput& inputs,
36+
ForwardInput& processed_inputs) {
37+
WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
38+
39+
if (!inputs.input_params.mm_data.valid()) {
40+
return;
41+
}
42+
43+
torch::Tensor input_embedding;
44+
torch::Tensor input_tokens_tensor;
45+
torch::Tensor input_indices_tensor;
46+
47+
const auto& mm_data = inputs.input_params.mm_data;
48+
const auto& processed_mm_data = processed_inputs.input_params.mm_data;
49+
50+
if (auto res =
51+
processed_mm_data.get<torch::Tensor>(LLM_REC_INPUT_TOKENS)) {
52+
input_tokens_tensor = res.value();
53+
}
54+
55+
// input indices 需要在 Host 侧生成位置索引
56+
if (auto res = mm_data.get<torch::Tensor>(LLM_REC_INPUT_INDICES)) {
57+
input_indices_tensor = res.value();
58+
}
59+
60+
if (auto res =
61+
processed_mm_data.get<torch::Tensor>(LLM_REC_INPUT_EMBEDDING)) {
62+
input_embedding = res.value();
63+
}
64+
65+
if (input_embedding.defined()) {
66+
input_embedding = input_embedding.to(dtype());
67+
}
68+
69+
if (input_indices_tensor.defined()) {
70+
layer::WordEmbedding word_embedding = get_word_embedding();
71+
torch::Tensor input_tokens_embedding =
72+
word_embedding(input_tokens_tensor, 0);
73+
74+
if (input_embedding.defined()) {
75+
std::vector<int> input_indices(
76+
input_indices_tensor.data_ptr<int>(),
77+
input_indices_tensor.data_ptr<int>() + input_indices_tensor.numel());
78+
79+
processed_inputs.input_params.input_embedding =
80+
merge_embeddings_by_indices(
81+
input_tokens_embedding, input_embedding, input_indices);
82+
} else {
83+
processed_inputs.input_params.input_embedding = input_tokens_embedding;
84+
}
85+
} else if (input_embedding.defined()) {
86+
processed_inputs.input_params.input_embedding = input_embedding;
87+
}
88+
}
89+
90+
torch::Tensor LlmRecWorkerImpl::merge_embeddings_by_indices(
91+
const torch::Tensor& input_tokens_embedding,
92+
const torch::Tensor& input_embedding,
93+
const std::vector<int>& input_indices) {
94+
CHECK_EQ(input_embedding.dim(), 2);
95+
CHECK_EQ(input_tokens_embedding.dim(), 2);
96+
CHECK_EQ(input_tokens_embedding.size(1), input_embedding.size(1));
97+
CHECK_EQ(input_tokens_embedding.dtype(), input_embedding.dtype());
98+
CHECK_EQ(input_tokens_embedding.device(), input_embedding.device());
99+
100+
const int64_t total_rows =
101+
input_tokens_embedding.size(0) + input_embedding.size(0);
102+
const int64_t cols = input_embedding.size(1);
103+
104+
torch::Device device = input_embedding.device();
105+
torch::Tensor merged = torch::empty(
106+
{total_rows, cols}, torch::dtype(input_embedding.dtype()).device(device));
107+
108+
std::vector<int> input_embedding_indices;
109+
for (int i = 0; i < static_cast<int>(total_rows); ++i) {
110+
if (std::find(input_indices.begin(), input_indices.end(), i) ==
111+
input_indices.end()) {
112+
input_embedding_indices.push_back(i);
113+
}
114+
}
115+
116+
CHECK_EQ(input_embedding_indices.size(), input_embedding.size(0));
117+
118+
torch::Tensor input_embedding_indices_tensor =
119+
torch::tensor(input_embedding_indices, torch::kInt64).to(device);
120+
merged.index_put_({input_embedding_indices_tensor, torch::indexing::Ellipsis},
121+
input_embedding);
122+
123+
torch::Tensor input_indices_tensor =
124+
torch::tensor(input_indices, torch::kInt64).to(device);
125+
merged.index_put_({input_indices_tensor, torch::indexing::Ellipsis},
126+
input_tokens_embedding);
127+
128+
return merged;
129+
}
130+
131+
} // namespace xllm
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <torch/torch.h>
19+
20+
#include <vector>
21+
22+
#include "runtime/rec_worker_impl.h"
23+
24+
namespace xllm {
25+
26+
class LlmRecWorkerImpl final : public RecWorkerImpl {
27+
public:
28+
LlmRecWorkerImpl(const ParallelArgs& parallel_args,
29+
const torch::Device& device,
30+
const runtime::Options& options);
31+
32+
~LlmRecWorkerImpl() override = default;
33+
34+
void prepare_work_before_execute(const ForwardInput& inputs,
35+
ForwardInput& processed_inputs) override;
36+
37+
private:
38+
torch::Tensor merge_embeddings_by_indices(
39+
const torch::Tensor& input_tokens_embedding,
40+
const torch::Tensor& input_embedding,
41+
const std::vector<int>& input_indices);
42+
};
43+
44+
} // namespace xllm
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "onerec_worker_impl.h"
17+
18+
#include <glog/logging.h>
19+
20+
#include <optional>
21+
22+
#include "common/device_monitor.h"
23+
#include "common/metrics.h"
24+
#include "framework/model/model_input_params.h"
25+
#include "util/timer.h"
26+
27+
namespace xllm {
28+
29+
OneRecWorkerImpl::OneRecWorkerImpl(const ParallelArgs& parallel_args,
30+
const torch::Device& device,
31+
const runtime::Options& options)
32+
: RecWorkerImpl(parallel_args, device, options) {}
33+
34+
std::optional<ForwardOutput> OneRecWorkerImpl::step(const ForwardInput& input) {
35+
Timer timer;
36+
device_.set_device();
37+
38+
const auto& sampling_params = input.sampling_params;
39+
const auto& input_params = input.input_params;
40+
41+
if (!input_params.rec_params.has_value()) {
42+
LOG(ERROR) << "OneRecWorkerImpl requires rec_params.";
43+
return std::nullopt;
44+
}
45+
46+
const auto& rec_params = input_params.rec_params.value();
47+
48+
torch::Tensor hidden_states;
49+
if (rec_params.rec_stage == RecModelInputParams::RecStage::PREFILL) {
50+
if (!rec_params.is_first_prefill) {
51+
ModelInputParams decoder_params = input_params;
52+
decoder_params.rec_params->is_encoder_forward = false;
53+
hidden_states = model_executor_->forward(
54+
input.token_ids, input.positions, kv_caches_, decoder_params);
55+
} else {
56+
const bool has_sparse_embedding =
57+
rec_params.encoder_sparse_embedding.defined();
58+
const bool has_encoder_tokens = rec_params.encoder_token_ids.defined() &&
59+
rec_params.encoder_positions.defined();
60+
61+
if (!has_sparse_embedding && !has_encoder_tokens) {
62+
LOG(ERROR) << "OneRecWorkerImpl first prefill requires encoder inputs.";
63+
return std::nullopt;
64+
}
65+
66+
ModelInputParams encoder_params = input_params;
67+
encoder_params.rec_params->is_encoder_forward = true;
68+
69+
torch::Tensor encoder_tokens;
70+
if (has_sparse_embedding) {
71+
encoder_params.rec_params->is_hybrid_mode = true;
72+
encoder_tokens = rec_params.encoder_sparse_embedding;
73+
} else {
74+
encoder_tokens = rec_params.encoder_token_ids;
75+
}
76+
77+
model_executor_->forward(encoder_tokens,
78+
rec_params.encoder_positions,
79+
kv_caches_,
80+
encoder_params);
81+
82+
ModelInputParams decoder_params = input_params;
83+
decoder_params.rec_params->is_encoder_forward = false;
84+
hidden_states = model_executor_->forward(
85+
input.token_ids, input.positions, kv_caches_, decoder_params);
86+
}
87+
} else {
88+
ModelInputParams decoder_params = input_params;
89+
decoder_params.rec_params->is_encoder_forward = false;
90+
hidden_states = model_executor_->forward(
91+
input.token_ids, input.positions, kv_caches_, decoder_params);
92+
}
93+
94+
if (!hidden_states.defined()) {
95+
return std::nullopt;
96+
}
97+
98+
if (!enable_schedule_overlap() && !driver_ && !dp_driver_ &&
99+
!options_.enable_speculative_decode()) {
100+
device_.synchronize_default_stream();
101+
COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds());
102+
DeviceMonitor::get_instance().update_active_activation_memory(
103+
device_.index());
104+
return std::nullopt;
105+
}
106+
107+
torch::Tensor logits;
108+
if (sampling_params.selected_token_idxes.defined()) {
109+
logits =
110+
model_->logits(hidden_states, sampling_params.selected_token_idxes);
111+
}
112+
113+
ForwardOutput output;
114+
115+
if (sampling_params.selected_token_idxes.defined()) {
116+
auto sample_output = sampler_->forward(logits, sampling_params);
117+
output.logits = logits;
118+
output.sample_output = sample_output;
119+
output.do_sample = sampling_params.do_sample;
120+
output.logprobs = sampling_params.logprobs;
121+
output.max_top_logprobs = sampling_params.max_top_logprobs;
122+
}
123+
124+
device_.synchronize_default_stream();
125+
COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds());
126+
DeviceMonitor::get_instance().update_active_activation_memory(
127+
device_.index());
128+
129+
return output;
130+
}
131+
132+
} // namespace xllm
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <torch/torch.h>
19+
20+
#include <optional>
21+
22+
#include "runtime/rec_worker_impl.h"
23+
24+
namespace xllm {
25+
26+
class OneRecWorkerImpl final : public RecWorkerImpl {
27+
public:
28+
OneRecWorkerImpl(const ParallelArgs& parallel_args,
29+
const torch::Device& device,
30+
const runtime::Options& options);
31+
32+
~OneRecWorkerImpl() override = default;
33+
34+
std::optional<ForwardOutput> step(const ForwardInput& input) override;
35+
};
36+
37+
} // namespace xllm

0 commit comments

Comments
 (0)