Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions kiro_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@
"sonnet": "claude-sonnet-4",
"haiku": "claude-haiku-4.5",
"opus": "claude-opus-4.5",
"opus-4.6": "claude-opus-4.6",
}

KIRO_MODELS = {"auto", "claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "claude-opus-4.5"}
KIRO_MODELS = {"auto", "claude-sonnet-4.5", "claude-sonnet-4", "claude-haiku-4.5", "claude-opus-4.5", "claude-opus-4.6"}

def map_model_name(model: str) -> str:
"""将外部模型名称映射到 Kiro 支持的名称"""
Expand All @@ -57,7 +58,7 @@ def map_model_name(model: str) -> str:
return model
model_lower = model.lower()
if "opus" in model_lower:
return "claude-opus-4.5"
return "claude-opus-4.5" if "4.5" in model_lower else "claude-opus-4.6"
if "haiku" in model_lower:
return "claude-haiku-4.5"
if "sonnet" in model_lower:
Expand Down
89 changes: 73 additions & 16 deletions kiro_proxy/handlers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,14 @@ async def generate():

if flow_id:
flow_monitor.fail_flow(flow_id, "rate_limit_error", "All accounts rate limited", 429)
yield f'data: {{"type":"error","error":{{"type":"rate_limit_error","message":"All accounts rate limited"}}}}\n\n'
yield f'event: error\ndata: {{"type":"error","error":{{"type":"rate_limit_error","message":"All accounts rate limited"}}}}\n\n'
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=429, duration_ms=duration, error="All accounts rate limited"
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=False, latency_ms=duration)
return

# 处理可重试的服务端错误
Expand All @@ -256,7 +263,14 @@ async def generate():
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "api_error", "Server error after retries", response.status_code)
yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Server error after retries"}}}}\n\n'
yield f'event: error\ndata: {{"type":"error","error":{{"type":"api_error","message":"Server error after retries"}}}}\n\n'
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=response.status_code, duration_ms=duration, error="Server error after retries"
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=False, latency_ms=duration)
return

if response.status_code != 200:
Expand Down Expand Up @@ -326,7 +340,14 @@ async def api_caller(prompt: str) -> str:

if flow_id:
flow_monitor.fail_flow(flow_id, error_type, error_msg, response.status_code, error_str)
yield f'data: {{"type":"error","error":{{"type":"{error_type}","message":"{error_msg}"}}}}\n\n'
yield f'event: error\ndata: {{"type":"error","error":{{"type":"{error_type}","message":"{error_msg}"}}}}\n\n'
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=response.status_code, duration_ms=duration, error=error_msg
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=False, latency_ms=duration)
return

# 标记开始流式传输
Expand All @@ -335,8 +356,9 @@ async def api_caller(prompt: str) -> str:

# 正常处理响应
msg_id = f"msg_{log_id}"
yield f'data: {{"type":"message_start","message":{{"id":"{msg_id}","type":"message","role":"assistant","content":[],"model":"{model}","stop_reason":null,"stop_sequence":null,"usage":{{"input_tokens":0,"output_tokens":0}}}}}}\n\n'
yield f'data: {{"type":"content_block_start","index":0,"content_block":{{"type":"text","text":""}}}}\n\n'
yield f'event: message_start\ndata: {{"type":"message_start","message":{{"id":"{msg_id}","type":"message","role":"assistant","content":[],"model":"{model}","stop_reason":null,"stop_sequence":null,"usage":{{"input_tokens":0,"output_tokens":0}}}}}}\n\n'
yield f'event: content_block_start\ndata: {{"type":"content_block_start","index":0,"content_block":{{"type":"text","text":""}}}}\n\n'
yield f'event: ping\ndata: {{"type":"ping"}}\n\n'

full_response = b""

Expand Down Expand Up @@ -367,7 +389,7 @@ async def api_caller(prompt: str) -> str:
full_content += content
if flow_id:
flow_monitor.add_chunk(flow_id, content)
yield f'data: {{"type":"content_block_delta","index":0,"delta":{{"type":"text_delta","text":{json.dumps(content)}}}}}\n\n'
yield f'event: content_block_delta\ndata: {{"type":"content_block_delta","index":0,"delta":{{"type":"text_delta","text":{json.dumps(content)}}}}}\n\n'
except Exception:
pass
pos += total_len
Expand All @@ -376,17 +398,17 @@ async def api_caller(prompt: str) -> str:

