Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -711,20 +711,34 @@ private constructor(
): Int

/**
* Prefill the KV cache with the given preprocessed audio input.
* Prefill the KV cache with pre-processed audio data (uint8 mel spectrogram).
*
* @param audio Input preprocessed audio as a byte array
* @param batchSize Input batch size
* @param nBins Input number of bins
* @param nFrames Input number of frames
* @param audio Input audio as a byte array (uint8 values)
* @param batchSize Input batch size (must be positive)
* @param nBins Input number of frequency bins (must be positive)
* @param nFrames Input number of time frames (must be positive)
* @throws ExecutorchRuntimeException if the prefill failed
* @throws IllegalArgumentException if dimensions are non-positive or array is too small
*/
@Experimental
fun prefillAudio(audio: ByteArray, batchSize: Int, nBins: Int, nFrames: Int) {
mLock.lock()
try {
checkNotReentrant()
checkNotDestroyed()
require(batchSize > 0 && nBins > 0 && nFrames > 0) {
"batchSize, nBins, and nFrames must all be positive"
}
Comment on lines +729 to +731
val expected: Long
try {
val partial = Math.multiplyExact(batchSize.toLong(), nBins.toLong())
expected = Math.multiplyExact(partial, nFrames.toLong())
} catch (ex: ArithmeticException) {
throw IllegalArgumentException("batchSize*nBins*nFrames overflows", ex)
}
require(audio.size.toLong() == expected) {
"audio.size (${audio.size}) must equal batchSize*nBins*nFrames ($expected)"
}
val nativeResult = prefillAudioInput(audio, batchSize, nBins, nFrames)
if (nativeResult != 0) {
throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed")
Expand All @@ -742,20 +756,34 @@ private constructor(
): Int

/**
* Prefill the KV cache with the given preprocessed audio input.
* Prefill the KV cache with pre-processed audio data (float32 mel spectrogram).
*
* @param audio Input preprocessed audio as a float array
* @param batchSize Input batch size
* @param nBins Input number of bins
* @param nFrames Input number of frames
* @param audio Input audio as a float array (float32 values)
* @param batchSize Input batch size (must be positive)
* @param nBins Input number of frequency bins (must be positive)
* @param nFrames Input number of time frames (must be positive)
* @throws ExecutorchRuntimeException if the prefill failed
* @throws IllegalArgumentException if dimensions are non-positive or array is too small
*/
@Experimental
fun prefillAudio(audio: FloatArray, batchSize: Int, nBins: Int, nFrames: Int) {
mLock.lock()
try {
checkNotReentrant()
checkNotDestroyed()
require(batchSize > 0 && nBins > 0 && nFrames > 0) {
"batchSize, nBins, and nFrames must all be positive"
}
val expected: Long
try {
val partial = Math.multiplyExact(batchSize.toLong(), nBins.toLong())
expected = Math.multiplyExact(partial, nFrames.toLong())
} catch (ex: ArithmeticException) {
throw IllegalArgumentException("batchSize*nBins*nFrames overflows", ex)
}
require(audio.size.toLong() == expected) {
"audio.size (${audio.size}) must equal batchSize*nBins*nFrames ($expected)"
}
val nativeResult = prefillAudioInputFloat(audio, batchSize, nBins, nFrames)
if (nativeResult != 0) {
throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed")
Expand All @@ -773,20 +801,34 @@ private constructor(
): Int

/**
* Prefill the KV cache with the given raw audio input.
* Prefill the KV cache with raw audio data.
*
* @param audio Input raw audio as a byte array
* @param batchSize Input batch size
* @param nChannels Input number of channels
* @param nSamples Input number of samples
* @param batchSize Input batch size (must be positive)
* @param nChannels Input number of channels (must be positive)
* @param nSamples Input number of samples (must be positive)
* @throws ExecutorchRuntimeException if the prefill failed
* @throws IllegalArgumentException if dimensions are non-positive or array is too small
*/
@Experimental
fun prefillRawAudio(audio: ByteArray, batchSize: Int, nChannels: Int, nSamples: Int) {
mLock.lock()
try {
checkNotReentrant()
checkNotDestroyed()
require(batchSize > 0 && nChannels > 0 && nSamples > 0) {
"batchSize, nChannels, and nSamples must all be positive"
}
val expected: Long
try {
val partial = Math.multiplyExact(batchSize.toLong(), nChannels.toLong())
expected = Math.multiplyExact(partial, nSamples.toLong())
} catch (ex: ArithmeticException) {
throw IllegalArgumentException("batchSize*nChannels*nSamples overflows", ex)
}
require(audio.size.toLong() == expected) {
"audio.size (${audio.size}) must equal batchSize*nChannels*nSamples ($expected)"
}
val nativeResult = prefillRawAudioInput(audio, batchSize, nChannels, nSamples)
if (nativeResult != 0) {
throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed")
Expand Down
51 changes: 22 additions & 29 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,22 +474,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
if (!runner_) {
return static_cast<jint>(Error::InvalidState);
}
if (data == nullptr) {
if (data == nullptr || batch_size <= 0 || n_bins <= 0 || n_frames <= 0) {
return static_cast<jint>(Error::InvalidArgument);
}
auto data_size = data->size();
if (data_size == 0) {
return 0;
size_t expected = static_cast<size_t>(batch_size) *
static_cast<size_t>(n_bins) * static_cast<size_t>(n_frames);
if (static_cast<size_t>(data_size) != expected) {
return static_cast<jint>(Error::InvalidArgument);
}
std::vector<jbyte> data_jbyte(data_size);
std::vector<uint8_t> data_u8(data_size);
data->getRegion(0, data_size, data_jbyte.data());
for (int i = 0; i < data_size; i++) {
data_u8[i] = data_jbyte[i];
}
data->getRegion(0, data_size, reinterpret_cast<jbyte*>(data_u8.data()));
llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
Comment on lines 480 to 488
std::vector<llm::MultimodalInput> inputs;
inputs.emplace_back(llm::MultimodalInput{std::move(audio)});
inputs.emplace_back(std::move(audio));
int32_t bos = needs_bos_ ? num_bos_ : 0;
auto result = runner_->prefill(inputs, bos, /*num_eos=*/0);
if (!result.ok()) {
Expand All @@ -499,7 +497,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Returns status_code
jint prefill_audio_input_float(
facebook::jni::alias_ref<jfloatArray> data,
jint batch_size,
Expand All @@ -508,22 +505,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
if (!runner_) {
return static_cast<jint>(Error::InvalidState);
}
if (data == nullptr) {
if (data == nullptr || batch_size <= 0 || n_bins <= 0 || n_frames <= 0) {
return static_cast<jint>(Error::InvalidArgument);
}
auto data_size = data->size();
if (data_size == 0) {
return 0;
size_t expected = static_cast<size_t>(batch_size) *
static_cast<size_t>(n_bins) * static_cast<size_t>(n_frames);
if (static_cast<size_t>(data_size) != expected) {
return static_cast<jint>(Error::InvalidArgument);
}
std::vector<jfloat> data_jfloat(data_size);
std::vector<float> data_f(data_size);
data->getRegion(0, data_size, data_jfloat.data());
for (int i = 0; i < data_size; i++) {
data_f[i] = data_jfloat[i];
}
data->getRegion(0, data_size, data_f.data());
llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
Comment on lines 511 to 519
std::vector<llm::MultimodalInput> inputs;
inputs.emplace_back(llm::MultimodalInput{std::move(audio)});
inputs.emplace_back(std::move(audio));
int32_t bos = needs_bos_ ? num_bos_ : 0;
auto result = runner_->prefill(inputs, bos, /*num_eos=*/0);
if (!result.ok()) {
Expand All @@ -533,7 +528,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Returns status_code
jint prefill_raw_audio_input(
facebook::jni::alias_ref<jbyteArray> data,
jint batch_size,
Expand All @@ -542,22 +536,21 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
if (!runner_) {
return static_cast<jint>(Error::InvalidState);
}
if (data == nullptr) {
if (data == nullptr || batch_size <= 0 || n_channels <= 0 ||
n_samples <= 0) {
return static_cast<jint>(Error::InvalidArgument);
}
auto data_size = data->size();
if (data_size == 0) {
return 0;
size_t expected = static_cast<size_t>(batch_size) *
static_cast<size_t>(n_channels) * static_cast<size_t>(n_samples);
if (static_cast<size_t>(data_size) != expected) {
return static_cast<jint>(Error::InvalidArgument);
}
std::vector<jbyte> data_jbyte(data_size);
std::vector<uint8_t> data_u8(data_size);
data->getRegion(0, data_size, data_jbyte.data());
for (int i = 0; i < data_size; i++) {
data_u8[i] = data_jbyte[i];
}
data->getRegion(0, data_size, reinterpret_cast<jbyte*>(data_u8.data()));
llm::RawAudio audio{std::move(data_u8), batch_size, n_channels, n_samples};
Comment on lines 543 to 551
std::vector<llm::MultimodalInput> inputs;
inputs.emplace_back(llm::MultimodalInput{std::move(audio)});
inputs.emplace_back(std::move(audio));
int32_t bos = needs_bos_ ? num_bos_ : 0;
auto result = runner_->prefill(inputs, bos, /*num_eos=*/0);
if (!result.ok()) {
Expand Down
Loading