diff --git a/apps/llm/app/multimodal_llm/index.tsx b/apps/llm/app/multimodal_llm/index.tsx index 0de5004849..09670f47a8 100644 --- a/apps/llm/app/multimodal_llm/index.tsx +++ b/apps/llm/app/multimodal_llm/index.tsx @@ -12,6 +12,11 @@ import { View, } from 'react-native'; import { launchImageLibrary } from 'react-native-image-picker'; +import { + AudioManager, + AudioRecorder, + AudioContext, +} from 'react-native-audio-api'; import { useIsFocused } from '@react-navigation/native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; import { models, useLLM } from 'react-native-executorch'; @@ -29,6 +34,7 @@ const SUGGESTED_PROMPTS = [ 'Describe this scene in detail', 'What objects can you see?', 'What text appears in this image?', + 'Transcribe the audio?', ]; import { useLLMStats } from '../../hooks/useLLMStats'; import { StatsBar } from '../../components/StatsBar'; @@ -46,11 +52,19 @@ function MultimodalLLMScreen() { const textInputRef = useRef(null); const { setGlobalGenerating } = useContext(GeneratingContext); - // Added error state + const [audioBuffer, setAudioBuffer] = useState(null); + const [audioLabel, setAudioLabel] = useState(null); + const [audioUrl, setAudioUrl] = useState(''); + const [isFetchingAudio, setIsFetchingAudio] = useState(false); + const [isRecording, setIsRecording] = useState(false); + const [hasMicPermission, setHasMicPermission] = useState(false); + const recorder = useRef(new AudioRecorder()); + const recordChunks = useRef([]); + const [error, setError] = useState(null); const vlm = useLLM({ - model: models.llm.lfm2_5_vl_1_6b(), + model: models.llm.gemma4_e2b(), }); const tokenCount = vlm.isReady ? 
vlm.getGeneratedTokenCount() : 0; const { stats, onMessageSend } = useLLMStats( @@ -68,6 +82,87 @@ function MultimodalLLMScreen() { if (vlm.error) setError(String(vlm.error)); }, [vlm.error]); + useEffect(() => { + AudioManager.setAudioSessionOptions({ + iosCategory: 'playAndRecord', + iosMode: 'spokenAudio', + iosOptions: ['allowBluetoothHFP', 'defaultToSpeaker'], + }); + (async () => { + const status = await AudioManager.requestRecordingPermissions(); + setHasMicPermission(status === 'Granted'); + })(); + }, []); + + const loadAudioFromUrl = async () => { + const url = audioUrl.trim(); + if (!url) return; + setIsFetchingAudio(true); + try { + const ctx = new AudioContext({ sampleRate: 16000 }); + const decoded = await ctx.decodeAudioData(url); + const pcm = decoded.getChannelData(0); + const name = url.split('/').pop() || 'audio'; + setAudioBuffer(pcm); + setAudioLabel(`${name} ยท ${(pcm.length / 16000).toFixed(1)}s`); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } finally { + setIsFetchingAudio(false); + } + }; + + const startRecording = async () => { + if (!hasMicPermission) { + setError('Microphone permission denied. Please enable it in Settings.'); + return; + } + recordChunks.current = []; + const sampleRate = 16000; + recorder.current.onAudioReady( + { sampleRate, bufferLength: 0.1 * sampleRate, channelCount: 1 }, + ({ buffer }) => { + recordChunks.current.push(new Float32Array(buffer.getChannelData(0))); + } + ); + try { + const ok = await AudioManager.setAudioSessionActivity(true); + if (!ok) { + setError('Cannot start audio session'); + return; + } + const result = recorder.current.start(); + if (result.status === 'error') { + setError(`Recording problems: ${result.message}`); + return; + } + setIsRecording(true); + } catch (e) { + setError(e instanceof Error ? 
e.message : String(e)); + } + }; + + const stopRecording = () => { + recorder.current.stop(); + setIsRecording(false); + const total = recordChunks.current.reduce((n, c) => n + c.length, 0); + if (total === 0) return; + const pcm = new Float32Array(total); + let off = 0; + for (const c of recordChunks.current) { + pcm.set(c, off); + off += c.length; + } + recordChunks.current = []; + setAudioBuffer(pcm); + setAudioLabel(`Recording ยท ${(pcm.length / 16000).toFixed(1)}s`); + }; + + const clearAudio = () => { + setAudioBuffer(null); + setAudioLabel(null); + }; + const pickImage = async () => { try { const result = await launchImageLibrary({ mediaType: 'photo' }); @@ -81,19 +176,27 @@ function MultimodalLLMScreen() { }; const sendMessage = async () => { - if (!userInput.trim() || vlm.isGenerating) return; + if (!(imageUri || audioBuffer || userInput.trim()) || vlm.isGenerating) + return; onMessageSend(); const text = userInput.trim(); setUserInput(''); textInputRef.current?.clear(); Keyboard.dismiss(); const currentImageUri = imageUri; + const currentAudio = audioBuffer; setImageUri(null); + setAudioBuffer(null); + setAudioLabel(null); try { - await vlm.sendMessage( - text, - currentImageUri ? { imagePath: currentImageUri } : undefined - ); + const media = + currentImageUri || currentAudio + ? { + ...(currentImageUri ? { imagePath: currentImageUri } : {}), + ...(currentAudio ? { audioBuffer: currentAudio } : {}), + } + : undefined; + await vlm.sendMessage(text, media); } catch (e) { // Updated to set UI error instead of just console.error setError(e instanceof Error ? e.message : String(e)); @@ -159,6 +262,42 @@ function MultimodalLLMScreen() { )} + {/* Audio URL input */} + + + + + {isFetchingAudio ? 'โ€ฆ' : 'Load'} + + + + + {/* Audio attachment strip */} + {audioLabel && ( + + ๐ŸŽต {audioLabel} + + โœ• + + + )} + ๐Ÿ“ท + {/* Mic record / stop button */} + + + {isRecording ? 
'โน๏ธ' : '๐ŸŽค'} + + + - {userInput.trim() && !vlm.isGenerating && ( - - - - )} + {(imageUri || audioBuffer || userInput.trim()) && + !vlm.isGenerating && ( + + + + )} {vlm.isGenerating && ( [] = [ + // Gemma4 + { label: 'Gemma4 e2b Quantized', value: GEMMA4_E2B_QUANTIZED }, // Llama 3.2 { label: 'Llama 3.2 1B', diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index f94ef918ac..18f7c10aec 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -223,6 +223,22 @@ inline std::vector getValue>(const jsi::Value &val, return getArrayAsVector(val, runtime); } +template <> +inline std::vector> +getValue>>(const jsi::Value &val, + jsi::Runtime &runtime) { + jsi::Array array = val.asObject(runtime).asArray(runtime); + const size_t length = array.size(runtime); + std::vector> result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + auto span = + getTypedArrayAsSpan(array.getValueAtIndex(runtime, i), runtime); + result.emplace_back(span.begin(), span.end()); + } + return result; +} + template <> inline std::vector getValue>(const jsi::Value &val, jsi::Runtime &runtime) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 7e0fa4b26e..331624270a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -4,8 +4,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -21,7 +21,6 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, std::vector capabilities, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker, 
Module::LoadMode::Mmap) { - if (capabilities.empty()) { runner_ = std::make_unique(std::move(module_), tokenizerSource); @@ -31,6 +30,9 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, if (cap == "vision") { encoders[llm::MultimodalType::Image] = std::make_unique(*module_); + } else if (cap == "audio") { + encoders[llm::MultimodalType::Audio] = + std::make_unique(*module_); } } runner_ = std::make_unique( @@ -74,63 +76,68 @@ std::string LLM::generate(std::string input, return output; } -std::string LLM::generateMultimodal(std::string prompt, - std::vector imagePaths, - std::string imageToken, - std::shared_ptr callback) { +std::string LLM::generateMultimodal( + std::string prompt, std::shared_ptr callback, + std::vector imagePaths, std::string imageToken, + std::vector> audioWaveforms, std::string audioToken) { if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Runner is not loaded"); } if (!runner_->is_multimodal()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "This model does not support multimodal input. Use generate(prompt, " - "callback) for text-only generation."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "This model does not support multimodal input."); } - if (imageToken.empty()) { + if (imageToken.empty() && audioToken.empty()) { throw RnExecutorchError( RnExecutorchErrorCode::InvalidUserInput, - "imageToken must not be empty. Pass the model's image token (e.g. " - "from tokenizer_config.json)."); + "At least one of imageToken/audioToken must be non-empty"); } - const size_t kImageTokenLen = imageToken.size(); - + // Scan the prompt once, splitting at the earliest placeholder at each step + // so that image/audio placeholders can be freely interleaved in the prompt. 
std::vector inputs; - size_t imageIdx = 0; - size_t searchPos = 0; - - while (true) { - size_t found = prompt.find(imageToken, searchPos); - if (found == std::string::npos) { - if (searchPos < prompt.size()) { - inputs.push_back(llm::make_text_input(prompt.substr(searchPos))); - } + size_t imageIdx = 0, audioIdx = 0, pos = 0; + constexpr int32_t kAudioSampleRate = 16000; + while (pos < prompt.size()) { + size_t imgAt = + imageToken.empty() ? std::string::npos : prompt.find(imageToken, pos); + size_t audAt = + audioToken.empty() ? std::string::npos : prompt.find(audioToken, pos); + if (imgAt == std::string::npos && audAt == std::string::npos) { + inputs.push_back(llm::make_text_input(prompt.substr(pos))); break; } - // Text segment before this placeholder - if (found > searchPos) { - inputs.push_back( - llm::make_text_input(prompt.substr(searchPos, found - searchPos))); + const bool imageFirst = imgAt != std::string::npos && + (audAt == std::string::npos || imgAt < audAt); + size_t at = imageFirst ? 
imgAt : audAt; + if (at > pos) { + inputs.push_back(llm::make_text_input(prompt.substr(pos, at - pos))); } - // Image at this position - if (imageIdx >= imagePaths.size()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "More '" + imageToken + - "' placeholders in prompt than image paths provided"); + if (imageFirst) { + if (imageIdx >= imagePaths.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More '" + imageToken + + "' placeholders than image paths"); + } + inputs.push_back(llm::make_image_input(imagePaths[imageIdx++])); + pos = at + imageToken.size(); + } else { + if (audioIdx >= audioWaveforms.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More '" + audioToken + + "' placeholders than audio waveforms"); + } + inputs.push_back(llm::make_audio_input( + std::move(audioWaveforms[audioIdx++]), kAudioSampleRate)); + pos = at + audioToken.size(); } - inputs.push_back(llm::make_image_input(imagePaths[imageIdx++])); - searchPos = found + kImageTokenLen; } - - if (imageIdx < imagePaths.size()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "More image paths provided than '" + imageToken + - "' placeholders in prompt"); + if (imageIdx < imagePaths.size() || audioIdx < audioWaveforms.size()) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "More image/audio paths provided than placeholders in prompt"); } - if (inputs.empty()) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "No inputs to generate from"); @@ -150,7 +157,6 @@ std::string LLM::generateMultimodal(std::string prompt, if (error != Error::Ok) { throw RnExecutorchError(error, "Failed to generate multimodal response"); } - return output; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index 222b5bc62f..bf1c44313d 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -22,10 +22,16 @@ class LLM : public BaseModel { std::string generate(std::string prompt, std::shared_ptr callback); - std::string generateMultimodal(std::string prompt, - std::vector imagePaths, - std::string imageToken, - std::shared_ptr callback); + // Audio variant: `audioWaveforms` is a parallel vector of fp32 mono 16 kHz + // PCM buffers (decoded upstream, same contract as SpeechToText::transcribe). + // The prompt is scanned for `imageToken` and/or `audioToken` placeholders; + // each placeholder consumes the next entry from its respective vector in + // order. Either set of paths/waveforms/token may be empty. + std::string generateMultimodal( + std::string prompt, std::shared_ptr callback, + std::vector imagePaths = {}, std::string imageToken = "", + std::vector> audioWaveforms = {}, + std::string audioToken = ""); void interrupt(); void reset(); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 1f34b3a18e..5f9d7287a5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -293,6 +293,7 @@ add_rn_test(LLMTests integration/LLMTest.cpp ${COMMON_DIR}/runner/sampler.cpp ${COMMON_DIR}/runner/arange_util.cpp ${COMMON_DIR}/runner/encoders/vision_encoder.cpp + ${COMMON_DIR}/runner/encoders/audio_encoder.cpp ${IMAGE_UTILS_SOURCES} LIBS tokenizers_deps opencv_deps ) diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.cpp b/packages/react-native-executorch/common/runner/base_llm_runner.cpp index a021040807..7229d64f20 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.cpp +++ b/packages/react-native-executorch/common/runner/base_llm_runner.cpp 
@@ -56,11 +56,16 @@ Error BaseLLMRunner::load() { ? static_cast(metadata_.at(kMaxContextLen)) : static_cast(metadata_.at(kMaxSeqLen)); } - if (config_.max_new_tokens < 0) - config_.max_new_tokens = - std::min(config_.max_seq_len, config_.max_context_length); config_.enable_dynamic_shape = static_cast(metadata_.at(kEnableDynamicShape)); + if (config_.max_new_tokens < 0) { + // For dynamic-shape PTEs, max_seq_len is the per-call decoder chunk + // size, not the generation budget — use max_context_length instead. + const int32_t seq_cap = config_.enable_dynamic_shape + ? config_.max_context_length + : config_.max_seq_len; + config_.max_new_tokens = std::min(seq_cap, config_.max_context_length); + } config_.enable_kv_cache = static_cast(metadata_.at(kUseKVCache)); eos_ids_ = std::make_unique>(); @@ -149,6 +154,8 @@ void BaseLLMRunner::set_repetition_penalty(float repetition_penalty) noexcept { config_.repetition_penalty = repetition_penalty; } +void BaseLLMRunner::set_topk(int32_t topk) noexcept { config_.topk = topk; } + void BaseLLMRunner::set_count_interval(size_t count_interval) { config_.output_token_batch_size = count_interval; } diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.h b/packages/react-native-executorch/common/runner/base_llm_runner.h index 9710f5ae70..82de49bea3 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.h +++ b/packages/react-native-executorch/common/runner/base_llm_runner.h @@ -55,6 +55,7 @@ class BaseLLMRunner { void set_topp(float topp) noexcept; void set_min_p(float min_p) noexcept; void set_repetition_penalty(float repetition_penalty) noexcept; + void set_topk(int32_t topk) noexcept; void set_count_interval(size_t count_interval); void set_time_interval(size_t time_interval); diff --git a/packages/react-native-executorch/common/runner/constants.h b/packages/react-native-executorch/common/runner/constants.h index f1fee23471..368371688a 100644 --- 
a/packages/react-native-executorch/common/runner/constants.h +++ b/packages/react-native-executorch/common/runner/constants.h @@ -23,8 +23,22 @@ inline constexpr auto kVisionEncoderMethod = "vision_encoder"; inline constexpr auto kAudioEncoderMethod = "audio_encoder"; inline constexpr auto kTokenEmbeddingMethod = "token_embedding"; inline constexpr auto kTextModelMethod = "text_decoder"; - inline constexpr auto numOfAddedBoSTokens = 0; inline constexpr auto numOfAddedEoSTokens = 0; +// Gemma4 +// PLE models only: token id that marks image placeholder slots in input_ids. +// token_embedding run on this id produces the per-layer PLE signal for image +// positions; the inputs_embeds output for those positions is discarded (the +// vision encoder output replaces it). +inline constexpr auto kImagePlaceholderId = "image_placeholder_id"; +// True iff the model exposes a per-layer-embedding (PLE) signal alongside +// inputs_embeds (Gemma4-style). When true, `token_embedding.execute()` +// returns the tuple (inputs_embeds, ple_tok) and the runner must thread +// ple_tok into text_decoder; when false (or absent), token_embedding returns +// inputs_embeds alone. Text-only PTEs that ship a single `forward` method +// omit this key entirely — it is meaningful only for multimodal PTEs that +// expose a separate `token_embedding` method. 
+inline constexpr auto kHasPLE = "has_ple"; + } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp new file mode 100644 index 0000000000..52fa01ad66 --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp @@ -0,0 +1,126 @@ +// common/runner/encoders/audio_encoder.cpp +#include "audio_encoder.h" + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::aten::SizesType; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; + +namespace { +// Matches AUDIO_SAMPLES_PER_BLOCK in gemma_export/experiments_vulkan/ +// op_bisect/iter201_mm_4method_dynaudio_prefill2048_export.py. +// The PTE's audio_samples dim was exported as `7680 * audio_blocks`. +constexpr int32_t kSamplesPerBlock = 7680; +// k ∈ [kAudioBlockKMin, kAudioBlockKMax] from MODEL_INTERFACE.md §6. +// k=62 == 29.76 s @ 16 kHz is the SDPA mask + rel-shift bake point. +constexpr int64_t kAudioBlockKMin = 1; +constexpr int64_t kAudioBlockKMax = 62; +} // namespace + +AudioEncoder::AudioEncoder(::executorch::extension::Module &module) + : module_(&module) {} + +Error AudioEncoder::load() { + if (is_loaded()) { + return Error::Ok; + } + auto method_names_result = module_->method_names(); + if (!method_names_result.ok()) { + return method_names_result.error(); + } + if (method_names_result->count(kAudioEncoderMethod) == 0) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "Model does not support audio: 'audio_encoder' method not found. 
" + "Check that the .pte file matches the declared capabilities."); + } + return module_->load_method(kAudioEncoderMethod); +} + +bool AudioEncoder::is_loaded() const noexcept { + return module_->is_method_loaded(kAudioEncoderMethod); +} + +int32_t AudioEncoder::encoderTokenCount() const { return last_token_count_; } + +Result AudioEncoder::encode(const MultimodalInput &input) { + if (!is_loaded()) { + return Error::InvalidState; + } + if (!input.is_audio()) { + return Error::InvalidArgument; + } + + const auto &wav = input.get_audio(); + ET_CHECK_OR_RETURN_ERROR(!wav.samples.empty(), InvalidArgument, + "AudioEncoder: empty waveform"); + ET_CHECK_OR_RETURN_ERROR( + wav.sample_rate == 16000, InvalidArgument, + "AudioEncoder: expected 16000 Hz waveform, got %d Hz", wav.sample_rate); + + const int64_t n_valid = static_cast(wav.samples.size()); + const int64_t k_blocks = (n_valid + kSamplesPerBlock - 1) / kSamplesPerBlock; + ET_CHECK_OR_RETURN_ERROR( + k_blocks >= kAudioBlockKMin && k_blocks <= kAudioBlockKMax, + InvalidArgument, + "AudioEncoder: waveform of %lld samples needs k_blocks=%lld; " + "audio_encoder accepts k in [%lld, %lld] (block=%d samples; max %.2f s " + "@ 16 kHz)", + static_cast(n_valid), static_cast(k_blocks), + static_cast(kAudioBlockKMin), + static_cast(kAudioBlockKMax), + static_cast(kSamplesPerBlock), + static_cast(kSamplesPerBlock) * + static_cast(kAudioBlockKMax) / 16000.0); + const int64_t n_padded = k_blocks * kSamplesPerBlock; + + // Own the padded waveform for the lifetime of this call; from_blob below + // borrows without copying. The current export takes + // forward(waveform[1, 7680*k] fp32, num_blocks: int64 scalar) + // — input 1 is a rank-0 Long telling the encoder how many of the K_MAX + // blocks contain real PCM. Passing a 2-d mask here trips "Attempted to + // change tensor rank: old=0, new=2". 
+ padded_wav_.assign(static_cast(n_padded), 0.0f); + std::memcpy(padded_wav_.data(), wav.samples.data(), + static_cast(n_valid) * sizeof(float)); + + num_blocks_scalar_ = n_valid; + + auto wav_tensor = ::executorch::extension::from_blob( + padded_wav_.data(), {1, static_cast(n_padded)}, + ::executorch::aten::ScalarType::Float); + + auto num_blocks_tensor = ::executorch::extension::from_blob( + &num_blocks_scalar_, {}, ::executorch::aten::ScalarType::Long); + + std::vector args = {EValue(*wav_tensor), EValue(*num_blocks_tensor)}; + auto exec_result = ET_UNWRAP(module_->execute(kAudioEncoderMethod, args)); + ET_CHECK_OR_RETURN_ERROR(!exec_result.empty(), InvalidState, + "audio_encoder returned no outputs"); + auto audio_tensor = exec_result[0].toTensor(); + ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState, + "audio_encoder output rank=%zd, expected 3", + audio_tensor.dim()); + last_token_count_ = static_cast(audio_tensor.size(1)); + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info, + "AudioEncoder: valid_samples=", n_valid, + " padded_samples=", n_padded, " k_blocks=", k_blocks, + " audio_tokens=", last_token_count_); + return exec_result[0]; +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/audio_encoder.h b/packages/react-native-executorch/common/runner/encoders/audio_encoder.h new file mode 100644 index 0000000000..f6890e9e46 --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/audio_encoder.h @@ -0,0 +1,40 @@ +// common/runner/encoders/audio_encoder.h +#pragma once + +#include "iencoder.h" +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +// Runs the Gemma4 `audio_encoder` PTE method. 
+// +// Contract mirrors SpeechToText (Whisper): JS hands in fp32 mono 16 kHz PCM +// via `MultimodalInput::get_audio()`; the PTE owns the log-mel frontend so +// this class just wraps the samples in a `[1, N_samples]` Float tensor and +// executes. Resampling and WAV/MP3 decoding are the caller's responsibility +// (e.g. react-native-audio-api). +class AudioEncoder : public IEncoder { +public: + explicit AudioEncoder(::executorch::extension::Module &module); + + ::executorch::runtime::Error load() override; + bool is_loaded() const noexcept override; + ::executorch::runtime::Result<::executorch::runtime::EValue> + encode(const MultimodalInput &input) override; + // Number of audio embedding tokens produced per encode() call. 0 until first + // encode, since Gemma4's audio_encoder has a dynamic T dim. + int32_t encoderTokenCount() const override; + +private: + ::executorch::extension::Module *module_; + int32_t last_token_count_ = 0; + std::vector padded_wav_; + int64_t num_blocks_scalar_ = 0; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp index de3e196c1f..59fee53e11 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -2,7 +2,6 @@ #include "vision_encoder.h" #include -#include #include #include diff --git a/packages/react-native-executorch/common/runner/irunner.h b/packages/react-native-executorch/common/runner/irunner.h index 54b14c354f..4e5b14444a 100644 --- a/packages/react-native-executorch/common/runner/irunner.h +++ b/packages/react-native-executorch/common/runner/irunner.h @@ -73,6 +73,11 @@ struct GenerationConfig { size_t output_token_batch_size = 10; size_t batch_time_interval_ms = 120; + // Top-k sampling – keep only the k highest-logit tokens before softmax. 
+ // 0 (default) disables top-k filtering. Stacks with topp: temperature -> + // top-k -> top-p -> softmax -> multinomial. + int32_t topk = 0; + // Enable dynamic input shapes (if implemented) or not // Impacts the prefill phase and causes TextPrefiller to pass all the tokens // at once if set to true. diff --git a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h index 071b193539..df6b0c7911 100644 --- a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h @@ -14,19 +14,53 @@ #include "text_decoder_runner.h" namespace executorch::extension::llm { +// Supports two PTE contracts, selected per-call from the kHasPLE metadata +// key (mirrors how kEnableDynamicShape etc. are read — queried on demand, +// not cached in a member). Callers that need it multiple times in a hot +// path should snapshot into a local. +// +// * Legacy (has_ple == false): +// token_embedding(ids) -> inputs_embeds +// text_decoder(inputs_embeds, input_pos) +// +// * Gemma-style PLE (has_ple == true): +// token_embedding(ids) -> (inputs_embeds, ple_tok) +// text_decoder(inputs_embeds, ple_tok, input_pos) +// ple_tok carries Gemma4's per-layer PLE signal keyed on input_ids. It's +// computed once in token_embedding and threaded through every decoder call +// so PLE fires at every position (including multimodal placeholder slots). class MultimodalDecoderRunner : public TextDecoderRunner { public: explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager, const GenerationConfig &config) : TextDecoderRunner(module, io_manager, config) {} + // True iff the loaded PTE uses the Gemma-style PLE contract above. + // Reads the kHasPLE constant_method every call; cheap, but callers in + // hot loops should snapshot into a local. 
+ bool has_ple() const { + auto r = module_->get(kHasPLE); + if (r.error() != ::executorch::runtime::Error::Ok) { + return false; + } + return r->toScalar().to(); + } + inline ::executorch::runtime::Result<::executorch::aten::Tensor> step(TensorPtr &tokens, int64_t start_pos) override { auto embed_result = module_->execute(kTokenEmbeddingMethod, tokens); if (!embed_result.ok()) { return embed_result.error(); } - return decode((*embed_result)[0], start_pos); + auto &embed_outputs = *embed_result; + if (has_ple()) { + ET_CHECK_MSG(embed_outputs.size() == 2, + "Expected 2 outputs (inputs_embeds, ple_tok) from " + "token_embedding, got %zu", + embed_outputs.size()); + return decode(embed_outputs[0], embed_outputs[1], start_pos); + } + return decode(embed_outputs[0], start_pos); } inline ::executorch::runtime::Result<::executorch::aten::Tensor> @@ -46,6 +80,24 @@ class MultimodalDecoderRunner : public TextDecoderRunner { return outputs[0].toTensor(); } + inline ::executorch::runtime::Result<::executorch::aten::Tensor> + decode(const ::executorch::runtime::EValue &embeddings, + const ::executorch::runtime::EValue &ple_tok, int64_t start_pos) { + auto start_pos_tensor = ::executorch::extension::from_blob( + &start_pos, {1}, ::executorch::aten::ScalarType::Long); + auto outputs_result = module_->execute( + kTextModelMethod, {embeddings, ple_tok, start_pos_tensor}); + if (!outputs_result.ok()) { + return outputs_result.error(); + } + auto &outputs = *outputs_result; + ET_CHECK_MSG(outputs.size() == 1, + "Expected 1 output from text_decoder, got %zu", + outputs.size()); + ET_CHECK_MSG(outputs[0].isTensor(), "text_decoder output is not a tensor"); + return outputs[0].toTensor(); + } + inline ::executorch::runtime::Error load() override { if (is_method_loaded()) { return ::executorch::runtime::Error::Ok; diff --git a/packages/react-native-executorch/common/runner/multimodal_input.h b/packages/react-native-executorch/common/runner/multimodal_input.h index 6b7de35014..b515c866d4 
100644 --- a/packages/react-native-executorch/common/runner/multimodal_input.h +++ b/packages/react-native-executorch/common/runner/multimodal_input.h @@ -19,6 +19,15 @@ namespace executorch::extension::llm { struct ImagePath { std::string path; }; +// In-memory raw audio (fp32, mono). Pattern mirrors SpeechToText: the JS +// layer decodes WAV/MP3 via react-native-audio-api and passes Float32Array +// samples; the PTE has the log-mel frontend baked in, so the runner only +// needs the waveform itself. sample_rate is expected to match the PTE's +// mel-extractor (Gemma4: 16000 Hz). +struct AudioWaveform { + std::vector samples; + int32_t sample_rate; +}; class MultimodalInput { public: @@ -27,6 +36,7 @@ class MultimodalInput { : data_(std::move(tokens)) {} explicit MultimodalInput(ImagePath image_path) : data_(std::move(image_path)) {} + explicit MultimodalInput(AudioWaveform audio) : data_(std::move(audio)) {} MultimodalInput(const MultimodalInput &) = default; MultimodalInput &operator=(const MultimodalInput &) = default; @@ -42,6 +52,9 @@ class MultimodalInput { bool is_image() const noexcept { return std::holds_alternative(data_); } + bool is_audio() const noexcept { + return std::holds_alternative(data_); + } const std::string &get_text() const & { return std::get(data_); } const std::vector &get_tokens() const & { @@ -50,9 +63,13 @@ class MultimodalInput { const std::string &get_image_path() const & { return std::get(data_).path; } + const AudioWaveform &get_audio() const & { + return std::get(data_); + } private: - std::variant, ImagePath> data_; + std::variant, ImagePath, AudioWaveform> + data_; }; inline MultimodalInput make_text_input(const std::string &text) noexcept { @@ -64,5 +81,9 @@ inline MultimodalInput make_text_input(std::string &&text) noexcept { inline MultimodalInput make_image_input(std::string path) noexcept { return MultimodalInput(ImagePath{std::move(path)}); } +inline MultimodalInput make_audio_input(std::vector samples, + int32_t 
sample_rate = 16000) noexcept { + return MultimodalInput(AudioWaveform{std::move(samples), sample_rate}); +} } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp index 83a1a7f79c..57b5d0ac40 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp @@ -8,11 +8,21 @@ // Ported from executorch/extension/llm/runner/multimodal_prefiller.cpp // with our token-embedding padding fix and LFM2-VL adaptations. +// +// Supports two PTE shapes, selected from MultimodalDecoderRunner::has_ple() +// (read from the kHasPLE metadata once per prefill() call): +// * Legacy : token_embedding -> inputs_embeds; +// text_decoder(inputs_embeds, cache_positions). +// * PLE : token_embedding -> (inputs_embeds, ple_tok); +// text_decoder(inputs_embeds, ple_tok, cache_positions). 
#include "multimodal_prefiller.h" #include "constants.h" #include "util.h" #include +#include +#include +#include namespace executorch::extension::llm { @@ -23,91 +33,433 @@ using ::executorch::runtime::Result; MultimodalPrefiller::MultimodalPrefiller( Module &module, MultimodalDecoderRunner &decoder_runner, - tokenizers::HFTokenizer &tokenizer, IEncoder *image_encoder) + tokenizers::HFTokenizer &tokenizer, IEncoder *image_encoder, + IEncoder *audio_encoder) : module_(&module), decoder_runner_(&decoder_runner), - tokenizer_(&tokenizer), image_encoder_(image_encoder) {} - -Result MultimodalPrefiller::prefill(const MultimodalInput &input, - int64_t &start_pos) { - EValue encoder_output; - std::vector padded_tokens_storage; - TensorPtr sliced_embed_storage; - - if (input.is_image()) { - ET_CHECK_OR_RETURN_ERROR(image_encoder_ != nullptr, InvalidState, - "No image encoder registered"); - auto encode_result = image_encoder_->encode(input); - ET_CHECK_OK_OR_RETURN_ERROR(encode_result.error(), "Image 
encoding failed"); - encoder_output = *encode_result; - - } else if (input.is_text() || input.is_tokens()) { - std::vector tokens; - if (input.is_text()) { - auto encode_result = tokenizer_->encode(input.get_text()); - if (!encode_result.ok()) { - ET_LOG(Error, "Tokenizer encode error %d", - static_cast(encode_result.error())); - return Error::InvalidArgument; + tokenizer_(&tokenizer), image_encoder_(image_encoder), + audio_encoder_(audio_encoder) {} + +Result +MultimodalPrefiller::prefill(const std::vector &inputs, + int64_t &start_pos) { + const bool has_ple = decoder_runner_->has_ple(); + const long t_prefill_begin = time_in_ms(); + + ET_CHECK_OR_RETURN_ERROR(!inputs.empty(), InvalidArgument, + "prefill: empty input list"); + + // ------------------------------------------------------------ + // Capacity & shape policy from PTE metadata. + // + // Three knobs drive prefill: + // * get_max_seq_len โ€” text_decoder S cap. In dynamic-shape PTEs this + // is the per-call chunk size (Gemma4 iter201 = + // 128); in static-shape PTEs (LFM2-VL) it is also + // the single-shot prefill cap. + // * get_max_context_len โ€” total KV budget (Gemma4 iter201 = 2048). Only + // materially used by the dynamic-shape path. + // * enable_dynamic_shape โ€” selects between chunked (true) and single-shot + // padded (false) prefill. + // ------------------------------------------------------------ + int64_t max_seq_len = -1; + { + auto r = module_->get(kMaxSeqLen); + if (r.error() == Error::Ok) { + max_seq_len = r->toScalar().to(); + } + } + + int64_t max_context_len = max_seq_len; + { + auto r = module_->get(kMaxContextLen); + if (r.error() == Error::Ok) { + max_context_len = r->toScalar().to(); + } + } + + bool enable_dynamic_shape = false; + { + auto r = module_->get(kEnableDynamicShape); + if (r.error() == Error::Ok) { + enable_dynamic_shape = r->toScalar().to(); + } + } + + const int64_t prefill_total_cap = + enable_dynamic_shape ? 
max_context_len : max_seq_len; + const int64_t decoder_chunk_size = max_seq_len; + + // ------------------------------------------------------------ + // Pass 1: build a fused input_ids buffer spanning all inputs. + // + // Mirrors gemma_export/experiments/infer_image.py::prefill_single_shot: + // llm_ids = prefix_ids + [0] * num_soft + suffix_ids + // Image positions use pad_token_id=0, matching HF modeling_gemma4.py:2190 + // (placeholder_id is rewritten to 0 before PLE lookup). The decoder embeds + // at those positions are then overwritten with the vision encoder output + // in pass 2. + // ------------------------------------------------------------ + struct ImageSlot { + const MultimodalInput *input; // non-owning, valid for duration of call + int64_t slot_start; + int64_t num_visual; + }; + // Audio tokens are dynamic per clip, so we encode first and remember a + // BYTE SNAPSHOT of the encoder output + count + dtype; pass 2 splices + // from the snapshot. + // + // We can NOT stash the EValue here. EValue holds an aten::Tensor which is + // just a TensorImpl*; `Method::get_output(i)` returns `const EValue&` to + // Method-internal storage and Module::execute copies that EValue into the + // returned vector. The copy shares the underlying TensorImpl, so a later + // execute() on the same method โ€” a second audio input in this prefill, + // a Vulkan backend output-buffer reuse across methods, or a load-time + // warm-up โ€” mutates `sizes()` in place under our feet. The original error + // ("audio encoder returned 96 tokens, expected 60") is exactly this: + // slot.num_audio was captured from the FIRST encode, slot.encoded.size(1) + // reflected the SECOND. Mirrors main_mm.cpp:604-675's copy-on-encode. 
+ struct AudioSlot { + std::vector bytes; + ::executorch::aten::ScalarType dtype; + int64_t slot_start; + int64_t num_audio; + int64_t audio_hidden; + }; + + std::vector ids; + ids.reserve(static_cast(prefill_total_cap)); + std::vector image_slots; + std::vector audio_slots; + long audio_encode_ms = 0; + int audio_calls = 0; + + for (const auto &input : inputs) { + if (input.is_image()) { + ET_CHECK_OR_RETURN_ERROR(image_encoder_ != nullptr, InvalidState, + "No image encoder registered"); + const int32_t num_visual = image_encoder_->encoderTokenCount(); + ET_CHECK_OR_RETURN_ERROR(num_visual > 0, InvalidState, + "Image encoder reports 0 visual tokens"); + image_slots.push_back(ImageSlot{&input, static_cast(ids.size()), + static_cast(num_visual)}); + ids.insert(ids.end(), static_cast(num_visual), 0); + } else if (input.is_audio()) { + ET_CHECK_OR_RETURN_ERROR(audio_encoder_ != nullptr, InvalidState, + "No audio encoder registered"); + const long t_aud_begin = time_in_ms(); + auto enc = audio_encoder_->encode(input); + ET_CHECK_OK_OR_RETURN_ERROR(enc.error(), "Audio encoding failed"); + audio_encode_ms += time_in_ms() - t_aud_begin; + audio_calls += 1; + // Snapshot the encoder output NOW โ€” see AudioSlot comment above for + // why the returned EValue's tensor metadata can't survive past the + // next module_->execute(). num_audio and audio_hidden are read from + // the tensor directly rather than from encoderTokenCount() so they + // are guaranteed to reflect THIS encode call. 
+ auto audio_tensor = enc->toTensor(); + ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState, + "audio_encoder output rank=%zd, expected 3", + audio_tensor.dim()); + const int64_t num_audio = static_cast(audio_tensor.size(1)); + const int64_t audio_hidden = static_cast(audio_tensor.size(2)); + ET_CHECK_OR_RETURN_ERROR(num_audio > 0, InvalidState, + "Audio encoder produced 0 tokens"); + std::vector bytes(audio_tensor.nbytes()); + std::memcpy(bytes.data(), audio_tensor.const_data_ptr(), + audio_tensor.nbytes()); + audio_slots.push_back( + AudioSlot{std::move(bytes), audio_tensor.scalar_type(), + static_cast(ids.size()), num_audio, audio_hidden}); + ids.insert(ids.end(), static_cast(num_audio), 0); + } else if (input.is_text() || input.is_tokens()) { + std::vector tokens; + if (input.is_text()) { + auto encode_result = tokenizer_->encode(input.get_text()); + if (!encode_result.ok()) { + ET_LOG(Error, "Tokenizer encode error %d", + static_cast(encode_result.error())); + return Error::InvalidArgument; + } + tokens = std::move(*encode_result); + } else { + tokens = input.get_tokens(); + } + for (auto t : tokens) { + ids.push_back(static_cast(t)); } - tokens = std::move(*encode_result); } else { - tokens = input.get_tokens(); + ET_LOG(Error, "Unsupported MultimodalInput type"); + return Error::NotSupported; } + } - const auto actual_seq_len = static_cast(tokens.size()); + const int64_t total_len = static_cast(ids.size()); + ET_CHECK_OR_RETURN_ERROR(total_len > 0, InvalidArgument, + "prefill produced zero tokens"); - // The token_embedding PTE has a fixed MAX_SEQ_LEN input buffer. - // Pad with zeros, run embedding, then slice output back to actual length. 
- int64_t max_seq_len = actual_seq_len; // fallback: no padding needed - auto max_seq_len_result = module_->get(kMaxSeqLen); - if (max_seq_len_result.error() == Error::Ok) { - max_seq_len = max_seq_len_result->toScalar().to(); - } + ET_CHECK_OR_RETURN_ERROR(total_len <= prefill_total_cap, InvalidArgument, + "Prefill length %lld exceeds %s (%lld)", + static_cast(total_len), + enable_dynamic_shape ? "get_max_context_len" + : "get_max_seq_len", + static_cast(prefill_total_cap)); + if (!enable_dynamic_shape) { + // Static-shape token_embedding needs fixed-length input; trailing pad + // zeros are inert because we copy only `total_len` rows out of the + // embedding output below. + ids.resize(static_cast(max_seq_len), 0); + } - padded_tokens_storage.assign(max_seq_len, 0); - std::ranges::copy(tokens, padded_tokens_storage.begin()); + // ------------------------------------------------------------ + // Single token_embedding call over the fused id buffer. + // ------------------------------------------------------------ + const int64_t tok_buf_len = static_cast(ids.size()); + auto token_tensor = ::executorch::extension::from_blob( + ids.data(), {1, static_cast(tok_buf_len)}, + ::executorch::aten::ScalarType::Long); - auto text_tensor = ::executorch::extension::from_blob( - padded_tokens_storage.data(), {1, static_cast(max_seq_len)}, - ::executorch::aten::ScalarType::Long); + const long t_tokembed_begin = time_in_ms(); + auto embed_result = module_->execute(kTokenEmbeddingMethod, token_tensor); + ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); + auto &embed_outputs = *embed_result; + const long t_tokembed_end = time_in_ms(); + + const size_t expected_outputs = has_ple ? 
2u : 1u; + ET_CHECK_OR_RETURN_ERROR(embed_outputs.size() == expected_outputs, + InvalidState, + "Expected %zu output(s) from token_embedding, " + "got %zu", + expected_outputs, embed_outputs.size()); - auto embed_result = module_->execute(kTokenEmbeddingMethod, text_tensor); - ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); + auto full_embed = embed_outputs[0].toTensor(); + const auto hidden = static_cast(full_embed.size(2)); - auto full_embed = (*embed_result)[0].toTensor(); - const auto embed_dim = static_cast(full_embed.size(2)); - sliced_embed_storage = ::executorch::extension::from_blob( - full_embed.mutable_data_ptr(), {1, actual_seq_len, embed_dim}, - ::executorch::aten::ScalarType::Float); - encoder_output = EValue(*sliced_embed_storage); + // Own the embeds for the live prefix โ€” subsequent vision_encoder.execute + // calls may reuse the token_embedding output buffer in the runtime. + // Dtype is whatever the exporter chose (fp32 baseline, fp16 + // s16k_jitmask_fp16); copy bytes through nbytes/numel so we don't assume the + // scalar type. + const ::executorch::aten::ScalarType embeds_dtype = full_embed.scalar_type(); + const size_t embeds_total_numel = static_cast(full_embed.numel()); + ET_CHECK_OR_RETURN_ERROR(embeds_total_numel > 0, InvalidState, + "token_embedding returned zero elements"); + const size_t embeds_elem_size = full_embed.nbytes() / embeds_total_numel; + const size_t embeds_prefix_bytes = static_cast(total_len) * + static_cast(hidden) * + embeds_elem_size; + std::vector embeds_buf(embeds_prefix_bytes); + std::memcpy(embeds_buf.data(), full_embed.mutable_data_ptr(), + embeds_prefix_bytes); - } else { - ET_LOG(Error, "Unsupported MultimodalInput type"); - return Error::NotSupported; + // Own the ple_tok prefix similarly. Dtype is whatever the exporter chose + // (commonly bf16/int8); we copy bytes through nbytes/numel without + // assuming the scalar type. 
`ple_elem_size` is hoisted so the chunked + // text_decoder loop below can use it for byte-offset slicing. + std::vector ple_tok_buf; + SizesType num_layers = 0; + SizesType ple_dim = 0; + size_t ple_elem_size = 0; + ::executorch::aten::ScalarType ple_tok_dtype = + ::executorch::aten::ScalarType::Float; + if (has_ple) { + auto full_ple_tok = embed_outputs[1].toTensor(); + num_layers = static_cast(full_ple_tok.size(2)); + ple_dim = static_cast(full_ple_tok.size(3)); + ple_tok_dtype = full_ple_tok.scalar_type(); + const size_t total_numel = static_cast(full_ple_tok.numel()); + const size_t total_bytes = full_ple_tok.nbytes(); + ET_CHECK_OR_RETURN_ERROR(total_numel > 0, InvalidState, + "ple_tok has zero elements"); + ple_elem_size = total_bytes / total_numel; + const size_t prefix_bytes = static_cast(total_len) * + static_cast(num_layers) * + static_cast(ple_dim) * ple_elem_size; + ple_tok_buf.resize(prefix_bytes); + std::memcpy(ple_tok_buf.data(), full_ple_tok.mutable_data_ptr(), + prefix_bytes); } - // Run text_decoder for prefill. - int64_t seq_len = encoder_output.toTensor().size(1); - if (seq_len == 0) { - ET_LOG(Error, "Encoder returned empty output"); - return Error::InvalidState; + // ------------------------------------------------------------ + // Pass 2: encode images and splice their outputs into embeds_buf. 
+ // ------------------------------------------------------------ + long vision_total_ms = 0; + int vision_calls = 0; + for (const auto &slot : image_slots) { + const long t_vis_begin = time_in_ms(); + auto encode_result = image_encoder_->encode(*slot.input); + ET_CHECK_OK_OR_RETURN_ERROR(encode_result.error(), "Image encoding failed"); + vision_total_ms += time_in_ms() - t_vis_begin; + vision_calls += 1; + auto encoder_output = *encode_result; + auto vision_tensor = encoder_output.toTensor(); + + const auto vision_dtype = vision_tensor.scalar_type(); + const size_t visual_elems = + static_cast(slot.num_visual) * static_cast(hidden); + uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * + static_cast(hidden) * + embeds_elem_size; + if (vision_dtype == embeds_dtype) { + const uint8_t *src = + static_cast(vision_tensor.const_data_ptr()); + std::memcpy(dst, src, visual_elems * embeds_elem_size); + } else if (vision_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::Half) { + const float *src = vision_tensor.const_data_ptr(); + auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_h[i] = ::executorch::aten::Half(src[i]); + } + } else if (vision_dtype == ::executorch::aten::ScalarType::Half && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = + vision_tensor.const_data_ptr<::executorch::aten::Half>(); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_f[i] = static_cast(src[i]); + } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidState, + "unsupported vision/text dtype pair: vision=%hhd text=%hhd", + static_cast(vision_dtype), static_cast(embeds_dtype)); + } + } + + // ------------------------------------------------------------ + // Pass 2b: splice encoded audio tokens into embeds_buf. 
Reads from the + // byte snapshot taken at encode time so post-encode execute() calls can't + // invalidate slot state. Same dtype-conversion matrix as vision. + // ------------------------------------------------------------ + for (auto &slot : audio_slots) { + ET_CHECK_OR_RETURN_ERROR( + slot.audio_hidden == static_cast(hidden), InvalidState, + "audio encoder hidden %lld != text_embed hidden %lld", + static_cast(slot.audio_hidden), + static_cast(hidden)); + + const auto audio_dtype = slot.dtype; + const size_t audio_elems = + static_cast(slot.num_audio) * static_cast(hidden); + const size_t audio_elem_size = + audio_elems > 0 ? slot.bytes.size() / audio_elems : 0; + ET_CHECK_OR_RETURN_ERROR( + audio_elem_size > 0 && + audio_elem_size * audio_elems == slot.bytes.size(), + InvalidState, + "audio slot bytes %zu inconsistent with num_audio=%lld hidden=%lld", + slot.bytes.size(), static_cast(slot.num_audio), + static_cast(hidden)); + + uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * + static_cast(hidden) * + embeds_elem_size; + + if (audio_dtype == embeds_dtype) { + std::memcpy(dst, slot.bytes.data(), audio_elems * embeds_elem_size); + } else if (audio_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::Half) { + const float *src = reinterpret_cast(slot.bytes.data()); + auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); + for (size_t i = 0; i < audio_elems; ++i) { + dst_h[i] = ::executorch::aten::Half(src[i]); + } + } else if (audio_dtype == ::executorch::aten::ScalarType::Half && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = + reinterpret_cast(slot.bytes.data()); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < audio_elems; ++i) { + dst_f[i] = static_cast(src[i]); + } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidState, + "unsupported audio/text dtype pair: audio=%hhd text=%hhd", + static_cast(audio_dtype), 
static_cast(embeds_dtype)); + } + } + + // ------------------------------------------------------------ + // Chunked text_decoder calls. + // + // Some PTEs (Gemma4 iter201) hard-cap text_decoder's S dim at + // get_max_seq_len (128) while the prefill budget extends to + // get_max_context_len (2048). KV cache state persists across calls via the + // absolute input_pos vector, so chunking is functionally transparent to + // the model. For single-shot static-shape PTEs (LFM2-VL) chunk_cap == + // total_len so the loop iterates exactly once โ€” preserving prior behavior. + // ------------------------------------------------------------ + const int64_t chunk_cap = + decoder_chunk_size > 0 ? decoder_chunk_size : total_len; + std::vector cache_positions(static_cast(total_len)); + for (int64_t i = 0; i < total_len; ++i) { + cache_positions[static_cast(i)] = start_pos + i; } - std::vector cache_positions; - auto cache_pos_result = populate_start_pos_or_cache_position( - module_, start_pos, cache_positions, seq_len, kTextModelMethod); - ET_CHECK_OK_OR_RETURN_ERROR(cache_pos_result.error()); + const long t_textdec_begin = time_in_ms(); + std::vector last_outs; + const int64_t num_chunks = (total_len + chunk_cap - 1) / chunk_cap; + for (int64_t ci = 0; ci < num_chunks; ++ci) { + const int64_t cs = ci * chunk_cap; + const int64_t ce = std::min(cs + chunk_cap, total_len); + const int64_t chunk_len = ce - cs; - auto prefill_result = - module_->execute(kTextModelMethod, {encoder_output, *cache_pos_result}); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error()); + uint8_t *embeds_chunk_ptr = + embeds_buf.data() + static_cast(cs) * + static_cast(hidden) * embeds_elem_size; + auto embeds_chunk = ::executorch::extension::from_blob( + embeds_chunk_ptr, {1, static_cast(chunk_len), hidden}, + embeds_dtype); - auto &prefill_outputs = *prefill_result; - ET_CHECK_OR_RETURN_ERROR(!prefill_outputs.empty(), InvalidState, + TensorPtr ple_chunk; + if (has_ple) { + uint8_t *ple_chunk_ptr = + 
ple_tok_buf.data() + static_cast(cs) * + static_cast(num_layers) * + static_cast(ple_dim) * ple_elem_size; + ple_chunk = ::executorch::extension::from_blob( + ple_chunk_ptr, + {1, static_cast(chunk_len), num_layers, ple_dim}, + ple_tok_dtype); + } + + auto pos_chunk = ::executorch::extension::from_blob( + cache_positions.data() + cs, {static_cast(chunk_len)}, + ::executorch::aten::ScalarType::Long); + + auto res = + has_ple ? module_->execute(kTextModelMethod, + {EValue(*embeds_chunk), EValue(*ple_chunk), + EValue(*pos_chunk)}) + : module_->execute(kTextModelMethod, + {EValue(*embeds_chunk), EValue(*pos_chunk)}); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + last_outs = std::move(*res); + } + const long t_textdec_end = time_in_ms(); + + ET_CHECK_OR_RETURN_ERROR(!last_outs.empty(), InvalidState, "text_decoder returned no outputs during prefill"); - auto logits = prefill_outputs[0].toTensor(); - start_pos += seq_len; + auto logits = last_outs[0].toTensor(); + const long t_logits_end = time_in_ms(); + start_pos += total_len; + + const long prefill_total = t_logits_end - t_prefill_begin; + const long tokembed_ms = t_tokembed_end - t_tokembed_begin; + const long textdec_ms = t_textdec_end - t_textdec_begin; + const long sample_ms = t_logits_end - t_textdec_end; + const long overhead_ms = + prefill_total - tokembed_ms - vision_total_ms - textdec_ms - sample_ms; + rnexecutorch::log( + rnexecutorch::LOG_LEVEL::Info, "prefill splits ms: total=", prefill_total, + " token_embed=", tokembed_ms, " vision(x", vision_calls, + ")=", vision_total_ms, " audio(x", audio_calls, ")=", audio_encode_ms, + " text_decoder=", textdec_ms, " logits->token=", sample_ms, + " overhead=", overhead_ms, " total_len=", total_len, + " chunks=", num_chunks, " chunk_cap=", chunk_cap, + " dynamic=", static_cast(enable_dynamic_shape)); return static_cast(decoder_runner_->logits_to_token(logits)); } @@ -127,6 +479,9 @@ Error MultimodalPrefiller::load() { if (methods.find(kVisionEncoderMethod) != 
methods.end()) { ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kVisionEncoderMethod)); } + if (methods.find(kAudioEncoderMethod) != methods.end()) { + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod)); + } return Error::Ok; } @@ -140,8 +495,13 @@ bool MultimodalPrefiller::is_method_loaded() { return false; } const auto &methods = *methods_res; - if (methods.find(kVisionEncoderMethod) != methods.end()) { - return module_->is_method_loaded(kVisionEncoderMethod); + if (methods.find(kVisionEncoderMethod) != methods.end() && + !module_->is_method_loaded(kVisionEncoderMethod)) { + return false; + } + if (methods.find(kAudioEncoderMethod) != methods.end() && + !module_->is_method_loaded(kAudioEncoderMethod)) { + return false; } return true; } diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.h b/packages/react-native-executorch/common/runner/multimodal_prefiller.h index d9b5a9bf5c..ee15a7d12a 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.h +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.h @@ -23,12 +23,13 @@ class MultimodalPrefiller { explicit MultimodalPrefiller(Module &module, MultimodalDecoderRunner &decoder_runner, tokenizers::HFTokenizer &tokenizer, - IEncoder *image_encoder = nullptr); + IEncoder *image_encoder = nullptr, + IEncoder *audio_encoder = nullptr); // Prefill one input segment. Updates start_pos in-place. // Returns the first predicted token after this segment. 
- ::executorch::runtime::Result prefill(const MultimodalInput &input, - int64_t &start_pos); + ::executorch::runtime::Result + prefill(const std::vector &inputs, int64_t &start_pos); ::executorch::runtime::Error load(); bool is_method_loaded(); @@ -38,6 +39,7 @@ class MultimodalPrefiller { MultimodalDecoderRunner *decoder_runner_; tokenizers::HFTokenizer *tokenizer_; IEncoder *image_encoder_; + IEncoder *audio_encoder_; }; } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.cpp b/packages/react-native-executorch/common/runner/multimodal_runner.cpp index 767fef9f38..3037c766ba 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_runner.cpp @@ -3,7 +3,6 @@ #include "constants.h" #include "util.h" #include -#include namespace executorch::extension::llm { @@ -54,8 +53,13 @@ Error MultimodalRunner::load_subcomponents() { if (enc_it != encoders_.end()) { image_encoder = enc_it->second.get(); } + IEncoder *audio_encoder = nullptr; + auto aud_it = encoders_.find(MultimodalType::Audio); + if (aud_it != encoders_.end()) { + audio_encoder = aud_it->second.get(); + } mm_prefiller_ = std::make_unique( - *module_, *mm_decoder_runner_, *tokenizer_, image_encoder); + *module_, *mm_decoder_runner_, *tokenizer_, image_encoder, audio_encoder); mm_token_generator_ = std::make_unique( tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true, std::move(eos_ids_), stats_ptr, config_); @@ -78,14 +82,10 @@ Error MultimodalRunner::generate_internal( } stats_.inference_start_ms = time_in_ms(); - - uint64_t prefill_next_token = 0; - for (const auto &input : inputs) { - auto prefill_result = mm_prefiller_->prefill(input, pos_); - if (!prefill_result.ok()) - return prefill_result.error(); - prefill_next_token = prefill_result.get(); - } + auto prefill_result = mm_prefiller_->prefill(inputs, pos_); + if (!prefill_result.ok()) + 
return prefill_result.error(); + uint64_t prefill_next_token = prefill_result.get(); stats_.first_token_ms = time_in_ms(); stats_.prompt_eval_end_ms = time_in_ms(); diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.h b/packages/react-native-executorch/common/runner/multimodal_runner.h index d24e0b40c2..c6180c54f0 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_runner.h @@ -10,7 +10,7 @@ namespace executorch::extension::llm { -enum class MultimodalType { Image }; +enum class MultimodalType { Image, Audio }; class MultimodalRunner : public BaseLLMRunner { public: diff --git a/packages/react-native-executorch/common/runner/sampler.cpp b/packages/react-native-executorch/common/runner/sampler.cpp index 26c75d4dd5..3484b8850f 100644 --- a/packages/react-native-executorch/common/runner/sampler.cpp +++ b/packages/react-native-executorch/common/runner/sampler.cpp @@ -35,6 +35,7 @@ #include "sampler.h" #include #include +#include #include namespace executorch { @@ -92,32 +93,52 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) { } } - auto compare = [](const ProbIndex &a, const ProbIndex &b) { - return a.prob > b.prob; - }; - std::sort(probindex.get(), probindex.get() + n0, compare); + std::sort(probindex.get(), probindex.get() + n0, + [](const ProbIndex &a, const ProbIndex &b) { + return a.prob > b.prob; + }); // truncate the list where cumulative probability exceeds topp T cumulative_prob = 0; - int last_idx = n0 - 1; // in case of rounding errors consider all elements + int last_idx = n0 - 1; for (int i = 0; i < n0; i++) { cumulative_prob += probindex[i].prob; - if (cumulative_prob > topp_) { + if (static_cast(cumulative_prob) > topp_) { last_idx = i; - break; // we've exceeded topp by including last_idx + break; } } // sample from the truncated list - const T &r = coin * cumulative_prob; + float r = coin * static_cast(cumulative_prob); T 
cdf = 0; for (int i = 0; i <= last_idx; i++) { cdf += probindex[i].prob; - if (r < cdf) { + if (r < static_cast(cdf)) { return probindex[i].index; } } - return probindex[last_idx].index; // in case of rounding errors + return probindex[last_idx].index; +} + +// Mask logits outside the top-k by rank to -inf. Ties at the k-th boundary +// are kept (matches HuggingFace TopKLogitsWarper). +template void Sampler::mask_topk(T *logits) { + if (topk_ <= 0 || topk_ >= vocab_size_) { + return; + } + // Partial-select the (topk_-th largest) threshold using nth_element on a + // copy of logits; O(n) average. + std::vector scratch(logits, logits + vocab_size_); + std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1), + scratch.end(), std::greater()); + const T threshold = scratch[topk_ - 1]; + const T neg_inf = std::numeric_limits::lowest(); + for (int i = 0; i < vocab_size_; i++) { + if (logits[i] < threshold) { + logits[i] = neg_inf; + } + } } Sampler::Sampler(int32_t vocab_size, float temperature, float topp, @@ -126,10 +147,80 @@ Sampler::Sampler(int32_t vocab_size, float temperature, float topp, : vocab_size_(vocab_size), inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f), topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty), + topk_(0), rng_state_(rng_seed) {} + +// Mask logits whose softmax-prob falls outside the top-p nucleus to -inf. +// Keeps the token that crosses the threshold (HuggingFace convention). +template void Sampler::mask_topp(T *logits) { + if (topp_ <= 0.0f || topp_ >= 1.0f) { + return; + } + // Softmax into a scratch probs[] (do not mutate logits yet). 
+ T max_val = logits[0]; + for (int i = 1; i < vocab_size_; i++) { + if (logits[i] > max_val) { + max_val = logits[i]; + } + } + std::unique_ptr[]> probindex = + std::make_unique[]>(vocab_size_); + T sum = 0; + for (int i = 0; i < vocab_size_; i++) { + T e = static_cast(expf(static_cast(logits[i] - max_val))); + probindex[i].prob = e; + probindex[i].index = i; + sum += e; + } + if (sum <= T(0)) { + return; + } + for (int i = 0; i < vocab_size_; i++) { + probindex[i].prob = probindex[i].prob / sum; + } + std::sort(probindex.get(), probindex.get() + vocab_size_, + [](const ProbIndex &a, const ProbIndex &b) { + return a.prob > b.prob; + }); + + // Find the smallest prefix whose cumulative probability >= topp_. + T cumulative = 0; + int last_idx = vocab_size_ - 1; + for (int i = 0; i < vocab_size_; i++) { + cumulative += probindex[i].prob; + if (static_cast(cumulative) >= topp_) { + last_idx = i; + break; + } + } + // Mark kept indices, then -inf the rest. + std::vector keep(vocab_size_, false); + for (int i = 0; i <= last_idx; i++) { + keep[probindex[i].index] = true; + } + const T neg_inf = std::numeric_limits::lowest(); + for (int i = 0; i < vocab_size_; i++) { + if (!keep[i]) { + logits[i] = neg_inf; + } + } +} + +Sampler::Sampler(int vocab_size, float temperature, float topp, int32_t topk, + unsigned long long rng_seed) + : vocab_size_(vocab_size), + inv_temperature_((temperature != 0.0f) ? 
(1.0f / temperature) : 0.0f), + topp_(topp), min_p_(0.0f), repetition_penalty_(1.0f), topk_(topk), rng_state_(rng_seed) {} +Sampler::Sampler(int vocab_size, float temperature, float topp, int32_t topk) + : Sampler(vocab_size, temperature, topp, topk, std::time(nullptr)) {} + +Sampler::Sampler(int vocab_size, float temperature, float topp, + unsigned long long rng_seed) + : Sampler(vocab_size, temperature, topp, 0, rng_seed) {} + Sampler::Sampler(int vocab_size, float temperature, float topp) - : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {} + : Sampler(vocab_size, temperature, topp, 0, std::time(nullptr)) {} template static void softmax(T *x, int size) { // find max value (for numerical stability) @@ -175,9 +266,11 @@ int32_t Sampler::sample(T *logits, const std::vector &recent_tokens) { apply_repetition_penalty(logits, vocab_size_, recent_tokens); // 2. apply the temperature to the logits apply_temperature(logits, vocab_size_); - // 3. apply softmax to the logits to get the probabilities for next token + // 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass) + mask_topk(logits); + // 4. apply softmax to the logits to get the probabilities for next token softmax(logits, vocab_size_); - // 4. apply min_p truncation + // 5. apply min_p truncation apply_min_p(logits, vocab_size_); // flip a (float) coin (this is our source of entropy for sampling) float coin = random_f32(&rng_state_); diff --git a/packages/react-native-executorch/common/runner/sampler.h b/packages/react-native-executorch/common/runner/sampler.h index 16811297ef..cd57e0524e 100644 --- a/packages/react-native-executorch/common/runner/sampler.h +++ b/packages/react-native-executorch/common/runner/sampler.h @@ -41,7 +41,18 @@ class Sampler { Sampler(int32_t vocab_size, float temperature, float topp, unsigned long long rng_seed, float min_p = 0.0f, float repetition_penalty = 1.0f); + // topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p. 
+ // Pipeline when temperature != 0: temperature -> top-k mask -> top-p mask + // -> softmax -> multinomial. Note: topk == 1 with temperature != 0 collapses + // to greedy; pass topk = 0 to keep full-vocab temperature sampling. + Sampler(int32_t vocab_size, float temperature, float topp, int32_t topk, + unsigned long long rng_seed); + Sampler(int32_t vocab_size, float temperature, float topp, int32_t topk); + + // Back-compat overloads (topk = 0 => disabled). + Sampler(int32_t vocab_size, float temperature, float topp, + unsigned long long rng_seed); Sampler(int32_t vocab_size, float temperature, float topp); template int32_t sample(T *logits); @@ -53,6 +64,9 @@ class Sampler { template int32_t sample_topp(T *probabilities, float coin); template int32_t sample_mult(T *probabilities, float coin); template int32_t sample_argmax(T *probabilities); + // In-place logit warpers: set excluded indices to -inf. + template void mask_topk(T *logits); + template void mask_topp(T *logits); template inline void apply_temperature(T *logits, int32_t vocab_size) { @@ -110,6 +124,7 @@ class Sampler { float topp_; float min_p_; float repetition_penalty_; + int32_t topk_; unsigned long long rng_state_; }; diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp index e67d3e41fb..3258c78ec5 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp @@ -31,7 +31,6 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager, // outer loop (call site) is responsible for managing state. 
::executorch::runtime::Result TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { - // ET_LOG(Info, "Input token %" PRIu64, input_token); auto method_meta_result = module_->method_meta("forward"); if (!method_meta_result.ok()) { return method_meta_result.error(); @@ -54,7 +53,7 @@ TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor); ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error()); inputs = inputs_res.get(); - auto outputs_res = module_->forward(inputs); + auto outputs_res = module_->execute("forward", inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); auto update_err = io_manager_->update_decode(outputs_res.get()); diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.h b/packages/react-native-executorch/common/runner/text_decoder_runner.h index bffc254bd6..d3aa229cd0 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.h @@ -10,6 +10,7 @@ #pragma once +#include "constants.h" #include "io_manager.h" #include "sampler.h" @@ -40,8 +41,8 @@ class TextDecoderRunner { step(TensorPtr &input, int64_t start_pos); /** - * Load the Module for text decode purpose. - * @return The error code. + * Load the Module for text decode purpose. Loads the dynamic-shape `forward` + * method used for both prefill and decode. */ virtual ::executorch::runtime::Error load() { return module_->load_method("forward"); diff --git a/packages/react-native-executorch/common/runner/text_prefiller.cpp b/packages/react-native-executorch/common/runner/text_prefiller.cpp index dc961158b7..dd6a1a67a8 100644 --- a/packages/react-native-executorch/common/runner/text_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/text_prefiller.cpp @@ -10,6 +10,7 @@ // LLM. 
#include "text_prefiller.h" +#include "rnexecutorch/Log.h" #include namespace executorch { @@ -33,15 +34,16 @@ TextPrefiller::prefill(std::vector &prompt_tokens, // Check if we need to chunk the prompt tokens int32_t num_prompt_tokens = prompt_tokens.size(); + const int32_t chunk_size = static_cast(max_seq_len_); - // If prompt tokens exceed max_seq_len_, we need to chunk them - if (num_prompt_tokens > max_seq_len_) { + // If prompt tokens exceed chunk_size, we need to chunk them + if (num_prompt_tokens > chunk_size) { uint64_t cur_token = 0; int num_tokens_to_process = 0; while (num_tokens_to_process < num_prompt_tokens) { - auto num_tokens_to_prefill_with = std::min( - num_prompt_tokens - num_tokens_to_process, max_seq_len_); + auto num_tokens_to_prefill_with = + std::min(num_prompt_tokens - num_tokens_to_process, chunk_size); std::vector prompt_tokens_to_process( num_tokens_to_prefill_with); @@ -75,7 +77,6 @@ TextPrefiller::prefill_chunk(std::vector &prompt_tokens, // store the token uint64_t cur_token; if (enable_parallel_prefill_ || !use_kv_cache_) { - // initialize tensor wrappers auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, executorch::aten::ScalarType::Long); diff --git a/packages/react-native-executorch/common/runner/text_runner.cpp b/packages/react-native-executorch/common/runner/text_runner.cpp index 5a75e00b4a..96df3b6c67 100644 --- a/packages/react-native-executorch/common/runner/text_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_runner.cpp @@ -16,9 +16,14 @@ TextRunner::TextRunner(std::unique_ptr module, : BaseLLMRunner(std::move(module), tokenizer_path, config) {} bool TextRunner::is_loaded() const { +#ifdef RNEX_BYPASS_TOKENIZER + return module_ && module_->is_loaded() && text_decoder_runner_ && + text_prefiller_ && text_token_generator_; +#else return module_ && module_->is_loaded() && tokenizer_ && tokenizer_->is_loaded() && text_decoder_runner_ && text_prefiller_ && text_token_generator_; +#endif } 
Error TextRunner::load_subcomponents() { @@ -65,6 +70,10 @@ Error TextRunner::generate_internal( stats_.inference_start_ms = time_in_ms(); + // Multi-turn: JS re-renders the full chat history each call, so reset KV + // position to 0 and re-prefill from scratch. + pos_ = 0; + int64_t context_len_left = static_cast(config_.max_context_length) - pos_; @@ -79,16 +88,23 @@ Error TextRunner::generate_internal( std::vector prompt_tokens = encodeResult.get(); int num_prompt_tokens = prompt_tokens.size(); + // For dynamic-shape PTEs (Gemma4 iter*), get_max_seq_len is the per-call + // decoder chunk size (e.g. 128) and the true generation budget lives in + // get_max_context_len. Static-shape PTEs set both equal, so this collapses + // to the old behavior. Mirrors multimodal_prefiller.cpp:96. + const int32_t seq_cap = config_.enable_dynamic_shape + ? config_.max_context_length + : config_.max_seq_len; + ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); - ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < config_.max_seq_len, - InvalidArgument, - "num_prompt_tokens %d >= max_seq_len %" PRId32, - num_prompt_tokens, config_.max_seq_len); + ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < seq_cap, InvalidArgument, + "num_prompt_tokens %d >= seq cap %" PRId32, + num_prompt_tokens, seq_cap); int32_t max_new_tokens = resolve_max_new_tokens( - num_prompt_tokens, config_.max_seq_len, - static_cast(context_len_left), config_.max_new_tokens); + num_prompt_tokens, seq_cap, static_cast(context_len_left), + config_.max_new_tokens); ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument, "Max new tokens %d is <= 0", max_new_tokens); diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index 7ecf6177a9..fc9f3dd3e6 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ 
b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -100,13 +100,15 @@ class TextTokenGenerator { prev_token = cur_token; stats_->on_sampling_begin(); - cur_token = - text_decoder_runner_->logits_to_token(logits_tensor, generated_tokens); + cur_token = text_decoder_runner_->logits_to_token(logits_tensor, + generated_tokens); stats_->on_sampling_end(); pos++; generated_tokens.push_back(cur_token); + const bool eos_reached_now = eos_ids_->find(cur_token) != eos_ids_->end(); + if (use_kv_cache_) { // update the token tensor. token_data will not be empty. // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) @@ -118,6 +120,22 @@ class TextTokenGenerator { tokens_managed, {1, static_cast(token_data.size())})); } + // Don't include the terminal EOS/EOT token in the streamed text — it + // would otherwise be appended to the assistant message stored in chat + // history and corrupt the next turn's chat-template rendering + // (e.g. duplicated ). + if (eos_reached_now) { + if (!token_cache.empty()) { + auto flush = tokenizer_->decode(token_cache, false); + if (flush.ok() && !flush.get().empty() && + !flush.get().ends_with("๏ฟฝ") && token_callback) { + token_callback(flush.get()); + } + token_cache.clear(); + } + break; + } + token_cache.push_back(static_cast(cur_token)); // print the token as string, decode it with the Tokenizer object @@ -142,8 +160,7 @@ class TextTokenGenerator { const auto eos_reached = eos_ids_->contains(cur_token); if (!cache_decoded.ends_with("๏ฟฝ") && - (countIntervalElapsed || timeIntervalElapsed || should_stop_ || - eos_reached)) { + (countIntervalElapsed || timeIntervalElapsed || should_stop_)) { token_callback(cache_decoded); token_cache.clear(); timestamp_ = std::chrono::high_resolution_clock::now(); @@ -152,13 +169,6 @@ class TextTokenGenerator { if (should_stop_) { break; } - - // data-dependent terminating condition: we have n_eos_ number of EOS - if (eos_ids_->find(cur_token) != eos_ids_->end()) { - printf("\n"); 
- ET_LOG(Info, "\nReached to the end of generation"); - break; - } } return pos - start_pos; } diff --git a/packages/react-native-executorch/src/constants/llmDefaults.ts b/packages/react-native-executorch/src/constants/llmDefaults.ts index a27a2f7a4f..77a60fe311 100644 --- a/packages/react-native-executorch/src/constants/llmDefaults.ts +++ b/packages/react-native-executorch/src/constants/llmDefaults.ts @@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy'; * @category Utilities - LLM */ export const DEFAULT_SYSTEM_PROMPT = - "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text."; + "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance"; /** * Generates a default structured output prompt based on the provided JSON schema. diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts index 9c9da9c420..bd04a37883 100644 --- a/packages/react-native-executorch/src/constants/modelRegistry.ts +++ b/packages/react-native-executorch/src/constants/modelRegistry.ts @@ -496,10 +496,12 @@ export const models = { M.LFM2_5_1_2B_INSTRUCT_QUANTIZED ), bielik_v3_0_1_5b: pair(M.BIELIK_V3_0_1_5B, M.BIELIK_V3_0_1_5B_QUANTIZED), + gemma4_e2b: base(M.GEMMA4_E2B), // Multimodal LLMs — same hook/module as plain LLMs, listed here so users // pick a model by capability ("LLM") rather than by modality. 
lfm2_5_vl_1_6b: base(M.LFM2_5_VL_1_6B_QUANTIZED), lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED), + gemma4_e2b_multimodal: base(M.GEMMA4_E2B_MM), }, classification: { efficientnet_v2_s: variant(EFFICIENTNET_V2_S_VARIANTS), diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 17c523f881..29726f116b 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -125,6 +125,37 @@ export const QWEN3_0_6B_QUANTIZED = { generationConfig: QWEN3_GENERATION_CONFIG, } as const; +// GEMMA 4 +const GEMMA4_E2B_XNNPACK = `${URL_PREFIX}-gemma-4/${VERSION_TAG}/e2b/xnnpack/gemma_4_e2b_xnnpack_8da4w.pte`; +const GEMMA4_E2B_VULKAN = `${URL_PREFIX}-gemma-4/${VERSION_TAG}/e2b/vulkan/gemma_4_e2b_vulkan_8da4w.pte`; +const GEMMA4_E2B_XNNPACK_MM = `${URL_PREFIX}-gemma-4-multimodal/${VERSION_TAG}/e2b/xnnpack/gemma_4_e2b_xnnpack_8da4w.pte`; +const GEMMA4_E2B_VULKAN_MM = `${URL_PREFIX}-gemma-4-multimodal/${VERSION_TAG}/e2b/vulkan/gemma_4_e2b_vulkan_8da4w.pte`; +const GEMMA4_TOKENIZER = `${URL_PREFIX}-gemma-4/${VERSION_TAG}/e2b/xnnpack/tokenizer.json`; +const GEMMA4_TOKENIZER_CONFIG = `${URL_PREFIX}-gemma-4/${VERSION_TAG}/e2b/xnnpack/tokenizer_config.json`; + +/** + * @category Models - LLM + */ +export const GEMMA4_E2B = { + modelName: 'gemma4-e2b', + modelSource: + Platform.OS === `android` ? GEMMA4_E2B_VULKAN : GEMMA4_E2B_XNNPACK, + tokenizerSource: GEMMA4_TOKENIZER, + tokenizerConfigSource: GEMMA4_TOKENIZER_CONFIG, +} as const; + +/** + * @category Models - VLM + */ +export const GEMMA4_E2B_MM = { + modelName: 'gemma4-e2b-multimodal', + modelSource: + Platform.OS === `android` ? 
GEMMA4_E2B_VULKAN_MM : GEMMA4_E2B_XNNPACK_MM, + tokenizerSource: GEMMA4_TOKENIZER, + tokenizerConfigSource: GEMMA4_TOKENIZER_CONFIG, + capabilities: ['vision', 'audio'], +} as const; + /** * @category Models - LLM */ diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index bceca47a56..f2bfb401d1 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -16,6 +16,13 @@ import { Logger } from '../common/Logger'; import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; +// Audio soft-token expansion constants for Gemma4's audio_encoder. +// Mirrors AUDIO_SAMPLES_PER_BLOCK (kSamplesPerBlock=7680) and the per-block +// soft-token rate in audio_encoder.cpp; used to size the context budget so +// long audio doesn't silently overflow get_max_seq_len during prefill. +const AUDIO_SAMPLES_PER_BLOCK = 7680; +const AUDIO_TOKENS_PER_BLOCK = 12; + export class LLMController { private nativeModule: any; private chatConfig: ChatConfig = DEFAULT_CHAT_CONFIG; @@ -236,6 +243,17 @@ export class LLMController { return token; } + private getAudioToken(): string { + const token = this.tokenizerConfig.audio_token; + if (!token) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidConfig, + "Tokenizer config is missing 'audio_token'. Audio-capable models require tokenizerConfigSource with an 'audio_token' field." 
+ ); + } + return token; + } + private filterSpecialTokens(text: string): string { let filtered = text; if ( @@ -244,6 +262,12 @@ export class LLMController { ) { filtered = filtered.replaceAll(this.tokenizerConfig.eos_token, ''); } + if ( + SPECIAL_TOKENS.EOT_TOKEN in this.tokenizerConfig && + this.tokenizerConfig.eot_token + ) { + filtered = filtered.replaceAll(this.tokenizerConfig.eot_token, ''); + } if ( SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && this.tokenizerConfig.pad_token @@ -269,25 +293,35 @@ export class LLMController { this.isGeneratingCallback(false); } - public async forward(input: string, imagePaths?: string[]): Promise { + public async forward( + input: string, + imagePaths?: string[], + audioWaveforms?: Float32Array[] + ): Promise { if (!this._isReady) { throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded); } if (this._isGenerating) { throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating); } + const hasImages = !!imagePaths && imagePaths.length > 0; + const hasAudio = !!audioWaveforms && audioWaveforms.length > 0; try { this.isGeneratingCallback(true); this.nativeModule.reset(); - const response = - imagePaths && imagePaths.length > 0 - ? await this.nativeModule.generateMultimodal( - input, - imagePaths.map(normalizeImagePath), - this.getImageToken(), - this.onToken - ) - : await this.nativeModule.generate(input, this.onToken); + let response: string; + if (hasImages || hasAudio) { + response = await this.nativeModule.generateMultimodal( + input, + this.onToken, + hasImages ? imagePaths!.map(normalizeImagePath) : [], + hasImages ? this.getImageToken() : '', + hasAudio ? audioWaveforms! : [], + hasAudio ? 
this.getAudioToken() : '' + ); + } else { + response = await this.nativeModule.generate(input, this.onToken); + } return this.filterSpecialTokens(response); } catch (e) { throw parseUnknownError(e); @@ -355,6 +389,9 @@ export class LLMController { const imagePaths = messages .filter((m) => m.mediaPath) .map((m) => m.mediaPath!); + const audioWaveforms = messages + .filter((m) => m.audioWaveform) + .map((m) => m.audioWaveform!); const renderedChat: string = this.applyChatTemplate( messages, @@ -365,19 +402,22 @@ export class LLMController { return await this.forward( renderedChat, - imagePaths.length > 0 ? imagePaths : undefined + imagePaths.length > 0 ? imagePaths : undefined, + audioWaveforms.length > 0 ? audioWaveforms : undefined ); } public async sendMessage( message: string, - media?: { imagePath?: string } + media?: { imagePath?: string; audioBuffer?: Float32Array } ): Promise { const mediaPath = media?.imagePath; + const audioBuffer = media?.audioBuffer; const newMessage: Message = { content: message, role: 'user', ...(mediaPath ? { mediaPath } : {}), + ...(audioBuffer ? { audioWaveform: audioBuffer } : {}), }; const updatedHistory = [...this._messageHistory, newMessage]; this.messageHistoryCallback(updatedHistory); @@ -392,7 +432,22 @@ export class LLMController { ); const textTokens = this.nativeModule.countTextTokens(rendered); const imageCount = messages.filter((m) => m.mediaPath).length; - return textTokens + imageCount * (visualTokenCount - 1); + // Audio soft-token expansion: Gemma4's audio_encoder pads samples to + // multiples of AUDIO_SAMPLES_PER_BLOCK (7680 @ 16 kHz) and emits + // AUDIO_TOKENS_PER_BLOCK (~12) soft tokens per padded block. The + // rendered template only contributes 1 token for the audio placeholder, + // so add (expansion - 1) per audio message to match prefill consumption. 
+ const audioTokenExpansion = messages.reduce((acc, m) => { + if (!m.audioWaveform) return acc; + const kBlocks = Math.max( + 1, + Math.ceil(m.audioWaveform.length / AUDIO_SAMPLES_PER_BLOCK) + ); + return acc + (AUDIO_TOKENS_PER_BLOCK * kBlocks - 1); + }, 0); + return ( + textTokens + imageCount * (visualTokenCount - 1) + audioTokenExpansion + ); }; const maxContextLength = this.nativeModule.getMaxContextLength(); const messageHistoryWithPrompt = @@ -497,12 +552,15 @@ function normalizeImagePath(path: string): string { * @returns Messages with image-bearing turns rewritten to structured content. */ function messagesForChatTemplate(messages: Message[]): any[] { - return messages.map((m) => - m.mediaPath && typeof m.content === 'string' - ? { - ...m, - content: [{ type: 'image' }, { type: 'text', text: m.content }], - } - : m - ); + return messages.map((m) => { + if (typeof m.content !== 'string') return m; + const hasImage = !!m.mediaPath; + const hasAudio = !!m.audioWaveform; + if (!hasImage && !hasAudio) return m; + const parts: any[] = []; + if (hasImage) parts.push({ type: 'image' }); + if (hasAudio) parts.push({ type: 'audio' }); + parts.push({ type: 'text', text: m.content }); + return { ...m, content: parts }; + }); } diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index 027e237997..a434011c9e 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -106,7 +106,10 @@ export function useLLM({ ); const sendMessage = useCallback( - (message: string, media?: { imagePath?: string }) => { + ( + message: string, + media?: { imagePath?: string; audioBuffer?: Float32Array } + ) => { setResponse(''); return controllerInstance.sendMessage(message, media); }, diff --git 
a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 6254775c15..1cdb2bc082 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -5,20 +5,23 @@ import { ResourceSource } from './common'; * Capabilities a multimodal LLM can have. * @category Types */ -export type LLMCapability = 'vision'; +export type LLMCapability = 'vision' | 'audio'; /** * Derives the media argument shape for `sendMessage` from a capabilities tuple. * @category Types */ export type MediaArg = - 'vision' extends C[number] ? { imagePath?: string } : object; + ('vision' extends C[number] ? { imagePath?: string } : object) & + ('audio' extends C[number] ? { audioBuffer?: Float32Array } : object); /** * Union of all built-in LLM model names. * @category Types */ export type LLMModelName = + | 'gemma4-e2b' + | 'gemma4-e2b-multimodal' | 'llama-3.2-3b' | 'llama-3.2-3b-qlora' | 'llama-3.2-3b-spinquant' @@ -289,6 +292,12 @@ export interface Message { * controller normalizes the path before passing it to native code. */ mediaPath?: string; + /** + * Optional fp32 mono 16 kHz PCM buffer. Only valid on `user` messages for + * models with the `'audio'` capability. The controller forwards it to the + * native `generateMultimodal` path. 
+ */ + audioWaveform?: Float32Array; } /** @@ -386,6 +395,7 @@ export interface ContextStrategy { export const SPECIAL_TOKENS = { BOS_TOKEN: 'bos_token', EOS_TOKEN: 'eos_token', + EOT_TOKEN: 'eot_token', UNK_TOKEN: 'unk_token', SEP_TOKEN: 'sep_token', PAD_TOKEN: 'pad_token', diff --git a/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so b/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so index 8c65aa5d85..1e882b92fa 100644 Binary files a/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so and b/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so differ diff --git a/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so b/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so index a56a5d20ac..45efcf585d 100644 Binary files a/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so and b/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so differ