Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions apisix/plugins/ai-drivers/openai-base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
--
local _M = {}

local claude_converter = require("apisix.plugins.ai-proxy.converter.claude_to_openai")

local mt = {
__index = _M
}
Expand Down Expand Up @@ -62,6 +64,14 @@ function _M.validate_request(ctx)
return nil, err
end

if ctx.ai_client_protocol == "claude" then
local converted, err = claude_converter.convert_request(request_table)
if not converted then
return nil, err
end
request_table = converted
end

return request_table, nil
end

Expand Down Expand Up @@ -147,7 +157,16 @@ local function read_response(conf, ctx, res, response_filter)
::CONTINUE::
end

if ctx.ai_client_protocol == "claude" then
local converted = claude_converter.convert_sse_events(ctx, chunk)
if converted then
chunk = converted
else
goto NEXT_CHUNK
end
end
plugin.lua_response_filter(ctx, res.headers, chunk)
::NEXT_CHUNK::
end
end

Expand Down Expand Up @@ -208,6 +227,14 @@ local function read_response(conf, ctx, res, response_filter)
ctx.var.llm_response_text = content_to_check
end
end
if ctx.ai_client_protocol == "claude" and res_body then
if res.status == 200 then
raw_res_body = core.json.encode(claude_converter.convert_response(res_body))
else
raw_res_body = core.json.encode(
claude_converter.convert_error_response(res.status, res_body))
end
end
plugin.lua_response_filter(ctx, headers, raw_res_body)
end

Expand Down
5 changes: 5 additions & 0 deletions apisix/plugins/ai-proxy/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ function _M.before_proxy(conf, ctx, on_error)
local ai_instance = ctx.picked_ai_instance
local ai_driver = require("apisix.plugins.ai-drivers." .. ai_instance.provider)

local is_claude = core.string.has_suffix(ctx.var.uri, "/v1/messages")
if is_claude then
ctx.ai_client_protocol = "claude"
end

local request_body, err = ai_driver.validate_request(ctx)
if not request_body then
return 400, err
Expand Down
Loading
Loading