diff --git a/README.md b/README.md
index 604b971..1b83d52 100644
--- a/README.md
+++ b/README.md
@@ -460,7 +460,7 @@ const App = () => {
return (
<>
- {cactusSTT.response}
+ {cactusSTT.transcription}
</>
);
};
@@ -672,7 +672,7 @@ The `useCactusSTT` hook manages a `CactusSTT` instance with reactive state. When
#### State
-- `response: string` - Current transcription text. Automatically accumulated during streaming. Cleared before each new transcription and when calling `reset()` or `destroy()`.
+- `transcription: string` - Current transcription text. Automatically accumulated during streaming. Cleared before each new transcription and when calling `reset()` or `destroy()`.
- `isGenerating: boolean` - Whether the model is currently generating (transcription or embedding). Both operations share this flag.
- `isInitializing: boolean` - Whether the model is initializing.
- `isDownloaded: boolean` - Whether the model is downloaded locally. Automatically checked when the hook mounts or model changes.
@@ -684,11 +684,11 @@ The `useCactusSTT` hook manages a `CactusSTT` instance with reactive state. When
- `download(params?: CactusSTTDownloadParams): Promise` - Downloads the model. Updates `isDownloading` and `downloadProgress` state during download. Sets `isDownloaded` to `true` on success.
- `init(): Promise` - Initializes the model for inference. Sets `isInitializing` to `true` during initialization.
-- `transcribe(params: CactusSTTTranscribeParams): Promise` - Transcribes audio to text. Automatically accumulates tokens in the `response` state during streaming. Sets `isGenerating` to `true` while generating. Clears `response` before starting.
+- `transcribe(params: CactusSTTTranscribeParams): Promise` - Transcribes audio to text. Automatically accumulates tokens in the `transcription` state during streaming. Sets `isGenerating` to `true` while generating. Clears `transcription` before starting.
- `audioEmbed(params: CactusSTTAudioEmbedParams): Promise` - Generates embeddings for the given audio. Sets `isGenerating` to `true` during operation.
- `stop(): Promise` - Stops ongoing generation. Clears any errors.
-- `reset(): Promise` - Resets the model's internal state. Also clears the `response` state.
-- `destroy(): Promise` - Releases all resources associated with the model. Clears the `response` state. Automatically called when the component unmounts.
+- `reset(): Promise` - Resets the model's internal state. Also clears the `transcription` state.
+- `destroy(): Promise` - Releases all resources associated with the model. Clears the `transcription` state. Automatically called when the component unmounts.
- `getModels(): Promise` - Fetches available models from the database and checks their download status. Results are cached in memory and reused on subsequent calls.
## Type Definitions
diff --git a/android/src/main/jniLibs/arm64-v8a/libcactus.a b/android/src/main/jniLibs/arm64-v8a/libcactus.a
index 7fb0617..60bf48f 100644
Binary files a/android/src/main/jniLibs/arm64-v8a/libcactus.a and b/android/src/main/jniLibs/arm64-v8a/libcactus.a differ
diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock
index fcb4602..0818af0 100644
--- a/example/ios/Podfile.lock
+++ b/example/ios/Podfile.lock
@@ -1,6 +1,6 @@
PODS:
- boost (1.84.0)
- - Cactus (1.1.0):
+ - Cactus (1.2.0):
- boost
- DoubleConversion
- fast_float
@@ -2643,7 +2643,7 @@ EXTERNAL SOURCES:
SPEC CHECKSUMS:
boost: 7e761d76ca2ce687f7cc98e698152abd03a18f90
- Cactus: 2949301f1229677c0bbeba6856d3c78e5798aace
+ Cactus: 8853f351fa4c1ef40bad6d2b0152f503b713dc3e
DoubleConversion: cb417026b2400c8f53ae97020b2be961b59470cb
fast_float: b32c788ed9c6a8c584d114d0047beda9664e7cc6
FBLazyVector: b8f1312d48447cca7b4abc21ed155db14742bd03
diff --git a/example/src/STTScreen.tsx b/example/src/STTScreen.tsx
index f58fbbe..898c2e5 100644
--- a/example/src/STTScreen.tsx
+++ b/example/src/STTScreen.tsx
@@ -136,11 +136,11 @@ const STTScreen = () => {
- {cactusSTT.response && (
+ {cactusSTT.transcription && (
Streaming:
- {cactusSTT.response}
+ {cactusSTT.transcription}
)}
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h
index e233cd2..77594ed 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h
@@ -319,7 +319,7 @@ class ConvCache {
};
struct KVCache {
- static constexpr size_t DEFAULT_WINDOW_SIZE = 512;
+ static constexpr size_t DEFAULT_WINDOW_SIZE = 1024;
static constexpr size_t DEFAULT_SINK_SIZE = 4;
struct LayerCache {
@@ -387,7 +387,7 @@ class Model {
const std::string& system_prompt = "", bool do_warmup = true);
virtual uint32_t generate(const std::vector& tokens, float temperature = -1.0f, float top_p = -1.0f,
- size_t top_k = 0, const std::string& profile_file = "");
+ size_t top_k = 0, const std::string& profile_file = "", bool prefill_only = false);
virtual uint32_t generate_with_images(const std::vector& tokens, const std::vector& image_paths,
float temperature = -1.0f, float top_p = -1.0f,
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h
index bf4bc19..bc43d1c 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h
@@ -92,32 +92,44 @@ struct TensorConfig {
struct BroadcastInfo {
std::vector output_shape;
bool needs_broadcasting;
-
+
static BroadcastInfo compute(const std::vector& lhs, const std::vector& rhs);
};
+class BufferPool;
+
struct BufferDesc {
std::vector shape;
size_t total_size;
size_t byte_size;
std::unique_ptr data;
void* external_data;
+ char* pooled_data;
Precision precision;
float quantization_scale;
-
+
BufferDesc();
BufferDesc(const std::vector& s, Precision prec = Precision::INT8, float scale = 1.0f);
-
+ ~BufferDesc();
+
+ BufferDesc(BufferDesc&& other) noexcept;
+ BufferDesc& operator=(BufferDesc&& other) noexcept;
+
+ BufferDesc(const BufferDesc&) = delete;
+ BufferDesc& operator=(const BufferDesc&) = delete;
+
void* get_data();
const void* get_data() const;
-
+
template
T* data_as() { return static_cast(get_data()); }
-
+
template
const T* data_as() const { return static_cast(get_data()); }
-
+
void allocate();
+ void allocate_from_pool(BufferPool& pool);
+ void release_to_pool(BufferPool& pool);
void set_external(void* ptr);
};
@@ -181,6 +193,33 @@ void compute_topk_node(GraphNode& node, const std::vector>& nodes, const std::unordered_map& node_index_map);
void compute_index_node(GraphNode& node, const std::vector>& nodes, const std::unordered_map& node_index_map);
+void shrink_thread_local_buffers();
+
+class BufferPool {
+public:
+ BufferPool() = default;
+ ~BufferPool() = default;
+
+ BufferPool(const BufferPool&) = delete;
+ BufferPool& operator=(const BufferPool&) = delete;
+
+ char* acquire(size_t byte_size);
+ void release(char* ptr, size_t byte_size);
+ void clear();
+
+ size_t active_bytes() const { return active_bytes_; }
+ size_t pool_bytes() const { return pool_bytes_; }
+ size_t peak_bytes() const { return peak_bytes_; }
+
+private:
+ std::unordered_map>> free_buffers_;
+ size_t active_bytes_ = 0;
+ size_t pool_bytes_ = 0;
+ size_t peak_bytes_ = 0;
+
+ size_t round_up_size(size_t size) const;
+};
+
namespace ValidationUtils {
void validate_tensor_dims(const std::vector& shape, size_t required_dims, const std::string& op_name);
void validate_precision(Precision actual, Precision required, const std::string& op_name);
@@ -286,6 +325,7 @@ class CactusGraph {
std::vector> mapped_files_;
std::unordered_map weight_cache_;
std::vector debug_nodes_;
+ BufferPool buffer_pool_;
};
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus
index 449c8ac..9ab26fc 100755
Binary files a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus and b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus differ
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h
index e233cd2..77594ed 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h
@@ -319,7 +319,7 @@ class ConvCache {
};
struct KVCache {
- static constexpr size_t DEFAULT_WINDOW_SIZE = 512;
+ static constexpr size_t DEFAULT_WINDOW_SIZE = 1024;
static constexpr size_t DEFAULT_SINK_SIZE = 4;
struct LayerCache {
@@ -387,7 +387,7 @@ class Model {
const std::string& system_prompt = "", bool do_warmup = true);
virtual uint32_t generate(const std::vector& tokens, float temperature = -1.0f, float top_p = -1.0f,
- size_t top_k = 0, const std::string& profile_file = "");
+ size_t top_k = 0, const std::string& profile_file = "", bool prefill_only = false);
virtual uint32_t generate_with_images(const std::vector& tokens, const std::vector& image_paths,
float temperature = -1.0f, float top_p = -1.0f,
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
index bf4bc19..bc43d1c 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
@@ -92,32 +92,44 @@ struct TensorConfig {
struct BroadcastInfo {
std::vector output_shape;
bool needs_broadcasting;
-
+
static BroadcastInfo compute(const std::vector& lhs, const std::vector& rhs);
};
+class BufferPool;
+
struct BufferDesc {
std::vector shape;
size_t total_size;
size_t byte_size;
std::unique_ptr data;
void* external_data;
+ char* pooled_data;
Precision precision;
float quantization_scale;
-
+
BufferDesc();
BufferDesc(const std::vector& s, Precision prec = Precision::INT8, float scale = 1.0f);
-
+ ~BufferDesc();
+
+ BufferDesc(BufferDesc&& other) noexcept;
+ BufferDesc& operator=(BufferDesc&& other) noexcept;
+
+ BufferDesc(const BufferDesc&) = delete;
+ BufferDesc& operator=(const BufferDesc&) = delete;
+
void* get_data();
const void* get_data() const;
-
+
template
T* data_as() { return static_cast(get_data()); }
-
+
template
const T* data_as() const { return static_cast(get_data()); }
-
+
void allocate();
+ void allocate_from_pool(BufferPool& pool);
+ void release_to_pool(BufferPool& pool);
void set_external(void* ptr);
};
@@ -181,6 +193,33 @@ void compute_topk_node(GraphNode& node, const std::vector>& nodes, const std::unordered_map& node_index_map);
void compute_index_node(GraphNode& node, const std::vector>& nodes, const std::unordered_map& node_index_map);
+void shrink_thread_local_buffers();
+
+class BufferPool {
+public:
+ BufferPool() = default;
+ ~BufferPool() = default;
+
+ BufferPool(const BufferPool&) = delete;
+ BufferPool& operator=(const BufferPool&) = delete;
+
+ char* acquire(size_t byte_size);
+ void release(char* ptr, size_t byte_size);
+ void clear();
+
+ size_t active_bytes() const { return active_bytes_; }
+ size_t pool_bytes() const { return pool_bytes_; }
+ size_t peak_bytes() const { return peak_bytes_; }
+
+private:
+ std::unordered_map>> free_buffers_;
+ size_t active_bytes_ = 0;
+ size_t pool_bytes_ = 0;
+ size_t peak_bytes_ = 0;
+
+ size_t round_up_size(size_t size) const;
+};
+
namespace ValidationUtils {
void validate_tensor_dims(const std::vector& shape, size_t required_dims, const std::string& op_name);
void validate_precision(Precision actual, Precision required, const std::string& op_name);
@@ -286,6 +325,7 @@ class CactusGraph {
std::vector> mapped_files_;
std::unordered_map weight_cache_;
std::vector debug_nodes_;
+ BufferPool buffer_pool_;
};
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus b/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus
index c71c4eb..1335f38 100755
Binary files a/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus and b/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus differ
diff --git a/package.json b/package.json
index 45e4055..6d42725 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "cactus-react-native",
- "version": "1.1.0",
+ "version": "1.2.0",
"description": "Run AI models locally on mobile devices",
"main": "./lib/module/index.js",
"types": "./lib/typescript/src/index.d.ts",
diff --git a/src/classes/CactusSTT.ts b/src/classes/CactusSTT.ts
index 2747217..243a4ab 100644
--- a/src/classes/CactusSTT.ts
+++ b/src/classes/CactusSTT.ts
@@ -25,10 +25,12 @@ export class CactusSTT {
private static readonly defaultModel = 'whisper-small';
private static readonly defaultContextSize = 2048;
+ private static readonly defaultPrompt =
+ '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>';
private static readonly defaultTranscribeOptions = {
maxTokens: 512,
};
- private static readonly defaultEmbedBufferSize = 32768;
+ private static readonly defaultEmbedBufferSize = 4096;
private static cactusModelsCache: CactusModel[] | null = null;
@@ -86,7 +88,7 @@ export class CactusSTT {
public async transcribe({
audioFilePath,
- prompt = '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>',
+ prompt,
options,
onToken,
}: CactusSTTTranscribeParams): Promise {
@@ -96,8 +98,12 @@ export class CactusSTT {
await this.init();
+ prompt = prompt ?? CactusSTT.defaultPrompt;
options = { ...CactusSTT.defaultTranscribeOptions, ...options };
- const responseBufferSize = 32768;
+
+ const responseBufferSize =
+ 8 * (options.maxTokens ?? CactusSTT.defaultTranscribeOptions.maxTokens) +
+ 256;
this.isGenerating = true;
try {
diff --git a/src/constants/packageVersion.ts b/src/constants/packageVersion.ts
index 3009f7a..9de35a5 100644
--- a/src/constants/packageVersion.ts
+++ b/src/constants/packageVersion.ts
@@ -1 +1 @@
-export const packageVersion = '1.1.0';
+export const packageVersion = '1.2.0';
diff --git a/src/hooks/useCactusSTT.ts b/src/hooks/useCactusSTT.ts
index 3446df8..aadf19b 100644
--- a/src/hooks/useCactusSTT.ts
+++ b/src/hooks/useCactusSTT.ts
@@ -21,7 +21,7 @@ export const useCactusSTT = ({
);
// State
- const [response, setResponse] = useState('');
+ const [transcription, setTranscription] = useState('');
const [isGenerating, setIsGenerating] = useState(false);
const [isInitializing, setIsInitializing] = useState(false);
const [isDownloaded, setIsDownloaded] = useState(false);
@@ -39,7 +39,7 @@ export const useCactusSTT = ({
useEffect(() => {
setCactusSTT(new CactusSTT({ model, contextSize }));
- setResponse('');
+ setTranscription('');
setIsGenerating(false);
setIsInitializing(false);
setIsDownloaded(false);
@@ -174,7 +174,7 @@ export const useCactusSTT = ({
}
setError(null);
- setResponse('');
+ setTranscription('');
setIsGenerating(true);
try {
return await cactusSTT.transcribe({
@@ -182,7 +182,7 @@ export const useCactusSTT = ({
prompt,
options,
onToken: (token) => {
- setResponse((prev) => prev + token);
+ setTranscription((prev) => prev + token);
onToken?.(token);
},
});
@@ -238,7 +238,7 @@ export const useCactusSTT = ({
setError(getErrorMessage(e));
throw e;
} finally {
- setResponse('');
+ setTranscription('');
}
}, [cactusSTT]);
@@ -250,7 +250,7 @@ export const useCactusSTT = ({
setError(getErrorMessage(e));
throw e;
} finally {
- setResponse('');
+ setTranscription('');
}
}, [cactusSTT]);
@@ -265,7 +265,7 @@ export const useCactusSTT = ({
}, [cactusSTT]);
return {
- response,
+ transcription,
isGenerating,
isInitializing,
isDownloaded,