Harden existing audio prefill APIs (#20136)#20136
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20136
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New FailuresAs of commit 69dba93 with merge base a9d5674 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D107929913. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
This PR aims to harden the Android audio prefill APIs by adding input validation and reducing unnecessary JNI copies while keeping API signatures unchanged.
Changes:
- Add dimension/size validation for
prefillAudio/prefillRawAudioin Kotlin and JNI. - Replace JNI double-allocation + per-element conversions with direct
getRegioninto native vectors. - Update Kotlin KDoc to clarify data types and thrown exceptions.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| extension/android/jni/jni_layer_llama.cpp | Adds validation and refactors audio/raw-audio JNI array handling to reduce copies (but currently has a correctness issue when arrays are larger than expected). |
| extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt | Adds require(...) validation for audio prefills and improves documentation. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| auto data_size = data->size(); | ||
| if (data_size == 0) { | ||
| return 0; | ||
| size_t expected = | ||
| static_cast<size_t>(batch_size) * static_cast<size_t>(n_bins) * | ||
| static_cast<size_t>(n_frames); | ||
| if (static_cast<size_t>(data_size) < expected) { | ||
| return static_cast<jint>(Error::InvalidArgument); | ||
| } | ||
| std::vector<jbyte> data_jbyte(data_size); | ||
| std::vector<uint8_t> data_u8(data_size); | ||
| data->getRegion(0, data_size, data_jbyte.data()); | ||
| for (int i = 0; i < data_size; i++) { | ||
| data_u8[i] = data_jbyte[i]; | ||
| } | ||
| data->getRegion( | ||
| 0, data_size, reinterpret_cast<jbyte*>(data_u8.data())); | ||
| llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames}; |
| auto data_size = data->size(); | ||
| if (data_size == 0) { | ||
| return 0; | ||
| size_t expected = | ||
| static_cast<size_t>(batch_size) * static_cast<size_t>(n_bins) * | ||
| static_cast<size_t>(n_frames); | ||
| if (static_cast<size_t>(data_size) < expected) { | ||
| return static_cast<jint>(Error::InvalidArgument); | ||
| } | ||
| std::vector<jfloat> data_jfloat(data_size); | ||
| std::vector<float> data_f(data_size); | ||
| data->getRegion(0, data_size, data_jfloat.data()); | ||
| for (int i = 0; i < data_size; i++) { | ||
| data_f[i] = data_jfloat[i]; | ||
| } | ||
| data->getRegion(0, data_size, data_f.data()); | ||
| llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames}; |
| auto data_size = data->size(); | ||
| if (data_size == 0) { | ||
| return 0; | ||
| size_t expected = | ||
| static_cast<size_t>(batch_size) * static_cast<size_t>(n_channels) * | ||
| static_cast<size_t>(n_samples); | ||
| if (static_cast<size_t>(data_size) < expected) { | ||
| return static_cast<jint>(Error::InvalidArgument); | ||
| } | ||
| std::vector<jbyte> data_jbyte(data_size); | ||
| std::vector<uint8_t> data_u8(data_size); | ||
| data->getRegion(0, data_size, data_jbyte.data()); | ||
| for (int i = 0; i < data_size; i++) { | ||
| data_u8[i] = data_jbyte[i]; | ||
| } | ||
| data->getRegion( | ||
| 0, data_size, reinterpret_cast<jbyte*>(data_u8.data())); | ||
| llm::RawAudio audio{std::move(data_u8), batch_size, n_channels, n_samples}; |
| * @param batchSize Input batch size | ||
| * @param nBins Input number of bins | ||
| * @param nFrames Input number of frames | ||
| * @param audio Input audio as a byte array (uint8 values) |
| * @param batchSize Input batch size | ||
| * @param nBins Input number of bins | ||
| * @param nFrames Input number of frames | ||
| * @param audio Input audio as a float array (float32 values) |
| * Prefill the KV cache with the given raw audio input. | ||
| * Prefill the KV cache with raw audio data. | ||
| * | ||
| * @param audio Input raw audio as a byte array |
| require(batchSize > 0 && nBins > 0 && nFrames > 0) { | ||
| "batchSize, nBins, and nFrames must all be positive" | ||
| } |
|
@claude review |
|
Claude finished @kirklandsign's task in 1m 41s —— View job Code Review: Harden existing audio prefill APIs
Overall this is a solid hardening PR — the validation additions, Bug: Vector size mismatch with
|
Summary: Improve the existing audio prefill methods (`prefillAudio`, `prefillRawAudio`) with input validation and code cleanup. No API signature changes. Kotlin: - Add `require` checks for positive dimensions and array size >= expected element count - Use `Math.multiplyExact` to detect overflow in dimension multiplication (consistent with ByteBuffer variants) - Improve docstrings: clarify data types (uint8/float32), document `throws IllegalArgumentException` JNI (`jni_layer_llama.cpp`): - Add dimension validation (`batch_size > 0`, `n_bins > 0`, etc.) - Add array size consistency check against `batchSize * nBins * nFrames` - Replace double-allocation + per-element copy with single allocation + `reinterpret_cast` / direct `getRegion` - Fix `emplace_back` lint (`modernize-use-emplace`) - Remove silent success on empty data (now returns `InvalidArgument`) - Use `size_t` casts to prevent integer overflow in size calculations Differential Revision: D107929913
12ab659 to
d165042
Compare
Summary: Improve the existing audio prefill methods (`prefillAudio`, `prefillRawAudio`) with input validation and code cleanup. No API signature changes. Kotlin: - Add `require` checks for positive dimensions and array size >= expected element count - Use `Math.multiplyExact` to detect overflow in dimension multiplication (consistent with ByteBuffer variants) - Improve docstrings: clarify data types (uint8/float32), document `throws IllegalArgumentException` JNI (`jni_layer_llama.cpp`): - Add dimension validation (`batch_size > 0`, `n_bins > 0`, etc.) - Add array size consistency check against `batchSize * nBins * nFrames` - Replace double-allocation + per-element copy with single allocation + `reinterpret_cast` / direct `getRegion` - Fix `emplace_back` lint (`modernize-use-emplace`) - Remove silent success on empty data (now returns `InvalidArgument`) - Use `size_t` casts to prevent integer overflow in size calculations Differential Revision: D107929913
d165042 to
1692652
Compare
Summary: Improve the existing audio prefill methods (`prefillAudio`, `prefillRawAudio`) with input validation and code cleanup. No API signature changes. Kotlin: - Add `require` checks for positive dimensions and array size >= expected element count - Use `Math.multiplyExact` to detect overflow in dimension multiplication (consistent with ByteBuffer variants) - Improve docstrings: clarify data types (uint8/float32), document `throws IllegalArgumentException` JNI (`jni_layer_llama.cpp`): - Add dimension validation (`batch_size > 0`, `n_bins > 0`, etc.) - Add array size consistency check against `batchSize * nBins * nFrames` - Replace double-allocation + per-element copy with single allocation + `reinterpret_cast` / direct `getRegion` - Fix `emplace_back` lint (`modernize-use-emplace`) - Remove silent success on empty data (now returns `InvalidArgument`) - Use `size_t` casts to prevent integer overflow in size calculations Differential Revision: D107929913
1692652 to
69dba93
Compare
Summary:
Improve the existing audio prefill methods (
prefillAudio,prefillRawAudio) with input validation and code cleanup. No API signature changes.Kotlin:
requirechecks for positive dimensions and array size >= expected element countMath.multiplyExactto detect overflow in dimension multiplication (consistent with ByteBuffer variants)throws IllegalArgumentExceptionJNI (
jni_layer_llama.cpp):batch_size > 0,n_bins > 0, etc.)batchSize * nBins * nFramesreinterpret_cast/ directgetRegionemplace_backlint (modernize-use-emplace)InvalidArgument)size_tcasts to prevent integer overflow in size calculationsDifferential Revision: D107929913