Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
563 changes: 555 additions & 8 deletions README.md

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_CXX_STANDARD 20)

# Define C++ library and add all sources
add_library(${PACKAGE_NAME} SHARED
  src/main/cpp/cpp-adapter.cpp
  ../cpp/HybridCactus.cpp
  ../cpp/HybridCactusUtil.cpp
  ../cpp/HybridCactusIndex.cpp
)

# Prebuilt Cactus core archive for the ABI currently being built.
add_library(libcactus STATIC IMPORTED)
set_target_properties(libcactus PROPERTIES
  IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libcactus.a"
)

# Prebuilt Cactus utility archive. It is declared exactly once, as a static
# import pointing at the .a file; a second add_library() for the same target
# name or a second IMPORTED_LOCATION would conflict with this one.
add_library(libcactus_util STATIC IMPORTED)
set_target_properties(libcactus_util PROPERTIES
  IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libcactus_util.a"
)

# Add Nitrogen specs :)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ class HybridCactusFileSystem : HybridCactusFileSystemSpec() {
modelFile.deleteRecursively()
}

/**
 * Asynchronously resolves the absolute path of the directory returned by
 * [indexFile] for the index called [name].
 */
override fun getIndexPath(name: String): Promise<String> {
  return Promise.async {
    val indexDir = indexFile(name)
    indexDir.absolutePath
  }
}

private fun cactusFile(): File {
val cactusDir = File(context.filesDir, "cactus")

Expand All @@ -221,6 +223,23 @@ class HybridCactusFileSystem : HybridCactusFileSystemSpec() {

/**
 * Returns the [File] for [model] inside the `models` directory under the
 * Cactus directory, creating the `models` directory first when it is missing.
 */
private fun modelFile(model: String): File {
  val modelsDir = File(cactusFile(), "models")

  // No early return before this point: the directory must exist by the time
  // the caller receives the model path.
  if (!modelsDir.exists()) {
    modelsDir.mkdirs()
  }

  return File(modelsDir, model)
}

/**
 * Returns the directory `indexes/<name>` under the Cactus directory, creating
 * it (and any missing parents) when it does not exist yet.
 */
private fun indexFile(name: String): File {
  val indexDir = File(cactusFile(), "indexes/$name")

  if (indexDir.exists().not()) {
    indexDir.mkdirs()
  }

  return indexDir
}
}
Binary file modified android/src/main/jniLibs/arm64-v8a/libcactus.a
Binary file not shown.
Binary file not shown.
Binary file not shown.
131 changes: 112 additions & 19 deletions cpp/HybridCactus.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "HybridCactus.hpp"

namespace margelo::nitro::cactus {

HybridCactus::HybridCactus() : HybridObject(TAG) {}

std::shared_ptr<Promise<void>>
Expand All @@ -19,7 +20,8 @@ HybridCactus::init(const std::string &modelPath, double contextSize,
corpusDir ? corpusDir->c_str() : nullptr);

if (!model) {
throw std::runtime_error("Failed to initialize Cactus model");
throw std::runtime_error("Cactus init failed: " +
std::string(cactus_get_last_error()));
}

this->_model = model;
Expand Down Expand Up @@ -65,7 +67,8 @@ std::shared_ptr<Promise<std::string>> HybridCactus::complete(
cactusTokenCallback, &callbackCtx);

if (result < 0) {
throw std::runtime_error("Cactus completion failed");
throw std::runtime_error("Cactus complete failed: " +
std::string(cactus_get_last_error()));
}

// Remove null terminator
Expand All @@ -75,12 +78,78 @@ std::shared_ptr<Promise<std::string>> HybridCactus::complete(
});
}

// Tokenizes `text` with the loaded model on a Promise worker.
//
// Resolves with the token ids reported by cactus_tokenize, each widened to a
// double. Rejects with std::runtime_error when no model is loaded or when
// cactus_tokenize returns a negative status.
std::shared_ptr<Promise<std::vector<double>>>
HybridCactus::tokenize(const std::string &text) {
  using TokenIds = std::vector<double>;

  return Promise<TokenIds>::async([this, text]() -> TokenIds {
    // Hold the model mutex for the whole native call.
    std::lock_guard<std::mutex> guard(_modelMutex);

    if (!_model) {
      throw std::runtime_error("Cactus model is not initialized");
    }

    // Scratch space of 2 * byte-length + 16 ids.
    // NOTE(review): presumably an upper bound on the token count - confirm.
    const size_t capacity = text.length() * 2 + 16;
    std::vector<uint32_t> ids(capacity);
    size_t produced = 0;

    const int status = cactus_tokenize(_model, text.c_str(), ids.data(),
                                       ids.size(), &produced);
    if (status < 0) {
      const std::string reason(cactus_get_last_error());
      throw std::runtime_error("Cactus tokenize failed: " + reason);
    }

    // Keep only the entries the native call reported.
    ids.resize(produced);

    TokenIds widened;
    widened.reserve(ids.size());
    for (const uint32_t id : ids) {
      widened.push_back(static_cast<double>(id));
    }
    return widened;
  });
}

