feat: ai-cache plugin #13308
base: master
New file (`@@ -0,0 +1,296 @@`) — the plugin entry point. It implements a two-layer response cache for LLM traffic: an exact-match L1 keyed by a hash of the prompt, and a semantic L2 keyed by an embedding of the prompt. Lookup failures degrade to a MISS rather than an error, and cache writes happen asynchronously from the log phase.

```lua
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements.  See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License.  You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local schema = require("apisix.plugins.ai-cache.schema")
local exact = require("apisix.plugins.ai-cache.exact")
local semantic = require("apisix.plugins.ai-cache.semantic")
local protocols = require("apisix.plugins.ai-protocols")
local http = require("resty.http")
local ngx = ngx
local ngx_time = ngx.time
local ngx_now = ngx.now
local ipairs = ipairs
local require = require
local tostring = tostring
local table_concat = table.concat

local plugin_name = "ai-cache"

local _M = {
    version = 0.1,
    priority = 1065,
    name = plugin_name,
    schema = schema.schema
}


local function layer_enabled(conf, name)
    local layers = conf.layers or { "exact", "semantic" }
    for _, l in ipairs(layers) do
        if l == name then return true end
    end
    return false
end


local function populate_ai_ctx_on_hit(ctx, protocol_name, body_tab, is_stream, cached_text)
    ctx.ai_client_protocol = protocol_name
    ctx.var.request_type = is_stream and "ai_stream" or "ai_chat"
    if body_tab.model then
        ctx.var.request_llm_model = body_tab.model
        ctx.var.llm_model = body_tab.model
    end
    ctx.var.llm_response_text = cached_text
end


function _M.check_schema(conf)
    local ok, err = core.schema.check(schema.schema, conf)
    if not ok then
        return false, err
    end

    if layer_enabled(conf, "semantic") then
        if not (conf.semantic and conf.semantic.embedding) then
            return false, "semantic layer requires semantic.embedding to be configured"
        end
    end

    core.utils.check_https({ "semantic.embedding.endpoint" }, conf, plugin_name)

    return true
end


function _M.access(conf, ctx)
    -- check bypass_on conditions
    if conf.bypass_on then
        local req_headers = ngx.req.get_headers()
        for _, rule in ipairs(conf.bypass_on) do
            if req_headers[rule.header] == rule.equals then
                ctx.ai_cache_status = "BYPASS"
                return
            end
        end
    end

    local body_tab, err = core.request.get_json_request_body_table()
    if not body_tab then
        core.log.warn("ai-cache: failed to read request body: ", err or "unknown error")
        ctx.ai_cache_status = "MISS"
        return
    end

    local protocol_name = protocols.detect(body_tab, ctx)
    if not protocol_name then
        core.log.warn("ai-cache: could not detect AI protocol, skipping cache")
        ctx.ai_cache_status = "MISS"
        return
    end

    local proto = protocols.get(protocol_name)
    local contents = proto.extract_request_content(body_tab)
    if not contents or #contents == 0 then
        ctx.ai_cache_status = "MISS"
        return
    end

    local prompt_text = table_concat(contents, " ")
    local scope_hash = exact.compute_scope_hash(conf, ctx)
    local prompt_hash = exact.compute_prompt_hash(prompt_text)

    local is_stream = body_tab.stream == true

    -- L1 exact lookup
    if layer_enabled(conf, "exact") then
        local cached_text, written_at, lookup_err = exact.get(conf, scope_hash, prompt_hash)
        if lookup_err then
            core.log.warn("ai-cache: L1 lookup error: ", lookup_err)
        elseif cached_text then
            core.log.info("ai-cache: L1 hit for key: ", prompt_hash)
            ctx.ai_cache_status = "HIT-L1"
            ctx.ai_cache_written_at = written_at
            if is_stream then
                core.response.set_header("Content-Type", "text/event-stream")
            else
                core.response.set_header("Content-Type", "application/json")
            end
            populate_ai_ctx_on_hit(ctx, protocol_name, body_tab, is_stream, cached_text)
            -- TODO: rename build_deny_response to build_response_from_text in a
            -- follow-up. We use it here to wrap cached text in the protocol's
            -- response shape, not for policy denial.
            return 200, proto.build_deny_response({
                stream = is_stream,
                text = cached_text,
            })
        end
    end

    -- L2 semantic lookup
    if layer_enabled(conf, "semantic") then
        local emb_conf = conf.semantic.embedding
        local emb_driver = require("apisix.plugins.ai-cache.embeddings." .. emb_conf.provider)
        local httpc = http.new()

        local t0 = ngx_now()
        local embedding, _, emb_err = emb_driver.get_embeddings(
            emb_conf, prompt_text, httpc, emb_conf.ssl_verify
        )
        if not embedding then
            core.log.warn("ai-cache: embedding fetch failed (degrading to MISS): ", emb_err)
            ctx.ai_cache_embedding_failed = true
        else
            ctx.ai_cache_embedding_latency_ms = (ngx_now() - t0) * 1000
            ctx.ai_cache_embedding_provider = emb_conf.provider
            ctx.ai_cache_embedding = embedding

            local threshold = conf.semantic.similarity_threshold or 0.95
            local cached_text, similarity, search_err = semantic.search(
                conf, scope_hash, embedding, threshold
            )

            if search_err then
                core.log.warn("ai-cache: L2 search error (degrading to MISS): ", search_err)
            elseif cached_text then
                core.log.info("ai-cache: L2 hit, similarity=", similarity)

                -- backfill L1 so the next identical prompt hits the cheaper exact layer
                if layer_enabled(conf, "exact") then
                    local l1_ttl = (conf.exact and conf.exact.ttl) or 3600
                    local l1_err = exact.set(
                        conf, scope_hash, prompt_hash, cached_text, l1_ttl
                    )
                    if l1_err then
                        core.log.warn("ai-cache: L2->L1 backfill failed: ", l1_err)
                    end
                end

                ctx.ai_cache_status = "HIT-L2"
                ctx.ai_cache_similarity = similarity
                if is_stream then
                    core.response.set_header("Content-Type", "text/event-stream")
                else
                    core.response.set_header("Content-Type", "application/json")
                end
                populate_ai_ctx_on_hit(ctx, protocol_name, body_tab, is_stream, cached_text)
                return 200, proto.build_deny_response({
                    stream = is_stream,
                    text = cached_text,
                })
            end
        end
    end

    ctx.ai_cache_status = "MISS"
    ctx.ai_cache_scope_hash = scope_hash
    ctx.ai_cache_prompt_hash = prompt_hash
    ctx.ai_cache_prompt_text = prompt_text
end


function _M.header_filter(conf, ctx)
    if not ctx.ai_cache_status then
        return
    end

    local status_header = (conf.headers and conf.headers.cache_status)
                          or "X-AI-Cache-Status"
    ngx.header[status_header] = ctx.ai_cache_status

    if ctx.ai_cache_status == "HIT-L1" and ctx.ai_cache_written_at then
        local age_header = (conf.headers and conf.headers.cache_age)
                           or "X-AI-Cache-Age"
        ngx.header[age_header] = tostring(ngx_time() - ctx.ai_cache_written_at)
    end

    if ctx.ai_cache_status == "HIT-L2" and ctx.ai_cache_similarity then
        local sim_header = (conf.headers and conf.headers.cache_similarity)
                           or "X-AI-Cache-Similarity"
        ngx.header[sim_header] = tostring(ctx.ai_cache_similarity)
    end
end


function _M.log(conf, ctx)
    if ctx.ai_cache_status ~= "MISS" then
        return
    end

    -- Early-MISS paths (body parse / protocol detect / empty content) skip
    -- key computation, so bail out if cache key fields are absent.
    if not ctx.ai_cache_prompt_hash or not ctx.ai_cache_prompt_text then
        return
    end

    local upstream_status = core.response.get_upstream_status(ctx) or ngx.status
    if not upstream_status or upstream_status < 200 or upstream_status >= 300 then
        return
    end

    local response_text = ctx.var.llm_response_text
    if not response_text or response_text == "" then
        return
    end

    local max_size = conf.max_cache_body_size or 1048576
    if #response_text > max_size then
        core.log.warn("ai-cache: response size ", #response_text,
                      " exceeds max_cache_body_size ", max_size,
                      ", skipping cache write")
        return
    end

    local exact_enabled = layer_enabled(conf, "exact")
    local semantic_enabled = layer_enabled(conf, "semantic")
    local ttl_exact = (conf.exact and conf.exact.ttl) or 3600
    local scope_hash = ctx.ai_cache_scope_hash
    local prompt_hash = ctx.ai_cache_prompt_hash
    local embedding = ctx.ai_cache_embedding

    -- write from a zero-delay timer: the log phase must not block the
    -- response, and cosocket I/O is unavailable in this phase but allowed
    -- in timer callbacks
    local ok, timer_err = ngx.timer.at(0, function(premature)
        if premature then
            return
        end

        if exact_enabled then
            local err = exact.set(conf, scope_hash, prompt_hash, response_text, ttl_exact)
            if err then
                ngx.log(ngx.WARN, "ai-cache: failed to write L1 cache: ", err)
            end
        end

        if semantic_enabled then
            if not embedding then
                return
            end

            local ttl_semantic = (conf.semantic and conf.semantic.ttl) or 86400
            local store_err = semantic.store(
                conf, scope_hash, embedding, response_text, ttl_semantic
            )
            if store_err then
                ngx.log(ngx.WARN, "ai-cache: failed to write L2 cache: ", store_err)
            end
        end
    end)
    if not ok then
        core.log.warn("ai-cache: failed to schedule cache write: ", timer_err)
    end
end


return _M
```
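For context, here is a sketch of the conf table such a plugin might receive — the kind of value `_M.check_schema` would be handed in a unit test. It is assembled only from the fields the code above reads (`layers`, `bypass_on`, `exact.ttl`, `semantic.similarity_threshold`, `semantic.ttl`, `semantic.embedding.*`, `max_cache_body_size`, `headers.*`); the authoritative shape lives in `ai-cache/schema.lua`, which is not part of this excerpt, so treat the exact structure as an assumption.

```lua
-- Hypothetical ai-cache configuration, inferred from the fields read by
-- the plugin code above; the real schema is defined in ai-cache/schema.lua.
local conf = {
    layers = { "exact", "semantic" },      -- enable both cache layers
    bypass_on = {
        -- skip the cache when the client sends X-AI-Cache-Bypass: 1
        { header = "x-ai-cache-bypass", equals = "1" },
    },
    exact = { ttl = 3600 },                -- L1 entries live for an hour
    semantic = {
        similarity_threshold = 0.97,       -- stricter than the 0.95 default
        ttl = 86400,                       -- L2 entries live for a day
        embedding = {
            provider = "azure-openai",     -- assumed provider module name
            endpoint = "https://example.openai.azure.com/...",  -- placeholder
            api_key = "...",               -- placeholder
            timeout = 3000,                -- ms, passed to httpc:set_timeout
            ssl_verify = true,
        },
    },
    max_cache_body_size = 1048576,         -- 1 MiB write cap (the default)
    headers = { cache_status = "X-AI-Cache-Status" },
}
```

With a configuration like this, clients would observe `X-AI-Cache-Status: HIT-L1` plus `X-AI-Cache-Age` on an exact hit, `HIT-L2` plus `X-AI-Cache-Similarity` on a semantic hit, and `BYPASS` when a bypass rule matches.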
New file (`@@ -0,0 +1,76 @@`) — an embeddings driver for the semantic layer, one of the provider modules the plugin loads via `require("apisix.plugins.ai-cache.embeddings." .. provider)`. The `api-key` request header suggests an Azure OpenAI-style embeddings endpoint.

```lua
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements.  See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License.  You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local type = type

local ngx = ngx
local HTTP_OK = ngx.HTTP_OK
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR

local _M = {}


function _M.get_embeddings(conf, text, httpc, ssl_verify)
    local body, err = core.json.encode({ input = text })
    if not body then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    httpc:set_timeout(conf.timeout)

    local res, err = httpc:request_uri(conf.endpoint, {
        method = "POST",
        headers = {
            ["Content-Type"] = "application/json",
            ["api-key"] = conf.api_key,
        },
        body = body,
        ssl_verify = ssl_verify,
        keepalive = true,
    })

    if not res or not res.body then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err or "no response from embeddings API"
    end

    if res.status ~= HTTP_OK then
        return nil, res.status, res.body
    end

    local res_tab, err = core.json.decode(res.body)
    if not res_tab then
        return nil, HTTP_INTERNAL_SERVER_ERROR, err
    end

    if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then
        return nil, HTTP_INTERNAL_SERVER_ERROR, "unexpected embedding response: " .. res.body
    end

    local embedding = res_tab.data[1].embedding
    if type(embedding) ~= "table" then
        return nil, HTTP_INTERNAL_SERVER_ERROR, "missing embedding field in response"
    end
    if #embedding == 0 then
        return nil, HTTP_INTERNAL_SERVER_ERROR, "embedding vector is empty"
    end

    return embedding, nil, nil
end


return _M
```
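A minimal sketch of how this driver is invoked, mirroring the call site in the plugin's access phase above. The module name `azure-openai`, the endpoint URL, and the key are assumptions for illustration; only the call signature and return convention (`embedding, nil, nil` on success; `nil, status, err` on failure) come from the diff.

```lua
-- Hypothetical standalone invocation of the embeddings driver; conf values
-- and the provider module name are placeholders, not part of this PR.
local http = require("resty.http")
local emb_driver = require("apisix.plugins.ai-cache.embeddings.azure-openai")

local emb_conf = {
    endpoint = "https://example.openai.azure.com/...",  -- placeholder
    api_key = "...",                                    -- placeholder
    timeout = 3000,                                     -- ms
    ssl_verify = true,
}

local httpc = http.new()
local embedding, status, err = emb_driver.get_embeddings(
    emb_conf, "what is apisix?", httpc, emb_conf.ssl_verify
)
if not embedding then
    ngx.log(ngx.WARN, "embedding failed, status=", status, ", err=", err)
else
    ngx.log(ngx.INFO, "got ", #embedding, "-dimensional embedding")
end
```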