Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 90 additions & 39 deletions cpp/kernels/talkerMLPKernels/talkerMLPKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,12 @@ void invokeScatter(rt::Tensor const& source, rt::Tensor const& indices, rt::Tens
template <int32_t VEC_SIZE = 8>
__global__ void assistantPreambleKernel(half const* __restrict__ projected, half const* __restrict__ ttsPadEmbed,
half const* __restrict__ ttsBosEmbed, half const* __restrict__ ttsEosEmbed, half const* __restrict__ embTable,
int32_t codecNothinkId, int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId,
int32_t codecBosId, int32_t textLen, int32_t hiddenDim, half* __restrict__ output)
int32_t codecNothinkId, int32_t codecThinkId, int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId,
int32_t codecPadId, int32_t codecBosId, int32_t langId, int32_t textLen, int32_t hiddenDim,
half* __restrict__ output)
{
constexpr int32_t kFixedPrefixLen = 8; // rows 0-7
// No-lang: 8 fixed prefix rows (0-7); lang: 9 fixed prefix rows (0-8) with langId injected at row 5.
int32_t const kFixedPrefixLen = (langId >= 0) ? 9 : 8;
int32_t const rowIdx = blockIdx.x;
int32_t const numVecs = hiddenDim / VEC_SIZE;

Expand All @@ -468,42 +470,91 @@ __global__ void assistantPreambleKernel(half const* __restrict__ projected, half

if (rowIdx < kFixedPrefixLen)
{
switch (rowIdx)
if (langId < 0)
{
case 0: srcA = projected; break;
case 1: srcA = projected + hiddenDim; break;
case 2: srcA = projected + 2 * hiddenDim; break;
case 3:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecNothinkId) * hiddenDim;
break;
case 4:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkBosId) * hiddenDim;
break;
case 5:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkEosId) * hiddenDim;
break;
case 6:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(speakerId) * hiddenDim;
break;
default: // rowIdx == 7
srcA = ttsBosEmbed;
srcB = embTable + static_cast<int64_t>(codecPadId) * hiddenDim;
break;
// No-language path (8-row prefix)
switch (rowIdx)
{
case 0: srcA = projected; break;
case 1: srcA = projected + hiddenDim; break;
case 2: srcA = projected + 2 * hiddenDim; break;
case 3:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecNothinkId) * hiddenDim;
break;
case 4:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkBosId) * hiddenDim;
break;
case 5:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkEosId) * hiddenDim;
break;
case 6:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(speakerId) * hiddenDim;
break;
default: // rowIdx == 7
srcA = ttsBosEmbed;
srcB = embTable + static_cast<int64_t>(codecPadId) * hiddenDim;
break;
}
}
else
{
// Language path (9-row prefix): codecThinkId at row 3, langId injected at row 5,
// codecThinkEosId shifted to row 6, speaker to row 7, codecPad/ttsBos to row 8.
switch (rowIdx)
{
case 0: srcA = projected; break;
case 1: srcA = projected + hiddenDim; break;
case 2: srcA = projected + 2 * hiddenDim; break;
case 3:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkId) * hiddenDim;
break;
case 4:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkBosId) * hiddenDim;
break;
case 5:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(langId) * hiddenDim;
break;
case 6:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(codecThinkEosId) * hiddenDim;
break;
case 7:
srcA = ttsPadEmbed;
srcB = embTable + static_cast<int64_t>(speakerId) * hiddenDim;
break;
default: // rowIdx == 8
srcA = ttsBosEmbed;
srcB = embTable + static_cast<int64_t>(codecPadId) * hiddenDim;
break;
}
}
}
else if (rowIdx < kFixedPrefixLen + textLen)
{
// Text token rows: projected[3 + (rowIdx-8)] + embTable[codec]
// Last text row uses codecBosId (start-of-generation marker);
// all preceding text rows use codecPadId. Matches PyTorch reference:
// assistant_codec_hidden = [zeros(3), no-think, thinkBos, thinkEos, speaker, codecPad, codecBos]
// Text token rows: projected[3 + (rowIdx - kFixedPrefixLen)] + embTable[codec]
//
// NOTE: the PyTorch reference (modeling_qwen3_tts.py, non_streaming_mode branch) builds the
// codec embed for text rows as embed([codec_pad_id] * (text_len + 1)), i.e. every text row
// and the trailing tts_eos row use codec_pad_id, and codec_bos_id appears only on the very
// last row (tts_pad + codec_bos). This kernel intentionally deviates from that reference:
// in BOTH the no-language and the language path, the last text row uses codec_bos and all
// preceding text rows use codec_pad.
int32_t const textIdx = rowIdx - kFixedPrefixLen;
srcA = projected + static_cast<int64_t>(3 + textIdx) * hiddenDim;
int32_t const codecId = (rowIdx == kFixedPrefixLen + textLen - 1) ? codecBosId : codecPadId;
// Keep both paths consistent with the existing pristine behaviour: last text row uses
// codec_bos, preceding rows use codec_pad. The Python reference's non_streaming_mode
// branch puts codec_pad on every text row, but bumping the language path to that layout
// in isolation caused the talker to runaway-generate (188 frames vs ~24 expected); the
// codec_bos marker on the final text row appears load-bearing for the kernel/sampler.
bool const isLastTextRow = (rowIdx == kFixedPrefixLen + textLen - 1);
int32_t const codecId = isLastTextRow ? codecBosId : codecPadId;
srcB = embTable + static_cast<int64_t>(codecId) * hiddenDim;
}
else if (rowIdx == kFixedPrefixLen + textLen)
Expand Down Expand Up @@ -539,16 +590,16 @@ __global__ void assistantPreambleKernel(half const* __restrict__ projected, half
}

