diff --git a/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h index 1a7c0966caf..cd3c4809873 100644 --- a/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h +++ b/cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h @@ -49,9 +49,9 @@ class HandleContextLogits : Algorithm tr::SizeType32 operator()(RequestVector const& contextRequests, std::vector const& numContextLogitsVec, tr::ITensor::SharedPtr const& logits, - DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, - tensorrt_llm::runtime::CudaStream const& stream, OptionalRef medusaBuffers, - tr::SizeType32 vocabId = 0) const; + std::vector>& decoderBuffers, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + tensorrt_llm::runtime::CudaStream const& stream, OptionalRef medusaBuffers) const; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h index a585aa6e491..d55f8c929c3 100644 --- a/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h +++ b/cpp/include/tensorrt_llm/batch_manager/handleGenerationLogits.h @@ -45,10 +45,11 @@ class HandleGenerationLogits : Algorithm HandleGenerationLogits() = default; - void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers, + void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests, + std::vector>& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, - tensorrt_llm::runtime::CudaStream const& stream, tr::ITensor::SharedPtr const& logits, - OptionalRef genRuntimeBuffers, tr::SizeType32 vocabId = 0) const; + tr::ITensor::SharedPtr const& logits, + OptionalRef genRuntimeBuffers) const; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index 92560641ef8..b1cb85a28ce 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -89,6 +89,11 @@ class GptDecoderBatched : public IGptDecoderBatched return mBufferManager; } + [[nodiscard]] BufferManager const& getDecoderBufferManager() const + { + return mDecoderBufferManager; + } + private: //! @brief Sets inputs for explicit draft tokens. 
void setExplicitDraftTokensInputs(decoder_batch::Input const& input); @@ -106,6 +111,7 @@ class GptDecoderBatched : public IGptDecoderBatched CudaStreamPtr mRuntimeStream; CudaStreamPtr mDecoderStream; BufferManager mBufferManager; + BufferManager mDecoderBufferManager; using GptDecoderPtr = std::unique_ptr; GptDecoderPtr mDecoder; diff --git a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp index 51e1ab5753a..f8c8fbaa989 100644 --- a/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp @@ -25,7 +25,6 @@ #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" -#include "tensorrt_llm/kernels/cfgKernels.h" namespace tru = tensorrt_llm::runtime::utils; @@ -69,15 +68,17 @@ void setupMedusaLogits(std::vector& medusaLogitsHeads, TensorPtr cons } // namespace SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, - std::vector const& numContextLogitsVec, TensorPtr const& logits, DecoderBuffers& decoderBuffers, + std::vector const& numContextLogitsVec, TensorPtr const& logits, + std::vector>& decoderBuffers, tr::ModelConfig const& modelConfig, BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream, - OptionalRef medusaBuffers, SizeType32 vocabId) const + OptionalRef medusaBuffers) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(HandleContextLogits); SizeType32 batchIndex{0}; SizeType32 logitsIndex{0}; + auto vocabSizes = modelConfig.getVocabSizes(); // Copy logits into decoderBuffers.logits for (auto const& llmReq : contextRequests) { @@ -91,6 +92,7 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, if (modelConfig.computeContextLogits()) { + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support context logits"); // Since the computational graph has been modified, only the last token is needed. 
TLLM_CHECK_WITH_INFO(!modelConfig.getSpeculativeDecodingMode().isMedusa() && !modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding(), @@ -117,28 +119,13 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, auto const numDecoderLogits = 1 + draftLength; TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); - // this is CFG support implementation, where we advance the logits index through the unconditional logits - if (llmReq->isCfg()) { - logitsIndex += numContextLogits + draftLength; - TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits); - // TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) - - float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); - SizeType32 vocabOffset = 0; - auto vocabSizes = modelConfig.getVocabSizes(); - for (SizeType32 i = 0; i < vocabId; ++i) - { - vocabOffset += vocabSizes[i]; - } - tensorrt_llm::kernels::invokeCfg(stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); - } - auto const seqSlot = llmReq->mSeqSlots.at(0); if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits()) { + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support speculative decoding"); TLLM_CHECK(medusaBuffers); // speculative decoding is not supported for numVocabs > 1 - auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot); + auto& medusaLogitsHeads = decoderBuffers.front()->draftBuffers.predictedDraftLogits.at(seqSlot); setupMedusaLogits(medusaLogitsHeads, medusaBuffers->medusaLogitsDevice, modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex - numDecoderLogits, numDecoderLogits); @@ -148,45 +135,48 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests, // save the accepted token logits from target model if (llmReq->getReturnGenerationLogits()) { + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support returning generation logits"); copyLastContextLogits(logitsView, *llmReq, manager); } TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid(*logitsView, manager, "logits") == false, "Found invalid number (NaN or Inf) in logits"); - auto& decoderLogits = decoderBuffers.logits.at(seqSlot); - if (reqBeamWidth > 1) { // Tile logits of context requests + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support beam search"); auto const logitsShape = logitsView->getShape(); auto const logitsType = logitsView->getDataType(); + auto& decoderLogits = decoderBuffers.front()->logits.at(seqSlot); decoderLogits = manager.gpu(ITensor::makeShape({reqBeamWidth, logitsShape.d[1]}), logitsType); tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, stream); decoderLogits->unsqueeze(0); } else { - auto curVocablogitsView = logitsView; + // iterate through vocabs and assign the logits to the decoder buffers + SizeType32 offset = 0; auto const logitsViewShape = logitsView->getShape(); - if (logitsViewShape.d[0] == 1) // if current nTok is 1, could have multiple vocabs - { - SizeType32 offset = 0; - auto vocabSizes = modelConfig.getVocabSizes(); - for (SizeType32 i = 0; i < vocabId; ++i) - { - offset += vocabSizes[i]; - } - curVocablogitsView = ITensor::slice(logitsView, {0, offset}, vocabSizes[vocabId]); // [vocabSize,] - curVocablogitsView = ITensor::view(curVocablogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); + 
TLLM_CHECK_WITH_INFO(logitsViewShape.d[0] == 1, "nTok should be 1 for multi-vocab"); + for (SizeType32 vocabId = 0; vocabId < (SizeType32)vocabSizes.size(); ++vocabId) { + // destination where this vocab's slice of logits is stored + auto& decoderLogits = decoderBuffers[vocabId]->logits.at(seqSlot); + auto curVocabLogitsView = logitsView; + curVocabLogitsView = ITensor::slice(logitsView, {0, offset}, vocabSizes[vocabId]); // [vocabSize,] + // restore the token dimension: [vocabSize] -> [1, vocabSize] + curVocabLogitsView = ITensor::view(curVocabLogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); + auto const updateLogitsViewShape = curVocabLogitsView->getShape(); + // expose the slice to the decoder as [1, 1, vocabSize] + decoderLogits = ITensor::view( + curVocabLogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); + offset += (SizeType32)vocabSizes[vocabId]; } - auto const updateLogitsViewShape = curVocablogitsView->getShape(); - decoderLogits = ITensor::view( - curVocablogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); } ++batchIndex; if (llmReq->isCfg()) { ++batchIndex; + logitsIndex += numContextLogits + draftLength; } } diff --git a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp index 2338ace6478..6533a55d0c7 100644 --- a/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp +++ b/cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp @@ -25,7 +25,6 @@ #include "tensorrt_llm/common/nvtxUtils.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" -#include "tensorrt_llm/kernels/cfgKernels.h" namespace tru = tensorrt_llm::runtime::utils; @@ -75,21 +74,14 @@ void setupMedusaLogits(std::vector& medusaLogitsHeads, TensorPtr cons } // namespace void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector const& generationRequests, - DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, BufferManager const& manager, - tensorrt_llm::runtime::CudaStream const& stream, TensorPtr const& logits, OptionalRef genRuntimeBuffers, - SizeType32 vocabId) const + std::vector>& decoderBuffers, + tr::ModelConfig const& modelConfig, BufferManager const& manager, + TensorPtr const& logits, OptionalRef genRuntimeBuffers) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(HandleGenerationLogits); - // compute where logits start for the given `vocabId` auto vocabSizes = modelConfig.getVocabSizes(); - SizeType32 vocabOffset = 0; - for (SizeType32 i = 0; i < vocabId; ++i) - { - vocabOffset += vocabSizes[i]; - } - for (auto const& llmReq : generationRequests) { auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth; @@ -110,38 +102,39 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid(*logitsView, manager, "logits") == false, "Found invalid number (NaN or Inf) in logits"); - // CFG implementation: get unconditional logits and add them to logitsView if (llmReq->isCfg()) { + // skip unconditional logits logitsIndex += numLogits; - TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex, numLogits); - // TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale) - float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0); - tensorrt_llm::kernels::invokeCfg(stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]); } - auto& decoderLogits = decoderBuffers.logits.at(seqSlot); + auto const logitsViewShape =
logitsView->getShape(); if (reqBeamWidth > 1) { + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support beam search"); + auto& decoderLogits = decoderBuffers.front()->logits.at(seqSlot); decoderLogits = logitsView; decoderLogits->unsqueeze(0); } else { - auto curVocablogitsView = logitsView; - if (logitsViewShape.d[0] == 1) // if current nTok is 1, could have multiple vocabs - { - curVocablogitsView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSizes[vocabId]); // [vocabSize,] - curVocablogitsView = ITensor::view( - curVocablogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); // [numLogits == 1, vocabSize] + SizeType32 vocabOffset = 0; + for (SizeType32 vocabId = 0; vocabId < (SizeType32)vocabSizes.size(); ++vocabId) { + auto& decoderLogits = decoderBuffers[vocabId]->logits.at(seqSlot); + TLLM_CHECK_WITH_INFO(logitsViewShape.d[0] == 1, "Multi-vocab requires nTok to be 1"); + auto curVocabLogitsView = logitsView; + curVocabLogitsView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSizes[vocabId]); // [vocabSize,] + curVocabLogitsView = ITensor::view(curVocabLogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); // [numLogits == 1, vocabSize] + auto const updateLogitsViewShape = curVocabLogitsView->getShape(); + decoderLogits = ITensor::view( + curVocabLogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); + vocabOffset += (SizeType32)vocabSizes[vocabId]; } - auto const updateLogitsViewShape = curVocablogitsView->getShape(); - decoderLogits = ITensor::view( - curVocablogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]})); } if (llmReq->getReturnGenerationLogits()) { + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support returning generation logits"); TLLM_CHECK_WITH_INFO(modelConfig.getSpeculativeDecodingMode().isNone() || modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal(), "Only speculative decoding with external draft tokens supports returning generation logits"); @@ -164,9 +157,10 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co } if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits()) { + TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support speculative decoding"); TLLM_CHECK(genRuntimeBuffers); // speculative decoding is not supported for numVocabs > 1 - auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot); + auto& medusaLogitsHeads = decoderBuffers.front()->draftBuffers.predictedDraftLogits.at(seqSlot); setupMedusaLogits(medusaLogitsHeads, genRuntimeBuffers->medusaBuffers->medusaLogitsDevice, modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex, draftLength); } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 470d025d7be..4ed65e1bd64 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -52,6 +52,7 @@ #include "tensorrt_llm/common/nvtxUtils.h" #include "tensorrt_llm/common/timestampUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/kernels/cfgKernels.h" #include "tensorrt_llm/layers/defaultDecodingParams.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/gptDecoderBatched.h" @@ -1918,25 +1919,58 @@ runtime::CudaEvent 
TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(decoderStepAsync); - runtime::CudaEvent decoderFinishEvent; + // First handle the context and generation logits: + // this forwards the per-vocab logit slices into the decoder buffers so the decoders can run. + auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId(); + auto& contextRuntimeBuffers = mBuffers.at(contextBufferId); + auto const logitsIndex = (*mHandleContextLogits)(scheduledRequests.contextRequests, + contextRuntimeBuffers->numContextLogits, contextRuntimeBuffers->logits, mDecoderBuffers, mModelConfig, + mRuntime->getBufferManager(), mRuntime->getStream(), contextRuntimeBuffers->medusaBuffers); + auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0; + auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId(); + auto& genRuntimeBuffers = mBuffers.at(genBufferId); + (*mHandleGenerationLogits)(genLogitsIndex, scheduledRequests.generationRequests, mDecoderBuffers, + mModelConfig, mRuntime->getBufferManager(), genRuntimeBuffers->logits, *genRuntimeBuffers); + + // copy cache indirection only for the first decoder; the other decoders do not use it + copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId, 0); + + // apply CFG to the generation logits + if (scheduledRequests.generationRequests.size() > 0) { + // check whether all generation requests use CFG + bool allCfg = true; + bool hasCfg = false; + for (auto const& llmReq : scheduledRequests.generationRequests) { + allCfg &= llmReq->isCfg(); + hasCfg |= llmReq->isCfg(); + } + if (hasCfg) { + TLLM_CHECK_WITH_INFO(allCfg, "A kernel that enables CFG per request is not implemented yet"); + // slice the generation logits; assumes nTok == 1 per request + TensorPtr genLogitsView = ITensor::slice( + contextRuntimeBuffers->logits, + genLogitsIndex, + scheduledRequests.generationRequests.size() + ); + // compute total vocab size + SizeType32 totalVocabSize = 0; + for (auto const vs : mModelConfig.getVocabSizes()) { + totalVocabSize += vs; + } + float cfgScale = scheduledRequests.generationRequests.front()->mSamplingConfig.cfgScale->at(0); + tensorrt_llm::kernels::invokeCfg( + mRuntime->getStream(), genLogitsView, + scheduledRequests.generationRequests.size(), // numRequests + totalVocabSize, + cfgScale + ); + } + } + + // These operations are not expected to apply in multi-vocab sampling, + // but keep them per vocab just in case, because they run on the runtime stream. for (SizeType32 vid = 0; vid < getNumVocabs(); vid++) { - auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId(); - auto& contextRuntimeBuffers = mBuffers.at(contextBufferId); - auto const logitsIndex = (*mHandleContextLogits)(scheduledRequests.contextRequests, - contextRuntimeBuffers->numContextLogits, contextRuntimeBuffers->logits, *mDecoderBuffers[vid], mModelConfig, - mRuntime->getBufferManager(), mRuntime->getStream(), contextRuntimeBuffers->medusaBuffers, vid); - - auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0; - auto const genBufferId = mCtxGenFusion ?
getFusedBufferId() : getGenerationBufferId(); - auto& genRuntimeBuffers = mBuffers.at(genBufferId); - (*mHandleGenerationLogits)(genLogitsIndex, scheduledRequests.generationRequests, *mDecoderBuffers[vid], - mModelConfig, mRuntime->getBufferManager(), mRuntime->getStream(), genRuntimeBuffers->logits, *genRuntimeBuffers, vid); - - // Copy indirection output into input - // TODO: Could we avoid this by modifying batchDecoder to take a vector of tensors instead? - copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId, vid); - mLogitsPostProcessorIsApplied = (*mLogitsPostProcessor)(scheduledRequests.contextRequests, scheduledRequests.generationRequests, mReplicateLogitsPostProcessor, *mDecoderBuffers[vid], mWorldConfig, *mRuntime, mLogitsPostProcessorBatched); @@ -1945,31 +1979,42 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques { mGuidedDecoder->execute(scheduledRequests, mRuntime->getBufferManager(), mDecoderBuffers[vid]->logits); } + } - auto const fusedBufferId = getFusedBufferId(); - auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId); + // finally run actual decoding + // each decoder runs on a dedicated stream + auto const fusedBufferId = getFusedBufferId(); + auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId); + std::vector finishedEvents; + for (SizeType32 vid = 0; vid < getNumVocabs(); vid++) + { + if (vid == 0) { + // wait for runtime stream to finish + auto event = CudaEvent{}; + mRuntime->getStream().record(event); + mDecoders[vid]->getDecoderStream()->wait(event.get()); + } + auto& decodingInput = mDecodingInputs[vid].at(mMicroBatchId); std::tie(decodingInput, mDecodingOutput[vid]) = (*mMakeDecodingBatchInputOutput)(scheduledRequests.contextRequests, scheduledRequests.generationRequests, *mDecoderBuffers[vid], mDecoderInputBuffers.at(fusedBufferId), mDecoders[vid]->getDecoderState(), - mModelConfig, getMaxNumSequences(), mOperatingBeamWidth, mRuntime->getBufferManager(), - mRuntime->getStream(), *fusedRuntimeBuffers); - - runtime::CudaEvent finishedEvent = mDecoders[vid]->forwardAsync(*mDecodingOutput[vid], *decodingInput); + mModelConfig, getMaxNumSequences(), mOperatingBeamWidth, mDecoders[vid]->getDecoderBufferManager(), + *mDecoders[vid]->getDecoderStream(), *fusedRuntimeBuffers); + [[maybe_unused]] auto unusedEvent = mDecoders[vid]->forwardAsync(*mDecodingOutput[vid], *decodingInput); auto const returnLogProbs = batchReturnLogProbs(scheduledRequests); - finishedEvent = updateDecoderBuffers(returnLogProbs, std::move(finishedEvent), vid); - if (vid == getNumVocabs() - 1) - { - decoderFinishEvent = std::move(finishedEvent); - // All decoders use the same cuda stream for now. 
The last finished event indicates all decoder finish - // decoding - } + TLLM_CHECK_WITH_INFO(!returnLogProbs || getNumVocabs() == 1, "Can't return log probs for multi-vocab"); + finishedEvents.push_back(updateDecoderBuffers(returnLogProbs, mDecoders[vid]->getDecoderBufferManager(), vid)); + } + + for (auto const& event : finishedEvents) { + mRuntime->getStream().wait(event.get()); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return decoderFinishEvent; + return std::move(finishedEvents.back()); } void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( @@ -2031,29 +2076,29 @@ void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs( } runtime::CudaEvent TrtGptModelInflightBatching::updateDecoderBuffers( - bool returnLogProbs, runtime::CudaEvent decoderFinishEvent, SizeType32 vocabId) + bool returnLogProbs, runtime::BufferManager const& manager, SizeType32 vocabId) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(updateDecoderBuffers); - // Chain copy after decoder event, using a different stream - mCopyBufferManager.getStream().wait(decoderFinishEvent); - - mCopyBufferManager.copy( + manager.copy( *mDecoders[vocabId]->getDecoderState().getAllNewTokens(), *mDecoderBuffers[vocabId]->newOutputTokensHost); - mCopyBufferManager.copy(*mDecoders[vocabId]->getDecoderState().getJointDecodingOutput().lengths, + manager.copy(*mDecoders[vocabId]->getDecoderState().getJointDecodingOutput().lengths, *mDecoderBuffers[vocabId]->sequenceLengthsHost); - auto const finishedSumDevice = mDecoders[vocabId]->getDecoderState().getFinishedSum(); - mCopyBufferManager.copy(*finishedSumDevice, *mDecoderBuffers[vocabId]->finishedSumHost); - auto const finishReasonsDevice = mDecoders[vocabId]->getDecoderState().getFinishReasons(); - mCopyBufferManager.copy(*finishReasonsDevice, *mDecoderBuffers[vocabId]->finishReasonsHost); + if (vocabId == 0) { + // no need to copy those for all the vocabs + auto const finishedSumDevice = mDecoders[vocabId]->getDecoderState().getFinishedSum(); + manager.copy(*finishedSumDevice, *mDecoderBuffers[vocabId]->finishedSumHost); + auto const finishReasonsDevice = mDecoders[vocabId]->getDecoderState().getFinishReasons(); + manager.copy(*finishReasonsDevice, *mDecoderBuffers[vocabId]->finishReasonsHost); + } if (returnLogProbs) { - mCopyBufferManager.copy( + manager.copy( *mDecoders[vocabId]->getDecoderState().getCumLogProbs(), *mDecoderBuffers[vocabId]->cumLogProbsHost); - mCopyBufferManager.copy( + manager.copy( *mDecoders[vocabId]->getDecoderState().getLogProbs(), *mDecoderBuffers[vocabId]->logProbsHost); } @@ -2062,7 +2107,7 @@ runtime::CudaEvent TrtGptModelInflightBatching::updateDecoderBuffers( // TODO: keep data on device for next iteration mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensDevice = mDecoders[vocabId]->getDecoderState().getNextDraftTokens(); - mCopyBufferManager.copy(*mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensDevice, + manager.copy(*mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensDevice, *mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensHost); if (mModelConfig.getSpeculativeDecodingMode().variableDraftLength()) @@ -2071,9 +2116,9 @@ runtime::CudaEvent TrtGptModelInflightBatching::updateDecoderBuffers( = mDecoders[vocabId]->getDecoderState().getNextDraftTokensLengths(); mDecoderBuffers[vocabId]->draftBuffers.prevDraftTokensLengthsDevice = mDecoders[vocabId]->getDecoderState().getPrevDraftTokensLengths(); - 
mCopyBufferManager.copy(*mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensLengthsDevice, + manager.copy(*mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensLengthsDevice, *mDecoderBuffers[vocabId]->draftBuffers.nextDraftTokensLengthsHost); - mCopyBufferManager.copy(*mDecoderBuffers[vocabId]->draftBuffers.prevDraftTokensLengthsDevice, + manager.copy(*mDecoderBuffers[vocabId]->draftBuffers.prevDraftTokensLengthsDevice, *mDecoderBuffers[vocabId]->draftBuffers.prevDraftTokensLengthsHost); } } @@ -2087,7 +2132,7 @@ runtime::CudaEvent TrtGptModelInflightBatching::updateDecoderBuffers( } runtime::CudaEvent copyEvent{}; - mCopyBufferManager.getStream().record(copyEvent); + manager.getStream().record(copyEvent); // Store the event for later sync. Sync stream before calling next decoder. Sync host before updating requests. TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); return copyEvent; diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h index 084a441d575..0c20ab8ac02 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h @@ -286,7 +286,7 @@ class TrtGptModelInflightBatching : public TrtGptModel ScheduledRequests const& scheduledRequests, std::optional const& decoderFinishEvent); runtime::CudaEvent updateDecoderBuffers( - bool returnLogProbs, runtime::CudaEvent decoderFinishEvent, SizeType32 vocabId = 0); + bool returnLogProbs, runtime::BufferManager const& manager, SizeType32 vocabId = 0); std::vector> communicateDecoderBuffers(bool returnLogProbs); void updateRequests(ScheduledRequests const& scheduledRequests); diff --git a/cpp/tensorrt_llm/kernels/cfgKernels.cu b/cpp/tensorrt_llm/kernels/cfgKernels.cu new file mode 100644 index 00000000000..f83d2371388 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cfgKernels.cu @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include "tensorrt_llm/kernels/cfgKernels.h" + + namespace tensorrt_llm::kernels + { + + namespace + { + + // The __global__ kernel should be defined in the .cu file and is best placed + // in an anonymous namespace to limit its visibility to this translation unit. + template + __global__ void applyCfgKernel( + T* logits, int const numRequests, int const vocabSize, float const cfgScale) + { + // Each block processes one or more requests using a grid-stride loop. + // This makes the kernel robust to any number of requests. + for (int reqIdx = blockIdx.x; reqIdx < numRequests; reqIdx += gridDim.x) + { + // Each thread in the block processes one or more vocab entries using a block-stride loop. + // This ensures all vocab entries are processed regardless of vocabSize. 
+ for (int vocabIdx = threadIdx.x; vocabIdx < vocabSize; vocabIdx += blockDim.x) + { + // The input tensor is conceptually [numRequests, 2, vocabSize] but laid out as + // a contiguous [numRequests * 2, vocabSize] tensor. + // We access the conditional logits at [reqIdx * 2 * vocabSize + vocabIdx] and + // unconditional logits at [reqIdx * 2 * vocabSize + vocabSize + vocabIdx]. + T* condLogitPtr = logits + reqIdx * 2 * vocabSize + vocabIdx; + T const* uncondLogitPtr = logits + reqIdx * 2 * vocabSize + vocabSize + vocabIdx; + + // Perform calculations in float for precision. + float condLogitFloat = static_cast(*condLogitPtr); + float uncondLogitFloat = static_cast(*uncondLogitPtr); + + // Apply the CFG formula: guidance * cond + (1 - guidance) * uncond + float result = cfgScale * condLogitFloat + (1.0f - cfgScale) * uncondLogitFloat; + + // Store the result back in place of the conditional logit. + *condLogitPtr = static_cast(result); + } + } + } + + } // anonymous namespace + + + // Definition of the function that launches the kernel. + // This function must be defined in the .cu file because it uses <<<...>>> syntax. + template + void invokeCfgKernel(T* logits, int const numRequests, int const vocabSize, float const cfgScale, cudaStream_t stream) + { + // A block size of 512 is a good general-purpose choice for memory-bound kernels. + dim3 block(512); + // Use a grid-stride loop, so we don't need a grid size equal to numRequests. + // Capping the grid size can improve efficiency by avoiding launching an excessive + // number of small blocks. 256 is a safe heuristic. + dim3 grid(std::min(numRequests, 256)); + + // Launch the kernel. + applyCfgKernel<<>>(logits, numRequests, vocabSize, cfgScale); + + // It is critical to check for errors after launching a kernel. + TLLM_CUDA_CHECK(cudaGetLastError()); + } + + // Explicitly instantiate the templates for the supported data types. + // This is required because the definition is in a .cu file and would not + // otherwise be visible to other parts of the program that include the header. + template void invokeCfgKernel(float* logits, int const numRequests, int const vocabSize, float const cfgScale, cudaStream_t stream); + template void invokeCfgKernel(half* logits, int const numRequests, int const vocabSize, float const cfgScale, cudaStream_t stream); + + } // namespace tensorrt_llm::kernels + \ No newline at end of file diff --git a/cpp/tensorrt_llm/kernels/cfgKernels.h b/cpp/tensorrt_llm/kernels/cfgKernels.h index 6c28dc91ffd..69b8032f8d3 100644 --- a/cpp/tensorrt_llm/kernels/cfgKernels.h +++ b/cpp/tensorrt_llm/kernels/cfgKernels.h @@ -1,60 +1,63 @@ #pragma once -#include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/runtime/iTensor.h" -#include -#include +#include "tensorrt_llm/runtime/cudaStream.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" namespace tensorrt_llm::kernels { -//! Apply classifier-free guidance (CFG) on GPU in-place using cuBLAS. -//! It overwrites `logitsView` with: logits = cfgScale * logits + (1 - cfgScale) * uncondLogits -//! Only the slice [vocabOffset, vocabOffset + vocabSize) is modified. +/** + * @brief Forward declaration for the templated kernel launcher. + * + * The full definition of this function resides in the .cu file and is compiled by NVCC. + * This declaration makes it visible to the inline `invokeCfg` function below. 
+ */ +template +void invokeCfgKernel(T* logits, int const numRequests, int const vocabSize, float const cfgScale, cudaStream_t stream); + +/** + * @brief Applies classifier-free guidance (CFG) to the logits tensor in-place on the GPU. + * + * This function is the main entry point for the CFG operation. It is a type-dispatcher + * that calls the appropriate templated kernel launcher based on the data type of the logits tensor. + * + * @details The formula applied is: + * `logits_cond = cfgScale * logits_cond + (1 - cfgScale) * logits_uncond` + * + * The input logits tensor is expected to have a shape that can be interpreted as + * [numRequests, 2, vocabSize], where logits[i, 0, :] are the conditional logits and + * logits[i, 1, :] are the unconditional logits. The result is written back into the + * conditional logits' location. For efficiency, the implementation treats the tensor + * as having the shape [numRequests * 2, vocabSize]. + * + * @param stream The CUDA stream to execute the kernel on. + * @param logitsView A shared pointer to the tensor containing both conditional and + * unconditional logits. The tensor is modified in-place. + * @param numRequests The number of requests in the batch. + * @param vocabSize The size of the vocabulary. + * @param cfgScale The guidance scale factor. A value of 1.0 effectively disables CFG. + */ inline void invokeCfg(tensorrt_llm::runtime::CudaStream const& stream, - runtime::ITensor::SharedPtr logitsView, runtime::ITensor::SharedPtr uncondLogitsView, - float cfgScale, runtime::SizeType32 vocabOffset, runtime::SizeType32 vocabSize) + runtime::ITensor::SharedPtr logitsView, int numRequests, int vocabSize, float cfgScale) { - using TensorPtr = runtime::ITensor::SharedPtr; - - // Restrict to current vocabulary segment. - TensorPtr logitsVocabView = runtime::ITensor::slice(logitsView, {0, vocabOffset}, vocabSize); - TensorPtr uncondLogitsVocabView = runtime::ITensor::slice(uncondLogitsView, {0, vocabOffset}, vocabSize); - - void* condPtr = logitsVocabView->data(); - void const* uncondPtr = uncondLogitsVocabView->data(); + auto const& logitsDataType = logitsView->getDataType(); - cudaDataType_t dataType{}; - switch (logitsVocabView->getDataType()) + if (logitsDataType == nvinfer1::DataType::kFLOAT) { - case nvinfer1::DataType::kFLOAT: dataType = CUDA_R_32F; break; - case nvinfer1::DataType::kHALF: dataType = CUDA_R_16F; break; - default: TLLM_THROW("Unsupported data type for CFG"); + invokeCfgKernel(runtime::bufferCast(*logitsView), + numRequests, vocabSize, cfgScale, stream.get()); + } + else if (logitsDataType == nvinfer1::DataType::kHALF) + { + invokeCfgKernel(runtime::bufferCast(*logitsView), + numRequests, vocabSize, cfgScale, stream.get()); + } + else + { + TLLM_THROW("Unsupported data type for CFG. Only float and half are supported."); } - - auto handlePtr = getCublasHandle(); - auto& handle = *handlePtr; - tensorrt_llm::common::check_cuda_error(cublasSetStream(handle, stream.get())); - - int n = static_cast(vocabSize); - int inc = 1; - - // Use float for the scaling factors and always accumulate in FP32 to - // satisfy cuBLAS requirements (FP16 vectors must use FP32 compute/alpha). 
- float alphaF = cfgScale; // Scaling factor in FP32 - float axpyF = 1.0f - cfgScale; // (1 - cfgScale) in FP32 - - tensorrt_llm::common::check_cuda_error( - cublasScalEx(handle, n, &alphaF, CUDA_R_32F, // alpha - condPtr, dataType, // x and its type - inc, CUDA_R_32F)); // increments + compute type - - tensorrt_llm::common::check_cuda_error( - cublasAxpyEx(handle, n, &axpyF, CUDA_R_32F, // alpha - uncondPtr, dataType, inc, // x - condPtr, dataType, inc, // y - CUDA_R_32F)); // compute type } } // namespace tensorrt_llm::kernels \ No newline at end of file diff --git a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp index 2eadcf29446..67f3b7ddb5b 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp @@ -104,7 +104,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod "__call__", [](HandleContextLogits const& self, RequestVector const& contextRequests, std::vector const& numContextLogitsVec, at::Tensor const& logits, - DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, + std::vector>& decoderBuffers, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream, OptionalRef medusaBuffers = std::nullopt) { @@ -121,15 +122,15 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod .def( "__call__", [](HandleGenerationLogits const& self, tr::SizeType32 logitsIndex, RequestVector const& generationRequests, - DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, - tensorrt_llm::runtime::CudaStream const& stream, + std::vector>& decoderBuffers, + tr::ModelConfig const& modelConfig, tr::BufferManager const& manager, at::Tensor const& logits, OptionalRef genRuntimeBuffers = std::nullopt) { - self(logitsIndex, generationRequests, decoderBuffers, modelConfig, manager, stream, + self(logitsIndex, generationRequests, decoderBuffers, modelConfig, manager, tr::TorchView::of(logits), genRuntimeBuffers); }, py::arg("logits_index"), py::arg("generation_requests"), py::arg("decoder_buffers"), - py::arg("model_config"), py::arg("buffer_manager"), py::arg("stream"), py::arg("logits"), + py::arg("model_config"), py::arg("buffer_manager"), py::arg("logits"), py::arg("gen_runtime_buffers") = std::nullopt) .def("name", [](HandleGenerationLogits const&) { return HandleGenerationLogits::name; }); diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 72cdbe5ebcf..073160e449c 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -37,7 +37,9 @@ using namespace tensorrt_llm::runtime; GptDecoderBatched::GptDecoderBatched(GptDecoderBatched::CudaStreamPtr stream, SpeculativeDecodingMode const& speculativeDecodingMode, nvinfer1::DataType dtype) : mRuntimeStream{std::move(stream)} + , mDecoderStream{std::make_shared()} , mBufferManager{mRuntimeStream} + , mDecoderBufferManager{mDecoderStream} { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -108,7 +110,6 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max } auto const device = mRuntimeStream->getDevice(); - mDecoderStream = std::make_shared(); TLLM_CHECK(mDecoderStream->getDevice() == device); if (vocabSize == 0) @@ -189,20 +190,12 @@ CudaEvent 
GptDecoderBatched::forwardAsync(decoder_batch::Output& output, decoder { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto eventStart = CudaEvent{}; - mRuntimeStream->record(eventStart); - mDecoderStream->wait(eventStart.get()); - forwardDispatch(output, input); CudaEvent event{}; mDecoderStream->record(event); - mRuntimeStream->wait(event); - - CudaEvent eventStop{}; - mRuntimeStream->record(eventStop); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return eventStop; + return event; } // TODO: produce new input and output
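The per-vocab handling in HandleContextLogits and HandleGenerationLogits reduces to walking a running offset over one fused logits row of width sum(vocabSizes) and handing decoder vocabId the columns [offset, offset + vocabSizes[vocabId]). A minimal host-side sketch of that arithmetic, assuming a single token per request and plain float logits (VocabSlice and splitFusedLogits are illustrative names, not part of the change):

#include <cstddef>
#include <vector>

// One fused logits row holds all vocabularies back to back: width == sum(vocabSizes).
struct VocabSlice
{
    float const* data;
    std::size_t size;
};

// Split a fused row into per-vocab views by walking a running offset, mirroring the
// offset bookkeeping in HandleContextLogits/HandleGenerationLogits for the nTok == 1 case.
// The slices alias the fused row; nothing is copied.
std::vector<VocabSlice> splitFusedLogits(float const* fusedRow, std::vector<std::size_t> const& vocabSizes)
{
    std::vector<VocabSlice> slices;
    slices.reserve(vocabSizes.size());
    std::size_t offset = 0;
    for (auto const vocabSize : vocabSizes)
    {
        // decoder vocabId sees columns [offset, offset + vocabSizes[vocabId]) of the fused row
        slices.push_back(VocabSlice{fusedRow + offset, vocabSize});
        offset += vocabSize;
    }
    return slices;
}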
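invokeCfg treats the logits buffer as [numRequests, 2, vocabSize], with each request's conditional row first and its unconditional row second, and blends them in place as cond = cfgScale * cond + (1 - cfgScale) * uncond. A CPU reference of the same indexing and math, useful for checking applyCfgKernel against (a sketch; applyCfgReference is an illustrative name, not part of the change):

#include <cstddef>
#include <vector>

// CPU reference for applyCfgKernel: `logits` is a flattened [numRequests, 2, vocabSize]
// buffer, with the conditional row at [req, 0, :] and the unconditional row at [req, 1, :].
// The blended result overwrites the conditional row, exactly as the kernel does in place.
void applyCfgReference(std::vector<float>& logits, std::size_t numRequests, std::size_t vocabSize, float cfgScale)
{
    for (std::size_t req = 0; req < numRequests; ++req)
    {
        float* cond = logits.data() + req * 2 * vocabSize; // [req, 0, :]
        float const* uncond = cond + vocabSize;            // [req, 1, :]
        for (std::size_t v = 0; v < vocabSize; ++v)
        {
            cond[v] = cfgScale * cond[v] + (1.0f - cfgScale) * uncond[v];
        }
    }
}

With cfgScale == 1.0 the unconditional row drops out, matching the header's note that a scale of 1.0 effectively disables CFG.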
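decoderStepAsync now follows a fork/join pattern across streams: decoder streams consume the logits views and CFG output produced on the runtime stream (the change records the runtime-stream event once, before the first decoder), and the runtime stream resumes only after waiting on the copy event returned by updateDecoderBuffers for every vocab. A bare CUDA runtime-API sketch of that pattern, here with every decoder stream waiting on the readiness event; forkJoinDecoders is an illustrative name, and the real code goes through the CudaStream/CudaEvent wrappers and the per-decoder BufferManager instead:

#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

// Fork/join across streams: decoder streams wait for work produced on the
// runtime stream, and the runtime stream waits for every decoder to finish.
void forkJoinDecoders(cudaStream_t runtimeStream, std::vector<cudaStream_t> const& decoderStreams)
{
    cudaEvent_t logitsReady;
    cudaEventCreateWithFlags(&logitsReady, cudaEventDisableTiming);
    cudaEventRecord(logitsReady, runtimeStream); // logits handling and CFG enqueued before this point

    std::vector<cudaEvent_t> decoderDone(decoderStreams.size());
    for (std::size_t i = 0; i < decoderStreams.size(); ++i)
    {
        cudaStreamWaitEvent(decoderStreams[i], logitsReady, 0); // fork: decoder waits for the logits
        // ... enqueue the decoder forward pass and host copies on decoderStreams[i] here ...
        cudaEventCreateWithFlags(&decoderDone[i], cudaEventDisableTiming);
        cudaEventRecord(decoderDone[i], decoderStreams[i]);
    }

    for (auto const doneEvent : decoderDone) // join: runtime stream waits for all decoders
    {
        cudaStreamWaitEvent(runtimeStream, doneEvent, 0);
    }

    cudaEventDestroy(logitsReady); // safe: destruction is deferred until pending work completes
    for (auto doneEvent : decoderDone)
    {
        cudaEventDestroy(doneEvent);
    }
}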