diff --git a/xllm/core/framework/sampling/CMakeLists.txt b/xllm/core/framework/sampling/CMakeLists.txt
index 764157d04..c2d113892 100644
--- a/xllm/core/framework/sampling/CMakeLists.txt
+++ b/xllm/core/framework/sampling/CMakeLists.txt
@@ -10,12 +10,14 @@ cc_library(
     rejection_sampler.h
     sampler.h
     beam_searcher.h
+    rec_constrained_decoding.h
   SRCS
     sampling_params.cpp
     logits_utils.cpp
     rejection_sampler.cpp
     sampler.cpp
     beam_searcher.cpp
+    rec_constrained_decoding.cpp
   DEPS
     :common
     glog::glog
diff --git a/xllm/core/framework/sampling/constrained_decoding.h b/xllm/core/framework/sampling/constrained_decoding.h
new file mode 100644
index 000000000..5b842a6ef
--- /dev/null
+++ b/xllm/core/framework/sampling/constrained_decoding.h
@@ -0,0 +1,47 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include <cstdint>
+#include <torch/torch.h>
+#include <vector>
+
+namespace xllm {
+
+// Constrained decoding is used to ensure that the generated content
+// conforms to specific formats or rules.
+class ConstrainedDecoding {
+ public:
+  virtual ~ConstrainedDecoding() = default;
+
+  // Precompute and cache fixed constraint masks (e.g., static vocabulary
+  // whitelists) to avoid redundant calculations during token generation.
+  // Returns: true if the cache was built successfully, false otherwise.
+  virtual bool build_mask_cache() = 0;
+
+  // Generate a dynamic constraint mask based on the already generated token
+  // sequences. This mask will be applied to filter out invalid tokens.
+  //
+  // Input: generated_token_list - 2D vector of token IDs, where each inner
+  // vector holds the generated tokens of a single sequence in the batch
+  // (format: [sequence_num][token_ids]).
+  // Output: tensor of shape [sequence_num, vocab_size], where 0.0f marks
+  // allowed tokens and a large negative value marks forbidden tokens for
+  // each sequence. The mask is applied by adding it to the model logits,
+  // which filters out the invalid tokens.
+  virtual torch::Tensor generate_mask(
+      const std::vector<std::vector<int32_t>>& generated_token_list) = 0;
+};
+}  // namespace xllm
diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.cpp b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
new file mode 100644
index 000000000..117f63bf0
--- /dev/null
+++ b/xllm/core/framework/sampling/rec_constrained_decoding.cpp
@@ -0,0 +1,197 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "rec_constrained_decoding.h"
+
+#include <algorithm>
+#include <glog/logging.h>
+#include <string>
+#include <torch/torch.h>
+
+#include <future>
+#include <memory>
+#include <mutex>
+#include <set>
+#include <vector>
+
+#include "common/global_flags.h"
+#include "common/version_singleton.h"
+#include "framework/state_dict/rec_vocab_dict.h"
+#include "util/slice.h"
+#include "util/tensor_helper.h"
+
+namespace xllm {
+RecConstrainedDecoding::RecConstrainedDecoding(uint64_t model_version,
+                                               const int32_t vocab_size,
+                                               torch::ScalarType dtype,
+                                               torch::Device device,
+                                               bool use_gen_threadpool)
+    : use_gen_threadpool_(use_gen_threadpool),
+      vocab_size_(vocab_size),
+      model_version_(model_version),
+      device_(device),
+      dtype_(dtype) {
+  if (use_gen_threadpool_) {
+    gen_threadpool_ = std::make_unique<ThreadPool>(GEN_MASK_THREAD_NUM);
+  }
+
+  build_mask_cache_ = false;
+}
+
+bool RecConstrainedDecoding::build_mask_cache() {
+  first_token_mask_ = torch::full({vocab_size_}, PRE_MASK_FACTOR, dtype_);
+
+  std::vector<int32_t> empty_token_ids;
+  Slice<int32_t> prefix_token_ids = {empty_token_ids.data(),
+                                     empty_token_ids.size()};
+
+  const std::set<int32_t>& first_token_ids =
+      VersionSingleton<RecVocabDict>::GetInstance(
+          std::to_string(model_version_))
+          ->get_next_tokens_by_prefix_tokens(prefix_token_ids);
+
+  for (auto token_id : first_token_ids) {
+    first_token_mask_[token_id] = 0;
+  }
+
+  first_token_mask_ = safe_to(first_token_mask_, device_, true);
+
+  build_mask_cache_ = true;
+
+  LOG(INFO) << "Build mask cache, first token ids size:"
+            << first_token_ids.size();
+
+  return true;
+}
+
+torch::Tensor RecConstrainedDecoding::generate_mask(
+    const std::vector<std::vector<int32_t>>& generated_token_list) {
+  if (!build_mask_cache_ || 0 == generated_token_list.size()) {
+    return torch::Tensor();
+  }
+
+  size_t token_size = generated_token_list[0].size();
+
+  // Generate mask for the first token
+  if (0 == token_size) {
+    size_t sequence_num = generated_token_list.size();
+    auto mask = first_token_mask_.unsqueeze(0);
+    return mask.repeat({static_cast<int64_t>(sequence_num), 1});
+  }
+
+  // Generate mask for non-first tokens
+  return generate_decode_mask(generated_token_list);
+}
+
+torch::Tensor RecConstrainedDecoding::generate_decode_mask(
+    const std::vector<std::vector<int32_t>>& generated_token_list) {
+  size_t sequence_num = generated_token_list.size();
+  torch::TensorOptions options = torch::dtype(dtype_).device(device_);
+  auto mask = torch::full({static_cast<int64_t>(sequence_num), vocab_size_},
+                          PRE_MASK_FACTOR, options);
+
+  std::mutex global_batch_mutex;
+  std::vector<int64_t> global_batch_token_indices;
+  std::vector<int64_t> global_batch_vocab_indices;
+
+  int max_index_num_per_token = 8192;
+  global_batch_token_indices.reserve(max_index_num_per_token * sequence_num);
+  global_batch_vocab_indices.reserve(max_index_num_per_token * sequence_num);
+
+  auto update_mask = [&](size_t start_idx, size_t end_idx) {
+    std::vector<int64_t> local_token_indices;
+    std::vector<int64_t> local_vocab_indices;
+    local_token_indices.reserve(max_index_num_per_token *
+                                (end_idx - start_idx));
+    local_vocab_indices.reserve(max_index_num_per_token *
+                                (end_idx - start_idx));
+
+    for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) {
+      Slice<int32_t> tokens_slice(generated_token_list[token_idx]);
+
+      const std::set<int32_t>& next_token_ids =
+          VersionSingleton<RecVocabDict>::GetInstance(
+              std::to_string(model_version_))
+              ->get_next_tokens_by_prefix_tokens(tokens_slice);
+
+      if (next_token_ids.size() > 0) {
+        for (int32_t vocab_idx : next_token_ids) {
+          local_token_indices.push_back(static_cast<int64_t>(token_idx));
+          local_vocab_indices.push_back(static_cast<int64_t>(vocab_idx));
+        }
+      } else {
"Fail to generate mask for tokens:" + << generated_token_list[token_idx]; + } + } + + // Merge local results to global batch (thread-safe) + if (!local_token_indices.empty()) { + std::lock_guard lock(global_batch_mutex); + global_batch_token_indices.insert(global_batch_token_indices.end(), + local_token_indices.begin(), + local_token_indices.end()); + global_batch_vocab_indices.insert(global_batch_vocab_indices.end(), + local_vocab_indices.begin(), + local_vocab_indices.end()); + } + }; + + if (use_gen_threadpool_) { + const size_t batch_size = std::max( + 1UL, (sequence_num + GEN_MASK_THREAD_NUM - 1) / GEN_MASK_THREAD_NUM); + const size_t num_batches = (sequence_num + batch_size - 1) / batch_size; + + std::vector> futures; + std::vector>> promises; + + promises.reserve(num_batches); + futures.reserve(num_batches); + + for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) { + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + size_t start_idx = batch_idx * batch_size; + size_t end_idx = std::min(start_idx + batch_size, sequence_num); + + gen_threadpool_->schedule( + [update_mask, start_idx, end_idx, promise]() mutable { + update_mask(start_idx, end_idx); + promise->set_value(); + }); + } + + for (auto& future : futures) { + future.get(); + } + } else { + update_mask(0, sequence_num); + } + + if (!global_batch_token_indices.empty()) { + auto token_indices = + torch::tensor(global_batch_token_indices, torch::kInt64); + auto vocab_indices = + torch::tensor(global_batch_vocab_indices, torch::kInt64); + token_indices = safe_to(token_indices, device_, true); + vocab_indices = safe_to(vocab_indices, device_, true); + mask.index_put_({token_indices, vocab_indices}, 0.0f); + } + + return mask; +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/sampling/rec_constrained_decoding.h b/xllm/core/framework/sampling/rec_constrained_decoding.h new file mode 100644 index 000000000..7cf049286 --- /dev/null +++ b/xllm/core/framework/sampling/rec_constrained_decoding.h @@ -0,0 +1,58 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#pragma once
+#include <cstdint>
+#include <memory>
+
+#include "constrained_decoding.h"
+#include "util/threadpool.h"
+
+namespace xllm {
+
+class RecConstrainedDecoding : public ConstrainedDecoding {
+ public:
+  RecConstrainedDecoding(uint64_t model_version,
+                         const int32_t vocab_size,
+                         torch::ScalarType dtype,
+                         torch::Device device,
+                         bool use_gen_threadpool = true);
+  virtual ~RecConstrainedDecoding() = default;
+
+  bool build_mask_cache() override;
+
+  torch::Tensor generate_mask(
+      const std::vector<std::vector<int32_t>>& generated_token_list) override;
+
+ private:
+  torch::Tensor generate_decode_mask(
+      const std::vector<std::vector<int32_t>>& generated_token_list);
+
+ private:
+  constexpr static float PRE_MASK_FACTOR = -10000.0f;
+  constexpr static int GEN_MASK_THREAD_NUM = 16;
+
+ private:
+  bool build_mask_cache_;
+  bool use_gen_threadpool_;
+  int32_t vocab_size_;
+  uint64_t model_version_;
+  torch::Device device_;
+  torch::ScalarType dtype_;
+  torch::Tensor first_token_mask_;
+  std::unique_ptr<ThreadPool> gen_threadpool_;
+};
+
+}  // namespace xllm
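
Usage sketch (illustrative, not part of the diff): a minimal example of how the new class is intended to be driven at a decode step, assuming a `RecVocabDict` has already been registered for the given model version via `VersionSingleton`. The wrapper function `apply_rec_constraints_example`, the model version value, and the logits/token setup are hypothetical placeholders, not code from this change.

```cpp
#include <cstdint>
#include <vector>

#include <torch/torch.h>

#include "rec_constrained_decoding.h"

// Hypothetical helper: applies the recommendation constraints to one decode
// step. Assumes a RecVocabDict was registered for model version 1 beforehand.
void apply_rec_constraints_example(
    torch::Tensor& logits,  // [sequence_num, vocab_size]
    const std::vector<std::vector<int32_t>>& generated_token_list) {
  xllm::RecConstrainedDecoding decoder(
      /*model_version=*/1,
      /*vocab_size=*/static_cast<int32_t>(logits.size(1)),
      logits.scalar_type(),
      logits.device(),
      /*use_gen_threadpool=*/true);

  // Precompute the first-token whitelist once; bail out if it cannot be built.
  if (!decoder.build_mask_cache()) {
    return;
  }

  // Per-step mask: 0.0f for allowed tokens, -10000.0f for forbidden ones.
  torch::Tensor mask = decoder.generate_mask(generated_token_list);
  if (mask.defined()) {
    // Adding the mask pushes forbidden tokens' logits far below the rest.
    logits += mask;
  }
}
```

In practice the decoder and its cached first-token mask would be built once per model version and reused across decode steps, rather than reconstructed per call as in this sketch.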