diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt index f95e796b83b..11b59dac64d 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt @@ -711,13 +711,14 @@ private constructor( ): Int /** - * Prefill the KV cache with the given preprocessed audio input. + * Prefill the KV cache with pre-processed audio data (uint8 mel spectrogram). * - * @param audio Input preprocessed audio as a byte array - * @param batchSize Input batch size - * @param nBins Input number of bins - * @param nFrames Input number of frames + * @param audio Input audio as a byte array (uint8 values) + * @param batchSize Input batch size (must be positive) + * @param nBins Input number of frequency bins (must be positive) + * @param nFrames Input number of time frames (must be positive) * @throws ExecutorchRuntimeException if the prefill failed + * @throws IllegalArgumentException if dimensions are non-positive or do not match audio.size */ @Experimental fun prefillAudio(audio: ByteArray, batchSize: Int, nBins: Int, nFrames: Int) { @@ -725,6 +726,19 @@ private constructor( try { checkNotReentrant() checkNotDestroyed() + require(batchSize > 0 && nBins > 0 && nFrames > 0) { + "batchSize, nBins, and nFrames must all be positive" + } + val expected: Long + try { + val partial = Math.multiplyExact(batchSize.toLong(), nBins.toLong()) + expected = Math.multiplyExact(partial, nFrames.toLong()) + } catch (ex: ArithmeticException) { + throw IllegalArgumentException("batchSize*nBins*nFrames overflows", ex) + } + require(audio.size.toLong() == expected) { + "audio.size (${audio.size}) must equal batchSize*nBins*nFrames ($expected)" + } val nativeResult = prefillAudioInput(audio, 
batchSize, nBins, nFrames) if (nativeResult != 0) { throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") @@ -742,13 +756,14 @@ private constructor( ): Int /** - * Prefill the KV cache with the given preprocessed audio input. + * Prefill the KV cache with pre-processed audio data (float32 mel spectrogram). * - * @param audio Input preprocessed audio as a float array - * @param batchSize Input batch size - * @param nBins Input number of bins - * @param nFrames Input number of frames + * @param audio Input audio as a float array (float32 values) + * @param batchSize Input batch size (must be positive) + * @param nBins Input number of frequency bins (must be positive) + * @param nFrames Input number of time frames (must be positive) * @throws ExecutorchRuntimeException if the prefill failed + * @throws IllegalArgumentException if dimensions are non-positive or do not match audio.size */ @Experimental fun prefillAudio(audio: FloatArray, batchSize: Int, nBins: Int, nFrames: Int) { @@ -756,6 +771,19 @@ private constructor( try { checkNotReentrant() checkNotDestroyed() + require(batchSize > 0 && nBins > 0 && nFrames > 0) { + "batchSize, nBins, and nFrames must all be positive" + } + val expected: Long + try { + val partial = Math.multiplyExact(batchSize.toLong(), nBins.toLong()) + expected = Math.multiplyExact(partial, nFrames.toLong()) + } catch (ex: ArithmeticException) { + throw IllegalArgumentException("batchSize*nBins*nFrames overflows", ex) + } + require(audio.size.toLong() == expected) { + "audio.size (${audio.size}) must equal batchSize*nBins*nFrames ($expected)" + } val nativeResult = prefillAudioInputFloat(audio, batchSize, nBins, nFrames) if (nativeResult != 0) { throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") @@ -773,13 +801,14 @@ private constructor( ): Int /** - * Prefill the KV cache with the given raw audio input. + * Prefill the KV cache with raw audio data. 
* * @param audio Input raw audio as a byte array - * @param batchSize Input batch size - * @param nChannels Input number of channels - * @param nSamples Input number of samples + * @param batchSize Input batch size (must be positive) + * @param nChannels Input number of channels (must be positive) + * @param nSamples Input number of samples (must be positive) * @throws ExecutorchRuntimeException if the prefill failed + * @throws IllegalArgumentException if dimensions are non-positive or do not match audio.size */ @Experimental fun prefillRawAudio(audio: ByteArray, batchSize: Int, nChannels: Int, nSamples: Int) { @@ -787,6 +816,19 @@ private constructor( try { checkNotReentrant() checkNotDestroyed() + require(batchSize > 0 && nChannels > 0 && nSamples > 0) { + "batchSize, nChannels, and nSamples must all be positive" + } + val expected: Long + try { + val partial = Math.multiplyExact(batchSize.toLong(), nChannels.toLong()) + expected = Math.multiplyExact(partial, nSamples.toLong()) + } catch (ex: ArithmeticException) { + throw IllegalArgumentException("batchSize*nChannels*nSamples overflows", ex) + } + require(audio.size.toLong() == expected) { + "audio.size (${audio.size}) must equal batchSize*nChannels*nSamples ($expected)" + } val nativeResult = prefillRawAudioInput(audio, batchSize, nChannels, nSamples) if (nativeResult != 0) { throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index b9215f978bc..9189fff17a3 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -474,22 +474,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { if (!runner_) { return static_cast(Error::InvalidState); } - if (data == nullptr) { + if (data == nullptr || batch_size <= 0 || n_bins <= 0 || n_frames <= 0) { return static_cast(Error::InvalidArgument); } auto data_size = data->size(); - if 
(data_size == 0) { - return 0; + size_t expected = static_cast(batch_size) * + static_cast(n_bins) * static_cast(n_frames); + if (static_cast(data_size) != expected) { + return static_cast(Error::InvalidArgument); } - std::vector data_jbyte(data_size); std::vector data_u8(data_size); - data->getRegion(0, data_size, data_jbyte.data()); - for (int i = 0; i < data_size; i++) { - data_u8[i] = data_jbyte[i]; - } + data->getRegion(0, data_size, reinterpret_cast(data_u8.data())); llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames}; std::vector inputs; - inputs.emplace_back(llm::MultimodalInput{std::move(audio)}); + inputs.emplace_back(std::move(audio)); int32_t bos = needs_bos_ ? num_bos_ : 0; auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); if (!result.ok()) { @@ -499,7 +497,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { return 0; } - // Returns status_code jint prefill_audio_input_float( facebook::jni::alias_ref data, jint batch_size, @@ -508,22 +505,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { if (!runner_) { return static_cast(Error::InvalidState); } - if (data == nullptr) { + if (data == nullptr || batch_size <= 0 || n_bins <= 0 || n_frames <= 0) { return static_cast(Error::InvalidArgument); } auto data_size = data->size(); - if (data_size == 0) { - return 0; + size_t expected = static_cast(batch_size) * + static_cast(n_bins) * static_cast(n_frames); + if (static_cast(data_size) != expected) { + return static_cast(Error::InvalidArgument); } - std::vector data_jfloat(data_size); std::vector data_f(data_size); - data->getRegion(0, data_size, data_jfloat.data()); - for (int i = 0; i < data_size; i++) { - data_f[i] = data_jfloat[i]; - } + data->getRegion(0, data_size, data_f.data()); llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames}; std::vector inputs; - inputs.emplace_back(llm::MultimodalInput{std::move(audio)}); + inputs.emplace_back(std::move(audio)); int32_t bos = needs_bos_ ? 
num_bos_ : 0; auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); if (!result.ok()) { @@ -533,7 +528,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { return 0; } - // Returns status_code jint prefill_raw_audio_input( facebook::jni::alias_ref data, jint batch_size, @@ -542,22 +536,21 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { if (!runner_) { return static_cast(Error::InvalidState); } - if (data == nullptr) { + if (data == nullptr || batch_size <= 0 || n_channels <= 0 || + n_samples <= 0) { return static_cast(Error::InvalidArgument); } auto data_size = data->size(); - if (data_size == 0) { - return 0; + size_t expected = static_cast(batch_size) * + static_cast(n_channels) * static_cast(n_samples); + if (static_cast(data_size) != expected) { + return static_cast(Error::InvalidArgument); } - std::vector data_jbyte(data_size); std::vector data_u8(data_size); - data->getRegion(0, data_size, data_jbyte.data()); - for (int i = 0; i < data_size; i++) { - data_u8[i] = data_jbyte[i]; - } + data->getRegion(0, data_size, reinterpret_cast(data_u8.data())); llm::RawAudio audio{std::move(data_u8), batch_size, n_channels, n_samples}; std::vector inputs; - inputs.emplace_back(llm::MultimodalInput{std::move(audio)}); + inputs.emplace_back(std::move(audio)); int32_t bos = needs_bos_ ? num_bos_ : 0; auto result = runner_->prefill(inputs, bos, /*num_eos=*/0); if (!result.ok()) {