// Runs cactus_score_window over `tokens` on a Promise worker.
//
// `tokens` are token ids carried as doubles; `start`, `end` and `context` are
// forwarded to cactus_score_window as size_t values (fractions truncate).
// Resolves with the text the native call wrote into a 1024-byte response
// buffer, cut at the first NUL. Rejects with std::runtime_error when no model
// is loaded, when an argument cannot be represented in the native integer
// type, or when cactus_score_window returns a negative status.
std::shared_ptr<Promise<std::string>>
HybridCactus::scoreWindow(const std::vector<double> &tokens, double start,
                          double end, double context) {
  return Promise<std::string>::async(
      [this, tokens, start, end, context]() -> std::string {
        std::lock_guard<std::mutex> lock(this->_modelMutex);

        if (!this->_model) {
          throw std::runtime_error("Cactus model is not initialized");
        }

        // Converting a negative, NaN or too-large double to an unsigned
        // integer is undefined behaviour, so reject such values up front.
        // The negated comparison also catches NaN.
        std::vector<uint32_t> tokenBuffer;
        tokenBuffer.reserve(tokens.size());
        for (const double d : tokens) {
          if (!(d >= 0.0 && d < 4294967296.0)) { // 2^32
            throw std::runtime_error(
                "Cactus score window failed: invalid token id");
          }
          tokenBuffer.emplace_back(static_cast<uint32_t>(d));
        }

        // Same concern for the window arguments, which become size_t.
        const auto toIndex = [](double value, const char *label) -> size_t {
          if (!(value >= 0.0 && value < static_cast<double>(SIZE_MAX))) {
            throw std::runtime_error(
                std::string("Cactus score window failed: invalid ") + label);
          }
          return static_cast<size_t>(value);
        };
        const size_t startIndex = toIndex(start, "start");
        const size_t endIndex = toIndex(end, "end");
        const size_t contextSize = toIndex(context, "context");

        std::string responseBuffer;
        responseBuffer.resize(1024);

        int result = cactus_score_window(
            this->_model, tokenBuffer.data(), tokenBuffer.size(), startIndex,
            endIndex, contextSize, responseBuffer.data(),
            responseBuffer.size());

        if (result < 0) {
          throw std::runtime_error("Cactus score window failed: " +
                                   std::string(cactus_get_last_error()));
        }

        // Remove null terminator
        responseBuffer.resize(strlen(responseBuffer.c_str()));

        return responseBuffer;
      });
}

