From 2f7f6980316c8580b6401f809732b9bd89b0dd69 Mon Sep 17 00:00:00 2001
From: magicheng0816
Date: Mon, 10 Nov 2025 21:17:30 +0800
Subject: [PATCH 1/4] feat: add constrained decoding for generative
 recommendation.

---
 xllm/core/framework/sampling/CMakeLists.txt   |   2 +
 .../framework/sampling/constrained_decoding.h |  36 ++++
 .../sampling/rec_constrained_decoding.cpp     | 186 ++++++++++++++++++
 .../sampling/rec_constrained_decoding.h       |  54 +++++
 4 files changed, 278 insertions(+)
 create mode 100644 xllm/core/framework/sampling/constrained_decoding.h
 create mode 100644 xllm/core/framework/sampling/rec_constrained_decoding.cpp
 create mode 100644 xllm/core/framework/sampling/rec_constrained_decoding.h

diff --git a/xllm/core/framework/sampling/CMakeLists.txt b/xllm/core/framework/sampling/CMakeLists.txt
index 764157d04..c2d113892 100644
--- a/xllm/core/framework/sampling/CMakeLists.txt
+++ b/xllm/core/framework/sampling/CMakeLists.txt
@@ -10,12 +10,14 @@ cc_library(
     rejection_sampler.h
     sampler.h
     beam_searcher.h
+    rec_constrained_decoding.h
   SRCS
     sampling_params.cpp
     logits_utils.cpp
     rejection_sampler.cpp
     sampler.cpp
     beam_searcher.cpp
+    rec_constrained_decoding.cpp
   DEPS
     :common
     glog::glog
diff --git a/xllm/core/framework/sampling/constrained_decoding.h b/xllm/core/framework/sampling/constrained_decoding.h
new file mode 100644
index 000000000..f1202287e
--- /dev/null
+++ b/xllm/core/framework/sampling/constrained_decoding.h
@@ -0,0 +1,36 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include
+#include
+#include
+
+namespace xllm {
+
+// constrained decoding is used to ensure that the generated content
+// conforms to specific formats or rules.
+class ConstrainedDecoding {
+ public:
+  virtual ~ConstrainedDecoding();
+
+  virtual bool build_mask_cache();
+
+  // input generated_token_list: [sequence_num][generated_token_ids]
+  // output: mask tensor[sequence_num,vocab_size]
+  virtual torch::Tensor generate_mask(
+      const std::vector<std::vector<int32_t>>& generated_token_list);
+};
+}  // namespace xllm
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.cpp b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
new file mode 100644
index 000000000..bd570c9bd
--- /dev/null
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
@@ -0,0 +1,186 @@
+#include "rec_constrained_decoding.h"
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "common/global_flags.h"
+#include "common/version_singleton.h"
+#include "framework/state_dict/rec_vocab_dict.h"
+#include "util/slice.h"
+#include "util/tensor_helper.h"
+
+namespace xllm {
+
+constexpr float PRE_MASK_FACTOR = -10000.0f;
+constexpr int GEN_MASK_THREAD_NUM = 16;
+
+RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
+                                               const int32_t vocab_size,
+                                               torch::ScalarType dtype,
+                                               torch::Device device,
+                                               bool use_gen_threadpool)
+    : model_version_(model_version),
+      vocab_size_(vocab_size),
+      dtype_(dtype),
+      device_(device),
+      use_gen_threadpool_(use_gen_threadpool) {
+  if (use_gen_threadpool_) {
+    gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
+  }
+
+  build_mask_cache_ = false;
+}
+
+bool RecConstrainedDecoding::build_mask_cache() {
+  first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_);
+
+  std::vector<int32_t> empty_token_ids;
+  Slice<int32_t> prefix_token_ids = {empty_token_ids.data(),
+                                     empty_token_ids.size()};
+
+  const std::set<int32_t>& first_token_ids =
+      VersionSingleton<RecVocabDict>::GetInstance(
+          std::to_string(model_version_))
+          ->get_next_tokens_by_prefix_tokens(prefix_token_ids);
+
+  for (auto token_id : first_token_ids) {
+    first_token_mask_[token_id] = 0;
+  }
+
+  first_token_mask_ = safe_to(first_token_mask_, device_, true);
+
+  build_mask_cache_ = true;
+
+  LOG(INFO) << "build mask cache, first token ids size:"
+            << first_token_ids.size();
+
+  return true;
+}
+
+torch::Tensor RecConstrainedDecoding::generate_mask(
+    const std::vector<std::vector<int32_t>>& generated_token_list) {
+  if (!build_mask_cache_ || 0 == generated_token_list.size()) {
+    return torch::Tensor();
+  }
+
+  size_t token_size = generated_token_list[0].size();
+
+  // generate mask for first token
+  if (0 == token_size) {
+    size_t sequence_num = generated_token_list.size();
+    auto mask = first_token_mask_.unsqueeze(0);
+    return mask.repeat({sequence_num, 1});
+  }
+
+  // generate mask for non-first token
+  return generate_decode_mask(generated_token_list);
+}
+
+torch::Tensor RecConstrainedDecoding::generate_decode_mask(
+    const std::vector<std::vector<int32_t>>& generated_token_list) {
+  size_t sequence_num = generated_token_list.size();
+  torch::TensorOptions options = torch::dtype(dtype_).device(device_);
+  auto mask =
+      torch::full({sequence_num, vocab_size_}, PRE_MASK_FACTOR, options);
+
+  std::mutex global_batch_mutex;
+  std::vector<int64_t> global_batch_token_indices;
+  std::vector<int64_t> global_batch_vocab_indices;
+
+  int max_index_num_per_token = 8192;
+  global_batch_token_indices.reserve(max_index_num_per_token * sequence_num);
+  global_batch_vocab_indices.reserve(max_index_num_per_token * sequence_num);
+
+  auto update_mask = [&](size_t start_idx, size_t end_idx) {
+    std::vector<int64_t> local_token_indices;
+    std::vector<int64_t> local_vocab_indices;
+    local_token_indices.reserve(max_index_num_per_token *
+                                (end_idx - start_idx));
+    local_vocab_indices.reserve(max_index_num_per_token *
+                                (end_idx - start_idx));
+
+    for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) {
+      Slice<int32_t> tokens_slice(generated_token_list[token_idx]);
+
+      const std::set<int32_t>& next_token_ids =
+          VersionSingleton<RecVocabDict>::GetInstance(
+              std::to_string(model_version_))
+              ->get_next_tokens_by_prefix_tokens(tokens_slice);
+
+      if (next_token_ids.size() > 0) {
+        for (int32_t vocab_idx : next_token_ids) {
+          local_token_indices.push_back(static_cast<int64_t>(token_idx));
+          local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
+        }
+      } else {
+        LOG(ERROR) << "fail to generate mask for tokens:"
+                   << generated_token_list[token_idx];
+      }
+    }
+
+    // merge local results to global batch (thread-safe)
+    if (!local_token_indices.empty()) {
+      std::lock_guard<std::mutex> lock(global_batch_mutex);
+      global_batch_token_indices.insert(global_batch_token_indices.end(),
+                                        local_token_indices.begin(),
+                                        local_token_indices.end());
+      global_batch_vocab_indices.insert(global_batch_vocab_indices.end(),
+                                        local_vocab_indices.begin(),
+                                        local_vocab_indices.end());
+    }
+  };
+
+  if (use_gen_threadpool_) {
+    const size_t batch_size = std::max(
+        1UL, (sequence_num + GEN_MASK_THREAD_NUM - 1) / GEN_MASK_THREAD_NUM);
+    const size_t num_batches = (sequence_num + batch_size - 1) / batch_size;
+
+    std::vector<std::future<void>> futures;
+    std::vector<std::shared_ptr<std::promise<void>>> promises;
+
+    promises.reserve(num_batches);
+    futures.reserve(num_batches);
+
+    for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
+      auto promise = std::make_shared<std::promise<void>>();
+      futures.push_back(promise->get_future());
+      promises.push_back(promise);
+
+      size_t start_idx = batch_idx * batch_size;
+      size_t end_idx = std::min(start_idx + batch_size, sequence_num);
+
+      gen_threadpool_->schedule(
+          [update_mask, start_idx, end_idx, promise]() mutable {
+            update_mask(start_idx, end_idx);
+            promise->set_value();
+          });
+    }
+
+    for (auto& future : futures) {
+      future.get();
+    }
+  } else {
+    update_mask(0, sequence_num);
+  }
+
+  if (!global_batch_token_indices.empty()) {
+    auto token_indices =
+        torch::tensor(global_batch_token_indices, torch::kInt64);
+    auto vocab_indices =
+        torch::tensor(global_batch_vocab_indices, torch::kInt64);
+    token_indices = safe_to(token_indices, device_, true);
+    vocab_indices = safe_to(vocab_indices, device_, true);
+    mask.index_put_({token_indices, vocab_indices}, 0.0f);
+  }
+
+  return mask;
+}
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.h b/xllm/core/framework/sampling/rec_constrained_decoding.h
new file mode 100644
index 000000000..555214cb3
--- /dev/null
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.h
@@ -0,0 +1,54 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include
+#include
+
+#include "constrained_decoding.h"
+#include "util/threadpool.h"
+
+namespace xllm {
+
+class RecConstrainedDecoding : public ConstrainedDecoding {
+ public:
+  RecConstrainedDecoding(uint64_t model_version,
+                         const int32_t vocab_size,
+                         torch::ScalarType dtype,
+                         torch::Device device,
+                         bool use_gen_threadpool_ = true);
+  virtual ~RecConstrainedDecoding() = default;
+
+  bool build_mask_cache() override;
+
+  torch::Tensor generate_mask(
+      const std::vector<std::vector<int32_t>>& generated_token_list) override;
+
+ private:
+  torch::Tensor generate_decode_mask(
+      const std::vector<std::vector<int32_t>>& generated_token_list);
+
+ private:
+  bool build_mask_cache_;
+  bool use_gen_threadpool_;
+  int32_t vocab_size_;
+  uint64_t model_version_;
+  torch::Device device_;
+  torch::ScalarType dtype_;
+  torch::Tensor first_token_mask_;
+  std::unique_ptr<ThreadPool> gen_threadpool_;
+};
+
+}  // namespace xllm

From 50e36f97c0fc655ad1e1ca234555df44e2f3460f Mon Sep 17 00:00:00 2001
From: magicheng0816
Date: Wed, 3 Dec 2025 15:36:05 +0800
Subject: [PATCH 2/4] feat: fix log style, etc.

---
 xllm/core/framework/sampling/constrained_decoding.h  |  6 +++---
 .../framework/sampling/rec_constrained_decoding.cpp  | 10 +++++-----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/xllm/core/framework/sampling/constrained_decoding.h b/xllm/core/framework/sampling/constrained_decoding.h
index f1202287e..2a6415a2b 100644
--- a/xllm/core/framework/sampling/constrained_decoding.h
+++ b/xllm/core/framework/sampling/constrained_decoding.h
@@ -20,7 +20,7 @@ limitations under the License.
 
 namespace xllm {
 
-// constrained decoding is used to ensure that the generated content
+// Constrained decoding is used to ensure that the generated content
 // conforms to specific formats or rules.
 class ConstrainedDecoding {
  public:
@@ -28,8 +28,8 @@ class ConstrainedDecoding {
 
   virtual bool build_mask_cache();
 
-  // input generated_token_list: [sequence_num][generated_token_ids]
-  // output: mask tensor[sequence_num,vocab_size]
+  // Input generated_token_list: [sequence_num][generated_token_ids]
+  // Output: mask tensor[sequence_num,vocab_size]
   virtual torch::Tensor generate_mask(
       const std::vector<std::vector<int32_t>>& generated_token_list);
 };
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.cpp b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
index bd570c9bd..be1f71859 100644
--- a/xllm/core/framework/sampling/rec_constrained_decoding.cpp
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
@@ -59,7 +59,7 @@ bool RecConstrainedDecoding::build_mask_cache() {
 
   build_mask_cache_ = true;
 
-  LOG(INFO) << "build mask cache, first token ids size:"
+  LOG(INFO) << "Build mask cache, first token ids size:"
             << first_token_ids.size();
 
   return true;
@@ -73,14 +73,14 @@ torch::Tensor RecConstrainedDecoding::generate_mask(
 
   size_t token_size = generated_token_list[0].size();
 
-  // generate mask for first token
+  // Generate mask for first token
   if (0 == token_size) {
     size_t sequence_num = generated_token_list.size();
     auto mask = first_token_mask_.unsqueeze(0);
     return mask.repeat({sequence_num, 1});
   }
 
-  // generate mask for non-first token
+  // Generate mask for non-first token
   return generate_decode_mask(generated_token_list);
 }
 
@@ -121,12 +121,12 @@ torch::Tensor RecConstrainedDecoding::generate_decode_mask(
         local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
       }
     } else {
-      LOG(ERROR) << "fail to generate mask for tokens:"
+      LOG(ERROR) << "Fail to generate mask for tokens:"
                  << generated_token_list[token_idx];
     }
   }
 
-  // merge local results to global batch (thread-safe)
+  // Merge local results to global batch (thread-safe)
   if (!local_token_indices.empty()) {
     std::lock_guard<std::mutex> lock(global_batch_mutex);
     global_batch_token_indices.insert(global_batch_token_indices.end(),

From 5a322eb1a19d1ba44268afbc5f2adc7391e3f752 Mon Sep 17 00:00:00 2001
From: magicheng0816
Date: Wed, 3 Dec 2025 21:21:03 +0800
Subject: [PATCH 3/4] feat: add comments, xllm header.

---
 .../framework/sampling/constrained_decoding.h | 15 +++++++++++++--
 .../sampling/rec_constrained_decoding.cpp     | 15 +++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/xllm/core/framework/sampling/constrained_decoding.h b/xllm/core/framework/sampling/constrained_decoding.h
index 2a6415a2b..2d505a21b 100644
--- a/xllm/core/framework/sampling/constrained_decoding.h
+++ b/xllm/core/framework/sampling/constrained_decoding.h
@@ -26,10 +26,21 @@ class ConstrainedDecoding {
  public:
   virtual ~ConstrainedDecoding();
 
+  // Precompute and cache fixed constraint masks (e.g., static vocabulary
+  // whitelists) to avoid redundant calculations during token generation.
+  // Returns: true if cache built successfully, false otherwise
   virtual bool build_mask_cache();
 
-  // Input generated_token_list: [sequence_num][generated_token_ids]
-  // Output: mask tensor[sequence_num,vocab_size]
+  // Generate dynamic constraint mask based on already generated token
+  // sequences. This mask will be applied to filter invalid tokens.
+  //
+  // Input: generated_token_list - 2D vector of token IDs, where each inner
+  // vector represents the generated tokens for a single sequence in the batch
+  // (format:[sequence_num][token_ids])
+  // Output: tensor of shape [sequence_num, vocab_size], where 0.0f
+  // indicates allowed tokens and a large negative number indicates forbidden
+  // tokens for each sequence, the usage is to filter invalid tokens by adding
+  // the mask to the model logits.
   virtual torch::Tensor generate_mask(
       const std::vector<std::vector<int32_t>>& generated_token_list);
 };
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.cpp b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
index be1f71859..0778b3ad4 100644
--- a/xllm/core/framework/sampling/rec_constrained_decoding.cpp
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
@@ -1,3 +1,18 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 #include "rec_constrained_decoding.h"
 
 #include

From bf356b0c173b8486e2af211ca415dd4c59d0ae20 Mon Sep 17 00:00:00 2001
From: magicheng0816
Date: Tue, 9 Dec 2025 10:47:24 +0800
Subject: [PATCH 4/4] feat: standardize some C++ implementations.

---
 xllm/core/framework/sampling/constrained_decoding.h     |  6 +++---
 .../framework/sampling/rec_constrained_decoding.cpp     | 10 +++-------
 .../core/framework/sampling/rec_constrained_decoding.h  |  4 ++++
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/xllm/core/framework/sampling/constrained_decoding.h b/xllm/core/framework/sampling/constrained_decoding.h
index 2d505a21b..5b842a6ef 100644
--- a/xllm/core/framework/sampling/constrained_decoding.h
+++ b/xllm/core/framework/sampling/constrained_decoding.h
@@ -24,12 +24,12 @@ namespace xllm {
 // conforms to specific formats or rules.
 class ConstrainedDecoding {
  public:
-  virtual ~ConstrainedDecoding();
+  virtual ~ConstrainedDecoding() = default;
 
   // Precompute and cache fixed constraint masks (e.g., static vocabulary
   // whitelists) to avoid redundant calculations during token generation.
   // Returns: true if cache built successfully, false otherwise
-  virtual bool build_mask_cache();
+  virtual bool build_mask_cache() = 0;
 
   // Generate dynamic constraint mask based on already generated token
   // sequences. This mask will be applied to filter invalid tokens.
@@ -42,6 +42,6 @@ class ConstrainedDecoding {
   // tokens for each sequence, the usage is to filter invalid tokens by adding
   // the mask to the model logits.
   virtual torch::Tensor generate_mask(
-      const std::vector<std::vector<int32_t>>& generated_token_list);
+      const std::vector<std::vector<int32_t>>& generated_token_list) = 0;
 };
 }  // namespace xllm
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.cpp b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
index 0778b3ad4..117f63bf0 100644
--- a/xllm/core/framework/sampling/rec_constrained_decoding.cpp
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
@@ -33,20 +33,16 @@ limitations under the License.
 #include "util/tensor_helper.h"
 
 namespace xllm {
-
-constexpr float PRE_MASK_FACTOR = -10000.0f;
-constexpr int GEN_MASK_THREAD_NUM = 16;
-
 RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
                                                const int32_t vocab_size,
                                                torch::ScalarType dtype,
                                                torch::Device device,
                                                bool use_gen_threadpool)
-    : model_version_(model_version),
+    : use_gen_threadpool_(use_gen_threadpool),
       vocab_size_(vocab_size),
-      dtype_(dtype),
+      model_version_(model_version),
       device_(device),
-      use_gen_threadpool_(use_gen_threadpool) {
+      dtype_(dtype) {
   if (use_gen_threadpool_) {
     gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
   }
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.h b/xllm/core/framework/sampling/rec_constrained_decoding.h
index 555214cb3..7cf049286 100644
--- a/xllm/core/framework/sampling/rec_constrained_decoding.h
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.h
@@ -40,6 +40,10 @@ class RecConstrainedDecoding : public ConstrainedDecoding {
   torch::Tensor generate_decode_mask(
       const std::vector<std::vector<int32_t>>& generated_token_list);
 
+ private:
+  constexpr static float PRE_MASK_FACTOR = -10000.0f;
+  constexpr static int GEN_MASK_THREAD_NUM = 16;
+
  private:
   bool build_mask_cache_;
   bool use_gen_threadpool_;
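
The comment block introduced in patch 3 describes how callers are expected to consume the result: generate_mask() returns a [sequence_num, vocab_size] tensor with 0.0f for allowed tokens and a large negative value for forbidden ones, and that mask is added to the model logits before sampling. A minimal sketch of that usage follows, assuming a caller that already holds a configured RecConstrainedDecoding instance; the helper name apply_rec_constraints and its wiring are illustrative assumptions, not code from this series.

#include <cstdint>
#include <vector>

#include <torch/torch.h>

#include "rec_constrained_decoding.h"

namespace xllm {

// Hypothetical helper (not part of this patch series): combines the
// constraint mask with a batch of logits of shape [sequence_num, vocab_size].
torch::Tensor apply_rec_constraints(
    RecConstrainedDecoding& decoding,
    const std::vector<std::vector<int32_t>>& generated_token_list,
    const torch::Tensor& logits) {
  // generate_mask() returns an undefined tensor when the mask cache has not
  // been built or the batch is empty; the logits then pass through unchanged.
  torch::Tensor mask = decoding.generate_mask(generated_token_list);
  if (!mask.defined()) {
    return logits;
  }
  // Adding PRE_MASK_FACTOR (-10000.0f) to forbidden positions drives their
  // post-softmax probabilities toward zero, so only allowed tokens survive.
  return logits + mask;
}

}  // namespace xllm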