result = parse_event_stream_full(full_response)

yield f'data: {{"type":"content_block_stop","index":0}}\n\n'
yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":0}}\n\n'

if result["tool_uses"]:
for i, tool_use in enumerate(result["tool_uses"], 1):
yield f'data: {{"type":"content_block_start","index":{i},"content_block":{{"type":"tool_use","id":"{tool_use["id"]}","name":"{tool_use["name"]}","input":{{}}}}}}\n\n'
yield f'data: {{"type":"content_block_delta","index":{i},"delta":{{"type":"input_json_delta","partial_json":{json.dumps(json.dumps(tool_use["input"]))}}}}}\n\n'
yield f'data: {{"type":"content_block_stop","index":{i}}}\n\n'
yield f'event: content_block_start\ndata: {{"type":"content_block_start","index":{i},"content_block":{{"type":"tool_use","id":"{tool_use["id"]}","name":"{tool_use["name"]}","input":{{}}}}}}\n\n'
yield f'event: content_block_delta\ndata: {{"type":"content_block_delta","index":{i},"delta":{{"type":"input_json_delta","partial_json":{json.dumps(json.dumps(tool_use["input"]))}}}}}\n\n'
yield f'event: content_block_stop\ndata: {{"type":"content_block_stop","index":{i}}}\n\n'

stop_reason = result["stop_reason"]
yield f'data: {{"type":"message_delta","delta":{{"stop_reason":"{stop_reason}","stop_sequence":null}},"usage":{{"output_tokens":100}}}}\n\n'
yield f'data: {{"type":"message_stop"}}\n\n'
yield f'event: message_delta\ndata: {{"type":"message_delta","delta":{{"stop_reason":"{stop_reason}","stop_sequence":null}},"usage":{{"output_tokens":100}}}}\n\n'
yield f'event: message_stop\ndata: {{"type":"message_stop"}}\n\n'

# 完成 Flow
if flow_id:
Expand All @@ -405,6 +427,13 @@ async def api_caller(prompt: str) -> str:
current_account.request_count += 1
current_account.last_used = time.time()
get_rate_limiter().record_request(current_account.id)
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=200, duration_ms=duration, error=None
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=True, latency_ms=duration)
return

except httpx.TimeoutException:
Expand All @@ -416,7 +445,14 @@ async def api_caller(prompt: str) -> str:
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "timeout_error", "Request timeout after retries", 408)
yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Request timeout after retries"}}}}\n\n'
yield f'event: error\ndata: {{"type":"error","error":{{"type":"api_error","message":"Request timeout after retries"}}}}\n\n'
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=408, duration_ms=duration, error="Request timeout after retries"
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=False, latency_ms=duration)
return
except httpx.ConnectError:
if retry_count < max_retries:
Expand All @@ -427,7 +463,14 @@ async def api_caller(prompt: str) -> str:
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "connection_error", "Connection error after retries", 502)
yield f'data: {{"type":"error","error":{{"type":"api_error","message":"Connection error after retries"}}}}\n\n'
yield f'event: error\ndata: {{"type":"error","error":{{"type":"api_error","message":"Connection error after retries"}}}}\n\n'
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=502, duration_ms=duration, error="Connection error after retries"
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=False, latency_ms=duration)
return
except Exception as e:
# 检查是否为可重试的网络错误
Expand All @@ -439,7 +482,14 @@ async def api_caller(prompt: str) -> str:
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "api_error", str(e), 500)
yield f'data: {{"type":"error","error":{{"type":"api_error","message":"{str(e)}"}}}}\n\n'
yield f'event: error\ndata: {{"type":"error","error":{{"type":"api_error","message":"{str(e)}"}}}}\n\n'
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id, timestamp=time.time(), method="POST", path="/v1/messages",
model=model, account_id=current_account.id if current_account else None,
status=500, duration_ms=duration, error=str(e)
))
stats_manager.record_request(account_id=current_account.id if current_account else "unknown", model=model, success=False, latency_ms=duration)
return

