6 changes: 3 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/handleContextLogits.h
@@ -49,9 +49,9 @@ class HandleContextLogits : Algorithm

tr::SizeType32 operator()(RequestVector const& contextRequests,
std::vector<tr::SizeType32> const& numContextLogitsVec, tr::ITensor::SharedPtr const& logits,
DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
tensorrt_llm::runtime::CudaStream const& stream, OptionalRef<MedusaBuffers> medusaBuffers,
tr::SizeType32 vocabId = 0) const;
std::vector<std::shared_ptr<DecoderBuffers>>& decoderBuffers,
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
tensorrt_llm::runtime::CudaStream const& stream, OptionalRef<MedusaBuffers> medusaBuffers) const;
};

} // namespace tensorrt_llm::batch_manager
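Commentary on the header change above (not part of the diff): `operator()` now receives one `DecoderBuffers` per vocabulary and drops the trailing `vocabId` parameter, so a single call fans the logits out to every vocabulary instead of being invoked once per `vocabId`. A minimal sketch of what a callsite might look like under that assumption; all local names here are hypothetical:

```cpp
// Hypothetical callsite -- illustrates the new parameter shape only.
// DecoderBuffers construction arguments are elided; none of these names are from this PR.
std::vector<std::shared_ptr<DecoderBuffers>> decoderBuffersPerVocab;
decoderBuffersPerVocab.reserve(modelConfig.getVocabSizes().size());
for (std::size_t vocabId = 0; vocabId < modelConfig.getVocabSizes().size(); ++vocabId)
{
    decoderBuffersPerVocab.emplace_back(std::make_shared<DecoderBuffers>(/* per-vocab sizing, elided */));
}

// One invocation now covers all vocabularies; the old per-call vocabId is gone.
auto const logitsIndex = HandleContextLogits{}(contextRequests, numContextLogitsVec, logits,
    decoderBuffersPerVocab, modelConfig, bufferManager, runtimeStream, medusaBuffers);
```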
@@ -45,10 +45,11 @@ class HandleGenerationLogits : Algorithm

HandleGenerationLogits() = default;

void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests, DecoderBuffers& decoderBuffers,
void operator()(tr::SizeType32 logitsIndex, RequestVector const& generationRequests,
std::vector<std::shared_ptr<DecoderBuffers>>& decoderBuffers,
tr::ModelConfig const& modelConfig, tr::BufferManager const& manager,
tensorrt_llm::runtime::CudaStream const& stream, tr::ITensor::SharedPtr const& logits,
OptionalRef<RuntimeBuffers> genRuntimeBuffers, tr::SizeType32 vocabId = 0) const;
tr::ITensor::SharedPtr const& logits,
OptionalRef<RuntimeBuffers> genRuntimeBuffers) const;
};

} // namespace tensorrt_llm::batch_manager
6 changes: 6 additions & 0 deletions cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -89,6 +89,11 @@ class GptDecoderBatched : public IGptDecoderBatched
return mBufferManager;
}

[[nodiscard]] BufferManager const& getDecoderBufferManager() const
{
return mDecoderBufferManager;
}

private:
//! @brief Sets inputs for explicit draft tokens.
void setExplicitDraftTokensInputs(decoder_batch::Input const& input);
@@ -106,6 +111,7 @@
CudaStreamPtr mRuntimeStream;
CudaStreamPtr mDecoderStream;
BufferManager mBufferManager;
BufferManager mDecoderBufferManager;

using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
GptDecoderPtr mDecoder;
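This diff adds the `mDecoderBufferManager` member and its accessor but not the initialization, which lives in gptDecoderBatched.cpp and is not shown here. A minimal sketch of the likely intent, assuming one `BufferManager` is bound to each stream; the initializer lines and the `decoder`/`vocabSize` names below are assumptions, not code from this PR:

```cpp
// Assumed member initialization (not shown in this diff): one BufferManager per stream,
// mirroring how mBufferManager pairs with mRuntimeStream.
//   , mBufferManager{mRuntimeStream}
//   , mDecoderBufferManager{mDecoderStream}

// Illustrative use of the new accessor: work issued through it is stream-ordered on the
// decoder stream instead of the runtime stream.
auto const& decoderManager = decoder.getDecoderBufferManager();
auto workspace = decoderManager.gpu(
    ITensor::makeShape({1, vocabSize}), nvinfer1::DataType::kFLOAT); // allocated on mDecoderStream
```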
62 changes: 26 additions & 36 deletions cpp/tensorrt_llm/batch_manager/handleContextLogits.cpp
@@ -25,7 +25,6 @@
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/kernels/cfgKernels.h"

namespace tru = tensorrt_llm::runtime::utils;

@@ -69,15 +68,17 @@ void setupMedusaLogits(std::vector<TensorPtr>& medusaLogitsHeads, TensorPtr cons
} // namespace

SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
std::vector<SizeType32> const& numContextLogitsVec, TensorPtr const& logits, DecoderBuffers& decoderBuffers,
std::vector<SizeType32> const& numContextLogitsVec, TensorPtr const& logits,
std::vector<std::shared_ptr<DecoderBuffers>>& decoderBuffers,
tr::ModelConfig const& modelConfig, BufferManager const& manager, tensorrt_llm::runtime::CudaStream const& stream,
OptionalRef<MedusaBuffers> medusaBuffers, SizeType32 vocabId) const
OptionalRef<MedusaBuffers> medusaBuffers) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(HandleContextLogits);

SizeType32 batchIndex{0};
SizeType32 logitsIndex{0};
auto vocabSizes = modelConfig.getVocabSizes();
// Copy logits into decoderBuffers.logits
for (auto const& llmReq : contextRequests)
{
@@ -91,6 +92,7 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,

if (modelConfig.computeContextLogits())
{
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support context logits");
// Since the computational graph has been modified, only the last token is needed.
TLLM_CHECK_WITH_INFO(!modelConfig.getSpeculativeDecodingMode().isMedusa()
&& !modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding(),
@@ -117,28 +119,13 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
auto const numDecoderLogits = 1 + draftLength;
TensorPtr logitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits);

// this is CFG support implementation, where we advance the logits index through the unconditional logits
if (llmReq->isCfg()) {
logitsIndex += numContextLogits + draftLength;
TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex - numDecoderLogits, numDecoderLogits);
// TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale)

float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0);
SizeType32 vocabOffset = 0;
auto vocabSizes = modelConfig.getVocabSizes();
for (SizeType32 i = 0; i < vocabId; ++i)
{
vocabOffset += vocabSizes[i];
}
tensorrt_llm::kernels::invokeCfg(stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]);
}

auto const seqSlot = llmReq->mSeqSlots.at(0);
if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
{
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support speculative decoding");
TLLM_CHECK(medusaBuffers);
// speculative decoding is not supported for numVocabs > 1
auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot);
auto& medusaLogitsHeads = decoderBuffers.front()->draftBuffers.predictedDraftLogits.at(seqSlot);
setupMedusaLogits(medusaLogitsHeads, medusaBuffers->medusaLogitsDevice,
modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex - numDecoderLogits,
numDecoderLogits);
@@ -148,45 +135,48 @@ SizeType32 HandleContextLogits::operator()(RequestVector const& contextRequests,
// save the accepted token logits from target model
if (llmReq->getReturnGenerationLogits())
{
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support returning generation logits");
copyLastContextLogits(logitsView, *llmReq, manager);
}

TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid<float>(*logitsView, manager, "logits") == false,
"Found invalid number (NaN or Inf) in logits");

auto& decoderLogits = decoderBuffers.logits.at(seqSlot);

if (reqBeamWidth > 1)
{
// Tile logits of context requests
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support beam search");
auto const logitsShape = logitsView->getShape();
auto const logitsType = logitsView->getDataType();
auto& decoderLogits = decoderBuffers.front()->logits.at(seqSlot);
decoderLogits = manager.gpu(ITensor::makeShape({reqBeamWidth, logitsShape.d[1]}), logitsType);
tensorrt_llm::runtime::kernels::tileTensor(*decoderLogits, *logitsView, reqBeamWidth, stream);
decoderLogits->unsqueeze(0);
}
else
{
auto curVocablogitsView = logitsView;
// iterate through vocabs and assign the logits to the decoder buffers
SizeType32 offset = 0;
auto const logitsViewShape = logitsView->getShape();
if (logitsViewShape.d[0] == 1) // if current nTok is 1, could have multiple vocabs
{
SizeType32 offset = 0;
auto vocabSizes = modelConfig.getVocabSizes();
for (SizeType32 i = 0; i < vocabId; ++i)
{
offset += vocabSizes[i];
}
curVocablogitsView = ITensor::slice(logitsView, {0, offset}, vocabSizes[vocabId]); // [vocabSize,]
curVocablogitsView = ITensor::view(curVocablogitsView, ITensor::makeShape({1, vocabSizes[vocabId]}));
TLLM_CHECK_WITH_INFO(logitsViewShape.d[0] == 1, "nTok should be 1 for multi-vocab");
for (SizeType32 vocabId = 0; vocabId < (SizeType32)vocabSizes.size(); ++vocabId) {
// per-vocab destination: this vocab's DecoderBuffers receives its slice of the logits
auto& decoderLogits = decoderBuffers[vocabId]->logits.at(seqSlot);
auto curVocabLogitsView = logitsView;
curVocabLogitsView = ITensor::slice(logitsView, {0, offset}, vocabSizes[vocabId]); // [vocabSize,]
// restore the leading token dimension: [vocabSize] -> [1, vocabSize]
curVocabLogitsView = ITensor::view(curVocabLogitsView, ITensor::makeShape({1, vocabSizes[vocabId]}));
auto const updateLogitsViewShape = curVocabLogitsView->getShape();
// expose the slice to the decoder as [numTokens == 1, beamWidth == 1, vocabSize]
decoderLogits = ITensor::view(
curVocabLogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]}));
offset += (SizeType32)vocabSizes[vocabId];
}
auto const updateLogitsViewShape = curVocablogitsView->getShape();
decoderLogits = ITensor::view(
curVocablogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]}));
}
++batchIndex;
if (llmReq->isCfg()) {
++batchIndex;
logitsIndex += numContextLogits + draftLength;
}
}

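The CFG kernel call (`invokeCfg`) is removed in this hunk, but the index bookkeeping for CFG requests is kept at the end of the loop: a CFG request produces a conditional block of context logits followed by an unconditional block, so the handler advances past the second block and counts the request as two batch entries. A standalone sketch of that bookkeeping with plain integers; the request values are example assumptions:

```cpp
#include <iostream>

int main()
{
    // Example values only; in the real loop they come from the request and draft config.
    struct Req { int numContextLogits; int draftLength; bool isCfg; };
    Req const requests[] = {{8, 0, false}, {8, 0, true}};

    int logitsIndex = 0;
    int batchIndex = 0;
    for (auto const& req : requests)
    {
        // Advance past the request's conditional logits (done earlier in the real loop).
        logitsIndex += req.numContextLogits + req.draftLength;
        ++batchIndex;
        if (req.isCfg)
        {
            // Skip the adjacent unconditional block and occupy a second batch slot,
            // matching the `if (llmReq->isCfg())` branch in the hunk above.
            logitsIndex += req.numContextLogits + req.draftLength;
            ++batchIndex;
        }
    }
    std::cout << "logitsIndex=" << logitsIndex << " batchIndex=" << batchIndex << "\n"; // 24 and 3
    return 0;
}
```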
48 changes: 21 additions & 27 deletions cpp/tensorrt_llm/batch_manager/handleGenerationLogits.cpp
@@ -25,7 +25,6 @@
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/kernels/cfgKernels.h"

namespace tru = tensorrt_llm::runtime::utils;

@@ -75,21 +74,14 @@ void setupMedusaLogits(std::vector<TensorPtr>& medusaLogitsHeads, TensorPtr cons
} // namespace

void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector const& generationRequests,
DecoderBuffers& decoderBuffers, tr::ModelConfig const& modelConfig, BufferManager const& manager,
tensorrt_llm::runtime::CudaStream const& stream, TensorPtr const& logits, OptionalRef<RuntimeBuffers> genRuntimeBuffers,
SizeType32 vocabId) const
std::vector<std::shared_ptr<DecoderBuffers>>& decoderBuffers,
tr::ModelConfig const& modelConfig, BufferManager const& manager,
TensorPtr const& logits, OptionalRef<RuntimeBuffers> genRuntimeBuffers) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(HandleGenerationLogits);

// compute where logits start for the given `vocabId`
auto vocabSizes = modelConfig.getVocabSizes();
SizeType32 vocabOffset = 0;
for (SizeType32 i = 0; i < vocabId; ++i)
{
vocabOffset += vocabSizes[i];
}

for (auto const& llmReq : generationRequests)
{
auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth;
@@ -110,38 +102,39 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co
TLLM_CHECK_DEBUG_WITH_INFO(tru::tensorHasInvalid<float>(*logitsView, manager, "logits") == false,
"Found invalid number (NaN or Inf) in logits");

// CFG implementation: get unconditional logits and add them to logitsView
if (llmReq->isCfg()) {
// skip unconditional logits
logitsIndex += numLogits;
TensorPtr uncondLogitsView = ITensor::slice(logits, logitsIndex, numLogits);
// TODO: implement CFG, apply logitsView = logitsView * cfgScale + uncondLogitsView * (1 - cfgScale)
float cfgScale = llmReq->mSamplingConfig.cfgScale->at(0);
tensorrt_llm::kernels::invokeCfg(stream, logitsView, uncondLogitsView, cfgScale, vocabOffset, vocabSizes[vocabId]);
}

auto& decoderLogits = decoderBuffers.logits.at(seqSlot);

auto const logitsViewShape = logitsView->getShape();
if (reqBeamWidth > 1)
{
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support beam search");
auto& decoderLogits = decoderBuffers.front()->logits.at(seqSlot);
decoderLogits = logitsView;
decoderLogits->unsqueeze(0);
}
else
{
auto curVocablogitsView = logitsView;
if (logitsViewShape.d[0] == 1) // if current nTok is 1, could have multiple vocabs
{
curVocablogitsView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSizes[vocabId]); // [vocabSize,]
curVocablogitsView = ITensor::view(
curVocablogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); // [numLogits == 1, vocabSize]
SizeType32 vocabOffset = 0;
for (SizeType32 vocabId = 0; vocabId < (SizeType32)vocabSizes.size(); ++vocabId) {
auto& decoderLogits = decoderBuffers[vocabId]->logits.at(seqSlot);
TLLM_CHECK_WITH_INFO(logitsViewShape.d[0] == 1, "Multi-vocab requires nTok to be 1");
auto curVocabLogitsView = logitsView;
curVocabLogitsView = ITensor::slice(logitsView, {0, vocabOffset}, vocabSizes[vocabId]); // [vocabSize,]
curVocabLogitsView = ITensor::view(curVocabLogitsView, ITensor::makeShape({1, vocabSizes[vocabId]})); // [numLogits == 1, vocabSize]
auto const updateLogitsViewShape = curVocabLogitsView->getShape();
decoderLogits = ITensor::view(
curVocabLogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]}));
vocabOffset += (SizeType32)vocabSizes[vocabId];
}
auto const updateLogitsViewShape = curVocablogitsView->getShape();
decoderLogits = ITensor::view(
curVocablogitsView, ITensor::makeShape({updateLogitsViewShape.d[0], 1, updateLogitsViewShape.d[1]}));
}

if (llmReq->getReturnGenerationLogits())
{
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support returning generation logits");
TLLM_CHECK_WITH_INFO(modelConfig.getSpeculativeDecodingMode().isNone()
|| modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal(),
"Only speculative decoding with external draft tokens supports returning generation logits");
@@ -164,9 +157,10 @@ void HandleGenerationLogits::operator()(SizeType32 logitsIndex, RequestVector co
}
if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
{
TLLM_CHECK_WITH_INFO(vocabSizes.size() == 1, "Multi-vocab does not support speculative decoding");
TLLM_CHECK(genRuntimeBuffers);
// speculative decoding is not supported for numVocabs > 1
auto& medusaLogitsHeads = decoderBuffers.draftBuffers.predictedDraftLogits.at(seqSlot);
auto& medusaLogitsHeads = decoderBuffers.front()->draftBuffers.predictedDraftLogits.at(seqSlot);
setupMedusaLogits(medusaLogitsHeads, genRuntimeBuffers->medusaBuffers->medusaLogitsDevice,
modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(), logitsIndex, draftLength);
}
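Both handlers now share the same fan-out pattern: a single logits row of width sum(vocabSizes) is split into contiguous per-vocab slices by walking a running offset. A self-contained illustration of that layout using plain std::vector instead of ITensor; the vocab sizes are example values, not from this PR:

```cpp
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // Example vocab sizes; the real values come from modelConfig.getVocabSizes().
    std::vector<std::size_t> const vocabSizes{32000, 512, 4};
    std::size_t const total = std::accumulate(vocabSizes.begin(), vocabSizes.end(), std::size_t{0});

    // One token's logits row: all vocabularies packed along the last dimension.
    std::vector<float> logitsRow(total, 0.0f);

    // Same walk as the per-vocab loops in handleContextLogits.cpp / handleGenerationLogits.cpp:
    // vocab i owns the contiguous range [offset, offset + vocabSizes[i]).
    std::size_t offset = 0;
    for (std::size_t vocabId = 0; vocabId < vocabSizes.size(); ++vocabId)
    {
        std::vector<float> perVocabLogits(logitsRow.begin() + offset,
            logitsRow.begin() + offset + vocabSizes[vocabId]);
        std::cout << "vocab " << vocabId << ": offset=" << offset
                  << " size=" << perVocabLogits.size() << "\n";
        offset += vocabSizes[vocabId];
    }
    return 0;
}
```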