From 7b94adf6bbb2efbb9acd247dc30c9ea71eea7e62 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Fri, 27 Jun 2025 10:16:42 +0000 Subject: [PATCH 1/9] [DRAFT] Initial implementation of CFG (classifier free guidance) for LLMs Signed-off-by: Viacheslav Klimkov --- .../batch_manager/handleGenerationLogits.h | 3 +- .../tensorrt_llm/batch_manager/llmRequest.h | 37 +- cpp/include/tensorrt_llm/executor/executor.h | 10 +- .../layers/defaultDecodingParams.h | 5 + .../tensorrt_llm/runtime/samplingConfig.h | 8 + .../batch_manager/allocateKvCache.cpp | 40 +- .../batch_manager/assignReqSeqSlots.cpp | 14 +- .../batch_manager/encoderBuffers.cpp | 102 +++-- .../batch_manager/generateRequestOptions.cpp | 2 +- .../batch_manager/guidedDecoder.cpp | 4 +- .../batch_manager/handleContextLogits.cpp | 12 +- .../batch_manager/handleGenerationLogits.cpp | 56 ++- .../batch_manager/kvCacheManager.cpp | 36 +- cpp/tensorrt_llm/batch_manager/llmRequest.cpp | 2 +- .../batch_manager/logitsPostProcessor.cpp | 4 +- .../makeDecodingBatchInputOutput.cpp | 4 +- .../batch_manager/rnnStateBuffers.cpp | 4 +- .../batch_manager/runtimeBuffers.cpp | 431 ++++++++++-------- .../batch_manager/transformerBuffers.cpp | 272 +++++------ .../trtGptModelInflightBatching.cpp | 24 +- cpp/tensorrt_llm/executor/samplingConfig.cpp | 21 +- .../pybind/batch_manager/algorithms.cpp | 7 +- .../pybind/batch_manager/bindings.cpp | 2 +- cpp/tensorrt_llm/pybind/bindings.cpp | 4 +- cpp/tensorrt_llm/pybind/executor/request.cpp | 12 +- cpp/tensorrt_llm/runtime/decoderState.cpp | 4 +- .../runtime/gptDecoderBatched.cpp | 2 +- cpp/tests/batch_manager/guidedDecoderTest.cpp | 4 +- .../batch_manager/llmRequestTest.cpp | 4 +- tensorrt_llm/runtime/model_runner_cpp.py | 1 + 30 files changed, 691 insertions(+), 440 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h index 33a255a91d1..a585aa6e491 100644 --- a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h +++ b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h @@ -46,7 +46,8 @@ class HandleGenerationLogits : Algorithm HandleGenerationLogits() = default; void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers, - tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, tr::ITensor::SharedPtr const& logits, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + tensorrt_llm::runtime::CudaStream const& stream, tr::ITensor::SharedPtr const& logits, OptionalRef genRuntimeBuffers, tr::SizeType32 vocabId = 0) const; }; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 478cc392675..32f56f03301 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -34,6 +34,7 @@ #include #include #include +#include namespace tensorrt_llm::batch_manager { @@ -520,6 +521,38 @@ class GenericLlmRequest return mNumReturnSequences; } + [[nodiscard]] bool isCfg() const + { + return mSamplingConfig.cfgScale.has_value() && mSamplingConfig.cfgScale->at(0) != 1.0f; + } + + [[nodiscard]] SizeType32 getNumSequences() const + { + if (isCfg()) { + TLLM_CHECK_WITH_INFO(mSamplingConfig.beamWidth == 1, "cfgScale is only supported for beamWidth = 1"); + return 2; + } + return 1; + } + + [[nodiscard]] SizeType32 getSeqSlot(int idx) const + { + TLLM_CHECK_WITH_INFO(idx >= 0 && idx < 
getNumSequences(), "seq slot idx is out of range"); + return mSeqSlots[idx]; + } + + [[nodiscard]] uint64_t getSeqSlotId(int idx = 0) const + { + if (idx == 0) { + return mRequestId; + } + if (isCfg() && idx == 1) { + return std::numeric_limits::max() - mRequestId; + } + TLLM_CHECK_WITH_INFO(false, "Sequence slot id is implemented for CFG only"); + return 0; + } + /// @brief Get the number of subrequests, the expected number of responses under non-streaming mode. In sampling /// mode, it will be equal to mSamplingConfig.numReturnSequences, while it will be equal to 1 in beam search. /// @return The number of subrequests in total request size. @@ -823,7 +856,7 @@ class GenericLlmRequest : LlmRequestState::kCONTEXT_INIT; mContextCurrentPosition = 0; mContextChunkSize = mPromptLen; - mSeqSlot.reset(); + mSeqSlots.clear(); } /// @brief Get the maximum length of tokens returned to the client. Use to ensure we don't return to @@ -1826,7 +1859,7 @@ class GenericLlmRequest runtime::SamplingConfig mSamplingConfig; std::optional mEndId{std::nullopt}; std::optional mPadId{std::nullopt}; - std::optional mSeqSlot{std::nullopt}; + std::vector mSeqSlots{}; std::optional mLogitsPostProcessor{std::nullopt}; bool mApplyLogitsPostProcessorBatched{false}; std::optional mClientId{std::nullopt}; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 40c278d1b89..f28a0d2882c 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -76,7 +76,8 @@ class SamplingConfig std::optional const& noRepeatNgramSize = std::nullopt, std::optional const& numReturnSequences = std::nullopt, std::optional const& minP = std::nullopt, - std::optional> const& beamWidthArray = std::nullopt); + std::optional> const& beamWidthArray = std::nullopt, + std::optional const& cfgScale = std::nullopt); bool operator==(SamplingConfig const& other) const; @@ -102,6 +103,7 @@ class SamplingConfig [[nodiscard]] std::optional getNumReturnSequences() const; [[nodiscard]] std::optional getMinP() const; [[nodiscard]] std::optional> getBeamWidthArray() const; + [[nodiscard]] std::optional getCfgScale() const; void setBeamWidth(SizeType32 beamWidth); void setTopK(std::optional const& topK); @@ -124,6 +126,7 @@ class SamplingConfig void setNumReturnSequences(std::optional const& numReturnSequences); void setMinP(std::optional const& minP); void setBeamWidthArray(std::optional> const& beamWidthArray); + void setCfgScale(std::optional const& cfgScale); private: static SizeType32 checkBeamWidth(SizeType32 beamWidth); @@ -145,6 +148,9 @@ class SamplingConfig static std::optional const& checkMinP(std::optional const& minP); static std::optional> const& checkBeamWidthArray( std::optional> const& beamWidthArray, std::optional const beamWidth); + static std::optional const& checkCfgScale(std::optional const& cfgScale); + + void updateNumReturnBeams(); friend class Serialization; @@ -196,6 +202,8 @@ class SamplingConfig std::optional mMinP; /// @brief Controls the beam width for each step for Variable-Beam-Width-Search. std::optional> mBeamWidthArray; + /// @brief Controls the cfg scale for sampling. + std::optional mCfgScale; }; /// @brief Additional output that should be gathered. 
diff --git a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h index d35080d5588..0fc4b1a3d53 100644 --- a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h +++ b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h @@ -133,6 +133,11 @@ class DefaultDecodingParams { return std::vector{1}; } + + [[nodiscard]] __host__ __device__ static constexpr float getCfgScale() + { + return 1.0f; + } }; } // namespace layers } // namespace tensorrt_llm diff --git a/cpp/include/tensorrt_llm/runtime/samplingConfig.h b/cpp/include/tensorrt_llm/runtime/samplingConfig.h index 00d7d0ae256..45771836df1 100644 --- a/cpp/include/tensorrt_llm/runtime/samplingConfig.h +++ b/cpp/include/tensorrt_llm/runtime/samplingConfig.h @@ -141,6 +141,8 @@ class SamplingConfig configs, [&configs](size_t ci) { return configs[ci].topK; }, layers::DefaultDecodingParams::getTopK()); topP = fuseValues( configs, [&configs](size_t ci) { return configs[ci].topP; }, layers::DefaultDecodingParams::getTopP()); + cfgScale = fuseValues( + configs, [&configs](size_t ci) { return configs[ci].cfgScale; }, layers::DefaultDecodingParams::getCfgScale()); // Generate a random seed for each samplingConfig with randomSeed == std::nullopt randomSeed = std::vector(configs.size()); @@ -230,6 +232,7 @@ class SamplingConfig SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32) SET_FROM_OPTIONAL(minP, MinP, FloatType) SET_FROM_OPTIONAL(beamWidthArray, BeamWidthArray, std::vector) + SET_FROM_OPTIONAL(cfgScale, CfgScale, FloatType) #undef SET_FROM_OPTIONAL } @@ -279,6 +282,8 @@ class SamplingConfig // valid &= validateVec("lengthPenalty", lengthPenalty, 0.f); valid &= validateVec("noRepeatNgramSize", noRepeatNgramSize, 0); valid &= validateVec("minP", minP, -fltEpsilon, {1.f}); + // TODO: validation of cfgScale? + valid &= validateVec("cfgScale", cfgScale, -10.0f); // TODO: check `beamWidthArray` // Detect greedy sampling and overwrite params. @@ -372,9 +377,12 @@ class SamplingConfig std::optional normalizeLogProbs; + OptVec cfgScale; // [1] or [batchSize] + bool operator==(SamplingConfig const& other) const { return beamWidth == other.beamWidth && numReturnSequences == other.numReturnSequences + && cfgScale == other.cfgScale && temperature == other.temperature && originalTemperature == other.originalTemperature && minLength == other.minLength && repetitionPenalty == other.repetitionPenalty && presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty diff --git a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp index c0482deb554..9041c9664a1 100644 --- a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp +++ b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp @@ -30,14 +30,9 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager { if (llmReq->isFirstContextChunk()) { - auto const requestId = llmReq->mRequestId; auto const promptLen = llmReq->mPromptLen; auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; auto draftLength = llmReq->getNumDraftTokens(); - - // Allocate/Reuse KV cache - kvCacheManager.addSequence(requestId, promptLen, reqBeamWidth, llmReq); - // EagleNet will increment kv cache up to maxPathLen to account for accepted tokens. // Then up to maxDecodingDraftTokens will be used to generate next draft tokens. 
if (modelConfig.getSpeculativeDecodingMode().isEagle()) @@ -45,26 +40,31 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager draftLength = modelConfig.getSpeculativeDecodingModule().getMaxPathLen() + modelConfig.getSpeculativeDecodingModule().getMaxDecodingTokens(); } + for (int i = 0; i < llmReq->getNumSequences(); i++) { + auto const requestId = llmReq->getSeqSlotId(i); - // Allocate more KV cache for speculative decoding - if (draftLength > 0) - { - for (SizeType32 di = 0; di < draftLength; ++di) + // Allocate/Reuse KV cache + kvCacheManager.addSequence(requestId, promptLen, reqBeamWidth, llmReq); + + // Allocate more KV cache for speculative decoding + if (draftLength > 0) { - kvCacheManager.addToken(requestId); + for (SizeType32 di = 0; di < draftLength; ++di) + { + kvCacheManager.addToken(requestId); + } } - } - if (crossKvCacheManager) - { - crossKvCacheManager->addSequence(requestId, llmReq->getEncoderOutputLen(), reqBeamWidth, llmReq); + if (crossKvCacheManager) + { + crossKvCacheManager->addSequence(requestId, llmReq->getEncoderOutputLen(), reqBeamWidth, llmReq); + } } } } for (auto const& llmReq : generationRequests) { - auto const requestId = llmReq->mRequestId; auto decodingTokens = llmReq->getNumDraftTokens() + 1; // EagleNet will increment kv cache up to maxPathLen to account for accepted tokens. @@ -74,10 +74,12 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager decodingTokens = modelConfig.getSpeculativeDecodingModule().getMaxPathLen() + modelConfig.getSpeculativeDecodingModule().getMaxDecodingTokens(); } - - for (SizeType32 di = 0; di < decodingTokens; ++di) - { - kvCacheManager.addToken(requestId); + for (int i = 0; i < llmReq->getNumSequences(); i++) { + auto const requestId = llmReq->getSeqSlotId(i); + for (SizeType32 di = 0; di < decodingTokens; ++di) + { + kvCacheManager.addToken(requestId); + } } } diff --git a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp index b38dd66bd40..8ca7917fdba 100644 --- a/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp +++ b/cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp @@ -30,15 +30,21 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMana { for (auto const& llmReq : requests) { - auto const isReqNew = (llmReq->isContextInitState() && !llmReq->mSeqSlot) + auto const isReqNew = (llmReq->isContextInitState() && llmReq->mSeqSlots.empty()) || (llmReq->isDisaggGenerationTransmissionComplete()); if (isReqNew && llmReq->getReturnPerfMetrics()) { llmReq->setFirstScheduledTime(std::chrono::steady_clock::now()); } - auto const reqSeqSlot = seqSlotManager.getSequenceSlot(isReqNew, llmReq->mRequestId); - TLLM_CHECK_WITH_INFO(reqSeqSlot, "Unable to get batch slot for reqId"); - llmReq->mSeqSlot = reqSeqSlot; + for (int i = 0; i < llmReq->getNumSequences(); i++) { + auto const reqSeqSlot = seqSlotManager.getSequenceSlot(isReqNew, llmReq->getSeqSlotId(i)); + TLLM_CHECK_WITH_INFO(reqSeqSlot, "Unable to get batch slot for reqId"); + if ((int)llmReq->mSeqSlots.size() >= i + 1) { + llmReq->mSeqSlots[i] = reqSeqSlot.value(); + } else { + llmReq->mSeqSlots.push_back(reqSeqSlot.value()); + } + } } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp b/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp index 5bb11b1bcdc..6eb1a529dd9 100644 --- a/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp +++ 
b/cpp/tensorrt_llm/batch_manager/encoderBuffers.cpp @@ -121,6 +121,10 @@ void EncoderBuffers::updateBufferSizes(RequestVector const& requests, ModelConfi { encoderInputLen += req->getEncoderInputLen(); encoderOutputLen += req->getEncoderOutputLen(); + if (req->isCfg()) { + // for CFG, repeat the encoder output twice + encoderOutputLen += req->getEncoderOutputLen(); + } maxInputLengthInBatch = std::max(maxInputLengthInBatch, req->getEncoderInputLen()); // Decoder input is encoder output } @@ -222,40 +226,53 @@ void EncoderBuffers::setFromInputs(RequestVector const& requests, ModelConfig co { SizeType32 const inputLength = llmReq->getEncoderInputLen(); SizeType32 const outputLength = llmReq->getEncoderOutputLen(); - if (llmReq->getEncoderInputFeatures()) - { - auto const& reqFeatures - = llmReq - ->getEncoderInputFeatures(); // whisper: [length, featureDim]; Vision: [batch_size, channel, W, H] - TLLM_LOG_DEBUG("EncoderBuffers::setFromInputs - request id = %d, input features length = %d", - llmReq->mRequestId, inputLength); - manager.copy(*reqFeatures, *ITensor::slice(inputFeatures, offset, inputLength)); - offset += inputLength; - } - else - { - auto const& reqTokens = *llmReq->getEncoderTokens().value(); - inputIdsAll.insert(inputIdsAll.end(), reqTokens.begin(), reqTokens.end()); - if (tokenTypeIds) + for (int s = 0; s < llmReq->getNumSequences(); s++) { + if (llmReq->getEncoderInputFeatures()) { - tokenTypeIdsAll.insert( - tokenTypeIdsAll.end(), tokenTypeIdsReserved.begin(), tokenTypeIdsReserved.begin() + inputLength); + if (s == 0) { + // copy input features from request to the buffer for conditional generation + auto const& reqFeatures + = llmReq + ->getEncoderInputFeatures(); // whisper: [length, featureDim]; Vision: [batch_size, channel, W, H] + TLLM_LOG_DEBUG("EncoderBuffers::setFromInputs - request id = %d, input features length = %d", + llmReq->mRequestId, inputLength); + manager.copy(*reqFeatures, *ITensor::slice(inputFeatures, offset, inputLength)); + offset += inputLength; + } else if (s == 1) { + // need to add dummy input of zeros for CFG + auto uncondFeatures = ITensor::slice(inputFeatures, offset, inputLength); + manager.setMem(*uncondFeatures, 0); + offset += inputLength; + } else { + TLLM_CHECK_WITH_INFO(false, "Unexpected sequence index for llmReq [%ld]: %d", llmReq->mRequestId, s); + } } + else + { + // TODO: CFG support for encoder that proceses tokens is not implemented yet + auto const& reqTokens = *llmReq->getEncoderTokens().value(); + inputIdsAll.insert(inputIdsAll.end(), reqTokens.begin(), reqTokens.end()); + if (tokenTypeIds) + { + tokenTypeIdsAll.insert( + tokenTypeIdsAll.end(), tokenTypeIdsReserved.begin(), tokenTypeIdsReserved.begin() + inputLength); + } + } + if (positionIds) + { + SizeType32 const length = modelConfig.isWhisper() ? outputLength : inputLength; + positionIdsAll.insert( + positionIdsAll.end(), positionIdsReserved.begin(), positionIdsReserved.begin() + length); + } + if (modelConfig.useLanguageAdapter()) + { + auto const languageAdapterRouting + = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), inputLength); + languageAdapterRoutingAll.insert( + languageAdapterRoutingAll.end(), std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + } + inputLengthsAll.push_back(inputLength); } - if (positionIds) - { - SizeType32 const length = modelConfig.isWhisper() ? 
outputLength : inputLength; - positionIdsAll.insert( - positionIdsAll.end(), positionIdsReserved.begin(), positionIdsReserved.begin() + length); - } - if (modelConfig.useLanguageAdapter()) - { - auto const languageAdapterRouting - = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), inputLength); - languageAdapterRoutingAll.insert( - languageAdapterRoutingAll.end(), std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); - } - inputLengthsAll.push_back(inputLength); } // copy inputs from host to device @@ -396,6 +413,9 @@ void EncoderBuffers::rearrangeOutputs(RequestVector const& requests, ModelConfig } } offset += size; + if (req->isCfg()) { + offset += size; + } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -478,17 +498,23 @@ void EncoderBuffers::setBufferSizes(RequestVector const& contextRequests, Reques for (auto const& llmReq : contextRequests) { - numRequests += 1; - encoderInputLen += llmReq->getEncoderInputLen(); + numRequests += llmReq->getNumSequences(); encoderOutputLen += llmReq->getEncoderOutputLen(); + if (llmReq->isCfg()) { + encoderInputLen += llmReq->getEncoderInputLen(); + encoderOutputLen += llmReq->getEncoderOutputLen(); + } maxInputLengthInBatch = std::max(maxInputLengthInBatch, llmReq->getEncoderInputLen()); } for (auto const& llmReq : genRequests) { encoderOutputLen += llmReq->getEncoderOutputLen(); + if (llmReq->isCfg()) { + encoderOutputLen += llmReq->getEncoderOutputLen(); + } auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; - numRequests += reqBeamWidth; // tile by beam width + numRequests += reqBeamWidth * llmReq->getNumSequences(); // tile by beam width maxInputLengthInBatch = std::max(maxInputLengthInBatch, llmReq->getEncoderInputLen()); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -522,10 +548,16 @@ void EncoderBuffers::fill( // copy encoder output to encoder output buffer for both ctx and gen requests, // disable freeing enc buffer in llm request for it size = llmReq->getEncoderOutputLen(); - auto const encoderOutputSlice = runtime::ITensor::slice(encoderOutput, offset, size); + auto encoderOutputSlice = runtime::ITensor::slice(encoderOutput, offset, size); manager.copy(*llmReq->getEncoderOutput(), *encoderOutputSlice); offset += size; inputLengthsAll.emplace_back(size); + if (llmReq->isCfg()) { + auto encoderOutputSlice = runtime::ITensor::slice(encoderOutput, offset, size); + manager.setMem(*encoderOutputSlice, 0); + offset += size; + inputLengthsAll.emplace_back(size); + } } } manager.copy(inputLengthsAll.data(), *inputLengths); diff --git a/cpp/tensorrt_llm/batch_manager/generateRequestOptions.cpp b/cpp/tensorrt_llm/batch_manager/generateRequestOptions.cpp index 159442cbead..cab2a69d465 100644 --- a/cpp/tensorrt_llm/batch_manager/generateRequestOptions.cpp +++ b/cpp/tensorrt_llm/batch_manager/generateRequestOptions.cpp @@ -179,7 +179,7 @@ GenerateRequestOptions::operator()(tr::ModelConfig const& modelConfig, tr::World decoderRequest.stopWordsList->squeeze(0); } // TODO: is this correct? 
- batchSlotsRange[batchIdx] = llmReq->mSeqSlot.value(); + batchSlotsRange[batchIdx] = llmReq->mSeqSlots.at(0); decoderRequests.push_back(decoderRequest); samplingConfigs.push_back(llmReq->mSamplingConfig); diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp index edf8917d27a..9c8668369ae 100644 --- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp +++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp @@ -79,7 +79,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests) { continue; } - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); if (llmReq->isContextInitState() && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen()) { @@ -162,7 +162,7 @@ void GuidedDecoder::execute(ScheduledRequests const& scheduledRequests, BufferMa auto const& guidedDecodingParams = llmReq->getGuidedDecodingParams(); if (guidedDecodingParams.has_value()) { - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const& logits = decoderBuffersLogits.at(seqSlot); auto const logitsBitmask = ITensor::at(mLogitsBitmask, {seqSlot}); diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp index 38e90efb975..9c6ba9bcd41 100644 --- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp @@ -114,9 +114,16 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, // Get the logits from the last context token and draft tokens auto const numDecoderLogits = 1 + draftLength; - auto const seqSlot = llmReq->mSeqSlot.value(); TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); + // This is part of the CFG support implementation: advance the logits index past the unconditional logits. + if (llmReq->isCfg()) { + logitsIndex += numContextLogits + draftLength; + TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); + // TODO: implement CFG here as well, i.e. apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) + } + + auto const seqSlot = llmReq->mSeqSlots.at(0); if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits()) { TLLM_CHECK(medusaBuffers); @@ -168,6 +175,9 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, curVocablogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); } ++batchIndex; + if (llmReq->isCfg()) { + ++batchIndex; + } } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp index 27429d780ea..18a7f0b1a91 100644 --- a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp @@ -39,6 +39,48 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; namespace { +template <typename T> +static void addAndScale(ITensor* cond, ITensor* uncond, SizeType32 size, float cfgScale) { + auto* condPtr = tensorrt_llm::runtime::bufferCast<T>(*cond); + auto* uncondPtr = tensorrt_llm::runtime::bufferCast<T>(*uncond); + for (SizeType32 i = 0; i < size; i++) { + condPtr[i] = condPtr[i] * (T)cfgScale + uncondPtr[i] * (T)(1 - cfgScale); + } +} + +static void applyCfgCpu(BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream, + TensorPtr logitsView, TensorPtr uncondLogitsView, + float cfgScale, SizeType32 vocabOffset, SizeType32 vocabSize) +{ + // This is a temporary testing implementation where CFG is applied on the CPU. + // It needs to become a proper device-side kernel (e.g. implemented with cuBLAS). + auto logitsVocabView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] + auto uncondLogitsVocabView = ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] + + auto logitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), logitsVocabView->getDataType()); + auto uncondLogitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), uncondLogitsVocabView->getDataType()); + ITensor* logitsCpuPtr = logitsCpu.get(); + ITensor* uncondLogitsCpuPtr = uncondLogitsCpu.get(); + + logitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); + manager.copy(*logitsVocabView, *logitsCpuPtr); + uncondLogitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); + manager.copy(*uncondLogitsVocabView, *uncondLogitsCpuPtr); + stream.synchronize(); + + if (logitsVocabView->getDataType() == nvinfer1::DataType::kFLOAT) + { + addAndScale<float>(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); + } + else if (logitsVocabView->getDataType() == nvinfer1::DataType::kHALF) + { + addAndScale<half>(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); + } + manager.copy(*logitsCpuPtr, *logitsVocabView); + stream.synchronize(); +} + + //! @brief Copy logits from generation phase under streaming mode. void copyStreamingGenerationLogits(BufferManager const& bufferManager, LlmRequest& llmReq) { @@ -75,7 +117,8 @@ void setupMedusaLogits(std::vector& medusaLogitsHeads, TensorPtr cons void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, BufferManager const& manager, - TensorPtr const& logits, OptionalRef genRuntimeBuffers, SizeType32 vocabId) const + tensorrt_llm::runtime::CudaStream const& stream, TensorPtr const& logits, OptionalRef genRuntimeBuffers, + SizeType32 vocabId) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(HandleGenerationLogits); @@ -91,7 +134,7 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co for (auto const& llmReq : generationRequests) { auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const draftLength = llmReq->getNumDraftTokens(); auto const numLogits = draftLength + reqBeamWidth; @@ -108,6 +151,15 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid(*logitsView, manager, "logits") == false, "Found invalid number (NaN or Inf) in logits"); + // CFG implementation: fetch the unconditional logits and blend them into logitsView + if (llmReq->isCfg()) { + logitsIndex += numLogits; + TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex, numLogits); + // applyCfgCpu computes logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale), currently on the CPU + float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); + applyCfgCpu(manager, stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); + } + auto& decoderLogits = decoderBuffers.logits.at(seqSlot); auto const logitsViewShape = logitsView->getShape(); if (reqBeamWidth > 1) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 
b688d0ef4ff..c33c77cbe14 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -915,7 +915,7 @@ void BlockManager::addSequence( mReusedTokens += static_cast(prepopulatedPromptLen); mTotalInputTokens += static_cast(uniqueTokens.size()); llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock()); - TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId, + TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.getSeqSlotId(), inputLength, prepopulatedPromptLen); } @@ -1422,7 +1422,8 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(LlmRequest const& req, bool tw auto const numNextBlocks = tc::ceilDiv(numNextTokens, getTokensPerBlock()); numRequiredBlocks = (numNextBlocks - numPastBlocks) * req.mSamplingConfig.beamWidth; } - return numRequiredBlocks; + // we need more blocks if there are multiple sequences in this request + return numRequiredBlocks * req.getNumSequences(); } SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req) const @@ -1431,17 +1432,17 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req) { if (req.isContextInitState() && req.getContextCurrentPosition() == 0) { - return tc::ceilDiv(req.getEncoderOutputLen(), getTokensPerBlock()); + return tc::ceilDiv(req.getEncoderOutputLen(), getTokensPerBlock()) * req.getNumSequences(); } return 0; // cross KV cache doesn't grow after the initial context phase } SizeType32 const numContextBlocks - = (std::min(req.mPromptLen, mMaxAttentionWindow + mTemporaryAttentionWindow) + mSinkBubbleLength) + = req.getNumSequences() * (std::min(req.mPromptLen, mMaxAttentionWindow + mTemporaryAttentionWindow) + mSinkBubbleLength) / getTokensPerBlock(); SizeType32 const numTotalBlocksPerBeam - = tc::ceilDiv(std::min(req.mPromptLen + req.mMaxNewTokens, mMaxAttentionWindow + mTemporaryAttentionWindow) + = req.getNumSequences() * tc::ceilDiv(std::min(req.mPromptLen + req.mMaxNewTokens, mMaxAttentionWindow + mTemporaryAttentionWindow) + mSinkBubbleLength, getTokensPerBlock()); @@ -1450,11 +1451,14 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req) SizeType32 numAllocBlocksPerBeam = 0; { std::scoped_lock lck(mSequencesMtx); - auto const seqIt = mSequences.find(req.mRequestId); - if (seqIt != mSequences.end()) - { - auto const& seq = seqIt->second; - numAllocBlocksPerBeam = seq.getCacheBlockIds().at(0).size(); + for (int i = 0; i < req.getNumSequences(); i++) { + auto const requestId = req.getSeqSlotId(i); + auto const seqIt = mSequences.find(requestId); + if (seqIt != mSequences.end()) + { + auto const& seq = seqIt->second; + numAllocBlocksPerBeam += seq.getCacheBlockIds().at(0).size(); + } } } @@ -1689,11 +1693,13 @@ void KVCacheManager::addSequence( void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) { - auto const requestId = llmRequest.mRequestId; - auto& sequence = getSequence(requestId); - if (mEnableBlockReuse && !sequence.isCyclic()) - { - mBlockManager.storeContextBlocks(sequence, llmRequest); + for (int i = 0; i < llmRequest.getNumSequences(); i++) { + auto const requestId = llmRequest.getSeqSlotId(i); + auto& sequence = getSequence(requestId); + if (mEnableBlockReuse && !sequence.isCyclic()) + { + mBlockManager.storeContextBlocks(sequence, llmRequest); + } } } diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp 
b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index 311d85145f2..da81b075c63 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -265,7 +265,7 @@ std::shared_ptr LlmRequest::createChildRequest(RequestIdType request childReq->mSequenceIndex = mChildRequests.size() + 1; childReq->mParentRequestId = this->mRequestId; childReq->mSequenceFinalVec = this->mSequenceFinalVec; - childReq->mSeqSlot.reset(); + childReq->mSeqSlots.clear(); // To ensure different randomness across children, assign a unique random seed to each child // by adding its sequence index to the base seed. If no seed is provided, the parent's seed defaults to 0. diff --git a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp index 888de98e7bf..716c4d6e28b 100644 --- a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp +++ b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp @@ -59,7 +59,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque logitsPostProcessorIsApplied = true; if (replicateLogitsPostProcessor || worldConfig.isFirstTensorParallelRank()) { - auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlot.value()); + auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlots.at(0)); (*llmReq->mLogitsPostProcessor)( llmReq->mRequestId, logits, llmReq->getTokens(), runtime.getStreamPtr(), llmReq->mClientId); } @@ -68,7 +68,7 @@ bool LogitsPostProcessor::operator()(RequestVector const& contextRequests, Reque { reqIdsVec.push_back(llmReq->mRequestId); - auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlot.value()); + auto& logits = decoderBuffers.logits.at(llmReq->mSeqSlots.at(0)); logitsVec.push_back(logits); beamTokensVec.emplace_back(llmReq->getTokens()); diff --git a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp index 9823a653759..9b940041a0b 100644 --- a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp +++ b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp @@ -103,7 +103,7 @@ std::vector getActiveSlots(RequestVector const& contextRequests, Req { if (llmReq->isGenerationInProgressState() || llmReq->isLastContextChunk()) { - activeSlots.push_back(llmReq->mSeqSlot.value()); + activeSlots.push_back(llmReq->mSeqSlots.at(0)); } } } @@ -132,7 +132,7 @@ void copySequenceLengths(RequestVector const& contextRequests, RequestVector con { auto const currentSequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens(); // Get position of the current sequence in the decoder - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); batchSlotsRange[batchIdx] = seqSlot; fillValuesRange[batchIdx] = currentSequenceLen; ++batchIdx; diff --git a/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp b/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp index cedc52c80b7..18112212170 100644 --- a/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/rnnStateBuffers.cpp @@ -52,7 +52,9 @@ void RnnStateBuffers::fillSlotMappings( SizeType32 batchIdx{0}; for (auto const& llmReq : contextRequests) { - auto const seqSlot = llmReq->mSeqSlot.value(); + // TODO: rnn state does not support CFG yet + TLLM_CHECK_WITH_INFO(!llmReq->isCfg(), "rnn state buffers do not support CFG yet"); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; 
rnnStateManager->fillSlotMapping(*slotMappingHost, batchIdx, seqSlot, reqBeamWidth); ++batchIdx; diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index 1a3576d7e91..1e685c28777 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -389,17 +389,20 @@ void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, Reques NVTX3_SCOPED_RANGE(runtimeBuffersSetBufferSizes); // set context sizes - numContextRequests = static_cast(contextRequests.size()); + numContextRequests = 0; + for (auto const& llmReq : contextRequests) { + numContextRequests += llmReq->getNumSequences(); + } auto numContextLogits = numContextRequests; numContextTokens = 0; maxContextLength = 0; for (auto const& llmReq : contextRequests) { auto const draftLength = llmReq->isLastContextChunk() ? llmReq->getNumDraftTokens() : 0; - numContextLogits += draftLength; + numContextLogits += draftLength * llmReq->getNumSequences(); auto const contextChunkSize = llmReq->getContextChunkSize(); - numContextTokens += contextChunkSize + draftLength; + numContextTokens += (contextChunkSize + draftLength) * llmReq->getNumSequences(); if (maxContextLength < llmReq->mPromptLen) { maxContextLength = llmReq->mPromptLen; @@ -407,15 +410,18 @@ void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, Reques } // set generation sizes - numGenRequests = static_cast(genRequests.size()); + numGenRequests = 0; + for (auto const& llmReq : genRequests) { + numGenRequests += llmReq->getNumSequences(); + } numGenSequences = 0; numGenTokens = 0; for (auto const& llmReq : genRequests) { - auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; - numGenSequences += reqBeamWidth; + auto reqBeamWidth = llmReq->mSamplingConfig.beamWidth; + numGenSequences += reqBeamWidth * llmReq->getNumSequences(); auto const draftLen = llmReq->getNumDraftTokens(); - numGenTokens += draftLen + reqBeamWidth; + numGenTokens += (draftLen + reqBeamWidth) * llmReq->getNumSequences(); } numLogits = numContextLogits + numGenTokens; @@ -425,6 +431,7 @@ void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, Reques encoderBuffers->setBufferSizes(contextRequests, genRequests); } + TLLM_LOG_WARNING(">>>>>setBufferSizes: numContextRequests %d; numGenRequests %d", numContextRequests, numGenRequests); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -500,21 +507,27 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request for (auto const& llmReq : requests) { // Get position of the current sequence in the decoder - auto const seqSlot = llmReq->mSeqSlot.value(); - seqSlotIndices[batchIdx] = seqSlot; - ++batchIdx; + for (const auto& seqSlot : llmReq->mSeqSlots) { + seqSlotIndices[batchIdx] = seqSlot; + ++batchIdx; + } } } - TLLM_CHECK(seqSlots->getSize() == static_cast(batchIdx)); manager.copy(*seqSlots, *seqSlotsDevice); } + // count number of context sequences + SizeType32 contextRequestsSize = 0; + for (auto const& llmReq : contextRequests) { + contextRequestsSize += llmReq->getNumSequences(); + } + // context preparation loop - if (!contextRequests.empty()) + if (contextRequestsSize > 0) { NVTX3_SCOPED_RANGE(contextPrepareLoop); - numContextLogits.resize(contextRequests.size()); + numContextLogits.resize(contextRequestsSize); SizeType32 batchIdx{0}; for (auto const& llmReq : contextRequests) @@ -524,102 +537,118 @@ void RuntimeBuffers::setFromInputs(RequestVector const& 
contextRequests, Request TLLM_CHECK_WITH_INFO( llmReq->getMaxNumGeneratedTokens() == 0, "Context request should not have generated tokens."); - auto const& reqTokens = llmReq->getTokens(0); - auto const& draftTokens = llmReq->getDraftTokens(); - auto const draftLength = llmReq->getNumDraftTokens(); - auto const& positionIds = llmReq->getPositionIds(); + // for CFG requests, add the inputs to the buffer twice + std::vector is_conditional_vec{true}; + if (llmReq->isCfg()) { + is_conditional_vec.push_back(false); + } + for (const auto& is_conditional : is_conditional_vec) { + auto const& origTokens = llmReq->getTokens(0); + std::vector dummyTokens; + if (!is_conditional) { + // TODO: need a special token for unconditional input, which is expanded + // to all zeros + dummyTokens.assign(origTokens.size(), 0); + } - auto const contextChunkSize = llmReq->getContextChunkSize(); - auto const beginCompute = llmReq->getContextCurrentPosition(); - auto const endCompute = beginCompute + contextChunkSize; - inputHost.insert(inputHost.end(), reqTokens.begin() + beginCompute, - reqTokens.begin() + beginCompute + contextChunkSize * llmReq->getNumVocabs()); + auto const& reqTokens = is_conditional ? origTokens : dummyTokens; + auto const& draftTokens = llmReq->getDraftTokens(); + auto const draftLength = llmReq->getNumDraftTokens(); + auto const& positionIds = llmReq->getPositionIds(); - logitsIdsHostPtr[totalNumLogits++] = contextChunkSize; - numContextLogits.at(batchIdx) = modelConfig.computeContextLogits() ? contextChunkSize : 1; + auto const contextChunkSize = llmReq->getContextChunkSize(); + auto const beginCompute = llmReq->getContextCurrentPosition(); + auto const endCompute = beginCompute + contextChunkSize; + inputHost.insert(inputHost.end(), reqTokens.begin() + beginCompute, + reqTokens.begin() + beginCompute + contextChunkSize * llmReq->getNumVocabs()); - if (llmReq->isLastContextChunk()) - { - inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); - std::fill_n(logitsIdsHostPtr + totalNumLogits, draftLength, 1); - totalNumLogits += draftLength; - } - auto const inputLength = contextChunkSize + (llmReq->isLastContextChunk() ? draftLength : 0); - contextLengthsHostPtr[batchIdx] = inputLength; - auto const sequenceLen = inputLength + llmReq->getContextCurrentPosition(); - sequenceLengthsHostPtr[batchIdx] = sequenceLen; + logitsIdsHostPtr[totalNumLogits++] = contextChunkSize; + numContextLogits.at(batchIdx) = modelConfig.computeContextLogits() ? contextChunkSize : 1; - if (static_cast(pastKeyValueLengthsPtr)) - { - pastKeyValueLengthsPtr[batchIdx] = beginCompute + inputLength; - } - - if (positionIds.has_value()) - { - TLLM_CHECK_WITH_INFO(!(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); - positionIdsHost.insert(positionIdsHost.end(), positionIds.value()->begin() + beginCompute, - positionIds.value()->begin() + endCompute); - } - else - { - if (isChatGlm) + if (llmReq->isLastContextChunk()) { - // Specialize for ChatGLM-6B with 2D-Position-Embedding - positionIdsHost.resize(totalInputSize + inputLength); - std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); - positionIdsHost.back() = positionIdsHost.back() - 1; + inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); + std::fill_n(logitsIdsHostPtr + totalNumLogits, draftLength, 1); + totalNumLogits += draftLength; + } + auto const inputLength = contextChunkSize + (llmReq->isLastContextChunk() ? 
draftLength : 0); + contextLengthsHostPtr[batchIdx] = inputLength; + auto const sequenceLen = inputLength + llmReq->getContextCurrentPosition(); + sequenceLengthsHostPtr[batchIdx] = sequenceLen; - positionIdsHostRow2.resize(totalInputSize + inputLength); - positionIdsHostRow2.back() = 1; + if (static_cast(pastKeyValueLengthsPtr)) + { + pastKeyValueLengthsPtr[batchIdx] = beginCompute + inputLength; } - else if (isGlm) + + if (positionIds.has_value()) { - // Specialize for GLM-10B with 2D-Position-Embedding and special value of the mask id position - auto start = inputHost.begin() + totalInputSize; - auto end = start + inputLength; - auto it = std::find_if( - start, end, [](SizeType32 id) { return id == 50260 || id == 50263 || id == 50264; }); - llmReq->mMaskPosition = (it != end) ? std::distance(start, it) : maxContextLength; - - positionIdsHost.resize(totalInputSize + inputLength); - std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); - positionIdsHost.back() = llmReq->mMaskPosition; - - positionIdsHostRow2.resize(totalInputSize + inputLength); - positionIdsHostRow2.back() = 1; + TLLM_CHECK_WITH_INFO(!(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); + positionIdsHost.insert(positionIdsHost.end(), positionIds.value()->begin() + beginCompute, + positionIds.value()->begin() + endCompute); } else { - // Other models - positionIdsHost.resize(totalInputSize + inputLength); - std::iota(std::begin(positionIdsHost) + totalInputSize, - std::begin(positionIdsHost) + totalInputSize + inputLength, beginCompute); + if (isChatGlm) + { + // Specialize for ChatGLM-6B with 2D-Position-Embedding + positionIdsHost.resize(totalInputSize + inputLength); + std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); + positionIdsHost.back() = positionIdsHost.back() - 1; + + positionIdsHostRow2.resize(totalInputSize + inputLength); + positionIdsHostRow2.back() = 1; + } + else if (isGlm) + { + // Specialize for GLM-10B with 2D-Position-Embedding and special value of the mask id position + auto start = inputHost.begin() + totalInputSize; + auto end = start + inputLength; + auto it = std::find_if( + start, end, [](SizeType32 id) { return id == 50260 || id == 50263 || id == 50264; }); + llmReq->mMaskPosition = (it != end) ? 
std::distance(start, it) : maxContextLength; + + positionIdsHost.resize(totalInputSize + inputLength); + std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0); + positionIdsHost.back() = llmReq->mMaskPosition; + + positionIdsHostRow2.resize(totalInputSize + inputLength); + positionIdsHostRow2.back() = 1; + } + else + { + // Other models + positionIdsHost.resize(totalInputSize + inputLength); + std::iota(std::begin(positionIdsHost) + totalInputSize, + std::begin(positionIdsHost) + totalInputSize + inputLength, beginCompute); + } } - } - if (modelConfig.useMrope()) - { - auto optMropeRotaryCosSin = llmReq->getMropeRotaryCosSin().value(); - TLLM_CHECK_WITH_INFO(optMropeRotaryCosSin->getShape().d[0] == mropeRotaryCosSinSize, - "Provided MropeRotarySinCos is %ld and expected is %d.\n", optMropeRotaryCosSin->getShape().d[0], - int(mropeRotaryCosSinSize)); + if (modelConfig.useMrope()) + { + auto optMropeRotaryCosSin = llmReq->getMropeRotaryCosSin().value(); + TLLM_CHECK_WITH_INFO(optMropeRotaryCosSin->getShape().d[0] == mropeRotaryCosSinSize, + "Provided MropeRotarySinCos is %ld and expected is %d.\n", optMropeRotaryCosSin->getShape().d[0], + int(mropeRotaryCosSinSize)); - auto const mropeRotaryCosSinCtx = ITensor::slice(mropeRotaryCosSin, batchIdx, 1); - manager.copy(*optMropeRotaryCosSin, *mropeRotaryCosSinCtx); - } + auto const mropeRotaryCosSinCtx = ITensor::slice(mropeRotaryCosSin, batchIdx, 1); + manager.copy(*optMropeRotaryCosSin, *mropeRotaryCosSinCtx); + } - if (modelConfig.useLanguageAdapter()) - { - auto const languageAdapterRouting = llmReq->getLanguageAdapterRouting( - modelConfig.getNumLanguages().value(), endCompute - beginCompute); - languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), - std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + if (modelConfig.useLanguageAdapter()) + { + auto const languageAdapterRouting = llmReq->getLanguageAdapterRouting( + modelConfig.getNumLanguages().value(), endCompute - beginCompute); + languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), + std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + } + totalInputSize += inputLength; + ++batchIdx; } - totalInputSize += inputLength; - ++batchIdx; } if (rnnStateBuffers) - { + { + // TODO: dont implement CFG for rnn state buffers for now rnnStateBuffers->fillSlotMappings(contextRequests, rnnStateManagerPtr); } @@ -635,108 +664,112 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request { NVTX3_SCOPED_RANGE(genPrepareLoop); - auto const numContextRequests = static_cast(contextRequests.size()); - auto numSequences = numContextRequests; + auto numSequences = contextRequestsSize; for (auto const& llmReq : genRequests) { - auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; - - auto const draftLength = llmReq->getNumDraftTokens(); - auto const& draftTokens = llmReq->getDraftTokens(); - auto const numLogits = draftLength + reqBeamWidth; - TLLM_CHECK(draftLength == 0 || reqBeamWidth == 1); - - auto const promptLen = llmReq->mPromptLen; - auto const sequenceLen = promptLen + llmReq->getMaxNumGeneratedTokens(); - auto const& positionIds = llmReq->getPositionIds(); - for (int beam = 0; beam < reqBeamWidth; ++beam) - { - auto const lastToken = llmReq->getLastTokens(beam); - auto const numTokens = llmReq->getNumTokens(beam); - if (llmReq->getNumVocabs() > 1) - { - auto const& beamTokens = llmReq->getTokens(beam); - TLLM_CHECK_WITH_INFO( - beamTokens.size() % 
llmReq->getNumVocabs() == 0, - "Number of tokens needs to be a multiple of number of vocabs!" - ); - inputHost.insert(inputHost.end(), beamTokens.cend() - llmReq->getNumVocabs(), beamTokens.cend()); - } - else + for (int s = 0; s < llmReq->getNumSequences(); s++) { + auto reqBeamWidth = llmReq->mSamplingConfig.beamWidth; + + auto const draftLength = llmReq->getNumDraftTokens(); + auto const& draftTokens = llmReq->getDraftTokens(); + auto const numLogits = draftLength + reqBeamWidth; + TLLM_CHECK(draftLength == 0 || reqBeamWidth == 1); + + auto const promptLen = llmReq->mPromptLen; + auto const sequenceLen = promptLen + llmReq->getMaxNumGeneratedTokens(); + auto const& positionIds = llmReq->getPositionIds(); + for (int reqBeam = 0; reqBeam < reqBeamWidth; ++reqBeam) { - inputHost.push_back(lastToken); + // for CFG, simply use tokens from the 0th beam during generation + int beam = llmReq->isCfg() ? 0 : reqBeam; + auto const numTokens = llmReq->getNumTokens(beam); + if (llmReq->getNumVocabs() > 1) + { + auto const& beamTokens = llmReq->getTokens(beam); + TLLM_CHECK_WITH_INFO( + beamTokens.size() % llmReq->getNumVocabs() == 0, + "Number of tokens needs to be a multiple of number of vocabs!" + ); + inputHost.insert(inputHost.end(), beamTokens.cend() - llmReq->getNumVocabs(), beamTokens.cend()); + } + else + { + auto const lastToken = llmReq->getLastTokens(beam); + inputHost.push_back(lastToken); + } + + // If model updates generation position ids do not append them here. + if (!modelConfig.getSpeculativeDecodingMode().updatesPositionIds()) + { + if (positionIds.has_value()) + { + TLLM_CHECK_WITH_INFO( + !(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); + auto last_context_position_id = positionIds.value()->back(); + positionIdsHost.push_back( + static_cast(last_context_position_id + sequenceLen - promptLen)); + } + else + { + if (isChatGlm) // ChatGLM-6B + { + positionIdsHost.push_back(static_cast(promptLen - 2)); + positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); + } + else if (isGlm) + { + positionIdsHost.push_back(llmReq->mMaskPosition); + positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); + } + else // GPT / ChatGLM2-6B / ChatGLM3-6B / BART + { + // positionIds is just the size of tokens -1 + positionIdsHost.push_back(numTokens - 1); + } + } + } + + if (draftLength > 0) + { + inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); + } + + if (modelConfig.useMrope()) + { + auto optMropePositionDeltas = llmReq->getMropePositionDeltas().value(); + mropePositionDeltasHost.push_back(optMropePositionDeltas); + } + + if (modelConfig.useLanguageAdapter()) + { + // Generation requests only have one token per sequence + auto const languageAdapterRouting + = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), 1); + languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), + std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + } } - // If model updates generation position ids do not append them here. 
- if (!modelConfig.getSpeculativeDecodingMode().updatesPositionIds()) + if (static_cast(pastKeyValueLengthsPtr)) { - if (positionIds.has_value()) - { - TLLM_CHECK_WITH_INFO( - !(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization"); - auto last_context_position_id = positionIds.value()->back(); - positionIdsHost.push_back( - static_cast(last_context_position_id + sequenceLen - promptLen)); - } - else - { - if (isChatGlm) // ChatGLM-6B - { - positionIdsHost.push_back(static_cast(promptLen - 2)); - positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); - } - else if (isGlm) - { - positionIdsHost.push_back(llmReq->mMaskPosition); - positionIdsHostRow2.push_back(static_cast(sequenceLen - promptLen + 1)); - } - else // GPT / ChatGLM2-6B / ChatGLM3-6B / BART - { - // positionIds is just the size of tokens -1 - positionIdsHost.push_back(numTokens - 1); - } - } + SizeType32 pastKeyValueLength = sequenceLen - 1; + std::fill_n(pastKeyValueLengthsPtr + numSequences, reqBeamWidth, pastKeyValueLength); } + totalInputSize += numLogits; - if (draftLength > 0) - { - inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end()); - } + std::fill_n(logitsIdsHostPtr + totalNumLogits, numLogits, 1); - if (modelConfig.useMrope()) - { - auto optMropePositionDeltas = llmReq->getMropePositionDeltas().value(); - mropePositionDeltasHost.push_back(optMropePositionDeltas); - } + totalNumLogits += numLogits; - if (modelConfig.useLanguageAdapter()) + if (rnnStateBuffers) { - // Generation requests only have one token per sequence - auto const languageAdapterRouting - = llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), 1); - languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(), - std::begin(languageAdapterRouting), std::end(languageAdapterRouting)); + TLLM_CHECK_WITH_INFO(!llmReq->isCfg(), "CFG is not supported for rnn state buffers"); + auto const seqSlot = llmReq->mSeqSlots[0]; + auto& rnnStateManager = *rnnStateManagerPtr; + rnnStateManager.fillSlotMapping(*rnnStateBuffers->slotMappingHost, numSequences, seqSlot, reqBeamWidth); } + numSequences += reqBeamWidth; } - - if (static_cast(pastKeyValueLengthsPtr)) - { - SizeType32 pastKeyValueLength = sequenceLen - 1; - std::fill_n(pastKeyValueLengthsPtr + numSequences, reqBeamWidth, pastKeyValueLength); - } - totalInputSize += numLogits; - - std::fill_n(logitsIdsHostPtr + totalNumLogits, numLogits, 1); - - totalNumLogits += numLogits; - - if (rnnStateBuffers) - { - auto const seqSlot = llmReq->mSeqSlot.value(); - auto& rnnStateManager = *rnnStateManagerPtr; - rnnStateManager.fillSlotMapping(*rnnStateBuffers->slotMappingHost, numSequences, seqSlot, reqBeamWidth); - } - numSequences += reqBeamWidth; } if (transformerBuffers && maxBeamWidth > 1) @@ -744,19 +777,20 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request transformerBuffers->copyCacheIndirection(genRequests, decoderBuffers.cacheIndirectionOutput, stream); } - numSequences = numContextRequests; + numSequences = contextRequestsSize; for (auto const& llmReq : genRequests) { - auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; - - auto const draftLength = llmReq->getNumDraftTokens(); + for (int s = 0; s < llmReq->getNumSequences(); s++) { + auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; + auto const draftLength = llmReq->getNumDraftTokens(); - auto const contextQLength = llmReq->mPromptLen + draftLength; - auto const sequenceLen = contextQLength + 
llmReq->getMaxNumGeneratedTokens(); + auto const contextQLength = llmReq->mPromptLen + draftLength; + auto const sequenceLen = contextQLength + llmReq->getMaxNumGeneratedTokens(); - std::fill_n(contextLengthsHostPtr + numSequences, reqBeamWidth, contextQLength); - std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen); - numSequences += reqBeamWidth; + std::fill_n(contextLengthsHostPtr + numSequences, reqBeamWidth, contextQLength); + std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen); + numSequences += reqBeamWidth; + } } if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind()) { @@ -990,10 +1024,9 @@ void RuntimeBuffers::setAttentionPriorIdx( SizeType32 totalContextEncoderOutputLen = 0; for (auto const& llmReq : contextRequests) { totalContextEncoderOutputLen += llmReq->getEncoderOutputLen(); - } - SizeType32 totalEncoderOutputLen = totalContextEncoderOutputLen; - for (auto const& llmReq : genRequests) { - totalEncoderOutputLen += llmReq->getEncoderOutputLen(); + if (llmReq->isCfg()) { + totalContextEncoderOutputLen += llmReq->getEncoderOutputLen(); + } } // create a cpu buffer for scores to find max score in @@ -1016,8 +1049,14 @@ void RuntimeBuffers::setAttentionPriorIdx( llmReq->setAttentionPriorIdx(maxIdx); } kvOffset += encoderOutputLen; + if (llmReq->isCfg()) { + kvOffset += encoderOutputLen; + } } qOffset += contextRequests[i]->getContextChunkSize(); + if (contextRequests[i]->isCfg()) { + qOffset += contextRequests[i]->getContextChunkSize(); + } } // for generation requests, there is no context, @@ -1045,8 +1084,14 @@ void RuntimeBuffers::setAttentionPriorIdx( llmReq->setAttentionPriorIdx(prevPriorIdx + maxIdx); } kvOffset += encoderOutputLen; + if (llmReq->isCfg( )) { + kvOffset += encoderOutputLen; + } } qOffset += 1; + if (genRequests[i]->isCfg()) { + qOffset += 1; + } } } diff --git a/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp b/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp index 8de1ae2b628..b98696b7f72 100644 --- a/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp @@ -359,11 +359,15 @@ void TransformerBuffers::resetCacheIndirection(RequestVector const& contextReque NVTX3_SCOPED_RANGE(resetCacheIndirection); auto const& stream = manager.getStream(); - auto const numContextRequests = contextRequests.size(); - + std::vector slots; + for (auto const& llmReq : contextRequests) { + for (int i = 0; i < llmReq->getNumSequences(); i++) { + slots.push_back(llmReq->getSeqSlot(i)); + } + } + auto const numContextRequests = slots.size(); std::fill_n(bufferCast(*fillValuesAlt), numContextRequests, 0); - std::transform(contextRequests.begin(), contextRequests.end(), bufferCast(*seqSlotsAlt), - [](auto const& llmReq) { return llmReq->mSeqSlot.value(); }); + std::copy(slots.begin(), slots.end(), bufferCast(*seqSlotsAlt)); auto const seqSlotsHostView = ITensor::slice(seqSlotsAlt, 0, numContextRequests); auto seqSlotsDeviceView = ITensor::slice(seqSlotsAltDevice, 0, numContextRequests); @@ -394,19 +398,21 @@ void TransformerBuffers::copyKvBlockOffsets(RequestVector const& contextRequests { for (auto const& llmReq : requests) { - auto const requestId = llmReq->mRequestId; - auto const isContextRequest = llmReq->isContextInitState(); - auto const beamWidth = isContextRequest ? 
contextBeamWidth : llmReq->mSamplingConfig.beamWidth; - auto const maxBeamBlockCount - = kvCacheManager->copyBlockOffsets(*kvCacheBlockOffsetsHost, numSequences, requestId); - maxBlockCount = std::max(maxBlockCount, maxBeamBlockCount); - if (crossKvCacheBlockOffsetsHost) - { - auto const maxCrossBeamBlockCount - = crossKvCacheManager->copyBlockOffsets(*crossKvCacheBlockOffsetsHost, numSequences, requestId); - maxCrossBlockCount = std::max(maxCrossBlockCount, maxCrossBeamBlockCount); + for (int i = 0; i < llmReq->getNumSequences(); i++) { + auto const requestId = llmReq->getSeqSlotId(i); + auto const isContextRequest = llmReq->isContextInitState(); + auto const beamWidth = isContextRequest ? contextBeamWidth : llmReq->mSamplingConfig.beamWidth; + auto const maxBeamBlockCount + = kvCacheManager->copyBlockOffsets(*kvCacheBlockOffsetsHost, numSequences, requestId); + maxBlockCount = std::max(maxBlockCount, maxBeamBlockCount); + if (crossKvCacheBlockOffsetsHost) + { + auto const maxCrossBeamBlockCount + = crossKvCacheManager->copyBlockOffsets(*crossKvCacheBlockOffsetsHost, numSequences, requestId); + maxCrossBlockCount = std::max(maxCrossBlockCount, maxCrossBeamBlockCount); + } + numSequences += beamWidth; } - numSequences += beamWidth; } } @@ -442,7 +448,13 @@ void TransformerBuffers::copyCacheIndirection( TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(copyCacheIndirection); - auto const numGenerationRequests = genRequests.size(); + std::vector slots; + for (auto const& llmReq : genRequests) { + for (int i = 0; i < llmReq->getNumSequences(); i++) { + slots.push_back(llmReq->getSeqSlot(i)); + } + } + auto const numGenerationRequests = slots.size(); auto batchedCopySrcOffsets = BufferRange(*cacheIndirBatchedCopySrcOffsets); auto batchedCopyDstOffsets = BufferRange(*cacheIndirBatchedCopyDstOffsets); @@ -454,8 +466,8 @@ void TransformerBuffers::copyCacheIndirection( cacheIndirShape.d[0] = 1; auto const copySize = static_cast(ITensor::volume(cacheIndirShape)); - std::transform(genRequests.begin(), genRequests.end(), batchedCopySrcOffsets.begin(), - [copySize](auto const& llmReq) { return llmReq->mSeqSlot.value() * copySize; }); + std::transform(slots.begin(), slots.end(), batchedCopySrcOffsets.begin(), + [copySize](auto const& slot) { return slot * copySize; }); std::generate_n( batchedCopyDstOffsets.begin(), numGenerationRequests, [copySize, i = 0]() mutable { return (i++) * copySize; }); std::fill_n(batchedCopySizes.begin(), numGenerationRequests, copySize); @@ -542,136 +554,140 @@ void TransformerBuffers::copyCrossAttentionMasks(RequestVector const& contextReq bool* pinnedMemPtr = bufferCastOrNull(crossAttentionMaskPinnedHost); for (auto const& llmReq : contextRequests) { - auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); - auto const position = llmReq->getContextCurrentPosition(); - auto const size = llmReq->getContextChunkSize(); - if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr) - { - auto memType = crossAttentionMaskRequest->getMemoryType(); - auto const crossAttentionMaskRequestDim0 - = static_cast(crossAttentionMaskRequest->getShape().d[0]); - auto const crossAttentionMaskRequestDim1 - = static_cast(crossAttentionMaskRequest->getShape().d[1]); - TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from contextRequests position %d chunkSize %d", - crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, position, size); - if ((position + size - 1) >= crossAttentionMaskRequestDim0) + for (int s = 0; s < 
llmReq->getNumSequences(); s++) { + auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); + auto const position = llmReq->getContextCurrentPosition(); + auto const size = llmReq->getContextChunkSize(); + if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr && s == 0) + { + auto memType = crossAttentionMaskRequest->getMemoryType(); + auto const crossAttentionMaskRequestDim0 + = static_cast(crossAttentionMaskRequest->getShape().d[0]); + auto const crossAttentionMaskRequestDim1 + = static_cast(crossAttentionMaskRequest->getShape().d[1]); + TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from contextRequests position %d chunkSize %d", + crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, position, size); + if ((position + size - 1) >= crossAttentionMaskRequestDim0) + { + TLLM_LOG_WARNING( + "The provided crossAttentionMask input is not complete for context phases, the last row " + "will be " + "used by default."); + } + // copy it to pinned memory if it is a cpu tensor. + if (memType == MemoryType::kCPU) + { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); + auto const copiedPosition + = std::min(crossAttentionMaskRequestDim0 - 1, static_cast(position)); + auto const copiedSize + = std::min(crossAttentionMaskRequestDim0 - copiedPosition, static_cast(size)); + SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); + SizeType64 inputMaskSize = (copiedSize * crossAttentionMaskRequestDim1); + std::memcpy( + pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, inputMaskSize); + pinnedMemPtr += inputMaskSize; + for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + { + SizeType64 tokenIdInPinnedMem + = std::min(copiedSize - 1, static_cast(tokenId - position)); + batchedCopySrcOffsets.begin()[numCopiedTokens] + = (pinnedMemPtr - primarySrcPtr) + tokenIdInPinnedMem * crossAttentionMaskRequestDim1; + batchedCopyDstOffsets.begin()[numCopiedTokens] + = numTokens * static_cast(maxEncoderInputLengthInBatch); + batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numCopiedTokens++; + numTokens++; + } + } + else + { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); + for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + { + batchedCopySrcOffsets.begin()[numCopiedTokens] + = static_cast(bufferCastOrNull(crossAttentionMaskRequest) - primarySrcPtr) + + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(tokenId)) + * crossAttentionMaskRequestDim1; + batchedCopyDstOffsets.begin()[numCopiedTokens] + = numTokens * static_cast(maxEncoderInputLengthInBatch); + batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numCopiedTokens++; + numTokens++; + } + } + } + else { + numTokens += size; TLLM_LOG_WARNING( - "The provided crossAttentionMask input is not complete for context phases, the last row " - "will be " - "used by default."); + "CrossAttentionMask is not provided for the request. Default padding attention mask will be " + "created."); } - // copy it to pinned memory if it is a cpu tensor. 
- if (memType == MemoryType::kCPU) + } + } + sync_check_cuda_error(stream.get()); + + for (auto const& llmReq : genRequests) + { + for (int s = 0; s < llmReq->getNumSequences(); s++) { + auto const promptLen = llmReq->mPromptLen; + auto const decodingIter = llmReq->getDecodingIter(); + auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); + if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr && s == 0) { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); - auto const copiedPosition - = std::min(crossAttentionMaskRequestDim0 - 1, static_cast(position)); - auto const copiedSize - = std::min(crossAttentionMaskRequestDim0 - copiedPosition, static_cast(size)); - SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); - SizeType64 inputMaskSize = (copiedSize * crossAttentionMaskRequestDim1); - std::memcpy( - pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, inputMaskSize); - pinnedMemPtr += inputMaskSize; - for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + auto const memType = crossAttentionMaskRequest->getMemoryType(); + auto const crossAttentionMaskRequestDim0 + = static_cast(crossAttentionMaskRequest->getShape().d[0]); + auto const crossAttentionMaskRequestDim1 + = static_cast(crossAttentionMaskRequest->getShape().d[1]); + TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from genRequests decodingIter %d", + crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, decodingIter); + if (promptLen + decodingIter - 1 >= crossAttentionMaskRequestDim0) { - SizeType64 tokenIdInPinnedMem - = std::min(copiedSize - 1, static_cast(tokenId - position)); - batchedCopySrcOffsets.begin()[numCopiedTokens] - = (pinnedMemPtr - primarySrcPtr) + tokenIdInPinnedMem * crossAttentionMaskRequestDim1; + TLLM_LOG_WARNING( + "The provided crossAttentionMask input is not complete for generation phases, the last row " + "will be " + "used by default."); + } + // copy it to pinned memory if it is a cpu tensor. 
+ if (memType == MemoryType::kCPU) + { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); + SizeType64 copiedPosition = std::min( + crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)); + SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); + SizeType64 inputMaskSize = crossAttentionMaskRequestDim1; + std::memcpy( + pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, inputMaskSize); + pinnedMemPtr += inputMaskSize; + batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast(pinnedMemPtr - primarySrcPtr); batchedCopyDstOffsets.begin()[numCopiedTokens] = numTokens * static_cast(maxEncoderInputLengthInBatch); batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; - numCopiedTokens++; - numTokens++; } - } - else - { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); - for (SizeType32 tokenId = position; tokenId < position + size; tokenId++) + else { + TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast(bufferCastOrNull(crossAttentionMaskRequest) - primarySrcPtr) - + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(tokenId)) + + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)) * crossAttentionMaskRequestDim1; batchedCopyDstOffsets.begin()[numCopiedTokens] = numTokens * static_cast(maxEncoderInputLengthInBatch); batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; - numCopiedTokens++; - numTokens++; } - } - } - else - { - numTokens += size; - TLLM_LOG_WARNING( - "CrossAttentionMask is not provided for the request. Default padding attention mask will be " - "created."); - } - } - sync_check_cuda_error(stream.get()); - - for (auto const& llmReq : genRequests) - { - auto const promptLen = llmReq->mPromptLen; - auto const decodingIter = llmReq->getDecodingIter(); - auto const& crossAttentionMaskRequest = llmReq->getCrossAttentionMask(); - if (bufferCastOrNull(crossAttentionMaskRequest) != nullptr) - { - auto const memType = crossAttentionMaskRequest->getMemoryType(); - auto const crossAttentionMaskRequestDim0 - = static_cast(crossAttentionMaskRequest->getShape().d[0]); - auto const crossAttentionMaskRequestDim1 - = static_cast(crossAttentionMaskRequest->getShape().d[1]); - TLLM_LOG_DEBUG("copyCrossAttentionMasks (shape [%d, %d]) from genRequests decodingIter %d", - crossAttentionMaskRequestDim0, crossAttentionMaskRequestDim1, decodingIter); - if (promptLen + decodingIter - 1 >= crossAttentionMaskRequestDim0) - { - TLLM_LOG_WARNING( - "The provided crossAttentionMask input is not complete for generation phases, the last row " - "will be " - "used by default."); - } - // copy it to pinned memory if it is a cpu tensor. 
- if (memType == MemoryType::kCPU) - { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on CPU."); - SizeType64 copiedPosition = std::min( - crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)); - SizeType64 inputMaskOffset = (copiedPosition * crossAttentionMaskRequestDim1); - SizeType64 inputMaskSize = crossAttentionMaskRequestDim1; - std::memcpy( - pinnedMemPtr, bufferCastOrNull(crossAttentionMaskRequest) + inputMaskOffset, inputMaskSize); - pinnedMemPtr += inputMaskSize; - batchedCopySrcOffsets.begin()[numCopiedTokens] = static_cast(pinnedMemPtr - primarySrcPtr); - batchedCopyDstOffsets.begin()[numCopiedTokens] - = numTokens * static_cast(maxEncoderInputLengthInBatch); - batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numCopiedTokens++; + numTokens++; } else { - TLLM_LOG_DEBUG("CrossAttentionMask tensor is on GPU."); - batchedCopySrcOffsets.begin()[numCopiedTokens] - = static_cast(bufferCastOrNull(crossAttentionMaskRequest) - primarySrcPtr) - + std::min(crossAttentionMaskRequestDim0 - 1, static_cast(promptLen + decodingIter - 1)) - * crossAttentionMaskRequestDim1; - batchedCopyDstOffsets.begin()[numCopiedTokens] - = numTokens * static_cast(maxEncoderInputLengthInBatch); - batchedCopySizes.begin()[numCopiedTokens] = crossAttentionMaskRequestDim1; + numTokens++; + TLLM_LOG_WARNING( + "CrossAttentionMask is not provided for the generation request. Full valid attentionMask will " + "be used " + "by default."); } - numCopiedTokens++; - numTokens++; - } - else - { - numTokens++; - TLLM_LOG_WARNING( - "CrossAttentionMask is not provided for the generation request. Full valid attentionMask will " - "be used " - "by default."); } } sync_check_cuda_error(stream.get()); diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 2c7a8c5705d..de230cbf749 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -1730,7 +1730,7 @@ void TrtGptModelInflightBatching::postProcessRequest( LlmRequest& llmReq, std::vector const& numDroppedTokens) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto const seqSlot = llmReq.mSeqSlot.value(); + auto const seqSlot = llmReq.mSeqSlots.at(0); auto const reqBeamWidth = llmReq.mSamplingConfig.beamWidth; auto const& bufferManager = getBufferManager(); @@ -1923,7 +1923,7 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId(); auto& genRuntimeBuffers = mBuffers.at(genBufferId); (*mHandleGenerationLogits)(genLogitsIndex, scheduledRequests.generationRequests, *mDecoderBuffers[vid], - mModelConfig, mRuntime->getBufferManager(), genRuntimeBuffers->logits, *genRuntimeBuffers, vid); + mModelConfig, mRuntime->getBufferManager(), mRuntime->getStream(), genRuntimeBuffers->logits, *genRuntimeBuffers, vid); // Copy indirection output into input // TODO: Could we avoid this by modifying batchDecoder to take a vector of tensors instead? 
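// With CFG, a request owns one sequence slot per branch (conditional and unconditional),
// so the cache-indirection copy below iterates llmReq->getNumSequences() slots and
// resolves each one via getSeqSlot().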
@@ -1985,14 +1985,16 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( { for (auto const& llmReq : requests) { - auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; - auto const seqSlot = llmReq->mSeqSlot.value(); - auto const copySize = static_cast(cacheIndirShape.d[2]) * reqBeamWidth; - srcOffsetsPtr[batchIdx] = seqSlot * copySize; - dstOffsetsPtr[batchIdx] = seqSlot * copySize; - copySizesPtr[batchIdx] = copySize; - maxCopySize = std::max(maxCopySize, copySize); - batchIdx++; + for (int s = 0; s < llmReq->getNumSequences(); s++) { + auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; + auto const seqSlot = llmReq->getSeqSlot(s); + auto const copySize = static_cast(cacheIndirShape.d[2]) * reqBeamWidth; + srcOffsetsPtr[batchIdx] = seqSlot * copySize; + dstOffsetsPtr[batchIdx] = seqSlot * copySize; + copySizesPtr[batchIdx] = copySize; + maxCopySize = std::max(maxCopySize, copySize); + batchIdx++; + } } } if (batchIdx != 0) @@ -2162,7 +2164,7 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu { continue; } - auto const seqSlot = llmReq->mSeqSlot.value(); + auto const seqSlot = llmReq->mSeqSlots.at(0); auto const numGeneratedTokens = llmReq->getNumDraftTokens() + 1; auto const currentNumOfTokens = llmReq->getMaxBeamNumTokens(); diff --git a/cpp/tensorrt_llm/executor/samplingConfig.cpp b/cpp/tensorrt_llm/executor/samplingConfig.cpp index e838815d7c8..c31f26bf7e2 100644 --- a/cpp/tensorrt_llm/executor/samplingConfig.cpp +++ b/cpp/tensorrt_llm/executor/samplingConfig.cpp @@ -36,7 +36,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty, OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP, - OptVec const& beamWidthArray) + OptVec const& beamWidthArray, OptFloat const& cfgScale) : mBeamWidth(checkBeamWidth(beamWidth)) , mTopK(checkTopK(topK)) , mTopP(checkTopP(topP)) @@ -56,6 +56,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF , mNumReturnSequences(checkNumReturnSequences(numReturnSequences, beamWidth)) , mMinP(checkMinP(minP)) , mBeamWidthArray(checkBeamWidthArray(beamWidthArray, beamWidth)) + , mCfgScale(checkCfgScale(cfgScale)) { updateNumReturnBeams(); } @@ -69,7 +70,7 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const && mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty && mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences - && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray; + && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray && mCfgScale == other.mCfgScale; } // Getters @@ -175,6 +176,11 @@ OptSize32 SamplingConfig::getNumReturnSequences() const return mNumReturnSequences; } +OptFloat SamplingConfig::getCfgScale() const +{ + return mCfgScale; +} + std::optional SamplingConfig::getMinP() const { return mMinP; @@ -295,6 +301,11 @@ void SamplingConfig::setBeamWidthArray(OptVec const& beamWidthArray) mBeamWidthArray = checkBeamWidthArray(beamWidthArray, std::nullopt); } +void SamplingConfig::setCfgScale(std::optional const& cfgScale) +{ + mCfgScale = checkCfgScale(cfgScale); +} + 
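// Usage sketch (illustrative values only): a client opts into CFG by setting cfgScale on
// the request's SamplingConfig; leaving it unset (or at 1.0) applies no guidance.
//
//   SamplingConfig samplingConfig(/*beamWidth=*/1);
//   samplingConfig.setCfgScale(2.0f);    // hypothetical guidance strength
//   samplingConfig.setTemperature(0.7f); // combines with the usual sampling knobs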
// Checkers SizeType32 SamplingConfig::checkBeamWidth(SizeType32 beamWidth) { @@ -311,6 +322,12 @@ OptFloat const& SamplingConfig::checkTopK(OptFloat const& topK) return topK; } +OptFloat const& SamplingConfig::checkCfgScale(OptFloat const& cfgScale) +{ + // TODO: implement checking the cfg scale + return cfgScale; +} + OptFloat const& SamplingConfig::checkTopP(OptFloat const& topP) { if (topP.has_value()) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp index 795ba065d97..2eadcf29446 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp @@ -122,13 +122,14 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod "__call__", [](HandleGenerationLogits const& self, tr::SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + tensorrt_llm::runtime::CudaStream const& stream, at::Tensor const& logits, OptionalRef genRuntimeBuffers = std::nullopt) { - self(logitsIndex, generationRequests, decoderBuffers, modelConfig, manager, tr::TorchView::of(logits), - genRuntimeBuffers); + self(logitsIndex, generationRequests, decoderBuffers, modelConfig, manager, stream, + tr::TorchView::of(logits), genRuntimeBuffers); }, py::arg("logits_index"), py::arg("generation_requests"), py::arg("decoder_buffers"), - py::arg("model_config"), py::arg("buffer_manager"), py::arg("logits"), + py::arg("model_config"), py::arg("buffer_manager"), py::arg("stream"), py::arg("logits"), py::arg("gen_runtime_buffers") = std::nullopt) .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; }); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 9362095412b..ecd8eeac955 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -220,7 +220,7 @@ void initBindings(pybind11::module_& m) .def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming) .def_readwrite("end_id", &GenLlmReq::mEndId) .def_readwrite("pad_id", &GenLlmReq::mPadId) - .def_readwrite("seq_slot", &GenLlmReq::mSeqSlot) + .def_readwrite("seq_slots", &GenLlmReq::mSeqSlots) .def_property_readonly("return_log_probs", &GenLlmReq::returnLogProbs) .def_property_readonly("return_context_logits", &GenLlmReq::getReturnContextLogits) .def_property_readonly("return_generation_logits", &GenLlmReq::getReturnGenerationLogits) diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 933c66564f8..86d673f9889 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -401,7 +401,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) }; auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig { - assert(t.size() == 19); + assert(t.size() == 20); tr::SamplingConfig config; config.beamWidth = t[0].cast(); @@ -423,6 +423,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) config.numReturnSequences = t[16].cast(); config.minP = t[17].cast>(); config.beamWidthArray = t[18].cast>>(); + config.cfgScale = t[19].cast>(); return config; }; @@ -450,6 +451,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("num_return_sequences", &tr::SamplingConfig::numReturnSequences) .def_readwrite("min_p", &tr::SamplingConfig::minP) .def_readwrite("beam_width_array", 
&tr::SamplingConfig::beamWidthArray) + .def_readwrite("cfg_scale", &tr::SamplingConfig::cfgScale) .def(py::pickle(SamplingConfigGetState, SamplingConfigSetState)) .def("__eq__", &tr::SamplingConfig::operator==); diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index df5f0788813..8263e9a6077 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -72,7 +72,7 @@ void initRequestBindings(pybind11::module_& m) }; auto samplingConfigSetstate = [](py::tuple const& state) { - if (state.size() != 19) + if (state.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -94,7 +94,8 @@ void initRequestBindings(pybind11::module_& m) state[15].cast>(), // NoRepeatNgramSize state[16].cast>(), // NumReturnSequences state[17].cast>(), // MinP - state[18].cast>>() // BeamWidthArray + state[18].cast>>(), // BeamWidthArray + state[19].cast>() // CfgScale ); }; py::class_(m, "SamplingConfig") @@ -116,7 +117,7 @@ void initRequestBindings(pybind11::module_& m) std::optional const& earlyStopping, std::optional const& noRepeatNgramSize, std::optional const& numReturnSequences, std::optional const& minP, - std::optional> const& beamWidthArray) + std::optional> const& beamWidthArray, std::optional const& cfgScale) { if (randomSeed.has_value()) { @@ -137,7 +138,7 @@ void initRequestBindings(pybind11::module_& m) return std::make_unique(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, seed, temperature, minTokens, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping, noRepeatNgramSize, - numReturnSequences, minP, beamWidthArray); + numReturnSequences, minP, beamWidthArray, cfgScale); }), py::arg("beam_width") = 1, py::kw_only(), py::arg("top_k") = py::none(), py::arg("top_p") = py::none(), py::arg("top_p_min") = py::none(), py::arg("top_p_reset_ids") = py::none(), @@ -147,7 +148,7 @@ void initRequestBindings(pybind11::module_& m) py::arg("presence_penalty") = py::none(), py::arg("frequency_penalty") = py::none(), py::arg("length_penalty") = py::none(), py::arg("early_stopping") = py::none(), py::arg("no_repeat_ngram_size") = py::none(), py::arg("num_return_sequences") = py::none(), - py::arg("min_p") = py::none(), py::arg("beam_width_array") = py::none()) + py::arg("min_p") = py::none(), py::arg("beam_width_array") = py::none(), py::arg("cfg_scale") = py::none()) .def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) .def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) .def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) @@ -176,6 +177,7 @@ void initRequestBindings(pybind11::module_& m) .def_property("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP) .def_property( "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray) + .def_property("cfg_scale", &tle::SamplingConfig::getCfgScale, &tle::SamplingConfig::setCfgScale) .def(py::pickle(samplingConfigGetstate, samplingConfigSetstate)); auto additionalModelOutputGetstate diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index b6854b0d5e6..6db7f884f66 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -407,9 +407,9 @@ void DecoderState::disableLookahead(RequestVector const& genRequests) for (auto const& 
llmReq : genRequests) { - if (llmReq->mSeqSlot) + if (!llmReq->mSeqSlots.empty()) { - setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1); + setNumDecodingEngineTokens(llmReq->mSeqSlots.at(0), 1); } } diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 81f3705b42e..72cdbe5ebcf 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -65,7 +65,7 @@ void GptDecoderBatched::disableLookahead(RequestVector const& genRequests, Tenso for (auto const& llmReq : genRequests) { samplingConfigs.push_back(llmReq->mSamplingConfig); - batchSlotsRange[batchIdx] = llmReq->mSeqSlot.value(); + batchSlotsRange[batchIdx] = llmReq->mSeqSlots.at(0); batchIdx += 1; } auto const batchSize = batchIdx; diff --git a/cpp/tests/batch_manager/guidedDecoderTest.cpp b/cpp/tests/batch_manager/guidedDecoderTest.cpp index 13e368b6331..1d2b41d32f4 100644 --- a/cpp/tests/batch_manager/guidedDecoderTest.cpp +++ b/cpp/tests/batch_manager/guidedDecoderTest.cpp @@ -120,10 +120,10 @@ class GuidedDecoderTest : public ::testing::Test auto llmReq1 = std::make_shared(1, 100, std::make_shared(10), SamplingConfig(), false); texec::GuidedDecodingParams guidedDecodingParams(texec::GuidedDecodingParams::GuideType::kJSON); llmReq1->setGuidedDecodingParams(guidedDecodingParams); - llmReq1->mSeqSlot = 1; + llmReq1->mSeqSlots.push_back(1); auto llmReq2 = std::make_shared(1, 100, std::make_shared(10), SamplingConfig(), false); - llmReq2->mSeqSlot = 2; + llmReq2->mSeqSlots.push_back(2); RequestVector contextRequests{llmReq1, llmReq2}; RequestVector generationRequests{}; diff --git a/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp b/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp index a7b10c66256..35c9c326045 100644 --- a/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp @@ -54,7 +54,7 @@ TEST_F(LlmRequestTest, fromExecutorRequest) EXPECT_EQ(llmReq.getOrigPromptLen(), inputTokens.size()); EXPECT_EQ(llmReq.getMaxSentTokenLen(), inputTokens.size()); EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT); - EXPECT_FALSE(llmReq.mSeqSlot); + EXPECT_TRUE(llmReq.mSeqSlots.empty()); // No speculative decoding config, draft tokens should be empty EXPECT_EQ(llmReq.getDraftTokens()->size(), 0); EXPECT_FALSE(llmReq.getEmbeddingBias().has_value()); @@ -488,7 +488,7 @@ TEST_F(LlmRequestTest, testCreateRequests) EXPECT_EQ(childReq1->getState(), llmReq.getState()); EXPECT_EQ(childReq1->mSamplingConfig.randomSeed.value(), std::vector{8}); EXPECT_EQ(llmReq2.mSamplingConfig.randomSeed.value(), std::vector{7}); - EXPECT_FALSE(childReq1->mSeqSlot); + EXPECT_TRUE(childReq1->mSeqSlots.empty()); } { diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index 678a86441e9..2077dff0d5f 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -639,6 +639,7 @@ def generate( "num_return_sequences", "min_p", "beam_width_array", + "cfg_scale", ] rename_params = {"num_beams": "beam_width", "random_seed": "seed"} sampling_params = { From feba85fca599be562e62f826bf0456f07d695646 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Mon, 30 Jun 2025 12:57:35 +0000 Subject: [PATCH 2/9] cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp: apply CFG during context phase too (first frame of logits) --- .../batch_manager/handleContextLogits.cpp | 51 +++++++++++++++++++ 1 file changed, 
51 insertions(+) diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp index 9c6ba9bcd41..0b13537c587 100644 --- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp @@ -39,6 +39,48 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; namespace { +template +static void addAndScale(ITensor* cond, ITensor* uncond, SizeType32 size, float cfgScale) { + auto* condPtr = tensorrt_llm::runtime::bufferCast(*cond); + auto* uncondPtr = tensorrt_llm::runtime::bufferCast(*uncond); + for (SizeType32 i = 0; i < size; i++) { + condPtr[i] = condPtr[i] * (T)cfgScale + uncondPtr[i] * (T)(1 - cfgScale); + } +} + +static void applyCfgCpu(BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream, + TensorPtr logitsView, TensorPtr uncondLogitsView, + float cfgScale, SizeType32 vocabOffset, SizeType32 vocabSize) +{ + // this is a temporary testing implementation where CFG is applied on CPU. + // it needs to become a kernel implemented with cublas + auto logitsVocabView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] + auto uncondLogitsVocabView = ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] + + auto logitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), logitsVocabView->getDataType()); + auto uncondLogitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), uncondLogitsVocabView->getDataType()); + ITensor* logitsCpuPtr = logitsCpu.get(); + ITensor* uncondLogitsCpuPtr = uncondLogitsCpu.get(); + + logitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); + manager.copy(*logitsVocabView, *logitsCpuPtr); + uncondLogitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); + manager.copy(*uncondLogitsVocabView, *uncondLogitsCpuPtr); + stream.synchronize(); + + if (logitsVocabView->getDataType() == nvinfer1::DataType::kFLOAT) + { + addAndScale(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); + } + else if (logitsVocabView->getDataType() == nvinfer1::DataType::kHALF) + { + addAndScale(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); + } + manager.copy(*logitsCpuPtr, *logitsVocabView); + stream.synchronize(); +} + + //! @brief Copy logits from context phase to beginning of generation logits. //! @details Usually, this concerns logits of 1 token. In speculative decoding this concerns draftLen + 1 tokens. 
void copyLastContextLogits(TensorPtr const& contextLogits, LlmRequest& llmReq, BufferManager const& bufferManager) @@ -121,6 +163,15 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, logitsIndex += numContextLogits + draftLength; TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); // TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) + + float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); + SizeType32 vocabOffset = 0; + auto vocabSizes = modelConfig.getVocabSizes(); + for (SizeType32 i = 0; i < vocabId; ++i) + { + vocabOffset += vocabSizes[i]; + } + applyCfgCpu(manager, stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); } auto const seqSlot = llmReq->mSeqSlots.at(0); From 38486306fcb69e1ad7efdbda57dec04bd7204643 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Mon, 30 Jun 2025 17:26:48 +0000 Subject: [PATCH 3/9] CFG in T5TTS: introduce acoustic token which expands to zeros for unconditional generation Signed-off-by: Viacheslav Klimkov --- cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp | 6 +++--- examples/models/contrib/t5tts/convert_checkpoint.py | 13 ++++++------- tensorrt_llm/models/t5tts/model.py | 4 +++- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index 1e685c28777..452eab660ef 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -546,9 +546,9 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request auto const& origTokens = llmReq->getTokens(0); std::vector dummyTokens; if (!is_conditional) { - // TODO: need a special token for unconditional input, which is expanded - // to all zeros - dummyTokens.assign(origTokens.size(), 0); + // TODO: that is special token added in "convert_checkpoint", + // that is expanded to all zeros. Should be configurable. + dummyTokens.assign(origTokens.size(), 16192); } auto const& reqTokens = is_conditional ? 
origTokens : dummyTokens; diff --git a/examples/models/contrib/t5tts/convert_checkpoint.py b/examples/models/contrib/t5tts/convert_checkpoint.py index dc133d26ca6..abd48a5d431 100644 --- a/examples/models/contrib/t5tts/convert_checkpoint.py +++ b/examples/models/contrib/t5tts/convert_checkpoint.py @@ -277,17 +277,16 @@ def convert_t5tts_decoder( prefix: str = "decoder", ): weights = {} - #weights['embedding.vocab_embedding.weight'] = model_dict['final_proj.weight'].clone().contiguous() - - weights['lm_head.weight'] = model_dict['final_proj.weight'].clone( - ).contiguous() + weights['lm_head.weight'] = model_dict['final_proj.weight'].clone().contiguous() weights['embedding.position_embedding.weight'] = model_dict[ f'{prefix}.position_embeddings.weight'].contiguous() - weights[f'embedding.vocab_embedding.weight'] = torch.cat( - [model_dict[f'audio_embeddings.{i}.weight'] for i in range(len(config.vocab_sizes))], dim=0 - ).contiguous() + embs = [model_dict[f'audio_embeddings.{i}.weight'] for i in range(len(config.vocab_sizes))] + embs.append(torch.zeros(1, 768, dtype=embs[0].dtype, device=embs[0].device)) + # embeddings have shape (2024 x 768) * 8, pad them adding extra entry in vocab which expands to zeros + # we dont change the config, instead we change usage of the embedding dim in the model definition + weights[f'embedding.vocab_embedding.weight'] = torch.cat(embs, dim=0).contiguous() num_layers = config.n_layer for i in range(num_layers): diff --git a/tensorrt_llm/models/t5tts/model.py b/tensorrt_llm/models/t5tts/model.py index ce10896a91d..7dbc0cf5d73 100644 --- a/tensorrt_llm/models/t5tts/model.py +++ b/tensorrt_llm/models/t5tts/model.py @@ -1019,7 +1019,9 @@ def __init__(self, config: PretrainedConfig): if self.mapping.is_first_pp_rank(): self.embedding = EncoderDecoderEmbedding( - self.config.vocab_size, + # TODO: vocab is expanded to incorporate service token used for unconditional generation + # during CFG + self.config.vocab_size + 1, self.num_vocabs, self.config.hidden_size, max_position_embeddings=self.config.max_position_embeddings, From d9dc161a24bc9e0b825ba3e1168e7df23b05d096 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 1 Jul 2025 10:17:36 +0000 Subject: [PATCH 4/9] runtimeBuffers: remove print, replace hardcode with vocab size from config Signed-off-by: Viacheslav Klimkov --- cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index 452eab660ef..f076b7c82a5 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -431,7 +431,6 @@ void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, Reques encoderBuffers->setBufferSizes(contextRequests, genRequests); } - TLLM_LOG_WARNING(">>>>>setBufferSizes: numContextRequests %d; numGenRequests %d", numContextRequests, numGenRequests); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -546,9 +545,9 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request auto const& origTokens = llmReq->getTokens(0); std::vector dummyTokens; if (!is_conditional) { - // TODO: that is special token added in "convert_checkpoint", - // that is expanded to all zeros. Should be configurable. 
- dummyTokens.assign(origTokens.size(), 16192); + // that is special token added in "convert_checkpoint", + // that is expanded to all zeros + dummyTokens.assign(origTokens.size(), modelConfig.getVocabSize()); } auto const& reqTokens = is_conditional ? origTokens : dummyTokens; From 0ff0aad7063b87c1f617b19b98f4a9af9829b2a6 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 1 Jul 2025 10:18:21 +0000 Subject: [PATCH 5/9] switch to using cublas to compute logits after cfg Signed-off-by: Viacheslav Klimkov --- .../batch_manager/handleContextLogits.cpp | 45 +------------- .../batch_manager/handleGenerationLogits.cpp | 45 +------------- cpp/tensorrt_llm/kernels/cfgKernels.h | 60 +++++++++++++++++++ 3 files changed, 64 insertions(+), 86 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/cfgKernels.h diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp index 0b13537c587..51e1ab5753a 100644 --- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp @@ -25,6 +25,7 @@ #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" +#include "tensorrt_llm/kernels/cfgKernels.h" namespace tru = tensorrt_llm::runtime::utils; @@ -39,48 +40,6 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; namespace { -template -static void addAndScale(ITensor* cond, ITensor* uncond, SizeType32 size, float cfgScale) { - auto* condPtr = tensorrt_llm::runtime::bufferCast(*cond); - auto* uncondPtr = tensorrt_llm::runtime::bufferCast(*uncond); - for (SizeType32 i = 0; i < size; i++) { - condPtr[i] = condPtr[i] * (T)cfgScale + uncondPtr[i] * (T)(1 - cfgScale); - } -} - -static void applyCfgCpu(BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream, - TensorPtr logitsView, TensorPtr uncondLogitsView, - float cfgScale, SizeType32 vocabOffset, SizeType32 vocabSize) -{ - // this is a temporary testing implementation where CFG is applied on CPU. - // it needs to become a kernel implemented with cublas - auto logitsVocabView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] - auto uncondLogitsVocabView = ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] - - auto logitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), logitsVocabView->getDataType()); - auto uncondLogitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), uncondLogitsVocabView->getDataType()); - ITensor* logitsCpuPtr = logitsCpu.get(); - ITensor* uncondLogitsCpuPtr = uncondLogitsCpu.get(); - - logitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); - manager.copy(*logitsVocabView, *logitsCpuPtr); - uncondLogitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); - manager.copy(*uncondLogitsVocabView, *uncondLogitsCpuPtr); - stream.synchronize(); - - if (logitsVocabView->getDataType() == nvinfer1::DataType::kFLOAT) - { - addAndScale(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); - } - else if (logitsVocabView->getDataType() == nvinfer1::DataType::kHALF) - { - addAndScale(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); - } - manager.copy(*logitsCpuPtr, *logitsVocabView); - stream.synchronize(); -} - - //! @brief Copy logits from context phase to beginning of generation logits. //! @details Usually, this concerns logits of 1 token. In speculative decoding this concerns draftLen + 1 tokens. 
void copyLastContextLogits(TensorPtr const& contextLogits, LlmRequest& llmReq, BufferManager const& bufferManager) @@ -171,7 +130,7 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, { vocabOffset += vocabSizes[i]; } - applyCfgCpu(manager, stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); + tensorrt_llm::kernels::invokeCfg(stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); } auto const seqSlot = llmReq->mSeqSlots.at(0); diff --git a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp index 18a7f0b1a91..2338ace6478 100644 --- a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp @@ -25,6 +25,7 @@ #include "tensorrt_llm/common/nvtxUtils.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" +#include "tensorrt_llm/kernels/cfgKernels.h" namespace tru = tensorrt_llm::runtime::utils; @@ -39,48 +40,6 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; namespace { -template -static void addAndScale(ITensor* cond, ITensor* uncond, SizeType32 size, float cfgScale) { - auto* condPtr = tensorrt_llm::runtime::bufferCast(*cond); - auto* uncondPtr = tensorrt_llm::runtime::bufferCast(*uncond); - for (SizeType32 i = 0; i < size; i++) { - condPtr[i] = condPtr[i] * (T)cfgScale + uncondPtr[i] * (T)(1 - cfgScale); - } -} - -static void applyCfgCpu(BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream, - TensorPtr logitsView, TensorPtr uncondLogitsView, - float cfgScale, SizeType32 vocabOffset, SizeType32 vocabSize) -{ - // this is a temporary testing implementation where CFG is applied on CPU. - // it needs to become a kernel implemented with cublas - auto logitsVocabView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] - auto uncondLogitsVocabView = ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); // [vocabSize,] - - auto logitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), logitsVocabView->getDataType()); - auto uncondLogitsCpu = manager.cpu(ITensor::makeShape({vocabSize}), uncondLogitsVocabView->getDataType()); - ITensor* logitsCpuPtr = logitsCpu.get(); - ITensor* uncondLogitsCpuPtr = uncondLogitsCpu.get(); - - logitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); - manager.copy(*logitsVocabView, *logitsCpuPtr); - uncondLogitsCpuPtr->reshape(ITensor::makeShape({vocabSize})); - manager.copy(*uncondLogitsVocabView, *uncondLogitsCpuPtr); - stream.synchronize(); - - if (logitsVocabView->getDataType() == nvinfer1::DataType::kFLOAT) - { - addAndScale(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); - } - else if (logitsVocabView->getDataType() == nvinfer1::DataType::kHALF) - { - addAndScale(logitsCpuPtr, uncondLogitsCpuPtr, vocabSize, cfgScale); - } - manager.copy(*logitsCpuPtr, *logitsVocabView); - stream.synchronize(); -} - - //! @brief Copy logits from generation phase under streaming mode. 
void copyStreamingGenerationLogits(BufferManager const& bufferManager, LlmRequest& llmReq) { @@ -157,7 +116,7 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex, numLogits); // TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); - applyCfgCpu(manager, stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); + tensorrt_llm::kernels::invokeCfg(stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); } auto& decoderLogits = decoderBuffers.logits.at(seqSlot); diff --git a/cpp/tensorrt_llm/kernels/cfgKernels.h b/cpp/tensorrt_llm/kernels/cfgKernels.h new file mode 100644 index 00000000000..6c28dc91ffd --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cfgKernels.h @@ -0,0 +1,60 @@ +#pragma once + +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include +#include + +namespace tensorrt_llm::kernels +{ + +//! Apply classifier-free guidance (CFG) on GPU in-place using cuBLAS. +//! It overwrites `logitsView` with: logits = cfgScale * logits + (1 - cfgScale) * uncondLogits +//! Only the slice [vocabOffset, vocabOffset + vocabSize) is modified. +inline void invokeCfg(tensorrt_llm::runtime::CudaStream const& stream, + runtime::ITensor::SharedPtr logitsView, runtime::ITensor::SharedPtr uncondLogitsView, + float cfgScale, runtime::SizeType32 vocabOffset, runtime::SizeType32 vocabSize) +{ + using TensorPtr = runtime::ITensor::SharedPtr; + + // Restrict to current vocabulary segment. + TensorPtr logitsVocabView = runtime::ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); + TensorPtr uncondLogitsVocabView = runtime::ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); + + void* condPtr = logitsVocabView->data(); + void const* uncondPtr = uncondLogitsVocabView->data(); + + cudaDataType_t dataType{}; + switch (logitsVocabView->getDataType()) + { + case nvinfer1::DataType::kFLOAT: dataType = CUDA_R_32F; break; + case nvinfer1::DataType::kHALF: dataType = CUDA_R_16F; break; + default: TLLM_THROW("Unsupported data type for CFG"); + } + + auto handlePtr = getCublasHandle(); + auto& handle = *handlePtr; + tensorrt_llm::common::check_cuda_error(cublasSetStream(handle, stream.get())); + + int n = static_cast(vocabSize); + int inc = 1; + + // Use float for the scaling factors and always accumulate in FP32 to + // satisfy cuBLAS requirements (FP16 vectors must use FP32 compute/alpha). 
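+    // The in-place blend cond = cfgScale * cond + (1 - cfgScale) * uncond is split into
+    // SCAL (cond *= cfgScale) followed by AXPY (cond += (1 - cfgScale) * uncond).
+    // Equivalently cond = uncond + cfgScale * (cond - uncond): cfgScale == 1 is a no-op,
+    // and larger values push the logits further away from the unconditional branch.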
+ float alphaF = cfgScale; // Scaling factor in FP32 + float axpyF = 1.0f - cfgScale; // (1 - cfgScale) in FP32 + + tensorrt_llm::common::check_cuda_error( + cublasScalEx(handle, n, &alphaF, CUDA_R_32F, // alpha + condPtr, dataType, // x and its type + inc, CUDA_R_32F)); // increments + compute type + + tensorrt_llm::common::check_cuda_error( + cublasAxpyEx(handle, n, &axpyF, CUDA_R_32F, // alpha + uncondPtr, dataType, inc, // x + condPtr, dataType, inc, // y + CUDA_R_32F)); // compute type +} + +} // namespace tensorrt_llm::kernels \ No newline at end of file From 3829d855b752719ffdb4e8e2a97aea9b332bd1b6 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Thu, 3 Jul 2025 14:41:04 +0000 Subject: [PATCH 6/9] Fixes to freeing cache and slots for case when llm request contains multiple sequences Signed-off-by: Viacheslav Klimkov --- .../tensorrt_llm/batch_manager/kvCacheUtils.h | 13 +++++++ .../batch_manager/cacheFormatter.cpp | 2 +- .../batch_manager/capacityScheduler.cpp | 7 +++- .../batch_manager/dataTransceiverImpl.cpp | 9 ++++- .../trtGptModelInflightBatching.cpp | 9 ++++- .../utils/inflightBatchingUtils.cpp | 38 ++++++++++--------- 6 files changed, 53 insertions(+), 25 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 476dce90779..c6dfc476276 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -39,6 +39,19 @@ class BlockRange { } + BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest const& llmRequest, SizeType32 beam, SizeType32 poolIdx = 0) + : mManager(&cacheManager) + , mPool(cacheManager.getBlockManager().getPrimaryPool(poolIdx)) + { + std::vector blockIds; + for (int i = 0; i < llmRequest.getNumSequences(); i++) { + auto const requestId = llmRequest.getSeqSlotId(i); + const auto& thisBlockIds = cacheManager.getSequence(requestId).getCacheBlockIds().at(beam); + blockIds.insert(blockIds.end(), thisBlockIds.begin(), thisBlockIds.end()); + } + mBlockIds = std::move(blockIds); + } + BlockRange(BaseKVCacheManager const& cacheManager, std::vector blockIds, SizeType32 poolIdx = 0) : mManager(&cacheManager) , mPool(cacheManager.getBlockManager().getPrimaryPool(poolIdx)) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index eb787936e50..2386056d746 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -49,7 +49,7 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest, constexpr SizeType32 beam{0}; auto& blockManager = mCacheManager->getBlockManager(); size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size(); - auto blockRange = BlockRange(*mCacheManager, llmRequest.mRequestId, beam); + auto blockRange = BlockRange(*mCacheManager, llmRequest, beam); if (requestBlockNum < blockRange.size() && requestBlockNum > 0) { // handle block reuse, the prefix blocks are reused diff --git a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp index 023a7c30e18..59f8b709595 100644 --- a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp @@ -390,9 +390,12 @@ std::tuple MaxUtilizationScheduler::operator()( // If we can't allocate a started request, we need to start freeing started requests // from the end of the vector and try again // Here we simulate 
freeing the kvCache blocks associated with that sequence - kvCacheManager.schedulingRemoveSequence((*lastStartedReqIt)->mRequestId); + for (int i = 0; i < (*lastStartedReqIt)->getNumSequences(); i++) { + auto const requestId = (*lastStartedReqIt)->getSeqSlotId(i); + kvCacheManager.schedulingRemoveSequence(requestId); + TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> pause", requestId); + } pausedRequests.emplace_back(*lastStartedReqIt); - TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> pause", (*lastStartedReqIt)->mRequestId); reqItEnd = std::next(lastStartedReqIt).base(); } else diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index 3597626c84f..0b049796311 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -142,8 +142,13 @@ void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) if (cacheFormatter != nullptr) { auto* cacheManager = cacheFormatter->getCacheManager(); - auto blockRange = kv_cache_manager::BlockRange( - *cacheManager, cacheManager->getNewlyAllocatedBlockIds(llmRequest.mRequestId)); + std::vector blockIds; + for (int i = 0; i < llmRequest.getNumSequences(); i++) { + auto const requestId = llmRequest.getSeqSlotId(i); + const auto& thisBlockIds = cacheManager->getNewlyAllocatedBlockIds(requestId); + blockIds.insert(blockIds.end(), thisBlockIds.begin(), thisBlockIds.end()); + } + auto blockRange = kv_cache_manager::BlockRange(*cacheManager, blockIds); requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index de230cbf749..051e980be12 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -806,7 +806,9 @@ void TrtGptModelInflightBatching::forwardSync() mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration."); mCacheTransceiver->respondAndSendAsync(llmReq.get()); } - mSeqSlotManager->freeSequenceSlot(llmReq->mRequestId); + for (int i = 0; i < llmReq->getNumSequences(); i++) { + mSeqSlotManager->freeSequenceSlot(llmReq->getSeqSlotId(i)); + } } } } @@ -2283,7 +2285,10 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu // At this point, KV cache rows are already gathered and moved to the right location. 
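        // For CFG requests the rewind has to cover every sequence slot owned by the
        // request, hence the per-sequence loop below.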
// We can safely rewind (draft - accepted) tokens - mKvCacheManager->rewindKVCache(llmReq->mRequestId, rewindLength); + for (int i = 0; i < llmReq->getNumSequences(); i++) { + auto const requestId = llmReq->getSeqSlotId(i); + mKvCacheManager->rewindKVCache(requestId, rewindLength); + } } } diff --git a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp index 6aa7d891357..43dfc8cff5a 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp @@ -220,24 +220,26 @@ void terminateRequest(SequenceSlotManager& seqSlotManager, LlmRequest& llmReq, S { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); // If a sequence slot is associated with this request id, free it - seqSlotManager.freeSequenceSlot(llmReq.mRequestId); - // Remove the sequence from kvCacheManager - auto const requestId = llmReq.mRequestId; - if (kvCacheManager) - { - kvCacheManager->removeSequence(requestId, llmReq); - } - if (crossKvCacheManager) - { - crossKvCacheManager->removeSequence(requestId, llmReq); - } - if (pause && !llmReq.isGenerationCompleteState()) - { - llmReq.pause(maxInputLen); - } - else - { - TLLM_LOG_DEBUG("terminated: request ID %lu, paused: %d", requestId, pause); + for (int i = 0; i < llmReq.getNumSequences(); i++) { + auto const requestId = llmReq.getSeqSlotId(i); + seqSlotManager.freeSequenceSlot(requestId); + // Remove the sequence from kvCacheManager + if (kvCacheManager) + { + kvCacheManager->removeSequence(requestId, llmReq); + } + if (crossKvCacheManager) + { + crossKvCacheManager->removeSequence(requestId, llmReq); + } + if (pause && !llmReq.isGenerationCompleteState()) + { + llmReq.pause(maxInputLen); + } + else + { + TLLM_LOG_DEBUG("terminated: request ID %lu, paused: %d", requestId, pause); + } } if (peftCacheManager) From 3100c283730bcf81cecceb24adb577f8abd69ce5 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Thu, 3 Jul 2025 15:08:49 +0000 Subject: [PATCH 7/9] Adjust gptManagerBenchmark to work with decoder models with cross attention, where encoder features are provided explicitely Signed-off-by: Viacheslav Klimkov --- benchmarks/cpp/gptManagerBenchmark.cpp | 152 +++++++++++++++--- benchmarks/cpp/utils/utils.cpp | 30 +++- benchmarks/cpp/utils/utils.h | 5 + .../models/contrib/t5tts/prepare_benchmark.py | 68 +++++--- 4 files changed, 206 insertions(+), 49 deletions(-) diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index bc8a2eb066f..7d1c5f022c1 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -40,6 +40,8 @@ #include #include #include +#include +#include using namespace tensorrt_llm::batch_manager; using namespace tensorrt_llm::runtime; @@ -53,6 +55,88 @@ namespace using TensorPtr = ITensor::SharedPtr; +// Add function to read model dtype from config +std::string getModelDtype(std::optional const& engineDir, texec::ModelType modelType) +{ + if (!engineDir.has_value()) + { + return "float16"; // default fallback + } + + std::filesystem::path configPath = engineDir.value() / "config.json"; + if (!std::filesystem::exists(configPath)) + { + TLLM_LOG_WARNING("Config file not found at %s, using default dtype float16", configPath.string().c_str()); + return "float32"; + } + + try + { + std::ifstream configFile(configPath); + nlohmann::json config; + configFile >> config; + + std::string dtype = "float32"; // default + + // Check if this is an 
engine config or checkpoint config + if (config.contains("pretrained_config")) + { + // Engine format + if (config["pretrained_config"].contains("dtype")) + { + dtype = config["pretrained_config"]["dtype"].get(); + } + } + TLLM_LOG_INFO("Detected model dtype: %s", dtype.c_str()); + return dtype; + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING("Failed to read dtype from config: %s, using default float16", e.what()); + return "float32"; + } +} + +// Add function to cast inputFeat tensor from fp32 to fp16 on CPU +texec::Tensor castInputFeatHalf(texec::Tensor const& inputFeat, std::string const& modelDtype) +{ + auto currentDtype = inputFeat.getDataType(); + auto memoryType = inputFeat.getMemoryType(); + + // Only cast CPU tensors from fp32 to fp16 + if (currentDtype != texec::DataType::kFP32 || memoryType != texec::MemoryType::kCPU) + { + TLLM_LOG_WARNING("InputFeat tensor is not fp32 or not on CPU, skipping cast"); + return inputFeat; + } + + try + { + auto shape = inputFeat.getShape(); + auto numElements = inputFeat.getSize(); + + // Create new fp16 tensor on CPU + texec::Tensor castedTensor = texec::Tensor::cpu(texec::DataType::kFP16, shape); + + // Cast data from fp32 to fp16 + auto const* sourceData = static_cast(inputFeat.getData()); + auto* targetData = static_cast(castedTensor.getData()); + + for (size_t i = 0; i < numElements; ++i) + { + targetData[i] = static_cast(sourceData[i]); + } + + TLLM_LOG_DEBUG("Casted inputFeat tensor from fp32 to fp16"); + return castedTensor; + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING("Failed to cast inputFeat tensor to fp16: %s, keeping original", e.what()); + return inputFeat; + } +} + class LoraLib { public: @@ -810,16 +894,19 @@ class ExecutorServer namespace { -texec::Request makeExecutorRequest(texec::VecTokens &inputTokenIds, int32_t outputLen, SizeType32 const& beamWidth, +texec::Request makeExecutorRequest(texec::VecTokens inputTokenIds, SizeType32 outputLen, SizeType32 const& beamWidth, std::optional const& eosId, std::optional const& padId, SizeType32 num_vocabs = 1, bool streaming = false, bool const& returnContextLogits = false, bool const& returnGenerationLogits = false, std::optional const& loraConfig = std::nullopt, std::optional const& lookaheadConfig = std::nullopt, std::optional encoderInputTokenIds = std::nullopt, - std::optional temperature = std::nullopt) + std::optional encoderFeatures = std::nullopt, + std::optional encoderOutLen = std::nullopt, + std::optional temperature = std::nullopt, std::optional cfgScale = std::nullopt) { auto samplingConfig = texec::SamplingConfig{beamWidth}; samplingConfig.setTemperature(temperature); + samplingConfig.setCfgScale(cfgScale); auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false}; auto request = texec::Request(inputTokenIds, outputLen, streaming, samplingConfig, outputConfig, eosId, padId, std::nullopt, // positionIds @@ -834,7 +921,14 @@ texec::Request makeExecutorRequest(texec::VecTokens &inputTokenIds, int32_t outp std::nullopt, // kvCacheRetentionConfig std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor - encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt); + encoderInputTokenIds.has_value() && encoderInputTokenIds.value().size() > 0 ? 
encoderInputTokenIds : std::nullopt, + std::nullopt, // client id + false, // returnAllGeneratedTokens + tensorrt_llm::executor::Request::kDefaultPriority, // priority + tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION, // type + std::nullopt, // ContextPhaseParams + encoderFeatures.has_value() && encoderFeatures.value().getSize() > 0 ? encoderFeatures : std::nullopt, + encoderOutLen); if (num_vocabs > 1) { request.setNumVocabs(num_vocabs); } @@ -853,10 +947,25 @@ void benchmarkExecutor(std::optional const& decoderEngine auto const& world = tensorrt_llm::mpi::MpiComm::world(); auto worldRank = world.getRank(); + // Determine model dtype from config + std::string modelDtype = "float32"; // default + if (decoderEngineDir.has_value()) { + modelDtype = getModelDtype(decoderEngineDir, executorModelType); + } + // Load dataset auto samples = parseWorkloadJson(datasetPath, maxNumSamples, maxPromptLen); auto const numSamples = samples.size(); + // Cast inputFeat tensors from fp32 to fp16 if model is fp16 + for (auto& sample : samples) + { + if (sample.inputFeat.getSize() > 0) + { + sample.inputFeat = castInputFeatHalf(sample.inputFeat, modelDtype); + } + } + auto recorder = std::make_shared(opCsvFile, benchmarkParams.streaming, beamWidth, responsesJsonFile); int32_t decoderStartTokenId = 0; std::shared_ptr executorServer; @@ -955,20 +1064,24 @@ void benchmarkExecutor(std::optional const& decoderEngine std::vector requests; for (auto i = 0; i < warmUp; ++i) { - if (executorModelType == texec::ModelType::kENCODER_DECODER) + if (executorModelType == texec::ModelType::kENCODER_DECODER || samples[0].inputIds.empty()) { if (samples[0].contextIds.empty()) { samples[0].contextIds.push_back(decoderStartTokenId); } requests.emplace_back(makeExecutorRequest(samples[0].contextIds, samples[0].outputLen, beamWidth, eosId, padId, num_vocabs, benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, std::nullopt, - benchmarkParams.requestLookaheadConfig, samples[0].inputIds)); + benchmarkParams.requestLookaheadConfig, samples[0].inputIds, samples[0].inputFeat, samples[0].inputLen, + benchmarkParams.temperature, benchmarkParams.cfgScale)); } else { requests.emplace_back(makeExecutorRequest(samples[0].inputIds, samples[0].outputLen, beamWidth, eosId, padId, num_vocabs, benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, std::nullopt, - benchmarkParams.requestLookaheadConfig, std::nullopt, benchmarkParams.temperature)); + benchmarkParams.requestLookaheadConfig, std::nullopt, + std::nullopt, + std::nullopt, + benchmarkParams.temperature, benchmarkParams.cfgScale)); } } executorServer->enqueue(std::move(requests), true); @@ -990,20 +1103,22 @@ void benchmarkExecutor(std::optional const& decoderEngine { loraConfig = texec::LoraConfig(samples[i].taskId); } - if (executorModelType == texec::ModelType::kENCODER_DECODER) + if (executorModelType == texec::ModelType::kENCODER_DECODER || samples[i].inputIds.empty()) { if (samples[i].contextIds.empty()) { samples[i].contextIds.push_back(decoderStartTokenId); } requests.emplace_back(makeExecutorRequest(samples[i].contextIds, samples[i].outputLen, beamWidth, eosId, padId, num_vocabs, benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig, - benchmarkParams.requestLookaheadConfig, samples[i].inputIds)); + benchmarkParams.requestLookaheadConfig, samples[i].inputIds, samples[i].inputFeat, samples[i].inputLen, + benchmarkParams.temperature, benchmarkParams.cfgScale)); } else { 
requests.emplace_back(makeExecutorRequest(samples[i].inputIds, samples[i].outputLen, beamWidth, eosId, padId, num_vocabs, benchmarkParams.streaming, returnContextLogits, returnGenerationLogits, loraConfig, - benchmarkParams.requestLookaheadConfig, std::nullopt, benchmarkParams.temperature)); + benchmarkParams.requestLookaheadConfig, std::nullopt, std::nullopt, std::nullopt, + benchmarkParams.temperature, benchmarkParams.cfgScale)); } } @@ -1160,6 +1275,7 @@ int main(int argc, char* argv[]) "Minimum token probability threshold for typical acceptance. Enables typical acceptance in Eagle", cxxopts::value()); options.add_options()("temperature", "Sampling temperature for each request", cxxopts::value()); + options.add_options()("cfg_scale", "Scale of classifier-free guidance (CFG) for each request", cxxopts::value()); options.add_options()( "eagle_use_dynamic_tree", "Whether to use Eagle-2", cxxopts::value()->default_value("false")); options.add_options()("eagle_dynamic_tree_max_top_k", @@ -1280,18 +1396,10 @@ int main(int argc, char* argv[]) { benchmarkParams.freeGpuMemoryFraction = result["kv_cache_free_gpu_mem_fraction"].as(); } - // Argument: K-V Cache Cross Attention Fraction. Only applicable to enc-dec models. - if (result.count("encoder_engine_dir") && result.count("decoder_engine_dir")) + // Argument: K-V Cache Cross Attention Fraction. Only applicable to models with xattn + if (result.count("cross_kv_cache_fraction")) { - if (result.count("cross_kv_cache_fraction")) - { - benchmarkParams.crossKvCacheFraction = result["cross_kv_cache_fraction"].as(); - } - else - { - benchmarkParams.crossKvCacheFraction - = 0.5f; // default value if not set. but non enc-dec should not even have this param set - } + benchmarkParams.crossKvCacheFraction = result["cross_kv_cache_fraction"].as(); } // Argument: Enable dynamic tuning of batch size @@ -1407,6 +1515,10 @@ int main(int argc, char* argv[]) { benchmarkParams.temperature = result["temperature"].as(); } + if (result.count("cfg_scale")) + { + benchmarkParams.cfgScale = result["cfg_scale"].as(); + } if (result.count("executor_lookahead_config")) { diff --git a/benchmarks/cpp/utils/utils.cpp b/benchmarks/cpp/utils/utils.cpp index c1cdab384c2..00839124111 100644 --- a/benchmarks/cpp/utils/utils.cpp +++ b/benchmarks/cpp/utils/utils.cpp @@ -22,6 +22,9 @@ #include #include +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include namespace tensorrt_llm::benchmark { @@ -78,16 +81,19 @@ Samples parseWorkloadJson( auto constexpr ignoreComments = true; TLLM_CHECK_WITH_INFO(std::filesystem::exists(datasetPath), "File does not exist: %s", datasetPath.c_str()); std::ifstream jsonStream(datasetPath); - auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ignoreComments); - Samples samples; + auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ignoreComments); for (auto const& sample : json["samples"]) { if (samples.size() >= maxNumSamples) break; int32_t taskId = sample.count("task_id") ? 
sample["task_id"].template get() : -1; - auto input_ids(sample["input_ids"].template get>()); + int32_t inputLen = sample["input_len"]; + std::vector input_ids; + if (sample.count("input_ids")) { + input_ids = sample["input_ids"].template get>(); + } if (maxPromptLen && (input_ids.size() > maxPromptLen.value())) { input_ids.resize(maxPromptLen.value()); @@ -97,7 +103,23 @@ Samples parseWorkloadJson( { context_ids = sample["context_ids"].template get>(); } - samples.emplace_back(Sample{std::move(input_ids), std::move(context_ids), sample["output_len"], taskId}); + texec::Tensor inputFeat; + if (sample.count("input_feat")) + { + auto inputFeatVec = sample["input_feat"].template get>(); + + if (!inputFeatVec.empty()) + { + TLLM_CHECK_WITH_INFO(inputFeatVec.size() % static_cast(inputLen) == 0, + "input_feat size %zu is not divisible by input_len %d", inputFeatVec.size(), inputLen); + + int32_t hiddenDim = static_cast(inputFeatVec.size() / inputLen); + inputFeat = texec::Tensor::of(inputFeatVec.data(), {inputLen, hiddenDim}); + } + } + + samples.emplace_back(Sample{std::move(input_ids), std::move(context_ids), inputFeat, inputLen, + sample["output_len"], taskId}); } if (samples.size() < maxNumSamples) diff --git a/benchmarks/cpp/utils/utils.h b/benchmarks/cpp/utils/utils.h index ecc0fb77a93..8f0409db76b 100644 --- a/benchmarks/cpp/utils/utils.h +++ b/benchmarks/cpp/utils/utils.h @@ -17,6 +17,8 @@ */ #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/tensor.h" +#include "tensorrt_llm/runtime/iTensor.h" #include #include @@ -87,6 +89,7 @@ struct BenchmarkParams std::optional eagleConfig; std::optional temperature; + std::optional cfgScale; std::optional executorLookaheadConfig; std::optional requestLookaheadConfig; @@ -225,6 +228,8 @@ struct Sample { std::vector inputIds; std::vector contextIds; + texec::Tensor inputFeat; + int32_t inputLen; int32_t outputLen; int32_t taskId; }; diff --git a/examples/models/contrib/t5tts/prepare_benchmark.py b/examples/models/contrib/t5tts/prepare_benchmark.py index 2fc6fd088d3..265a0035b8d 100644 --- a/examples/models/contrib/t5tts/prepare_benchmark.py +++ b/examples/models/contrib/t5tts/prepare_benchmark.py @@ -5,13 +5,20 @@ from pathlib import Path def generate_samples( - num_samples, text_vocab_size, audio_vocab_size, - input_mean, input_std, input_min, input_max, - context_mean, context_std, context_min, context_max, - output_mean, output_std, output_min, output_max, + num_samples, audio_vocab_size, + text_stats, + context_stats, + output_stats, output_file, + text_vocab_size=None, + text_emb_dim=None, num_vocabs=8): + + input_mean, input_std, input_min, input_max = text_stats + context_mean, context_std, context_min, context_max = context_stats + output_mean, output_std, output_min, output_max = output_stats + # Create metadata metadata = { "workload_type": "token-norm-dist", @@ -34,9 +41,27 @@ def generate_samples( context_len = min(max(context_min, int(np.random.normal(context_mean, context_std))), context_max) output_len = min(max(output_min, int(np.random.normal(output_mean, output_std))), output_max) + + sample = { + "input_len": input_len, + # no need to multiply by num_vocabs, + # this defines number of decoder iterations + "output_len": output_len, + "task_id": -1 # As in your example + } + # Generate input_ids: random ints in range (0, 2048) - input_ids = [random.randint(0, text_vocab_size - 1) for _ in range(input_len)] - + input_ids = None + input_emb = None + if text_vocab_size is not None: + input_ids = 
[random.randint(0, text_vocab_size - 1) for _ in range(input_len)] + sample["input_ids"] = input_ids + elif text_emb_dim is not None: + input_emb = np.random.randn(input_len, text_emb_dim).flatten().tolist() + sample["input_feat"] = input_emb + else: + raise ValueError("Either text_vocab_size or text_emb_dim must be provided") + # Generate context_ids as specified context_matrix = np.random.randint(0, audio_vocab_size, size=(context_len, num_vocabs)) # Set first row to zeros @@ -48,18 +73,8 @@ def generate_samples( # Flatten to 1D array context_ids = context_matrix.flatten().tolist() - - # Create sample - sample = { - "input_len": input_len, - "input_ids": input_ids, - "context_ids": context_ids, - # no need to multiply by num_vocabs, - # this defines number of decoder iterations - "output_len": output_len, - "task_id": -1 # As in your example - } - + sample["context_ids"] = context_ids + samples.append(sample) # Create the full JSON structure @@ -87,18 +102,21 @@ def main(): default=[3 * 75, 0, 3 * 75, 3 * 75], help='Context length parameters: mean, std, max') parser.add_argument('--output_len', type=int, nargs=4, metavar=('MEAN', 'STD', 'MIN', 'MAX'), default=[5 * 75, 0, 5 * 75, 5 * 75], help='Output length parameters: mean, std, max') - parser.add_argument('--text_vocab_size', type=int, default=98, help='Text vocabulary size') + parser.add_argument('--text_vocab_size', type=int, default=None, help='Text vocabulary size') # 98 + parser.add_argument('--text_emb_dim', type=int, default=None, help='Text embedding dimension') # 768 parser.add_argument('--audio_vocab_size', type=int, default=2048, help='Audio vocabulary size') args = parser.parse_args() - + generate_samples( - args.samples, args.text_vocab_size, args.audio_vocab_size, - args.input_len[0], args.input_len[1], args.input_len[2], args.input_len[3], - args.context_len[0], args.context_len[1], args.context_len[2], args.context_len[3], - args.output_len[0], args.output_len[1], args.output_len[2], args.output_len[3], + args.samples, args.audio_vocab_size, + args.input_len, + args.context_len, + args.output_len, args.output, - args.num_vocabs + text_vocab_size=args.text_vocab_size, + text_emb_dim=args.text_emb_dim, + num_vocabs=args.num_vocabs, ) if __name__ == "__main__": From b626eb71f3fcea88e73cb4ecdbdbcf433685abab Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Fri, 4 Jul 2025 12:15:54 +0000 Subject: [PATCH 8/9] gptManagerBenchmark.cpp: pre-copy encoder features to GPU in benchmark Signed-off-by: Viacheslav Klimkov --- benchmarks/cpp/gptManagerBenchmark.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index 7d1c5f022c1..32fab824e5d 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -126,9 +126,12 @@ texec::Tensor castInputFeatHalf(texec::Tensor const& inputFeat, std::string cons { targetData[i] = static_cast(sourceData[i]); } - TLLM_LOG_DEBUG("Casted inputFeat tensor from fp32 to fp16"); - return castedTensor; + + auto stream = std::make_shared(); + texec::Tensor castedTensorGpu = castedTensor.copyToGpu(stream); + + return castedTensorGpu; } catch (std::exception const& e) { From 43daa79104826661a24e838f0f19a521d2a97789 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 8 Jul 2025 08:54:35 +0000 Subject: [PATCH 9/9] Enable latest grpo checkpoint, apply attention prior in leader thread Signed-off-by: Viacheslav Klimkov --- 
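[Editor's note on the workload format touched by patches 7 and 8 above: prepare_benchmark.py now emits either "input_ids" or a flattened "input_feat" matrix, and parseWorkloadJson accepts both, inferring the hidden dimension from input_feat's length divided by input_len. The sketch below shows a minimal sample of that shape. The field names are taken from the diffs; the concrete sizes (128 input frames, hidden dim 768, num_vocabs 8, all-zero context) and the output file name are illustrative assumptions only, not values mandated by the patches.]

import json
import numpy as np

input_len, hidden_dim, num_vocabs, context_len = 128, 768, 8, 225

sample = {
    "input_len": input_len,
    # "input_ids" is omitted here: with the parser changes above it is optional, and the
    # encoder input is carried by "input_feat" instead, flattened [input_len, hidden_dim] floats.
    "input_feat": np.random.randn(input_len, hidden_dim).flatten().tolist(),
    # Flattened [context_len, num_vocabs] decoder context tokens (all zeros just for illustration).
    "context_ids": np.zeros((context_len, num_vocabs), dtype=np.int64).flatten().tolist(),
    "output_len": 375,
    "task_id": -1,
}

with open("workload.json", "w") as f:
    json.dump({"samples": [sample]}, f)

[End of note.]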
cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp | 2 +- .../decoderMaskedMultiheadAttentionTemplate.h | 2 +- tensorrt_llm/models/t5tts/model.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp index f076b7c82a5..05a30f4f3c9 100644 --- a/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp @@ -1029,7 +1029,7 @@ void RuntimeBuffers::setAttentionPriorIdx( } // create a cpu buffer for scores to find max score in - SizeType32 searchLength = 10; + SizeType32 searchLength = 5; auto const& manager = runtime.getBufferManager(); auto const& stream = runtime.getStream(); auto scoresHost = manager.cpu(ITensor::makeShape({searchLength}), scores->getDataType()); diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index c3efde9e2f6..98438e2a411 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -2093,7 +2093,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } } - if (is_active && has_attention_mask && DO_CROSS_ATTENTION) { + if (is_active && has_attention_mask && DO_CROSS_ATTENTION && is_leader) { // TODO: This is a fix to take into account custom attention mask during cross attention. // It is implicitely excludes EOS token from encoder sequence, this is to be checked. // It penalizes masked tokens with -1e9, this can be adjusted to implement attention prior floor. diff --git a/tensorrt_llm/models/t5tts/model.py b/tensorrt_llm/models/t5tts/model.py index 7dbc0cf5d73..5a193c8e1e5 100644 --- a/tensorrt_llm/models/t5tts/model.py +++ b/tensorrt_llm/models/t5tts/model.py @@ -51,8 +51,8 @@ MLPType.FusedGatedMLP: FusedGatedMLP, } -COMPUTE_SCORES_FROM_LAYERS = [4, 6, 10] -APPLY_PRIOR_TO_LAYERS = [4, 6, 10] +COMPUTE_SCORES_FROM_LAYERS = [4,6,10] +APPLY_PRIOR_TO_LAYERS = [4,5,6,7,8,9,10,11] class PositionwiseConvFF(Module):
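[Editor's note on what the new cfg_scale sampling parameter controls: with classifier-free guidance the runtime evaluates two sequences per request and blends their logits before sampling. Exactly where TensorRT-LLM applies that blend is outside the scope of these benchmark-side diffs, so the NumPy sketch below only shows the standard CFG formula, with cond_logits/uncond_logits as assumed inputs; treat it as an illustration rather than the project's implementation.]

import numpy as np

def cfg_blend(cond_logits: np.ndarray, uncond_logits: np.ndarray, cfg_scale: float) -> np.ndarray:
    # Standard classifier-free guidance: start from the unconditional distribution and move
    # toward the conditional one by cfg_scale. A scale of 1.0 returns the conditional logits
    # unchanged, i.e. no guidance.
    return uncond_logits + cfg_scale * (cond_logits - uncond_logits)

# Toy usage over a 5-token vocabulary:
cond = np.log(np.array([0.1, 0.2, 0.4, 0.2, 0.1]))
uncond = np.log(np.array([0.2, 0.2, 0.2, 0.2, 0.2]))
guided = cfg_blend(cond, uncond, cfg_scale=3.0)

[In gptManagerBenchmark, the new --cfg_scale option simply forwards this value to SamplingConfig::setCfgScale on every generated request; it does not change the workload itself. End of note.]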