From 64f30d14967d2937c0926821dc91f1c3dae7fab2 Mon Sep 17 00:00:00 2001 From: Andrea Donetti Date: Thu, 14 May 2026 12:56:23 -0600 Subject: [PATCH 1/2] fix(parser): skip whitespace-only chunks Prevent empty semantic chunks from reaching embedding providers, where they can produce invalid zero-dimensional results and pollute the vault or cache. The check lives in the parser callback path so all embedding providers and SQLite entry points share the same filtering behavior before provider-specific code runs. --- src/dbmem-parser.c | 8 ++++++-- test/unittest.c | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/dbmem-parser.c b/src/dbmem-parser.c index 03c3476..0258baa 100644 --- a/src/dbmem-parser.c +++ b/src/dbmem-parser.c @@ -1110,8 +1110,12 @@ int dbmem_parse (const char *md, size_t md_len, dbmem_parse_settings *settings) src_len = src_end - src_off; } - // Invoke callback - if (settings->callback) { + // Invoke callback (skip whitespace-only chunks) + bool has_text = false; + for (size_t k = 0; k < chunk_len; k++) { + if (!isspace((unsigned char)chunk_text[k])) { has_text = true; break; } + } + if (has_text && settings->callback) { rc = settings->callback(chunk_text, chunk_len, src_off, src_len, settings->xdata, i); if (rc != 0) break; } diff --git a/test/unittest.c b/test/unittest.c index e05d600..00fe619 100644 --- a/test/unittest.c +++ b/test/unittest.c @@ -2564,6 +2564,8 @@ typedef struct { char api_key[256]; } dummy_engine_t; +static int dummy_compute_calls = 0; + static void *dummy_init(const char *model, const char *api_key, void *xdata, char err_msg[1024]) { UNUSED_PARAM(model); UNUSED_PARAM(xdata); @@ -2584,6 +2586,7 @@ static int dummy_compute(void *engine, const char *text, int text_len, void *xda UNUSED_PARAM(xdata); dummy_engine_t *e = (dummy_engine_t *)engine; e->compute_count++; + dummy_compute_calls++; result->n_tokens = text_len / 4; result->truncated = false; result->n_embd = e->dimension; @@ -2779,6 +2782,35 @@ TEST(sqlite_custom_provider_add_text) { sqlite3_close(db); } +TEST(sqlite_custom_provider_skips_whitespace_only_text) { + sqlite3 *db = open_test_db(); + ASSERT(db != NULL); + + dbmem_provider_t prov = { .init = dummy_init, .compute = dummy_compute, .free = dummy_free }; + int rc = sqlite3_memory_register_provider(db, "dummy", &prov); + ASSERT_EQ(rc, SQLITE_OK); + + sqlite3_int64 result = 0; + rc = exec_get_int(db, "SELECT memory_set_model('dummy', 'test-model');", &result); + ASSERT_EQ(rc, SQLITE_OK); + + dummy_compute_calls = 0; + rc = exec_get_int(db, "SELECT memory_add_text(' \n\n \n');", &result); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(result, 1); + ASSERT_EQ(dummy_compute_calls, 0); + + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault;", &result); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(result, 0); + + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_cache;", &result); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(result, 0); + + sqlite3_close(db); +} + TEST(sqlite_custom_provider_persists_truncated_metadata) { sqlite3 *db = open_test_db(); ASSERT(db != NULL); @@ -3190,6 +3222,7 @@ int main(int argc, char *argv[]) { RUN_TEST(sqlite_custom_provider_register); RUN_TEST(sqlite_custom_provider_set_model); RUN_TEST(sqlite_custom_provider_add_text); + RUN_TEST(sqlite_custom_provider_skips_whitespace_only_text); RUN_TEST(sqlite_custom_provider_persists_truncated_metadata); RUN_TEST(sqlite_mdx_preprocessing_applies_only_to_mdx_files); RUN_TEST(sqlite_custom_provider_null_callbacks); From 850f439e4053606a107fbea241a0dd56eb2ef24b Mon Sep 17 00:00:00 2001 From: Andrea Donetti Date: Thu, 14 May 2026 13:57:23 -0600 Subject: [PATCH 2/2] fix(local): harden embedding context rebuilds Configure local llama contexts from max_tokens plus overlay_tokens so encoder inputs fit the chunk sizes produced by the parser. The local engine now sizes n_ctx, n_batch, n_ubatch, and token buffers together, caps the context to bounded/model-supported values, prepares reusable batches with sequence metadata, and truncates over-capacity tokenization explicitly instead of relying on llama.cpp assertions. Move llama diagnostics out of per-engine logger user_data because llama_log_set installs a process-global callback. Thread-local diagnostic capture keeps load/context errors useful without writing through stale engine pointers after another connection replaces or frees an engine. Rebuild the local engine when max_tokens or overlay_tokens changes because those options alter chunk sizes after memory_set_model. The option update now runs under a savepoint, rolls back in-memory and persisted settings on rebuild failure, and keeps the previous engine alive unless the replacement is fully ready. Invalidate cached embeddings for the active local provider/model after a successful context rebuild. The cache key does not include local context sizing, so clearing that provider/model avoids reusing stale embeddings, token counts, or truncation metadata generated under the previous context window. Add a logger regression test for stale global logger user_data and bump the extension version to 1.2.1. --- src/dbmem-embed.h | 2 +- src/dbmem-lembed.c | 211 ++++++++++++++++++++++++++++++++------------ src/sqlite-memory.c | 110 ++++++++++++++++++++--- src/sqlite-memory.h | 2 +- test/unittest.c | 13 +++ 5 files changed, 265 insertions(+), 73 deletions(-) diff --git a/src/dbmem-embed.h b/src/dbmem-embed.h index 15c1c1a..540a0ed 100644 --- a/src/dbmem-embed.h +++ b/src/dbmem-embed.h @@ -22,7 +22,7 @@ typedef struct { float *embedding; // Pointer to embedding (points to engine's buffer, do not free) } embedding_result_t; -dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, char err_msg[DBMEM_ERRBUF_SIZE]); +dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, int max_context_tokens, char err_msg[DBMEM_ERRBUF_SIZE]); int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *text, int text_len, embedding_result_t *result); bool dbmem_local_engine_warmup (dbmem_local_engine_t *engine); void dbmem_local_engine_free (dbmem_local_engine_t *engine); diff --git a/src/dbmem-lembed.c b/src/dbmem-lembed.c index e3c842f..2d6fbd0 100644 --- a/src/dbmem-lembed.c +++ b/src/dbmem-lembed.c @@ -13,6 +13,21 @@ #include #include +#define DBMEM_LOCAL_MIN_CONTEXT_TOKENS 128 +#define DBMEM_LOCAL_MAX_CONTEXT_TOKENS 8192 + +#if defined(_MSC_VER) +#define DBMEM_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L +#define DBMEM_THREAD_LOCAL _Thread_local +#else +#define DBMEM_THREAD_LOCAL __thread +#endif + +static DBMEM_THREAD_LOCAL bool dbmem_llama_diag_enabled = false; +static DBMEM_THREAD_LOCAL char dbmem_llama_diag[DBMEM_ERRBUF_SIZE]; +static DBMEM_THREAD_LOCAL size_t dbmem_llama_diag_len = 0; + struct dbmem_local_engine_t { dbmem_context *context; @@ -26,6 +41,7 @@ struct dbmem_local_engine_t { // Model info int n_embd; // Embedding dimension (e.g., 768 for nomic-embed) int n_ctx; // Maximum context length in tokens + int n_ubatch; // Maximum physical batch size for encoder input bool is_encoder_only; // True for BERT-style models, false for GPT-style // Settings @@ -33,7 +49,9 @@ struct dbmem_local_engine_t { // Reusable buffers (avoid repeated allocations) llama_token *tokens; // Pre-allocated buffer for tokenized input - int tokens_capacity; // Size of tokens buffer (equals n_ctx) + int tokens_capacity; // Size of tokens buffer, capped by n_ubatch + struct llama_batch batch; // Pre-allocated llama.cpp batch with sequence metadata + bool batch_initialized; // True when batch must be freed float *embedding; // Pre-allocated buffer for output embedding (n_embd floats) // Statistics @@ -75,75 +93,130 @@ static void dbmem_embedding_normalize (float *vec, int n) { } void dbmem_logger (enum ggml_log_level level, const char *text, void *user_data) { - dbmem_local_engine_t *engine = (dbmem_local_engine_t *)user_data; - //if (ai->db == NULL) return; - //if ((level == GGML_LOG_LEVEL_INFO) && (ai->options.log_info == false)) return; - - const char *type = NULL; - switch (level) { - case GGML_LOG_LEVEL_NONE: type = "NONE"; break; - case GGML_LOG_LEVEL_DEBUG: type = "DEBUG"; break; - case GGML_LOG_LEVEL_INFO: type = "INFO"; break; - case GGML_LOG_LEVEL_WARN: type = "WARNING"; break; - case GGML_LOG_LEVEL_ERROR: type = "ERROR"; break; - case GGML_LOG_LEVEL_CONT: type = NULL; break; + UNUSED_PARAM(user_data); + if (!dbmem_llama_diag_enabled || !text) return; + + if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_CONT) { + size_t remaining = sizeof(dbmem_llama_diag) - dbmem_llama_diag_len; + if (remaining > 1) { + int written = snprintf(dbmem_llama_diag + dbmem_llama_diag_len, remaining, "%s", text); + if (written > 0) { + size_t used = (size_t)written; + if (used >= remaining) { + dbmem_llama_diag_len = sizeof(dbmem_llama_diag) - 1; + } else { + dbmem_llama_diag_len += used; + } + } + } } - - // DEBUG - // printf("%s %s\n", type, text); - - //const char *values[] = {type, text}; - //int types[] = {(type == NULL) ? SQLITE_NULL : SQLITE_TEXT, SQLITE_TEXT}; - //int lens[] = {-1, -1}; - //sqlite_db_write(NULL, ai->db, LOG_TABLE_INSERT_STMT, values, types, lens, 2); } // MARK: - +static void dbmem_llama_diag_begin(void) { + dbmem_llama_diag[0] = 0; + dbmem_llama_diag_len = 0; + dbmem_llama_diag_enabled = true; +} + +static void dbmem_llama_diag_end(void) { + dbmem_llama_diag_enabled = false; +} + +static const char *dbmem_llama_diag_message(void) { + return dbmem_llama_diag[0] ? dbmem_llama_diag : NULL; +} + static void dbmem_local_set_error(dbmem_local_engine_t *engine, const char *message) { if (!engine || !engine->context) return; dbmem_context_set_error(engine->context, message); } -dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, char err_msg[DBMEM_ERRBUF_SIZE]) { +static bool dbmem_local_batch_prepare(dbmem_local_engine_t *engine, int n_tokens) { + if (!engine || !engine->batch_initialized) return false; + if (n_tokens <= 0 || n_tokens > engine->tokens_capacity) return false; + + engine->batch.n_tokens = 0; + for (int i = 0; i < n_tokens; i++) { + engine->batch.token[i] = engine->tokens[i]; + engine->batch.pos[i] = i; + engine->batch.n_seq_id[i] = 1; + engine->batch.seq_id[i][0] = 0; + engine->batch.logits[i] = 1; + engine->batch.n_tokens++; + } + + return true; +} + +static int dbmem_local_process_batch(dbmem_local_engine_t *engine) { + if (engine->is_encoder_only) { + return llama_encode(engine->ctx, engine->batch); + } + return llama_decode(engine->ctx, engine->batch); +} + +dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, int max_context_tokens, char err_msg[DBMEM_ERRBUF_SIZE]) { dbmem_local_engine_t *engine = (dbmem_local_engine_t *)dbmemory_zeroalloc(sizeof(dbmem_local_engine_t)); if (!engine) return NULL; engine->context = (dbmem_context *)ctx; // set logger - llama_log_set(dbmem_logger, engine); + llama_log_set(dbmem_logger, NULL); + dbmem_llama_diag_begin(); // Initialize backend llama_backend_init(); // Load model struct llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = 0; + model_params.split_mode = LLAMA_SPLIT_MODE_NONE; + model_params.main_gpu = -1; engine->model = llama_model_load_from_file(model_path, model_params); if (!engine->model) { - snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to load model: %s", model_path); + const char *diag = dbmem_llama_diag_message(); + if (diag) { + snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to load model: %s: %s", model_path, diag); + } else { + snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to load model: %s", model_path); + } goto cleanup; } // Get model's native context length int n_ctx_train = llama_model_n_ctx_train(engine->model); + int n_ctx = max_context_tokens * 4; + if (n_ctx < DBMEM_LOCAL_MIN_CONTEXT_TOKENS) n_ctx = DBMEM_LOCAL_MIN_CONTEXT_TOKENS; + if (n_ctx > DBMEM_LOCAL_MAX_CONTEXT_TOKENS) n_ctx = DBMEM_LOCAL_MAX_CONTEXT_TOKENS; + if (n_ctx_train > 0 && n_ctx > n_ctx_train) n_ctx = n_ctx_train; // Create context struct llama_context_params ctx_params = llama_context_default_params(); ctx_params.embeddings = true; - ctx_params.n_ctx = n_ctx_train; - ctx_params.n_batch = n_ctx_train; - ctx_params.n_ubatch = n_ctx_train; + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ctx; + ctx_params.n_ubatch = n_ctx; + ctx_params.offload_kqv = false; + ctx_params.op_offload = false; engine->ctx = llama_init_from_model(engine->model, ctx_params); if (!engine->ctx) { - snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to create context"); + const char *diag = dbmem_llama_diag_message(); + if (diag) { + snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to create context: %s", diag); + } else { + snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to create context"); + } goto cleanup; } // Get model info engine->vocab = llama_model_get_vocab(engine->model); - engine->n_embd = llama_model_n_embd(engine->model); + engine->n_embd = llama_model_n_embd_out(engine->model); engine->n_ctx = llama_n_ctx(engine->ctx); + engine->n_ubatch = llama_n_ubatch(engine->ctx); engine->pooling = llama_pooling_type(engine->ctx); engine->mem = llama_get_memory(engine->ctx); @@ -159,12 +232,22 @@ dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path // Allocate token buffer engine->tokens_capacity = engine->n_ctx; + if (engine->n_ubatch > 0 && engine->tokens_capacity > engine->n_ubatch) { + engine->tokens_capacity = engine->n_ubatch; + } engine->tokens = (llama_token *)dbmemory_alloc(sizeof(llama_token) * engine->tokens_capacity); if (!engine->tokens) { snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate token buffer"); goto cleanup; } + engine->batch = llama_batch_init(engine->tokens_capacity, 0, 1); + engine->batch_initialized = true; + if (!engine->batch.token || !engine->batch.pos || !engine->batch.n_seq_id || !engine->batch.seq_id || !engine->batch.logits) { + snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate llama batch"); + goto cleanup; + } + // Allocate single embedding buffer engine->embedding = (float *)dbmemory_alloc(sizeof(float) * engine->n_embd); if (!engine->embedding) { @@ -177,9 +260,11 @@ dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path engine->total_tokens_processed = 0; engine->total_embeddings_generated = 0; + dbmem_llama_diag_end(); return engine; cleanup: + dbmem_llama_diag_end(); dbmem_local_engine_free(engine); return NULL; } @@ -190,17 +275,8 @@ bool dbmem_local_engine_warmup (dbmem_local_engine_t *engine) { const char *warmup_text = "Warmup"; int warmup_tokens = llama_tokenize(engine->vocab, warmup_text, (int32_t)strlen(warmup_text), engine->tokens, engine->tokens_capacity, true, true); - if (warmup_tokens > 0) { - struct llama_batch batch = { - .n_tokens = warmup_tokens, - .token = engine->tokens, - .embd = NULL, - .pos = NULL, - .n_seq_id = NULL, - .seq_id = NULL, - .logits = NULL, - }; - llama_encode(engine->ctx, batch); + if (warmup_tokens > 0 && dbmem_local_batch_prepare(engine, warmup_tokens)) { + dbmem_local_process_batch(engine); if (engine->mem != NULL) { llama_memory_clear(engine->mem, true); @@ -215,30 +291,46 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex if (text_len == -1) text_len = (int)strlen(text); if (text_len == 0) return 0; + bool truncated = false; + // Tokenize int n_tokens = llama_tokenize(engine->vocab, text, text_len, engine->tokens, engine->tokens_capacity, true, true); if (n_tokens < 0) { - dbmem_local_set_error(engine, "Tokenization failed (text too long?)"); - return -1; + int needed = -n_tokens; + if (needed <= 0) { + dbmem_local_set_error(engine, "Tokenization failed"); + return -1; + } + + llama_token *all_tokens = (llama_token *)dbmemory_alloc(sizeof(llama_token) * needed); + if (!all_tokens) { + dbmem_local_set_error(engine, "Failed to allocate token overflow buffer"); + return -1; + } + + int full_tokens = llama_tokenize(engine->vocab, text, text_len, all_tokens, needed, true, true); + if (full_tokens < 0) { + dbmemory_free(all_tokens); + dbmem_local_set_error(engine, "Tokenization failed"); + return -1; + } + + n_tokens = engine->tokens_capacity; + memcpy(engine->tokens, all_tokens, sizeof(llama_token) * n_tokens); + dbmemory_free(all_tokens); + truncated = true; } // Handle token overflow: truncate to max context size - bool truncated = false; - if (n_tokens > engine->n_ctx) { + if (n_tokens > engine->tokens_capacity) { truncated = true; - n_tokens = engine->n_ctx; + n_tokens = engine->tokens_capacity; } - // Create batch - struct llama_batch batch = { - .n_tokens = n_tokens, - .token = engine->tokens, - .embd = NULL, - .pos = NULL, - .n_seq_id = NULL, - .seq_id = NULL, - .logits = NULL, - }; + if (!dbmem_local_batch_prepare(engine, n_tokens)) { + dbmem_local_set_error(engine, "Failed to prepare llama batch"); + return -1; + } // Clear memory if (engine->mem != NULL) { @@ -246,9 +338,9 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex } // Encode - int ret = llama_encode(engine->ctx, batch); + int ret = dbmem_local_process_batch(engine); if (ret != 0) { - dbmem_local_set_error(engine, "Llama_encode failed"); + dbmem_local_set_error(engine, "llama batch processing failed"); return -1; } @@ -297,6 +389,11 @@ void dbmem_local_engine_free (dbmem_local_engine_t *engine) { dbmemory_free(engine->tokens); engine->tokens = NULL; } + if (engine->batch_initialized) { + llama_batch_free(engine->batch); + memset(&engine->batch, 0, sizeof(engine->batch)); + engine->batch_initialized = false; + } if (engine->ctx) { llama_free(engine->ctx); engine->ctx = NULL; diff --git a/src/sqlite-memory.c b/src/sqlite-memory.c index 5663b97..4356404 100644 --- a/src/sqlite-memory.c +++ b/src/sqlite-memory.c @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -999,6 +998,28 @@ static void dbmem_clear (sqlite3_context *context, int argc, sqlite3_value **arg // MARK: - Cache Clear - +static int dbmem_cache_clear_provider_model(sqlite3 *db, const char *provider, const char *model) { + static const char *sql = "DELETE FROM dbmem_cache WHERE provider=?1 AND model=?2;"; + if (!provider || !model) return SQLITE_OK; + + sqlite3_stmt *vm = NULL; + int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL); + if (rc != SQLITE_OK) goto cleanup; + + rc = sqlite3_bind_text(vm, 1, provider, -1, SQLITE_STATIC); + if (rc != SQLITE_OK) goto cleanup; + + rc = sqlite3_bind_text(vm, 2, model, -1, SQLITE_STATIC); + if (rc != SQLITE_OK) goto cleanup; + + rc = sqlite3_step(vm); + if (rc == SQLITE_DONE) rc = SQLITE_OK; + +cleanup: + if (vm) sqlite3_finalize(vm); + return rc; +} + static void dbmem_cache_clear (sqlite3_context *context, int argc, sqlite3_value **argv) { sqlite3 *db = sqlite3_context_db_handle(context); int rc; @@ -1013,15 +1034,7 @@ static void dbmem_cache_clear (sqlite3_context *context, int argc, sqlite3_value const char *provider = (const char *)sqlite3_value_text(argv[0]); const char *model = (const char *)sqlite3_value_text(argv[1]); - sqlite3_stmt *vm = NULL; - rc = sqlite3_prepare_v2(db, "DELETE FROM dbmem_cache WHERE provider=?1 AND model=?2;", -1, &vm, NULL); - if (rc == SQLITE_OK) { - sqlite3_bind_text(vm, 1, provider, -1, SQLITE_STATIC); - sqlite3_bind_text(vm, 2, model, -1, SQLITE_STATIC); - rc = sqlite3_step(vm); - if (rc == SQLITE_DONE) rc = SQLITE_OK; - } - if (vm) sqlite3_finalize(vm); + rc = dbmem_cache_clear_provider_model(db, provider, model); } else { sqlite3_result_error(context, "The function memory_cache_clear expects 0 or 2 arguments", SQLITE_ERROR); return; @@ -1141,7 +1154,8 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value * return; } - new_l_engine = dbmem_local_engine_init(ctx, model, ctx->error_msg); + int max_context_tokens = (int)(ctx->max_tokens + ctx->overlay_tokens); + new_l_engine = dbmem_local_engine_init(ctx, model, max_context_tokens, ctx->error_msg); if (new_l_engine == NULL) { dbmemory_free(new_provider); dbmemory_free(new_model); @@ -1299,26 +1313,94 @@ static void dbmem_set_apikey (sqlite3_context *context, int argc, sqlite3_value // MARK: - +static bool dbmem_is_local_context_option(const char *key) { + return (strcasecmp(key, DBMEM_SETTINGS_KEY_MAX_TOKENS) == 0 || + strcasecmp(key, DBMEM_SETTINGS_KEY_OVERLAY_TOKENS) == 0); +} + +static int dbmem_rebuild_local_engine_for_context_options(dbmem_context *ctx) { + #ifndef DBMEM_OMIT_LOCAL_ENGINE + if (!ctx || !ctx->is_local || ctx->is_custom || !ctx->model || !ctx->l_engine) { + return SQLITE_OK; + } + + int max_context_tokens = (int)(ctx->max_tokens + ctx->overlay_tokens); + dbmem_local_engine_t *new_l_engine = dbmem_local_engine_init(ctx, ctx->model, max_context_tokens, ctx->error_msg); + if (new_l_engine == NULL) { + return SQLITE_ERROR; + } + + if (ctx->engine_warmup) { + dbmem_local_engine_warmup(new_l_engine); + } + + int rc = dbmem_cache_clear_provider_model(ctx->db, ctx->provider, ctx->model); + if (rc != SQLITE_OK) { + dbmem_local_engine_free(new_l_engine); + return rc; + } + + dbmem_local_engine_free(ctx->l_engine); + ctx->l_engine = new_l_engine; + #else + UNUSED_PARAM(ctx); + #endif + + return SQLITE_OK; +} + static void dbmem_set_option (sqlite3_context *context, int argc, sqlite3_value **argv) { + UNUSED_PARAM(argc); + // sanity check type if (sqlite3_value_type(argv[0]) != SQLITE_TEXT) { sqlite3_result_error(context, "The function memory_set_option expects the key argument to be of type TEXT", SQLITE_ERROR); return; } - // update settings sqlite3 *db = sqlite3_context_db_handle(context); const char *key = (const char *)sqlite3_value_text(argv[0]); - int rc = dbmem_settings_write_value(db, key, argv[1]); // retrieve context dbmem_context *ctx = (dbmem_context *)sqlite3_user_data(context); + ctx->error_msg[0] = 0; + + bool context_option = dbmem_is_local_context_option(key); + size_t old_max_tokens = ctx->max_tokens; + size_t old_overlay_tokens = ctx->overlay_tokens; + + int rc = sqlite3_exec(db, "SAVEPOINT dbmem_set_option;", NULL, NULL, NULL); + bool savepoint_started = (rc == SQLITE_OK); + + if (rc == SQLITE_OK) { + rc = dbmem_settings_write_value(db, key, argv[1]); + } if (rc == SQLITE_OK) { dbmem_settings_sync(ctx, key, argv[1]); } + + if (rc == SQLITE_OK && context_option && + (old_max_tokens != ctx->max_tokens || old_overlay_tokens != ctx->overlay_tokens)) { + rc = dbmem_rebuild_local_engine_for_context_options(ctx); + } + + if (rc == SQLITE_OK && savepoint_started) { + rc = sqlite3_exec(db, "RELEASE dbmem_set_option;", NULL, NULL, NULL); + savepoint_started = false; + } + + if (rc != SQLITE_OK) { + if (savepoint_started) { + sqlite3_exec(db, "ROLLBACK TO dbmem_set_option; RELEASE dbmem_set_option;", NULL, NULL, NULL); + } + if (context_option) { + ctx->max_tokens = old_max_tokens; + ctx->overlay_tokens = old_overlay_tokens; + } + } - (rc == SQLITE_OK) ? sqlite3_result_int(context, 1) : sqlite3_result_error(context, sqlite3_errmsg(db), -1); + (rc == SQLITE_OK) ? sqlite3_result_int(context, 1) : sqlite3_result_error(context, ctx->error_msg[0] ? ctx->error_msg : sqlite3_errmsg(db), -1); } static void dbmem_get_option (sqlite3_context *context, int argc, sqlite3_value **argv) { diff --git a/src/sqlite-memory.h b/src/sqlite-memory.h index 77e17e9..0ee29bc 100644 --- a/src/sqlite-memory.h +++ b/src/sqlite-memory.h @@ -26,7 +26,7 @@ extern "C" { #endif -#define SQLITE_DBMEMORY_VERSION "1.2.0" +#define SQLITE_DBMEMORY_VERSION "1.2.1" // public API SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi); diff --git a/test/unittest.c b/test/unittest.c index 00fe619..6f17745 100644 --- a/test/unittest.c +++ b/test/unittest.c @@ -31,6 +31,10 @@ #ifdef TEST_SQLITE_EXTENSION #include "sqlite-memory.h" +#ifndef DBMEM_OMIT_LOCAL_ENGINE +#include "ggml.h" +void dbmem_logger(enum ggml_log_level level, const char *text, void *user_data); +#endif #endif // ============================================================================ @@ -3082,6 +3086,12 @@ TEST(sqlite_set_model_failed_remote_switch_keeps_custom_engine) { } #endif +#ifndef DBMEM_OMIT_LOCAL_ENGINE +TEST(sqlite_local_logger_ignores_stale_user_data) { + dbmem_logger(GGML_LOG_LEVEL_WARN, "ignored warning", (void *)1); +} +#endif + #endif // TEST_SQLITE_EXTENSION // ============================================================================ @@ -3234,6 +3244,9 @@ int main(int argc, char *argv[]) { #else RUN_TEST(sqlite_set_model_failed_remote_switch_keeps_custom_engine); #endif +#ifndef DBMEM_OMIT_LOCAL_ENGINE + RUN_TEST(sqlite_local_logger_ignores_stale_user_data); +#endif #endif printf("\n=== Results ===\n");