void invokeAssistantPreamble(rt::Tensor const& projected, rt::Tensor const& ttsPadEmbed, rt::Tensor const& ttsBosEmbed,
rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkBosId,
int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId, int32_t textLen,
rt::Tensor& output, cudaStream_t stream)
rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkId,
int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId,
int32_t langId, int32_t textLen, rt::Tensor& output, cudaStream_t stream)
{
constexpr int32_t kVecSize = 8;

int32_t const hiddenDim = static_cast<int32_t>(projected.getShape()[1]);
int32_t const numVecs = hiddenDim / kVecSize;
// totalRows = 8 fixed prefix + textLen text rows + 2 suffix rows
int32_t const totalRows = 8 + textLen + 2;
int32_t const kFixedPrefixLen = (langId >= 0) ? 9 : 8;
int32_t const totalRows = kFixedPrefixLen + textLen + 2;

// 128 threads covers H=1024 with VEC_SIZE=8 in one pass
dim3 const block(std::min(numVecs, 128));
Expand All @@ -562,8 +613,8 @@ void invokeAssistantPreamble(rt::Tensor const& projected, rt::Tensor const& ttsP
half* outPtr = static_cast<half*>(output.rawPointer());

assistantPreambleKernel<kVecSize><<<grid, block, 0, stream>>>(projPtr, padPtr, bosPtr, eosPtr, embPtr,
codecNothinkId, codecThinkBosId, codecThinkEosId, speakerId, codecPadId, codecBosId, textLen, hiddenDim,
outPtr);
codecNothinkId, codecThinkId, codecThinkBosId, codecThinkEosId, speakerId, codecPadId, codecBosId, langId,
textLen, hiddenDim, outPtr);
CUDA_CHECK(cudaPeekAtLastError());
}

