@@ -29,6 +29,7 @@ limitations under the License.
2929#include " framework/model/model_input_params.h"
3030#include " framework/request/sequence.h"
3131#include " framework/sampling/sampling_params.h"
32+ #include " rec_batch_input_builder.h"
3233#include " runtime/params_utils.h"
3334#include " util/slice.h"
3435#include " util/tensor_helper.h"
@@ -96,6 +97,10 @@ void Batch::add(const std::vector<Sequence*>& sequences) {
9697ForwardInput Batch::prepare_forward_input (uint32_t num_decoding_tokens,
9798 uint32_t min_decoding_batch_size,
9899 const ModelArgs& args) {
100+ if (sequences_.empty () && !sequence_groups_.empty ()) {
101+ return prepare_rec_forward_input (
102+ num_decoding_tokens, min_decoding_batch_size, args);
103+ }
99104 BatchInputBuilder builder (sequences_,
100105 allowed_max_tokens_,
101106 input_embeddings_vec_,
@@ -108,6 +113,43 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
108113 min_decoding_batch_size);
109114}
110115
116+ ForwardInput Batch::prepare_rec_forward_input (uint32_t num_decoding_tokens,
117+ uint32_t min_decoding_batch_size,
118+ const ModelArgs& args,
119+ ThreadPool* thread_pool) {
120+ RecType rec_type = RecType::kNone ;
121+ if (!sequence_groups_.empty () && !sequence_groups_[0 ]->sequences ().empty ()) {
122+ rec_type = sequence_groups_[0 ]->sequences ()[0 ]->rec_type ();
123+ }
124+
125+ auto builder = RecBatchInputBuilder::Create (rec_type,
126+ sequence_groups_,
127+ allowed_max_tokens_,
128+ input_embeddings_vec_,
129+ mm_data_vec_,
130+ swap_block_transfer_infos_,
131+ batch_id_,
132+ &args,
133+ thread_pool);
134+ return builder->build_rec_forward_input (num_decoding_tokens,
135+ min_decoding_batch_size);
136+ }
137+
138+ std::vector<Sequence*> Batch::get_sequences () const {
139+ if (!sequences_.empty ()) {
140+ return sequences_;
141+ }
142+
143+ std::vector<Sequence*> result;
144+ for (const auto * seq_group : sequence_groups_) {
145+ const auto & sequences = seq_group->sequences ();
146+ for (const auto & seq_ptr : sequences) {
147+ result.push_back (seq_ptr.get ());
148+ }
149+ }
150+ return result;
151+ }
152+
111153void Batch::dp_balance_shuffle_seqs () {
112154 // this shuffle operation is mainly used for npu with 24 cores
113155 // and specific mla op implementation
@@ -217,7 +259,8 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output,
217259 // this means all sequences are in prefill stage status.
218260 const int64_t num_seqs = raw_output.outputs .size ();
219261 int64_t output_idx = 0 ;
220- for (auto * seq : sequences_) {
262+ const auto sequences = get_sequences ();
263+ for (auto * seq : sequences) {
221264 if (seq->finished ()) {
222265 output_idx++;
223266 continue ;
@@ -264,7 +307,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
264307 if (sample_output.embeddings .defined ()) {
265308 const int64_t num_seqs = sample_output.embeddings .size (0 );
266309 int64_t output_idx = 0 ;
267- for (auto * seq : sequences_) {
310+ const auto sequences = get_sequences ();
311+ for (auto * seq : sequences) {
268312 CHECK_LT (output_idx, num_seqs);
269313 auto cur_seq_embed =
270314 safe_to (sample_output.embeddings [output_idx++], torch::kFloat32 );
@@ -277,7 +321,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output,
277321 // this means all sequences are in prefill stage status.
278322 const int64_t num_seqs = sample_output.next_tokens .size (0 );
279323 int64_t output_idx = 0 ;
280- for (auto * seq : sequences_) {
324+ const auto sequences = get_sequences ();
325+ for (auto * seq : sequences) {
281326 if (seq->finished ()) {
282327 output_idx++;
283328 continue ;
@@ -352,7 +397,8 @@ void Batch::process_embedding_output(const torch::Tensor& output_embedding) {
352397 Token token (0 );
353398 if (output_embedding.defined ()) {
354399 int32_t slice_img_index = 0 ;
355- for (auto * seq : sequences_) { // TODO
400+ const auto sequences = get_sequences ();
401+ for (auto * seq : sequences) {
356402 const auto & mm_data = seq->get_mm_data ();
357403
358404 auto pixel_values = mm_data.get_tensor_vec (" pixel_values" );
0 commit comments