From e07142178b35a3f7de5a889673695423538ebd21 Mon Sep 17 00:00:00 2001 From: suharvest Date: Tue, 26 May 2026 23:07:34 -0400 Subject: [PATCH] feat https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/87: add CustomVoice language conditioning support for Qwen3-TTS --- .../talkerMLPKernels/talkerMLPKernels.cu | 129 ++++++++++++------ .../talkerMLPKernels/talkerMLPKernels.h | 59 +++++--- cpp/runtime/qwen3OmniTTSRuntime.cpp | 73 ++++++++-- cpp/runtime/qwen3OmniTTSRuntime.h | 12 +- examples/omni/qwen3_tts_inference.cpp | 7 + experimental/llm_loader/export_all_cli.py | 6 + 6 files changed, 212 insertions(+), 74 deletions(-) diff --git a/cpp/kernels/talkerMLPKernels/talkerMLPKernels.cu b/cpp/kernels/talkerMLPKernels/talkerMLPKernels.cu index 60056ae..ed268c5 100644 --- a/cpp/kernels/talkerMLPKernels/talkerMLPKernels.cu +++ b/cpp/kernels/talkerMLPKernels/talkerMLPKernels.cu @@ -453,10 +453,12 @@ void invokeScatter(rt::Tensor const& source, rt::Tensor const& indices, rt::Tens template __global__ void assistantPreambleKernel(half const* __restrict__ projected, half const* __restrict__ ttsPadEmbed, half const* __restrict__ ttsBosEmbed, half const* __restrict__ ttsEosEmbed, half const* __restrict__ embTable, - int32_t codecNothinkId, int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, - int32_t codecBosId, int32_t textLen, int32_t hiddenDim, half* __restrict__ output) + int32_t codecNothinkId, int32_t codecThinkId, int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, + int32_t codecPadId, int32_t codecBosId, int32_t langId, int32_t textLen, int32_t hiddenDim, + half* __restrict__ output) { - constexpr int32_t kFixedPrefixLen = 8; // rows 0-7 + // No-lang: 8 fixed prefix rows (0-7); lang: 9 fixed prefix rows (0-8) with langId injected at row 5. + int32_t const kFixedPrefixLen = (langId >= 0) ? 
9 : 8; int32_t const rowIdx = blockIdx.x; int32_t const numVecs = hiddenDim / VEC_SIZE; @@ -468,42 +470,91 @@ __global__ void assistantPreambleKernel(half const* __restrict__ projected, half if (rowIdx < kFixedPrefixLen) { - switch (rowIdx) + if (langId < 0) { - case 0: srcA = projected; break; - case 1: srcA = projected + hiddenDim; break; - case 2: srcA = projected + 2 * hiddenDim; break; - case 3: - srcA = ttsPadEmbed; - srcB = embTable + static_cast(codecNothinkId) * hiddenDim; - break; - case 4: - srcA = ttsPadEmbed; - srcB = embTable + static_cast(codecThinkBosId) * hiddenDim; - break; - case 5: - srcA = ttsPadEmbed; - srcB = embTable + static_cast(codecThinkEosId) * hiddenDim; - break; - case 6: - srcA = ttsPadEmbed; - srcB = embTable + static_cast(speakerId) * hiddenDim; - break; - default: // rowIdx == 7 - srcA = ttsBosEmbed; - srcB = embTable + static_cast(codecPadId) * hiddenDim; - break; + // No-language path (8-row prefix) + switch (rowIdx) + { + case 0: srcA = projected; break; + case 1: srcA = projected + hiddenDim; break; + case 2: srcA = projected + 2 * hiddenDim; break; + case 3: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(codecNothinkId) * hiddenDim; + break; + case 4: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(codecThinkBosId) * hiddenDim; + break; + case 5: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(codecThinkEosId) * hiddenDim; + break; + case 6: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(speakerId) * hiddenDim; + break; + default: // rowIdx == 7 + srcA = ttsBosEmbed; + srcB = embTable + static_cast(codecPadId) * hiddenDim; + break; + } + } + else + { + // Language path (9-row prefix): codecThinkId at row 3, langId injected at row 5, + // codecThinkEosId shifted to row 6, speaker to row 7, codecPad/ttsBos to row 8. 
+ switch (rowIdx) + { + case 0: srcA = projected; break; + case 1: srcA = projected + hiddenDim; break; + case 2: srcA = projected + 2 * hiddenDim; break; + case 3: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(codecThinkId) * hiddenDim; + break; + case 4: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(codecThinkBosId) * hiddenDim; + break; + case 5: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(langId) * hiddenDim; + break; + case 6: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(codecThinkEosId) * hiddenDim; + break; + case 7: + srcA = ttsPadEmbed; + srcB = embTable + static_cast(speakerId) * hiddenDim; + break; + default: // rowIdx == 8 + srcA = ttsBosEmbed; + srcB = embTable + static_cast(codecPadId) * hiddenDim; + break; + } } } else if (rowIdx < kFixedPrefixLen + textLen) { - // Text token rows: projected[3 + (rowIdx-8)] + embTable[codec] - // Last text row uses codecBosId (start-of-generation marker); - // all preceding text rows use codecPadId. Matches PyTorch reference: - // assistant_codec_hidden = [zeros(3), no-think, thinkBos, thinkEos, speaker, codecPad, codecBos] + // Text token rows: projected[3 + (rowIdx-kFixedPrefixLen)] + embTable[codec] + // + // PyTorch reference (modeling_qwen3_tts.py, non_streaming_mode branch) builds the + // codec embed for text rows as embed([codec_pad_id] * (text_len + 1)), i.e. every + // text row uses codec_pad_id and the trailing tts_eos row also uses codec_pad_id. + // codec_bos_id is reserved for the very last row (tts_pad + codec_bos), not the last + // text row. Legacy (no-language) path: keep the original kernel's behaviour of using + // codec_bos on the last text row to avoid silently changing established outputs. int32_t const textIdx = rowIdx - kFixedPrefixLen; srcA = projected + static_cast(3 + textIdx) * hiddenDim; - int32_t const codecId = (rowIdx == kFixedPrefixLen + textLen - 1) ? 
codecBosId : codecPadId; + // Keep both paths consistent with the existing pristine behaviour: last text row uses + // codec_bos, preceding rows use codec_pad. The Python reference's non_streaming_mode + // branch puts codec_pad on every text row, but bumping the language path to that layout + // in isolation caused the talker to runaway-generate (188 frames vs ~24 expected); the + // codec_bos marker on the final text row appears load-bearing for the kernel/sampler. + bool const isLastTextRow = (rowIdx == kFixedPrefixLen + textLen - 1); + int32_t const codecId = isLastTextRow ? codecBosId : codecPadId; srcB = embTable + static_cast(codecId) * hiddenDim; } else if (rowIdx == kFixedPrefixLen + textLen) @@ -539,16 +590,16 @@ __global__ void assistantPreambleKernel(half const* __restrict__ projected, half } void invokeAssistantPreamble(rt::Tensor const& projected, rt::Tensor const& ttsPadEmbed, rt::Tensor const& ttsBosEmbed, - rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkBosId, - int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId, int32_t textLen, - rt::Tensor& output, cudaStream_t stream) + rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkId, + int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId, + int32_t langId, int32_t textLen, rt::Tensor& output, cudaStream_t stream) { constexpr int32_t kVecSize = 8; int32_t const hiddenDim = static_cast(projected.getShape()[1]); int32_t const numVecs = hiddenDim / kVecSize; - // totalRows = 8 fixed prefix + textLen text rows + 2 suffix rows - int32_t const totalRows = 8 + textLen + 2; + int32_t const kFixedPrefixLen = (langId >= 0) ? 
9 : 8; + int32_t const totalRows = kFixedPrefixLen + textLen + 2; // 128 threads covers H=1024 with VEC_SIZE=8 in one pass dim3 const block(std::min(numVecs, 128)); @@ -562,8 +613,8 @@ void invokeAssistantPreamble(rt::Tensor const& projected, rt::Tensor const& ttsP half* outPtr = static_cast(output.rawPointer()); assistantPreambleKernel<<>>(projPtr, padPtr, bosPtr, eosPtr, embPtr, - codecNothinkId, codecThinkBosId, codecThinkEosId, speakerId, codecPadId, codecBosId, textLen, hiddenDim, - outPtr); + codecNothinkId, codecThinkId, codecThinkBosId, codecThinkEosId, speakerId, codecPadId, codecBosId, langId, + textLen, hiddenDim, outPtr); CUDA_CHECK(cudaPeekAtLastError()); } diff --git a/cpp/kernels/talkerMLPKernels/talkerMLPKernels.h b/cpp/kernels/talkerMLPKernels/talkerMLPKernels.h index 08b6cdf..05f16ea 100644 --- a/cpp/kernels/talkerMLPKernels/talkerMLPKernels.h +++ b/cpp/kernels/talkerMLPKernels/talkerMLPKernels.h @@ -94,32 +94,49 @@ void invokeScatter(rt::Tensor const& source, rt::Tensor const& indices, rt::Tens //! \brief Fused non-streaming assistant preamble construction for TTS input projection //! //! Builds the complete non-streaming prefill buffer in one pass. -//! Total rows written = 8 + textLen + 2 (= seqLen + 2). -//! -//! Row layout (written at outputOffset): -//! [0-2]: projected[0-2] (role tokens) -//! [3]: ttsPadEmbed + talkerEmbTable[codecNothinkId] -//! [4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId] -//! [5]: ttsPadEmbed + talkerEmbTable[codecThinkEosId] -//! [6]: ttsPadEmbed + talkerEmbTable[speakerId] -//! [7]: ttsBosEmbed + talkerEmbTable[codecPadId] -//! [8..8+N-2]: projected[3+i] + talkerEmbTable[codecPadId] (text tokens, N=textLen) -//! [8+N-1]: projected[3+N-1] + talkerEmbTable[codecBosId] (last text = start-of-generation) -//! [8+N]: ttsEosEmbed + talkerEmbTable[codecPadId] -//! [8+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId] +//! Two layouts based on whether a language conditioning ID is provided: +//! +//! 
No-language path (langId < 0): +//! Total rows = 8 + textLen + 2 (= seqLen + 2). Uses codecNothinkId at row 3. +//! [0-2]: projected[0-2] +//! [3]: ttsPadEmbed + talkerEmbTable[codecNothinkId] +//! [4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId] +//! [5]: ttsPadEmbed + talkerEmbTable[codecThinkEosId] +//! [6]: ttsPadEmbed + talkerEmbTable[speakerId] +//! [7]: ttsBosEmbed + talkerEmbTable[codecPadId] +//! [8..8+N-1]: projected[3+i] + talkerEmbTable[codecPad/codecBos] (last row uses codecBosId) +//! [8+N]: ttsEosEmbed + talkerEmbTable[codecPadId] +//! [8+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId] +//! +//! Language path (langId >= 0, CustomVoice with language conditioning): +//! Total rows = 9 + textLen + 2 (= seqLen + 3 with N = seqLen - 8, i.e. one row more than the no-language layout). +//! Uses codecThinkId at row 3 and injects langId at row 5. +//! [0-2]: projected[0-2] +//! [3]: ttsPadEmbed + talkerEmbTable[codecThinkId] (NOTE: think, not no-think) +//! [4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId] +//! [5]: ttsPadEmbed + talkerEmbTable[langId] (NEW row, language embed) +//! [6]: ttsPadEmbed + talkerEmbTable[codecThinkEosId] +//! [7]: ttsPadEmbed + talkerEmbTable[speakerId] +//! [8]: ttsBosEmbed + talkerEmbTable[codecPadId] +//! [9..9+N-1]: projected[3+i] + talkerEmbTable[codecPad/codecBos] (last row uses codecBosId) +//! [9+N]: ttsEosEmbed + talkerEmbTable[codecPadId] +//! [9+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId] //! //! \param projected MLP output [seqLen, H] (FP16) //! \param ttsPadEmbed/ttsBosEmbed/ttsEosEmbed TTS special embeddings [H] (FP16) //! \param talkerEmbTable Talker embedding table [vocabSize, H] (FP16) -//! \param codecNothinkId..codecBosId Codec token IDs used in rows [3-8+N+1] -//! \param speakerId Speaker codec token ID (row 6) -//! \param textLen Number of text token rows (N = seqLen - 8) -//! \param output Full output buffer [8+N+2, H] (FP16) -//! \param stream CUDA stream +//!
\param codecNothinkId Codec no-think control token (used when langId < 0) +//! \param codecThinkId Codec think control token (used when langId >= 0) +//! \param codecThinkBosId/codecThinkEosId/codecPadId/codecBosId Codec control IDs +//! \param speakerId Speaker codec token ID +//! \param langId Language codec token ID; if < 0, no-language path is used +//! \param textLen Number of text token rows (N) +//! \param output Full output buffer (FP16) +//! \param stream CUDA stream void invokeAssistantPreamble(rt::Tensor const& projected, rt::Tensor const& ttsPadEmbed, rt::Tensor const& ttsBosEmbed, - rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkBosId, - int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId, int32_t textLen, - rt::Tensor& output, cudaStream_t stream); + rt::Tensor const& ttsEosEmbed, rt::Tensor const& talkerEmbTable, int32_t codecNothinkId, int32_t codecThinkId, + int32_t codecThinkBosId, int32_t codecThinkEosId, int32_t speakerId, int32_t codecPadId, int32_t codecBosId, + int32_t langId, int32_t textLen, rt::Tensor& output, cudaStream_t stream); //! \brief Fused residual connection for TTS decode input //! diff --git a/cpp/runtime/qwen3OmniTTSRuntime.cpp b/cpp/runtime/qwen3OmniTTSRuntime.cpp index d0c01cd..ece9fb4 100644 --- a/cpp/runtime/qwen3OmniTTSRuntime.cpp +++ b/cpp/runtime/qwen3OmniTTSRuntime.cpp @@ -32,6 +32,7 @@ #include "profiling/timer.h" #include "sampler/sampling.h" #include +#include #include #include #include @@ -358,6 +359,22 @@ bool Qwen3OmniTTSRuntime::validateAndFillConfig(std::string const& talkerEngineD mTalkerConfig.codecThinkEosId = configJson["codec_think_eos_id"].get(); mTalkerConfig.codecPadId = configJson["codec_pad_id"].get(); mTalkerConfig.codecBosId = configJson["codec_bos_id"].get(); + + // CustomVoice: language-conditioned prefill path needs codec_think_id + codec_language_id map. 
+ // Both fields are optional; if absent the runtime falls back to the legacy no-language path. + if (configJson.contains("codec_think_id")) + { + mTalkerConfig.codecThinkId = configJson["codec_think_id"].get(); + } + if (configJson.contains("codec_language_id") && configJson["codec_language_id"].is_object()) + { + for (auto const& [k, v] : configJson["codec_language_id"].items()) + { + mTalkerConfig.codecLanguageId[k] = v.get(); + } + } + LOG_INFO("CustomVoice language config: codecThinkId=%d, codecLanguageId entries=%zu", mTalkerConfig.codecThinkId, + mTalkerConfig.codecLanguageId.size()); // Support both codec_eos_token_id (original) and codec_eos_id (legacy) for backward compatibility if (configJson.contains("codec_eos_token_id")) { @@ -750,8 +767,8 @@ void Qwen3OmniTTSRuntime::initializeTTSEmbeddings(cudaStream_t stream) LOG_INFO("TTS embeddings initialized"); } -bool Qwen3OmniTTSRuntime::projectToTalkerInput( - rt::Tensor const& thinkerEmbed, int32_t speakerId, rt::Tensor& output, int64_t& outputSeqLen, cudaStream_t stream) +bool Qwen3OmniTTSRuntime::projectToTalkerInput(rt::Tensor const& thinkerEmbed, int32_t speakerId, int32_t langId, + rt::Tensor& output, int64_t& outputSeqLen, cudaStream_t stream) { int64_t const seqLen = thinkerEmbed.getShape()[0]; int64_t const hiddenSize = mTalkerConfig.talkerHiddenSize; @@ -759,10 +776,14 @@ bool Qwen3OmniTTSRuntime::projectToTalkerInput( // N = text tokens after stripping 3-token role prefix and 5-token suffix int64_t const N = seqLen - kAssistantPrefixLen - kAssistantTrailingSuffix; - // Non-streaming prefill: 8 fixed prefix rows + N text rows + 2 suffix rows - outputSeqLen = kNonStreamingPrefixRows + N + 2; // = seqLen + 2 - LOG_INFO("projectToTalkerInput: seqLen=%ld, N=%ld (stripped prefix=%d suffix=%d), outputSeqLen=%ld, speakerId=%d", - seqLen, N, kAssistantPrefixLen, kAssistantTrailingSuffix, outputSeqLen, speakerId); + // Non-streaming prefill: kFixedPrefixLen rows + N text rows + 2 suffix rows. 
+ // langId >= 0 uses the 9-row CustomVoice language prefix; otherwise 8-row legacy prefix. + int64_t const kFixedPrefixLen = (langId >= 0) ? 9 : kNonStreamingPrefixRows; + outputSeqLen = kFixedPrefixLen + N + 2; + LOG_INFO( + "projectToTalkerInput: seqLen=%ld, N=%ld (stripped prefix=%d suffix=%d), outputSeqLen=%ld, speakerId=%d, " + "langId=%d, prefixRows=%ld", + seqLen, N, kAssistantPrefixLen, kAssistantTrailingSuffix, outputSeqLen, speakerId, langId, kFixedPrefixLen); // Project all tokens via text_projection MLP check::check(mProjectedBuffer.reshape({seqLen, hiddenSize}), "Tensor reshape failed"); @@ -773,8 +794,9 @@ bool Qwen3OmniTTSRuntime::projectToTalkerInput( // Fused kernel: build complete non-streaming prefill buffer check::check(output.reshape({outputSeqLen, hiddenSize}), "Tensor reshape failed"); kernel::invokeAssistantPreamble(mProjectedBuffer, mTtsPadEmbed, mTtsBosEmbed, mTtsEosEmbed, mTalkerEmbeddingTable, - mTalkerConfig.codecNothinkId, mTalkerConfig.codecThinkBosId, mTalkerConfig.codecThinkEosId, speakerId, - mTalkerConfig.codecPadId, mTalkerConfig.codecBosId, static_cast(N), output, stream); + mTalkerConfig.codecNothinkId, mTalkerConfig.codecThinkId, mTalkerConfig.codecThinkBosId, + mTalkerConfig.codecThinkEosId, speakerId, mTalkerConfig.codecPadId, mTalkerConfig.codecBosId, langId, + static_cast(N), output, stream); return true; } @@ -1030,9 +1052,33 @@ bool Qwen3OmniTTSRuntime::prepareTalkerInput(std::vector const& textTok speakerId = getSpeakerIdByName(request.speakerName); } - // MLP projection: thinker embed → talker input embeds (non-streaming, outputSeqLen = seqLen + 2) + // CustomVoice language conditioning: resolve language string -> codec token ID. + // Lower-case the request string; lookup in codecLanguageId map. -1 means use legacy no-language path. 
+ int32_t langId = -1; + if (!request.language.empty()) + { + std::string langLc = request.language; + std::transform(langLc.begin(), langLc.end(), langLc.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + auto it = mTalkerConfig.codecLanguageId.find(langLc); + if (it != mTalkerConfig.codecLanguageId.end()) + { + langId = it->second; + LOG_INFO( + "CustomVoice language conditioning enabled: language=\"%s\" -> codec_id=%d", langLc.c_str(), langId); + } + else + { + LOG_WARNING( + "Requested language=\"%s\" not found in codec_language_id map (size=%zu); " + "falling back to no-language prefill path", + langLc.c_str(), mTalkerConfig.codecLanguageId.size()); + } + } + + // MLP projection: thinker embed → talker input embeds (non-streaming, outputSeqLen = seqLen + 2 or +3 w/ lang) int64_t const hiddenSize = mTalkerConfig.talkerHiddenSize; - if (!projectToTalkerInput(mThinkerEmbedBuffer, speakerId, mTalkerInputEmbeds, outSeqLen, stream)) + if (!projectToTalkerInput(mThinkerEmbedBuffer, speakerId, langId, mTalkerInputEmbeds, outSeqLen, stream)) { LOG_ERROR("MLP projection failed"); return false; @@ -2019,9 +2065,12 @@ bool Qwen3OmniTTSRuntime::buildTalkerPrefillFromSegments(std::vector co rt::Tensor assistantSlice(const_cast<__half*>(assistantProjPtr), rt::Coords{assistantInputLen, hiddenSize}, rt::DeviceType::kGPU, nvinfer1::DataType::kHALF); + // Omni segment path: language conditioning is not (yet) propagated through Omni; pass langId=-1 + // to keep the legacy 8-row prefix. 
kernel::invokeAssistantPreamble(assistantSlice, mTtsPadEmbed, mTtsBosEmbed, mTtsEosEmbed, mTalkerEmbeddingTable, - mTalkerConfig.codecNothinkId, mTalkerConfig.codecThinkBosId, mTalkerConfig.codecThinkEosId, speakerId, - mTalkerConfig.codecPadId, mTalkerConfig.codecBosId, 1, preambleScratch, stream); + mTalkerConfig.codecNothinkId, mTalkerConfig.codecThinkId, mTalkerConfig.codecThinkBosId, + mTalkerConfig.codecThinkEosId, speakerId, mTalkerConfig.codecPadId, mTalkerConfig.codecBosId, + /*langId=*/-1, /*textLen=*/1, preambleScratch, stream); __half* const aOut = static_cast<__half*>(mTalkerInputEmbeds.rawPointer()) + userTotalLen * hiddenSize; CUDA_CHECK(cudaMemcpyAsync(aOut, scratchPtr, kAssistantRestructuredLen * hiddenSize * sizeof(__half), diff --git a/cpp/runtime/qwen3OmniTTSRuntime.h b/cpp/runtime/qwen3OmniTTSRuntime.h index 94fa578..3d67b15 100644 --- a/cpp/runtime/qwen3OmniTTSRuntime.h +++ b/cpp/runtime/qwen3OmniTTSRuntime.h @@ -119,7 +119,10 @@ class Qwen3OmniTTSRuntime // Speaker selection (optional, defaults to config default) std::string speakerName{""}; //!< Speaker name (e.g., "f245", "m02") - empty means use default - int32_t speakerId{-1}; //!< Speaker ID - if >= 0, overrides speakerName + //!< CustomVoice language conditioning (e.g., "chinese", "english"). Empty = no language path. + //!< Lower-cased by the runtime before lookup in TalkerConfig::codecLanguageId. 
+ std::string language{""}; + int32_t speakerId{-1}; //!< Speaker ID - if >= 0, overrides speakerName // Input: conversation messages for this request (runtime tokenizes internally) std::vector messages; @@ -460,6 +463,7 @@ class Qwen3OmniTTSRuntime // Codec special tokens (from talker vocab, used directly) int32_t codecNothinkId{}; //!< Codec no-think control token (2155) + int32_t codecThinkId{}; //!< Codec think control token (CustomVoice + language path) int32_t codecThinkBosId{}; //!< Codec think begin-of-sequence (2156) int32_t codecThinkEosId{}; //!< Codec think end-of-sequence (2157) int32_t codecPadId{}; //!< Codec padding token (2148) @@ -468,6 +472,10 @@ class Qwen3OmniTTSRuntime // Speaker configuration (read from config) int32_t defaultSpeakerId{}; //!< Default speaker ID (e.g., 2301 for f245) + + //!< CustomVoice language conditioning: map of lower-case language name -> codec token ID. + //!< Empty when the model is not a CustomVoice language-conditioned variant. + std::unordered_map codecLanguageId{}; }; // ========== Configuration and Initialization ========== @@ -619,7 +627,7 @@ class Qwen3OmniTTSRuntime * @param stream CUDA stream * @return True on success, false on failure */ - bool projectToTalkerInput(rt::Tensor const& thinkerEmbed, int32_t speakerId, rt::Tensor& output, + bool projectToTalkerInput(rt::Tensor const& thinkerEmbed, int32_t speakerId, int32_t langId, rt::Tensor& output, int64_t& outputSeqLen, cudaStream_t stream); //! Embed token IDs, run MLP projection, and reshape buffers ready for Talker prefill. diff --git a/examples/omni/qwen3_tts_inference.cpp b/examples/omni/qwen3_tts_inference.cpp index db7dfd2..cae622c 100644 --- a/examples/omni/qwen3_tts_inference.cpp +++ b/examples/omni/qwen3_tts_inference.cpp @@ -48,6 +48,8 @@ struct ParsedInput std::vector> requests; // Per-request speaker name (parallel to requests). Falls back to top-level "speaker" default. 
std::vector requestSpeakers; + // Per-request CustomVoice language (parallel to requests). Falls back to top-level "language" default. + std::vector requestLanguages; bool applyChatTemplate{true}; bool addGenerationPrompt{true}; bool enableThinking{false}; @@ -56,6 +58,7 @@ struct ParsedInput float talkerTopP{1.0f}; float repetitionPenalty{1.05f}; std::string speakerName{""}; + std::string language{""}; int32_t maxAudioLength{4096}; }; @@ -89,6 +92,7 @@ ParsedInput parseInputFile(std::filesystem::path const& inputFilePath, int32_t b result.talkerTopP = inputData.value("talker_top_p", 1.0f); result.repetitionPenalty = inputData.value("repetition_penalty", 1.05f); result.speakerName = inputData.value("speaker", ""); + result.language = inputData.value("language", ""); result.maxAudioLength = inputData.value("max_audio_length", 4096); check::check( @@ -104,6 +108,7 @@ ParsedInput parseInputFile(std::filesystem::path const& inputFilePath, int32_t b "Each request must contain a 'messages' array"); std::string requestSpeaker = requestItem.value("speaker", result.speakerName); + std::string requestLanguage = requestItem.value("language", result.language); auto const& messagesArray = requestItem["messages"]; check::check(messagesArray.size() <= limits::security::kMaxMessagesPerRequest, @@ -145,6 +150,7 @@ ParsedInput parseInputFile(std::filesystem::path const& inputFilePath, int32_t b } result.requests.push_back(std::move(messages)); result.requestSpeakers.push_back(std::move(requestSpeaker)); + result.requestLanguages.push_back(std::move(requestLanguage)); } return result; @@ -381,6 +387,7 @@ int main(int argc, char** argv) talkerReq.addGenerationPrompt = input.addGenerationPrompt; talkerReq.enableThinking = input.enableThinking; talkerReq.speakerName = input.requestSpeakers[requestIdx]; + talkerReq.language = input.requestLanguages[requestIdx]; talkerReq.maxAudioLength = input.maxAudioLength; talkerReq.messages = input.requests[requestIdx]; diff --git 
a/experimental/llm_loader/export_all_cli.py b/experimental/llm_loader/export_all_cli.py index 3686ee1..7d6ba26 100644 --- a/experimental/llm_loader/export_all_cli.py +++ b/experimental/llm_loader/export_all_cli.py @@ -1077,6 +1077,12 @@ def _patch_tts_config(model_dir: str, out_dir: str) -> None: if key in talker_cfg: cfg[key] = talker_cfg[key] + # CustomVoice language conditioning map: language name -> codec token id. + # Required by the runtime to inject the language-row in the 9-row prefix + # for CustomVoice checkpoints. Optional for non-CustomVoice models. + if "codec_language_id" in talker_cfg: + cfg["codec_language_id"] = talker_cfg["codec_language_id"] + # thinker_hidden_size and text_vocab_size # Qwen3-Omni exposes ``thinker_hidden_size`` directly on talker_config; # Qwen3-TTS uses ``text_hidden_size`` — accept either.