Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ const App = () => {
return (
<>
<Button onPress={handleTranscribe} title="Transcribe" />
<Text>{cactusSTT.response}</Text>
<Text>{cactusSTT.transcription}</Text>
</>
);
};
Expand Down Expand Up @@ -672,7 +672,7 @@ The `useCactusSTT` hook manages a `CactusSTT` instance with reactive state. When

#### State

- `response: string` - Current transcription text. Automatically accumulated during streaming. Cleared before each new transcription and when calling `reset()` or `destroy()`.
- `transcription: string` - Current transcription text. Automatically accumulated during streaming. Cleared before each new transcription and when calling `reset()` or `destroy()`.
- `isGenerating: boolean` - Whether the model is currently generating (transcription or embedding). Both operations share this flag.
- `isInitializing: boolean` - Whether the model is initializing.
- `isDownloaded: boolean` - Whether the model is downloaded locally. Automatically checked when the hook mounts or model changes.
Expand All @@ -684,11 +684,11 @@ The `useCactusSTT` hook manages a `CactusSTT` instance with reactive state. When

- `download(params?: CactusSTTDownloadParams): Promise<void>` - Downloads the model. Updates `isDownloading` and `downloadProgress` state during download. Sets `isDownloaded` to `true` on success.
- `init(): Promise<void>` - Initializes the model for inference. Sets `isInitializing` to `true` during initialization.
- `transcribe(params: CactusSTTTranscribeParams): Promise<CactusSTTTranscribeResult>` - Transcribes audio to text. Automatically accumulates tokens in the `response` state during streaming. Sets `isGenerating` to `true` while generating. Clears `response` before starting.
- `transcribe(params: CactusSTTTranscribeParams): Promise<CactusSTTTranscribeResult>` - Transcribes audio to text. Automatically accumulates tokens in the `transcription` state during streaming. Sets `isGenerating` to `true` while generating. Clears `transcription` before starting.
- `audioEmbed(params: CactusSTTAudioEmbedParams): Promise<CactusSTTAudioEmbedResult>` - Generates embeddings for the given audio. Sets `isGenerating` to `true` during operation.
- `stop(): Promise<void>` - Stops ongoing generation. Clears any errors.
- `reset(): Promise<void>` - Resets the model's internal state. Also clears the `response` state.
- `destroy(): Promise<void>` - Releases all resources associated with the model. Clears the `response` state. Automatically called when the component unmounts.
- `reset(): Promise<void>` - Resets the model's internal state. Also clears the `transcription` state.
- `destroy(): Promise<void>` - Releases all resources associated with the model. Clears the `transcription` state. Automatically called when the component unmounts.
- `getModels(): Promise<CactusModel[]>` - Fetches available models from the database and checks their download status. Results are cached in memory and reused on subsequent calls.

## Type Definitions
Expand Down
Binary file modified android/src/main/jniLibs/arm64-v8a/libcactus.a
Binary file not shown.
4 changes: 2 additions & 2 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
PODS:
- boost (1.84.0)
- Cactus (1.1.0):
- Cactus (1.2.0):
- boost
- DoubleConversion
- fast_float
Expand Down Expand Up @@ -2643,7 +2643,7 @@ EXTERNAL SOURCES:

SPEC CHECKSUMS:
boost: 7e761d76ca2ce687f7cc98e698152abd03a18f90
Cactus: 2949301f1229677c0bbeba6856d3c78e5798aace
Cactus: 8853f351fa4c1ef40bad6d2b0152f503b713dc3e
DoubleConversion: cb417026b2400c8f53ae97020b2be961b59470cb
fast_float: b32c788ed9c6a8c584d114d0047beda9664e7cc6
FBLazyVector: b8f1312d48447cca7b4abc21ed155db14742bd03
Expand Down
4 changes: 2 additions & 2 deletions example/src/STTScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ const STTScreen = () => {
</TouchableOpacity>
</View>

{cactusSTT.response && (
{cactusSTT.transcription && (
<View style={styles.responseContainer}>
<Text style={styles.responseLabel}>Streaming:</Text>
<View style={styles.responseBox}>
<Text style={styles.responseText}>{cactusSTT.response}</Text>
<Text style={styles.responseText}>{cactusSTT.transcription}</Text>
</View>
</View>
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class ConvCache {
};

struct KVCache {
static constexpr size_t DEFAULT_WINDOW_SIZE = 512;
static constexpr size_t DEFAULT_WINDOW_SIZE = 1024;
static constexpr size_t DEFAULT_SINK_SIZE = 4;

struct LayerCache {
Expand Down Expand Up @@ -387,7 +387,7 @@ class Model {
const std::string& system_prompt = "", bool do_warmup = true);

virtual uint32_t generate(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
size_t top_k = 0, const std::string& profile_file = "");
size_t top_k = 0, const std::string& profile_file = "", bool prefill_only = false);

virtual uint32_t generate_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
float temperature = -1.0f, float top_p = -1.0f,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,44 @@ struct TensorConfig {
struct BroadcastInfo {
std::vector<size_t> output_shape;
bool needs_broadcasting;

static BroadcastInfo compute(const std::vector<size_t>& lhs, const std::vector<size_t>& rhs);
};

class BufferPool;

struct BufferDesc {
std::vector<size_t> shape;
size_t total_size;
size_t byte_size;
std::unique_ptr<char[]> data;
void* external_data;
char* pooled_data;
Precision precision;
float quantization_scale;

BufferDesc();
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);

~BufferDesc();

BufferDesc(BufferDesc&& other) noexcept;
BufferDesc& operator=(BufferDesc&& other) noexcept;

BufferDesc(const BufferDesc&) = delete;
BufferDesc& operator=(const BufferDesc&) = delete;

void* get_data();
const void* get_data() const;

template<typename T>
T* data_as() { return static_cast<T*>(get_data()); }

template<typename T>
const T* data_as() const { return static_cast<const T*>(get_data()); }

void allocate();
void allocate_from_pool(BufferPool& pool);
void release_to_pool(BufferPool& pool);
void set_external(void* ptr);
};

Expand Down Expand Up @@ -181,6 +193,33 @@ void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphN
void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);

void shrink_thread_local_buffers();

class BufferPool {
public:
BufferPool() = default;
~BufferPool() = default;

BufferPool(const BufferPool&) = delete;
BufferPool& operator=(const BufferPool&) = delete;

char* acquire(size_t byte_size);
void release(char* ptr, size_t byte_size);
void clear();

size_t active_bytes() const { return active_bytes_; }
size_t pool_bytes() const { return pool_bytes_; }
size_t peak_bytes() const { return peak_bytes_; }

private:
std::unordered_map<size_t, std::vector<std::unique_ptr<char[]>>> free_buffers_;
size_t active_bytes_ = 0;
size_t pool_bytes_ = 0;
size_t peak_bytes_ = 0;

size_t round_up_size(size_t size) const;
};

namespace ValidationUtils {
void validate_tensor_dims(const std::vector<size_t>& shape, size_t required_dims, const std::string& op_name);
void validate_precision(Precision actual, Precision required, const std::string& op_name);
Expand Down Expand Up @@ -286,6 +325,7 @@ class CactusGraph {
std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
std::unordered_map<std::string, size_t> weight_cache_;
std::vector<DebugNodeEntry> debug_nodes_;
BufferPool buffer_pool_;
};


Expand Down
Binary file modified ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class ConvCache {
};

struct KVCache {
static constexpr size_t DEFAULT_WINDOW_SIZE = 512;
static constexpr size_t DEFAULT_WINDOW_SIZE = 1024;
static constexpr size_t DEFAULT_SINK_SIZE = 4;

struct LayerCache {
Expand Down Expand Up @@ -387,7 +387,7 @@ class Model {
const std::string& system_prompt = "", bool do_warmup = true);

virtual uint32_t generate(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
size_t top_k = 0, const std::string& profile_file = "");
size_t top_k = 0, const std::string& profile_file = "", bool prefill_only = false);

virtual uint32_t generate_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
float temperature = -1.0f, float top_p = -1.0f,
Expand Down
52 changes: 46 additions & 6 deletions ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,44 @@ struct TensorConfig {
struct BroadcastInfo {
std::vector<size_t> output_shape;
bool needs_broadcasting;

static BroadcastInfo compute(const std::vector<size_t>& lhs, const std::vector<size_t>& rhs);
};

class BufferPool;

struct BufferDesc {
std::vector<size_t> shape;
size_t total_size;
size_t byte_size;
std::unique_ptr<char[]> data;
void* external_data;
char* pooled_data;
Precision precision;
float quantization_scale;

BufferDesc();
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);

~BufferDesc();

BufferDesc(BufferDesc&& other) noexcept;
BufferDesc& operator=(BufferDesc&& other) noexcept;

BufferDesc(const BufferDesc&) = delete;
BufferDesc& operator=(const BufferDesc&) = delete;

void* get_data();
const void* get_data() const;

template<typename T>
T* data_as() { return static_cast<T*>(get_data()); }

template<typename T>
const T* data_as() const { return static_cast<const T*>(get_data()); }

void allocate();
void allocate_from_pool(BufferPool& pool);
void release_to_pool(BufferPool& pool);
void set_external(void* ptr);
};

Expand Down Expand Up @@ -181,6 +193,33 @@ void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphN
void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);

void shrink_thread_local_buffers();

class BufferPool {
public:
BufferPool() = default;
~BufferPool() = default;

BufferPool(const BufferPool&) = delete;
BufferPool& operator=(const BufferPool&) = delete;

char* acquire(size_t byte_size);
void release(char* ptr, size_t byte_size);
void clear();

size_t active_bytes() const { return active_bytes_; }
size_t pool_bytes() const { return pool_bytes_; }
size_t peak_bytes() const { return peak_bytes_; }

private:
std::unordered_map<size_t, std::vector<std::unique_ptr<char[]>>> free_buffers_;
size_t active_bytes_ = 0;
size_t pool_bytes_ = 0;
size_t peak_bytes_ = 0;

size_t round_up_size(size_t size) const;
};

namespace ValidationUtils {
void validate_tensor_dims(const std::vector<size_t>& shape, size_t required_dims, const std::string& op_name);
void validate_precision(Precision actual, Precision required, const std::string& op_name);
Expand Down Expand Up @@ -286,6 +325,7 @@ class CactusGraph {
std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
std::unordered_map<std::string, size_t> weight_cache_;
std::vector<DebugNodeEntry> debug_nodes_;
BufferPool buffer_pool_;
};


Expand Down
Binary file modified ios/cactus.xcframework/ios-arm64/cactus.framework/cactus
Binary file not shown.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "cactus-react-native",
"version": "1.1.0",
"version": "1.2.0",
"description": "Run AI models locally on mobile devices",
"main": "./lib/module/index.js",
"types": "./lib/typescript/src/index.d.ts",
Expand Down
12 changes: 9 additions & 3 deletions src/classes/CactusSTT.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ export class CactusSTT {

private static readonly defaultModel = 'whisper-small';
private static readonly defaultContextSize = 2048;
private static readonly defaultPrompt =
'<|startoftranscript|><|en|><|transcribe|><|notimestamps|>';
private static readonly defaultTranscribeOptions = {
maxTokens: 512,
};
private static readonly defaultEmbedBufferSize = 32768;
private static readonly defaultEmbedBufferSize = 4096;

private static cactusModelsCache: CactusModel[] | null = null;

Expand Down Expand Up @@ -86,7 +88,7 @@ export class CactusSTT {

public async transcribe({
audioFilePath,
prompt = '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>',
prompt,
options,
onToken,
}: CactusSTTTranscribeParams): Promise<CactusSTTTranscribeResult> {
Expand All @@ -96,8 +98,12 @@ export class CactusSTT {

await this.init();

prompt = prompt ?? CactusSTT.defaultPrompt;
options = { ...CactusSTT.defaultTranscribeOptions, ...options };
const responseBufferSize = 32768;

const responseBufferSize =
8 * (options.maxTokens ?? CactusSTT.defaultTranscribeOptions.maxTokens) +
256;

this.isGenerating = true;
try {
Expand Down
2 changes: 1 addition & 1 deletion src/constants/packageVersion.ts
Original file line number Diff line number Diff line change
@@ -1 +1 @@
export const packageVersion = '1.1.0';
export const packageVersion = '1.2.0';
14 changes: 7 additions & 7 deletions src/hooks/useCactusSTT.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export const useCactusSTT = ({
);

// State
const [response, setResponse] = useState('');
const [transcription, setTranscription] = useState('');
const [isGenerating, setIsGenerating] = useState(false);
const [isInitializing, setIsInitializing] = useState(false);
const [isDownloaded, setIsDownloaded] = useState(false);
Expand All @@ -39,7 +39,7 @@ export const useCactusSTT = ({
useEffect(() => {
setCactusSTT(new CactusSTT({ model, contextSize }));

setResponse('');
setTranscription('');
setIsGenerating(false);
setIsInitializing(false);
setIsDownloaded(false);
Expand Down Expand Up @@ -174,15 +174,15 @@ export const useCactusSTT = ({
}

setError(null);
setResponse('');
setTranscription('');
setIsGenerating(true);
try {
return await cactusSTT.transcribe({
audioFilePath,
prompt,
options,
onToken: (token) => {
setResponse((prev) => prev + token);
setTranscription((prev) => prev + token);
onToken?.(token);
},
});
Expand Down Expand Up @@ -238,7 +238,7 @@ export const useCactusSTT = ({
setError(getErrorMessage(e));
throw e;
} finally {
setResponse('');
setTranscription('');
}
}, [cactusSTT]);

Expand All @@ -250,7 +250,7 @@ export const useCactusSTT = ({
setError(getErrorMessage(e));
throw e;
} finally {
setResponse('');
setTranscription('');
}
}, [cactusSTT]);

Expand All @@ -265,7 +265,7 @@ export const useCactusSTT = ({
}, [cactusSTT]);

return {
response,
transcription,
isGenerating,
isInitializing,
isDownloaded,
Expand Down
Loading