return StreamingResponse(generate(), media_type="text/event-stream")
Expand All @@ -452,8 +502,10 @@ async def _handle_non_stream(kiro_request, headers, account, model, log_id, star
current_account = account
max_retries = 2
retry_ctx = RetryableRequest(max_retries=2)
should_log = False

for retry in range(max_retries + 1):
should_log = False
try:
async with httpx.AsyncClient(verify=False, timeout=300) as client:
response = await client.post(KIRO_API_URL, json=kiro_request, headers=headers)
Expand Down Expand Up @@ -547,9 +599,11 @@ async def api_caller(prompt: str) -> str:
),
)

should_log = True
return convert_kiro_response_to_anthropic(result, model, f"msg_{log_id}")

except HTTPException:
should_log = True
raise
except httpx.TimeoutException as e:
error_msg = f"Request timeout: {e}"
Expand All @@ -560,6 +614,7 @@ async def api_caller(prompt: str) -> str:
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "timeout_error", "Request timeout after retries", 408)
should_log = True
raise HTTPException(408, "Request timeout after retries")
except httpx.ConnectError as e:
error_msg = f"Connection error: {e}"
Expand All @@ -570,6 +625,7 @@ async def api_caller(prompt: str) -> str:
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "connection_error", "Connection error after retries", 502)
should_log = True
raise HTTPException(502, "Connection error after retries")
except Exception as e:
error_msg = str(e)
Expand All @@ -581,9 +637,10 @@ async def api_caller(prompt: str) -> str:
continue
if flow_id:
flow_monitor.fail_flow(flow_id, "api_error", str(e), 500)
should_log = True
raise HTTPException(500, str(e))
finally:
if retry == max_retries or status_code == 200:
if should_log:
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id,
Expand Down
33 changes: 17 additions & 16 deletions kiro_proxy/handlers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ async def call_summary(prompt: str) -> str:
content = ""
current_account = account
max_retries = 2

for retry in range(max_retries + 1):

try:
for retry in range(max_retries + 1):
try:
async with httpx.AsyncClient(verify=False, timeout=120) as client:
resp = await client.post(KIRO_API_URL, json=kiro_request, headers=headers)
Expand Down Expand Up @@ -239,20 +240,20 @@ async def call_summary(prompt: str) -> str:
await asyncio.sleep(0.5 * (2 ** retry))
continue
raise HTTPException(500, str(e))

# 记录日志
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id,
timestamp=time.time(),
method="POST",
path=f"/v1/models/{model_name}:generateContent",
model=model,
account_id=current_account.id if current_account else None,
status=status_code,
duration_ms=duration,
error=error_msg
))
finally:
# 记录日志
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id,
timestamp=time.time(),
method="POST",
path=f"/v1/models/{model_name}:generateContent",
model=model,
account_id=current_account.id if current_account else None,
status=status_code,
duration_ms=duration,
error=error_msg
))

# 使用转换函数生成 Gemini 格式响应
return convert_kiro_response_to_gemini(result, model)
49 changes: 25 additions & 24 deletions kiro_proxy/handlers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ async def call_summary(prompt: str) -> str:
content = ""
current_account = account
max_retries = 2

for retry in range(max_retries + 1):

try:
for retry in range(max_retries + 1):
try:
async with httpx.AsyncClient(verify=False, timeout=120) as client:
resp = await client.post(KIRO_API_URL, json=kiro_request, headers=headers)
Expand Down Expand Up @@ -239,28 +240,28 @@ async def call_summary(prompt: str) -> str:
await asyncio.sleep(0.5 * (2 ** retry))
continue
raise HTTPException(500, str(e))

# 记录日志
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id,
timestamp=time.time(),
method="POST",
path="/v1/chat/completions",
model=model,
account_id=current_account.id if current_account else None,
status=status_code,
duration_ms=duration,
error=error_msg
))
# 记录统计
stats_manager.record_request(
account_id=current_account.id if current_account else "unknown",
model=model,
success=status_code == 200,
latency_ms=duration
)
finally:
# 记录日志
duration = (time.time() - start_time) * 1000
state.add_log(RequestLog(
id=log_id,
timestamp=time.time(),
method="POST",
path="/v1/chat/completions",
model=model,
account_id=current_account.id if current_account else None,
status=status_code,
duration_ms=duration,
error=error_msg
))

# 记录统计
stats_manager.record_request(
account_id=current_account.id if current_account else "unknown",
model=model,
success=status_code == 200,
latency_ms=duration
)

if stream:
async def generate():
Expand Down
Loading