std::shared_ptr<Promise<std::string>> HybridCactus::transcribe(
const std::string &audioFilePath, const std::string &prompt,
double responseBufferSize, const std::optional<std::string> &optionsJson,
const std::variant<std::vector<double>, std::string> &audio,
const std::string &prompt, double responseBufferSize,
const std::optional<std::string> &optionsJson,
const std::optional<std::function<void(const std::string & /* token */,
double /* tokenId */)>> &callback) {
return Promise<std::string>::async([this, audioFilePath, prompt, optionsJson,
return Promise<std::string>::async([this, audio, prompt, optionsJson,
callback,
responseBufferSize]() -> std::string {
std::lock_guard<std::mutex> lock(this->_modelMutex);
Expand All @@ -105,14 +174,34 @@ std::shared_ptr<Promise<std::string>> HybridCactus::transcribe(
std::string responseBuffer;
responseBuffer.resize(responseBufferSize);

int result =
cactus_transcribe(this->_model, audioFilePath.c_str(), prompt.c_str(),
responseBuffer.data(), responseBufferSize,
optionsJson ? optionsJson->c_str() : nullptr,
cactusTokenCallback, &callbackCtx);
int result;
if (std::holds_alternative<std::string>(audio)) {
result = cactus_transcribe(
this->_model, std::get<std::string>(audio).c_str(), prompt.c_str(),
responseBuffer.data(), responseBufferSize,
optionsJson ? optionsJson->c_str() : nullptr, cactusTokenCallback,
&callbackCtx, nullptr, 0);
} else {
const auto &audioDoubles = std::get<std::vector<double>>(audio);

std::vector<uint8_t> audioBytes;
audioBytes.reserve(audioDoubles.size());

for (double d : audioDoubles) {
d = std::clamp(d, 0.0, 255.0);
audioBytes.emplace_back(static_cast<uint8_t>(d));
}

result = cactus_transcribe(this->_model, nullptr, prompt.c_str(),
responseBuffer.data(), responseBufferSize,
optionsJson ? optionsJson->c_str() : nullptr,
cactusTokenCallback, &callbackCtx,
audioBytes.data(), audioBytes.size());
}

if (result < 0) {
throw std::runtime_error("Cactus transcription failed");
throw std::runtime_error("Cactus transcribe failed: " +
std::string(cactus_get_last_error()));
}

// Remove null terminator
Expand All @@ -123,9 +212,10 @@ std::shared_ptr<Promise<std::string>> HybridCactus::transcribe(
}

std::shared_ptr<Promise<std::vector<double>>>
HybridCactus::embed(const std::string &text, double embeddingBufferSize) {
HybridCactus::embed(const std::string &text, double embeddingBufferSize,
bool normalize) {
return Promise<std::vector<double>>::async(
[this, text, embeddingBufferSize]() -> std::vector<double> {
[this, text, embeddingBufferSize, normalize]() -> std::vector<double> {
std::lock_guard<std::mutex> lock(this->_modelMutex);

if (!this->_model) {
Expand All @@ -135,12 +225,13 @@ HybridCactus::embed(const std::string &text, double embeddingBufferSize) {
std::vector<float> embeddingBuffer(embeddingBufferSize);
size_t embeddingDim;

int result =
cactus_embed(this->_model, text.c_str(), embeddingBuffer.data(),
embeddingBufferSize * sizeof(float), &embeddingDim);
int result = cactus_embed(
this->_model, text.c_str(), embeddingBuffer.data(),
embeddingBufferSize * sizeof(float), &embeddingDim, normalize);

if (result < 0) {
throw std::runtime_error("Cactus embedding failed");
throw std::runtime_error("Cactus embed failed: " +
std::string(cactus_get_last_error()));
}

embeddingBuffer.resize(embeddingDim);
Expand Down Expand Up @@ -169,7 +260,8 @@ HybridCactus::imageEmbed(const std::string &imagePath,
embeddingBufferSize * sizeof(float), &embeddingDim);

if (result < 0) {
throw std::runtime_error("Cactus image embedding failed");
throw std::runtime_error("Cactus image embed failed: " +
std::string(cactus_get_last_error()));
}

embeddingBuffer.resize(embeddingDim);
Expand Down Expand Up @@ -198,7 +290,8 @@ HybridCactus::audioEmbed(const std::string &audioPath,
embeddingBufferSize * sizeof(float), &embeddingDim);

if (result < 0) {
throw std::runtime_error("Cactus audio embedding failed");
throw std::runtime_error("Cactus audio embed failed: " +
std::string(cactus_get_last_error()));
}

embeddingBuffer.resize(embeddingDim);
Expand Down
15 changes: 12 additions & 3 deletions cpp/HybridCactus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,24 @@ class HybridCactus : public HybridCactusSpec {
double /* tokenId */)>> &callback)
override;

// Tokenizes `text` with the loaded model. The promise resolves with the token
// ids as doubles and rejects if no model is loaded or tokenization fails.
std::shared_ptr<Promise<std::vector<double>>>
tokenize(const std::string &text) override;

// Passes `tokens` (token ids carried as doubles) together with `start`, `end`
// and `context` to the native score-window call. The promise resolves with the
// response text that call produces and rejects if no model is loaded or the
// call fails.
std::shared_ptr<Promise<std::string>>
scoreWindow(const std::vector<double> &tokens, double start, double end,
double context) override;

// Transcribes audio with the loaded model.
//
// `audio` is either a string, which is passed to the native call as the audio
// file path, or a vector of doubles, each of which is clamped to 0..255 and
// sent as one byte of an in-memory audio buffer. `responseBufferSize` sizes
// the response buffer, `optionsJson` is forwarded when present, and `callback`
// optionally receives each token and its id. The promise resolves with the
// response text and rejects when transcription fails.
std::shared_ptr<Promise<std::string>> transcribe(
    const std::variant<std::vector<double>, std::string> &audio,
    const std::string &prompt, double responseBufferSize,
    const std::optional<std::string> &optionsJson,
    const std::optional<std::function<void(const std::string & /* token */,
                                           double /* tokenId */)>> &callback)
    override;

// Computes a text embedding with the loaded model.
//
// `embeddingBufferSize` is the number of float slots allocated for the native
// call, and `normalize` is forwarded to it unchanged. The promise resolves
// with the embedding values and rejects if no model is loaded or embedding
// fails.
std::shared_ptr<Promise<std::vector<double>>>
embed(const std::string &text, double embeddingBufferSize,
      bool normalize) override;

std::shared_ptr<Promise<std::vector<double>>>
imageEmbed(const std::string &imagePath, double embeddingBufferSize) override;
Expand Down
Loading
Loading