diff --git a/docs/docs/03-hooks/01-natural-language-processing/useLLM.md b/docs/docs/03-hooks/01-natural-language-processing/useLLM.md
index 56deff82a..98dd26aed 100644
--- a/docs/docs/03-hooks/01-natural-language-processing/useLLM.md
+++ b/docs/docs/03-hooks/01-natural-language-processing/useLLM.md
@@ -75,6 +75,8 @@ For more information on loading resources, take a look at [loading models](../..
 | `deleteMessage` | `(index: number) => void` | Deletes all messages starting with message on `index` position. After deletion `messageHistory` will be updated. |
 | `messageHistory` | `Message[]` | History containing all messages in conversation. This field is updated after model responds to `sendMessage`. |
 | `getGeneratedTokenCount` | `() => number` | Returns the number of tokens generated in the last response. |
+| `getPromptTokenCount` | `() => number` | Returns the number of tokens in the prompt (input tokens) for the current session. |
+| `getTotalTokenCount` | `() => number` | Returns the total number of tokens processed (sum of prompt tokens and generated tokens). |

 Type definitions
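To show how the two new hook methods are intended to be read together, here is a minimal usage sketch. It is illustrative only: the helper name `logTokenUsage` and the assumption that `LLMType` is importable from the package root are not part of this diff.

```ts
import type { LLMType } from 'react-native-executorch';

// Illustrative helper (assumption: `llm` is the object returned by useLLM once a
// model is loaded and has produced at least one response).
function logTokenUsage(llm: LLMType) {
  const promptTokens = llm.getPromptTokenCount(); // input tokens for the current session
  const generatedTokens = llm.getGeneratedTokenCount(); // tokens in the last response
  const totalTokens = llm.getTotalTokenCount(); // prompt + generated
  console.log({ promptTokens, generatedTokens, totalTokens });
}
```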
@@ -110,6 +112,8 @@ interface LLMType {
     generationConfig?: GenerationConfig;
   }) => void;
   getGeneratedTokenCount: () => number;
+  getPromptTokenCount: () => number;
+  getTotalTokenCount: () => number;
   generate: (messages: Message[], tools?: LLMTool[]) => Promise;
   sendMessage: (message: string) => Promise;
   deleteMessage: (index: number) => void;
diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md
index a2d3f4e8c..be92bc4cd 100644
--- a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md
+++ b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md
@@ -44,6 +44,8 @@ llm.delete();
 | `delete` | `() => void` | Method to delete the model from memory. Note you cannot delete model while it's generating. You need to interrupt it first and make sure model stopped generation. |
 | `interrupt` | `() => void` | Interrupts model generation. It may return one more token after interrupt. |
 | `getGeneratedTokenCount` | `() => number` | Returns the number of tokens generated in the last response. |
+| `getPromptTokenCount` | `() => number` | Returns the number of tokens in the prompt (input tokens) for the current session. |
+| `getTotalTokenCount` | `() => number` | Returns the total number of tokens processed (sum of prompt tokens and generated tokens). |

 Type definitions
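The same counters are exposed on `LLMModule`. The sketch below is a hypothetical reading of them on an already-loaded instance; the `declare const` stands in for whatever loading code the surrounding docs use and is not part of this diff.

```ts
import { LLMModule } from 'react-native-executorch';

// Assumption: `llm` is an LLMModule instance that has already been loaded and has
// answered at least one request (loading code omitted here).
declare const llm: LLMModule;

const prompt = llm.getPromptTokenCount(); // input tokens for the current session
const generated = llm.getGeneratedTokenCount(); // tokens generated in the last response
const total = llm.getTotalTokenCount(); // prompt + generated
console.log(`tokens processed: ${total} (${prompt} prompt + ${generated} generated)`);
```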
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
index 9f7cb08ea..a1ce8e8e8 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -107,6 +107,11 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
         synchronousHostFunction<&Model::getGeneratedTokenCount>,
         "getGeneratedTokenCount"));

+    addFunctions(JSI_EXPORT_FUNCTION(
+        ModelHostObject,
+        synchronousHostFunction<&Model::getPromptTokenCount>,
+        "getPromptTokenCount"));
+
     addFunctions(
         JSI_EXPORT_FUNCTION(ModelHostObject,
                             synchronousHostFunction<&Model::setCountInterval>,
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp
index 9d4f9feb0..c2b4c831f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp
@@ -71,6 +71,13 @@ size_t LLM::getGeneratedTokenCount() const noexcept {
   return runner->stats_.num_generated_tokens;
 }

+size_t LLM::getPromptTokenCount() const noexcept {
+  if (!runner || !runner->is_loaded()) {
+    return 0;
+  }
+  return runner->stats_.num_prompt_tokens;
+}
+
 size_t LLM::getMemoryLowerBound() const noexcept {
   return memorySizeLowerBound;
 }
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h
index 6b778f56c..b23c9677a 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h
@@ -23,6 +23,7 @@ class LLM : public BaseModel {
   void interrupt();
   void unload() noexcept;
   size_t getGeneratedTokenCount() const noexcept;
+  size_t getPromptTokenCount() const noexcept;
   size_t getMemoryLowerBound() const noexcept;
   void setCountInterval(size_t countInterval);
   void setTemperature(float temperature);
diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts
index 5c768aab4..06ae59100 100644
--- a/packages/react-native-executorch/src/controllers/LLMController.ts
+++ b/packages/react-native-executorch/src/controllers/LLMController.ts
@@ -255,6 +255,20 @@ export class LLMController {
     return this.nativeModule.getGeneratedTokenCount();
   }

+  public getPromptTokenCount(): number {
+    if (!this.nativeModule) {
+      throw new RnExecutorchError(
+        RnExecutorchErrorCode.ModuleNotLoaded,
+        "Cannot get prompt token count for a model that's not loaded."
+      );
+    }
+    return this.nativeModule.getPromptTokenCount();
+  }
+
+  public getTotalTokenCount(): number {
+    return this.getGeneratedTokenCount() + this.getPromptTokenCount();
+  }
+
   public async generate(
     messages: Message[],
     tools?: LLMTool[]
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts
index 852441556..d8f3d80ca 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts
@@ -130,6 +130,16 @@ export const useLLM = ({
     [controllerInstance]
   );

+  const getPromptTokenCount = useCallback(
+    () => controllerInstance.getPromptTokenCount(),
+    [controllerInstance]
+  );
+
+  const getTotalTokenCount = useCallback(
+    () => controllerInstance.getTotalTokenCount(),
+    [controllerInstance]
+  );
+
   return {
     messageHistory,
     response,
@@ -139,6 +149,8 @@
     downloadProgress,
     error,
     getGeneratedTokenCount: getGeneratedTokenCount,
+    getPromptTokenCount: getPromptTokenCount,
+    getTotalTokenCount: getTotalTokenCount,
     configure: configure,
     generate: generate,
     sendMessage: sendMessage,
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts
index b00aac29c..bbb4c6d36 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts
@@ -85,6 +85,14 @@ export class LLMModule {
     return this.controller.getGeneratedTokenCount();
   }

+  getPromptTokenCount() {
+    return this.controller.getPromptTokenCount();
+  }
+
+  getTotalTokenCount() {
+    return this.controller.getTotalTokenCount();
+  }
+
   delete() {
     this.controller.delete();
   }
diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts
index c2ecf69aa..9d7d39cc2 100644
--- a/packages/react-native-executorch/src/types/llm.ts
+++ b/packages/react-native-executorch/src/types/llm.ts
@@ -18,6 +18,8 @@ export interface LLMType {
     generationConfig?: GenerationConfig;
   }) => void;
   getGeneratedTokenCount: () => number;
+  getPromptTokenCount: () => number;
+  getTotalTokenCount: () => number;
   generate: (messages: Message[], tools?: LLMTool[]) => Promise;
   sendMessage: (message: string) => Promise;
   deleteMessage: (index: number) => void;
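Because `getTotalTokenCount` is derived in the controller rather than read from the native stats, a hypothetical consistency check (the function name and the import path are assumptions, not part of this diff) spells out the invariant callers can rely on:

```ts
import type { LLMType } from 'react-native-executorch';

// Hypothetical check: for any loaded model, the derived total should equal the
// sum of the two counters exposed by this change.
function assertTokenCountsConsistent(llm: LLMType): void {
  const expected = llm.getPromptTokenCount() + llm.getGeneratedTokenCount();
  if (llm.getTotalTokenCount() !== expected) {
    throw new Error('token count invariant violated');
  }
}
```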