Expand Down
59 changes: 38 additions & 21 deletions cpp/kernels/talkerMLPKernels/talkerMLPKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,32 +94,49 @@ void invokeScatter(rt::Tensor const& source, rt::Tensor const& indices, rt::Tens
//! \brief Fused non-streaming assistant preamble construction for TTS input projection
//!
//! Builds the complete non-streaming prefill buffer in one pass.
//! Total rows written = 8 + textLen + 2 (= seqLen + 2).
//!
//! Row layout (written at outputOffset):
//! [0-2]: projected[0-2] (role tokens)
//! [3]: ttsPadEmbed + talkerEmbTable[codecNothinkId]
//! [4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId]
//! [5]: ttsPadEmbed + talkerEmbTable[codecThinkEosId]
//! [6]: ttsPadEmbed + talkerEmbTable[speakerId]
//! [7]: ttsBosEmbed + talkerEmbTable[codecPadId]
//! [8..8+N-2]: projected[3+i] + talkerEmbTable[codecPadId] (text tokens, N=textLen)
//! [8+N-1]: projected[3+N-1] + talkerEmbTable[codecBosId] (last text = start-of-generation)
//! [8+N]: ttsEosEmbed + talkerEmbTable[codecPadId]
//! [8+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId]
//! Two layouts based on whether a language conditioning ID is provided:
//!
//! No-language path (langId < 0):
//! Total rows = 8 + textLen + 2 (= seqLen + 2). Uses codecNothinkId at row 3.
//! [0-2]: projected[0-2]
//! [3]: ttsPadEmbed + talkerEmbTable[codecNothinkId]
//! [4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId]
//! [5]: ttsPadEmbed + talkerEmbTable[codecThinkEosId]
//! [6]: ttsPadEmbed + talkerEmbTable[speakerId]
//! [7]: ttsBosEmbed + talkerEmbTable[codecPadId]
//! [8..8+N-1]: projected[3+i] + talkerEmbTable[codecPadId]; last text row [8+N-1] uses codecBosId
//! [8+N]: ttsEosEmbed + talkerEmbTable[codecPadId]
//! [8+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId]
//!
//! Language path (langId >= 0, CustomVoice with language conditioning):
//! Total rows = 9 + textLen + 2 (= seqLen + 3 = original-seqLen + 2, since N is also +1 upstream).
//! Uses codecThinkId at row 3 and injects langId at row 5.
//! [0-2]: projected[0-2]
//! [3]: ttsPadEmbed + talkerEmbTable[codecThinkId] (NOTE: think, not no-think)
//! [4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId]
//! [5]: ttsPadEmbed + talkerEmbTable[langId] (NEW row, language embed)
//! [6]: ttsPadEmbed + talkerEmbTable[codecThinkEosId]
//! [7]: ttsPadEmbed + talkerEmbTable[speakerId]
//! [8]: ttsBosEmbed + talkerEmbTable[codecPadId]
//! [9..9+N-1]: projected[3+i] + talkerEmbTable[codecPadId]; last text row [9+N-1] uses codecBosId
//! [9+N]: ttsEosEmbed + talkerEmbTable[codecPadId]
//! [9+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId]
//!
//! \param projected MLP output [seqLen, H] (FP16)
//! \param ttsPadEmbed/ttsBosEmbed/ttsEosEmbed TTS special embeddings [H] (FP16)
//! \param talkerEmbTable Talker embedding table [vocabSize, H] (FP16)
//! \param codecNothinkId..codecBosId Codec token IDs used in rows [3-8+N+1]
//! \param speakerId Speaker codec token ID (row 6)
//! \param textLen Number of text token rows (N = seqLen - 8)
//! \param output Full output buffer [8+N+2, H] (FP16)
//! \param stream CUDA stream
//! \param codecNothinkId Codec no-think control token (used when langId < 0)
//! \param codecThinkId Codec think control token (used when langId >= 0)
//! \param codecThinkBosId/codecThinkEosId/codecPadId/codecBosId Codec control IDs
//! \param speakerId Speaker codec token ID
//! \param langId Language codec token ID; if < 0, no-language path is used
//! \param textLen Number of text token rows (N)
//! \param output Full output buffer (FP16)
//! \param stream CUDA stream
void invokeAssistantPreamble(rt::Tensor const& projected, rt::Tensor const& ttsPadEmbed, rt::Tensor const& ttsBosEmbed,
rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkBosId,
int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId, int32_t textLen,
rt::Tensor& output, cudaStream_t stream);
rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkId,
int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId,
int32_t langId, int32_t textLen, rt::Tensor& output, cudaStream_t stream);

//! \brief Fused residual connection for TTS decode input
//!
Expand Down
Loading