From dd0e9a2400d890fa616ace716b4e6be3f7bc2412 Mon Sep 17 00:00:00 2001 From: zqbxdev Date: Sun, 31 May 2026 08:00:26 +0800 Subject: [PATCH 1/6] fix: preserve gemini multimodal inputs --- services/protocol/gemini_native.py | 97 ++++++++++++++++++++++--- services/protocol/openai_v1_response.py | 23 ++++-- utils/helper.py | 13 ++++ 3 files changed, 116 insertions(+), 17 deletions(-) diff --git a/services/protocol/gemini_native.py b/services/protocol/gemini_native.py index 4d26402..89212e7 100644 --- a/services/protocol/gemini_native.py +++ b/services/protocol/gemini_native.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import binascii import json import re import time @@ -28,6 +30,8 @@ class ToolConfig: CompletionFunc = Callable[[dict[str, Any], ModelSpec, list[dict[str, Any]]], gemini.GeminiCompletion] +_INLINE_MEDIA_MIME_TYPES = {"image/png", "image/jpeg", "image/jpg", "image/webp", "image/gif"} +_MAX_INLINE_MEDIA_BYTES = 10 * 1024 * 1024 def list_models() -> dict[str, Any]: @@ -78,7 +82,8 @@ def generate_content(model: str, body: dict[str, Any], completion_func: Completi parsed = parse_native_tool_response(completion.content, tools, tool_config) if tools else [] if parsed: return gemini_response(model_id, function_call_parts(parsed), "STOP", text) - return gemini_response(model_id, [{"text": tool_calls.strip_tool_markup(completion.content)}], "STOP", text) + content = native_text_response(completion.content) if tools else completion.content + return gemini_response(model_id, [{"text": tool_calls.strip_tool_markup(content)}], "STOP", text) def stream_generate_content(model: str, body: dict[str, Any], completion_func: CompletionFunc | None = None, first_event: dict[str, Any] | None = None) -> Iterator[dict[str, Any]]: @@ -102,6 +107,30 @@ def stream_generate_content(model: str, body: dict[str, Any], completion_func: C yield gemini_response(_model_id(model), [], "STOP", text) +def _inline_media_part(part: dict[str, Any]) -> tuple[dict[str, Any], str] | None: + inline = _dict_value(part, "inline_data", "inlineData") + if not isinstance(inline, dict): + return None + data = inline.get("data") + if not isinstance(data, str) or not data: + return None + mime_type = str(inline.get("mime_type") or inline.get("mimeType") or "application/octet-stream").strip().lower() or "application/octet-stream" + if mime_type == "image/jpg": + mime_type = "image/jpeg" + if mime_type not in _INLINE_MEDIA_MIME_TYPES: + raise HTTPException(status_code=400, detail={"error": "unsupported inline media mime type"}) + try: + decoded = base64.b64decode(data, validate=True) + except (binascii.Error, ValueError) as exc: + raise HTTPException(status_code=400, detail={"error": "invalid inline media data"}) from exc + if not decoded: + raise HTTPException(status_code=400, detail={"error": "inline media is empty"}) + if len(decoded) > _MAX_INLINE_MEDIA_BYTES: + raise HTTPException(status_code=400, detail={"error": "inline media is too large"}) + content_part = {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{data}"}} + return content_part, f"[image:{mime_type}]" + + def messages_from_contents(body: dict[str, Any]) -> tuple[list[dict[str, Any]], str]: contents = body.get("contents") if not isinstance(contents, list) or not contents: @@ -109,7 +138,7 @@ def messages_from_contents(body: dict[str, Any]) -> tuple[list[dict[str, Any]], messages: list[dict[str, Any]] = [] text_parts: list[str] = [] non_text_parts: list[str] = [] - saw_media = False + media_previews: list[str] = [] for content in contents: if not isinstance(content, dict): continue @@ -117,38 +146,45 @@ def messages_from_contents(body: dict[str, Any]) -> tuple[list[dict[str, Any]], parts = content.get("parts") if not isinstance(parts, list): continue - message_parts: list[str] = [] + message_parts: list[dict[str, Any]] = [] for part in parts: if not isinstance(part, dict): continue if isinstance(part.get("text"), str): text = str(part.get("text") or "") if text: - message_parts.append(text) + message_parts.append({"type": "text", "text": text}) text_parts.append(text) continue - if part.get("inline_data") is not None or part.get("inlineData") is not None: - saw_media = True + inline_media = _inline_media_part(part) + if inline_media is not None: + media_part, preview = inline_media + message_parts.append(media_part) + media_previews.append(preview) continue call = _dict_value(part, "function_call", "functionCall") if call: serialized = "Function call: " + json.dumps(call, ensure_ascii=False, separators=(",", ":")) - message_parts.append(serialized) + message_parts.append({"type": "text", "text": serialized}) non_text_parts.append(serialized) continue response = _dict_value(part, "function_response", "functionResponse") if response: serialized = "Function response: " + json.dumps(response, ensure_ascii=False, separators=(",", ":")) - message_parts.append(serialized) + message_parts.append({"type": "text", "text": serialized}) non_text_parts.append(serialized) - if message_parts: - messages.append({"role": role, "content": "\n".join(message_parts)}) + if not message_parts: + continue + if all(part.get("type") == "text" for part in message_parts): + messages.append({"role": role, "content": "\n".join(str(part.get("text") or "") for part in message_parts)}) + else: + messages.append({"role": role, "content": message_parts}) request_text = "\n".join(text_parts).strip() if not request_text and non_text_parts: request_text = "\n".join(non_text_parts).strip() + if not request_text and media_previews: + request_text = "\n".join(media_previews).strip() if not request_text: - if saw_media: - raise HTTPException(status_code=400, detail={"error": "Gemini native inline media is not supported by this provider"}) raise HTTPException(status_code=400, detail={"error": "Gemini generateContent requires at least one text part"}) if not messages: raise HTTPException(status_code=400, detail={"error": "Gemini generateContent requires at least one text part"}) @@ -221,6 +257,31 @@ def inject_native_tool_prompt(messages: list[dict[str, Any]], tools: list[Native return [{"role": "system", "content": prompt}, *messages] +def _extract_json_object(text: str) -> dict[str, Any]: + decoder = json.JSONDecoder() + start = text.find("{") + while start != -1: + try: + value, _ = decoder.raw_decode(text, start) + except json.JSONDecodeError: + start = text.find("{", start + 1) + continue + return value if isinstance(value, dict) else {} + return {} + + +def native_text_response(text: str) -> str: + stripped = _strip_fences(text or "").strip() + obj = _extract_json_object(stripped) + if not obj: + return text + status = str(obj.get("status") or "").strip().lower() + if status != "text": + return text + content = obj.get("content") + return str(content).strip() if content is not None else "" + + def parse_native_tool_response(text: str, tools: list[NativeTool], config: ToolConfig | None = None) -> list[tool_calls.ParsedToolCall]: config = config or ToolConfig() available = [item.name for item in tools] @@ -291,6 +352,18 @@ def _native_role_to_openai(role: str) -> str: return "user" +def _strip_fences(text: str) -> str: + stripped = str(text or "").strip() + if not stripped.startswith("```"): + return stripped + lines = stripped.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].strip() == "```": + lines = lines[:-1] + return "\n".join(lines).strip() + + def _dict_value(data: dict[str, Any], *keys: str) -> dict[str, Any]: for key in keys: value = data.get(key) diff --git a/services/protocol/openai_v1_response.py b/services/protocol/openai_v1_response.py index 56c99ce..a15c4ab 100644 --- a/services/protocol/openai_v1_response.py +++ b/services/protocol/openai_v1_response.py @@ -17,7 +17,7 @@ stream_text_deltas, text_backend, ) -from utils.helper import extract_image_from_message_content, extract_response_prompt, has_response_image_generation_tool +from utils.helper import extract_image_from_message_content, extract_response_prompt, has_image_message_content, has_response_image_generation_tool gpt_chat = chat_adapter("gpt") @@ -49,6 +49,12 @@ def extract_response_image(input_value: object) -> tuple[bytes, str] | None: return None +def _typed_response_content(items: list[dict[str, Any]]) -> str | list[dict[str, Any]]: + if any(has_image_message_content([item]) for item in items): + return [dict(item) for item in items] + return extract_response_prompt(items) + + def messages_from_input(input_value: object, instructions: object = None) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] system_text = str(instructions or "").strip() @@ -63,9 +69,9 @@ def messages_from_input(input_value: object, instructions: object = None) -> lis return messages if isinstance(input_value, list): if all(isinstance(item, dict) and item.get("type") for item in input_value): - text = extract_response_prompt(input_value) - if text: - messages.append({"role": "user", "content": text}) + content = _typed_response_content(input_value) + if content: + messages.append({"role": "user", "content": content}) return messages for item in input_value: if isinstance(item, dict): @@ -94,9 +100,16 @@ def _messages_from_response_item(item: dict[str, Any]) -> list[dict[str, Any]]: "tool_call_id": str(item.get("call_id") or item.get("tool_call_id") or ""), "content": str(item.get("output") or item.get("content") or ""), }] + content = item.get("content") + if has_image_message_content(content): + message_content = content + elif has_image_message_content([item]): + message_content = _typed_response_content([item]) + else: + message_content = extract_response_prompt([item]) or content or "" return [{ "role": str(item.get("role") or "user"), - "content": extract_response_prompt([item]) or item.get("content") or "", + "content": message_content, }] diff --git a/utils/helper.py b/utils/helper.py index 75fefa4..b69ffb9 100644 --- a/utils/helper.py +++ b/utils/helper.py @@ -97,6 +97,19 @@ def new_uuid() -> str: return str(uuid.uuid4()) +def has_image_message_content(content: object, _depth: int = 0) -> bool: + if _depth > 20: + return False + if isinstance(content, dict): + item_type = str(content.get("type") or "").strip() + if item_type in {"image", "image_url", "input_image"}: + return True + return any(has_image_message_content(value, _depth + 1) for value in content.values()) + if isinstance(content, list): + return any(has_image_message_content(item, _depth + 1) for item in content) + return False + + def is_image_chat_request(body: dict[str, object]) -> bool: model = str(body.get("model") or "").strip() modalities = body.get("modalities") From 516f73d59f8b22fb0ee22132e5ddd503a9484e58 Mon Sep 17 00:00:00 2001 From: zqbxdev Date: Sun, 31 May 2026 08:00:32 +0800 Subject: [PATCH 2/6] fix: reject unsupported gemini images --- services/protocol/openai_v1_chat_complete.py | 23 ++++++++++++++----- services/providers/gemini/client.py | 24 +++++++++++++++++++- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/services/protocol/openai_v1_chat_complete.py b/services/protocol/openai_v1_chat_complete.py index 0d2f623..38bae66 100644 --- a/services/protocol/openai_v1_chat_complete.py +++ b/services/protocol/openai_v1_chat_complete.py @@ -23,7 +23,7 @@ stream_text_deltas, text_backend, ) -from utils.helper import build_chat_image_markdown_content, extract_chat_image, extract_chat_prompt, is_image_chat_request, parse_image_count +from utils.helper import build_chat_image_markdown_content, extract_chat_image, extract_chat_prompt, has_image_message_content, is_image_chat_request, parse_image_count gpt_chat = chat_adapter("gpt") @@ -440,8 +440,15 @@ def image_result_content(result: dict[str, Any]) -> str: return str(result.get("message") or "Image generation completed.") +def has_chat_image_input(body: dict[str, Any]) -> bool: + messages = body.get("messages") + if not isinstance(messages, list): + return False + return any(isinstance(message, dict) and has_image_message_content(message.get("content")) for message in messages) + + def gemini_image_chat_unsupported() -> HTTPException: - return _unsupported_image_error(GEMINI_PROVIDER) + return HTTPException(status_code=400, detail={"error": "Gemini Web image input is not supported by this upstream adapter"}) def image_chat_response(body: dict[str, Any]) -> dict[str, Any]: @@ -512,10 +519,12 @@ def stream_image_chat_completion(image_outputs: Iterable[ImageOutput], model: st def handle(body: dict[str, Any]) -> dict[str, Any] | Iterator[dict[str, Any]]: if body.get("stream"): - if is_image_chat_request(body): - return image_chat_events(body) model, messages, _ = text_chat_parts(body) spec = resolve_model(model) + if spec.provider == GEMINI_PROVIDER and has_chat_image_input(body): + raise gemini_image_chat_unsupported() + if is_image_chat_request(body): + return image_chat_events(body) if tool_calls.has_function_tools(body): if spec.provider == GROK_PROVIDER: if grok_chat.is_app_chat_model(spec): @@ -553,10 +562,12 @@ def handle(body: dict[str, Any]) -> dict[str, Any] | Iterator[dict[str, Any]]: if spec.provider == GEMINI_PROVIDER: return stream_gemini_chat_completion(body, spec, messages, model) return stream_text_chat_completion(text_backend(), messages, model, stream_include_usage(body)) - if is_image_chat_request(body): - return image_chat_response(body) model, messages, original_messages = text_chat_parts(body) spec = resolve_model(model) + if spec.provider == GEMINI_PROVIDER and has_chat_image_input(body): + raise gemini_image_chat_unsupported() + if is_image_chat_request(body): + return image_chat_response(body) if spec.provider == GROK_PROVIDER: if grok_chat.is_app_chat_model(spec): response = grok_chat.chat_completion(body, spec, messages) diff --git a/services/providers/gemini/client.py b/services/providers/gemini/client.py index dda6d4f..41a8500 100644 --- a/services/providers/gemini/client.py +++ b/services/providers/gemini/client.py @@ -26,6 +26,9 @@ GEMINI_SENSITIVE_COOKIE_NAMES = ("__Secure-1PSID", "__Secure-1PSIDTS", "SNlM0e", "at", "session_token") GEMINI_NON_COOKIE_FIELDS = ("SNlM0e", "session_token", "at") GEMINI_WEB_RPC_ID = "assistant.lamda.BardFrontendService.StreamGenerate" +GEMINI_WEB_IMAGE_UNSUPPORTED_DETAIL = "Gemini Web image input is not supported by this upstream adapter" +GEMINI_IMAGE_PART_TYPES = {"image", "image_url", "input_image"} +GEMINI_IMAGE_PAYLOAD_KEYS = {"image_url", "inlineData", "inline_data"} @dataclass(frozen=True) @@ -232,6 +235,23 @@ def rotate_psidts_cookie(session: object, cookie_header: str, user_agent: str | return merged +def contains_image_content(value: object) -> bool: + if isinstance(value, dict): + block_type = str(value.get("type") or "").strip() + if block_type in GEMINI_IMAGE_PART_TYPES: + return True + if any(key in value for key in GEMINI_IMAGE_PAYLOAD_KEYS): + return True + return any(contains_image_content(item) for item in value.values()) + if isinstance(value, list): + return any(contains_image_content(item) for item in value) + return False + + +def raise_unsupported_image_input() -> None: + raise HTTPException(status_code=400, detail={"error": GEMINI_WEB_IMAGE_UNSUPPORTED_DETAIL}) + + def message_text(content: object) -> str: if isinstance(content, str): return content @@ -255,6 +275,8 @@ def message_text(content: object) -> str: def build_prompt(messages: list[dict[str, Any]]) -> str: parts: list[str] = [] for message in messages: + if contains_image_content(message.get("content")): + raise_unsupported_image_input() role = str(message.get("role") or "user").strip().lower() text = message_text(message.get("content")).strip() if not text: @@ -623,12 +645,12 @@ def extract_completion(payload: object) -> GeminiCompletion: def chat_completion(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> GeminiCompletion: from services.account_service import account_service + payload = build_web_payload(spec, body, messages) access_token = account_service.get_text_access_token(provider="gemini") if not access_token: raise HTTPException(status_code=503, detail={"error": "no available Gemini account"}) account = account_service.get_account(access_token) or {"access_token": access_token, "provider": "gemini"} cookie_header = account_cookie_header(account) - payload = build_web_payload(spec, body, messages) session_token = account_session_token(account) if session_token: payload["session_token"] = session_token From 69a668b4b59081459c109adacf1f75a6719cb1a1 Mon Sep 17 00:00:00 2001 From: zqbxdev Date: Sun, 31 May 2026 08:00:38 +0800 Subject: [PATCH 3/6] fix: thread gemini deep research options --- services/gemini_deep_research.py | 56 +++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/services/gemini_deep_research.py b/services/gemini_deep_research.py index f950317..0224eb8 100644 --- a/services/gemini_deep_research.py +++ b/services/gemini_deep_research.py @@ -43,27 +43,42 @@ def to_dict(self) -> dict[str, Any]: CompletionTextFunc = Callable[[str, str], str] +def _deep_research_language(body: dict[str, Any]) -> str: + return str(body.get("language") or body.get("lang") or "").strip() + + +def _deep_research_max_sources(body: dict[str, Any]) -> int: + raw = body.get("max_sources", body.get("maxSources", 0)) + try: + value = int(raw or 0) + except (TypeError, ValueError): + return 0 + return max(0, value) + + def run_deep_research(body: dict[str, Any], completion_func: CompletionTextFunc | None = None) -> dict[str, Any]: query = _query(body) model = str(body.get("model") or "gemini-2.5-pro") + language = _deep_research_language(body) + max_sources = _deep_research_max_sources(body) result = DeepResearchResult(id=f"dr_{uuid.uuid4().hex}", status="in_progress", query=query, model=model) started = time.time() complete = completion_func or gemini_native.complete_text try: - plan_text = complete(model, _plan_prompt(query)) + plan_text = complete(model, _plan_prompt(query, language, max_sources)) plan = _json_value(plan_text) questions = _questions(plan, query) result.steps.append({"type": "plan", "content": plan}) findings: list[dict[str, Any]] = [] for index, question in enumerate(questions, start=1): - research_text = complete(model, _research_prompt(query, question)) + research_text = complete(model, _research_prompt(query, question, language, max_sources)) research = _json_value(research_text) sources = _sources(research) result.sources.extend(sources) step = {"type": "research", "index": index, "question": question, "content": research} result.steps.append(step) findings.append({"question": question, "research": research}) - synthesis = complete(model, _synthesis_prompt(query, findings)) + synthesis = complete(model, _synthesis_prompt(query, findings, language, max_sources)) synthesis_data = _json_value(synthesis) result.summary = str(synthesis_data.get("summary") or synthesis).strip() result.status = "completed" @@ -83,19 +98,21 @@ def run_deep_research(body: dict[str, Any], completion_func: CompletionTextFunc def stream_deep_research(body: dict[str, Any], completion_func: CompletionTextFunc | None = None) -> Iterator[tuple[str, dict[str, Any]]]: query = _query(body) model = str(body.get("model") or "gemini-2.5-pro") + language = _deep_research_language(body) + max_sources = _deep_research_max_sources(body) result_id = f"dr_{uuid.uuid4().hex}" yield "progress", {"id": result_id, "status": "in_progress", "query": query, "model": model} complete = completion_func or gemini_native.complete_text started = time.time() try: - plan = _json_value(complete(model, _plan_prompt(query))) + plan = _json_value(complete(model, _plan_prompt(query, language, max_sources))) questions = _questions(plan, query) yield "step", {"id": result_id, "type": "plan", "content": plan} findings: list[dict[str, Any]] = [] sources: list[dict[str, Any]] = [] steps = [{"type": "plan", "content": plan}] for index, question in enumerate(questions, start=1): - research = _json_value(complete(model, _research_prompt(query, question))) + research = _json_value(complete(model, _research_prompt(query, question, language, max_sources))) step = {"type": "research", "index": index, "question": question, "content": research} steps.append(step) findings.append({"question": question, "research": research}) @@ -103,7 +120,7 @@ def stream_deep_research(body: dict[str, Any], completion_func: CompletionTextFu for source in _sources(research): sources.append(source) yield "source", {"id": result_id, "source": source} - synthesis = complete(model, _synthesis_prompt(query, findings)) + synthesis = complete(model, _synthesis_prompt(query, findings, language, max_sources)) synthesis_data = _json_value(synthesis) result = DeepResearchResult( id=result_id, @@ -194,16 +211,31 @@ def _query(body: dict[str, Any]) -> str: return query -def _plan_prompt(query: str) -> str: - return "Return JSON only with a questions array for researching this query: " + query +def _plan_prompt(query: str, language: str = "", max_sources: int = 0) -> str: + prompt = "Return JSON only with a questions array for researching this query: " + query + if language: + prompt += " Respond in language: " + language + if max_sources > 0: + prompt += " Use at most " + str(max_sources) + " sources overall." + return prompt -def _research_prompt(query: str, question: str) -> str: - return "Return JSON only with summary and sources array for query " + json.dumps(query) + " subquestion " + json.dumps(question) +def _research_prompt(query: str, question: str, language: str = "", max_sources: int = 0) -> str: + prompt = "Return JSON only with summary and sources array for query " + json.dumps(query) + " subquestion " + json.dumps(question) + if language: + prompt += " language " + json.dumps(language) + if max_sources > 0: + prompt += " max_sources " + str(max_sources) + return prompt -def _synthesis_prompt(query: str, findings: list[dict[str, Any]]) -> str: - return "Return JSON only with summary for query " + json.dumps(query) + " using findings " + json.dumps(findings, ensure_ascii=False) +def _synthesis_prompt(query: str, findings: list[dict[str, Any]], language: str = "", max_sources: int = 0) -> str: + prompt = "Return JSON only with summary for query " + json.dumps(query) + " using findings " + json.dumps(findings, ensure_ascii=False) + if language: + prompt += " language " + json.dumps(language) + if max_sources > 0: + prompt += " max_sources " + str(max_sources) + return prompt def _json_value(text: str) -> dict[str, Any]: From a8d22169719b5dd54f08338ac2371d2ba232fb98 Mon Sep 17 00:00:00 2001 From: zqbxdev Date: Sun, 31 May 2026 08:00:45 +0800 Subject: [PATCH 4/6] fix: redact sensitive error payloads --- services/protocol/error_response.py | 97 ++++++++++++++++++++++++++++- test/test_error_response.py | 88 ++++++++++++++++++++++++++ 2 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 test/test_error_response.py diff --git a/services/protocol/error_response.py b/services/protocol/error_response.py index c3df3ae..0ae59a2 100644 --- a/services/protocol/error_response.py +++ b/services/protocol/error_response.py @@ -1,10 +1,55 @@ from __future__ import annotations +import re from typing import Any from fastapi.responses import JSONResponse +SENSITIVE_ERROR_FALLBACK = "request failed" +MAX_PUBLIC_ERROR_MESSAGE_LENGTH = 500 + +_SENSITIVE_MARKERS = ( + "access_token", + "refresh_token", + "id_token", + "authorization", + "bearer ", + "set-cookie", + "cookie", + "session token", + "secret_key", + "oauth", + "sso", +) +_NOISY_MARKERS = ( + "traceback", + "upstreamhttperror", + "backend-api/", + "chatgpt.com", + "status=", + "body=", + "curl: (", + "tls connect error", + "openssl_internal", + "connection reset", + "read timed out", + "connect timeout", + "max retries exceeded", + "httpconnectionpool", + "httpsconnectionpool", + "clientconnectorerror", + "serverdisconnectederror", + "failed to establish a new connection", +) +_EMAIL_PATTERN = re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE) +_PYTHON_FRAME_PATTERN = re.compile(r'File "[^"]+", line \d+', re.IGNORECASE) +_EXCEPTION_REPR_PATTERN = re.compile(r"\b[A-Za-z_][A-Za-z0-9_]*(?:Error|Exception)\s*\(") +_SECRET_ASSIGNMENT_PATTERN = re.compile( + r"(?i)\b(?:access_token|refresh_token|id_token|authorization|cookie|session_token|secret_key|api_key)\b\s*[:=]" +) + + def _message_from_value(value: object) -> str: if isinstance(value, str): return value @@ -60,6 +105,52 @@ def _default_error_code(status_code: int) -> str: return "upstream_error" +_DATA_URL_PATTERN = re.compile(r"data:[-+./\w]+;base64,[A-Za-z0-9+/=]+", re.IGNORECASE) +_LONG_BASE64_PATTERN = re.compile(r"(? str: + value = str(message or "").strip() + if not value: + return fallback + normalized = " ".join(value.split()) + lowered = normalized.lower() + if ( + any(marker in lowered for marker in _SENSITIVE_MARKERS) + or any(marker in lowered for marker in _NOISY_MARKERS) + or _EMAIL_PATTERN.search(normalized) + or _PYTHON_FRAME_PATTERN.search(normalized) + or _EXCEPTION_REPR_PATTERN.search(normalized) + or _SECRET_ASSIGNMENT_PATTERN.search(normalized) + or _DATA_URL_PATTERN.search(normalized) + or _LONG_BASE64_PATTERN.search(normalized) + ): + return fallback + if len(normalized) > MAX_PUBLIC_ERROR_MESSAGE_LENGTH: + return normalized[: MAX_PUBLIC_ERROR_MESSAGE_LENGTH - 3] + "..." + return normalized + + +def _sanitize_payload_value(value: object, *, fallback: object | None = None) -> object: + if not isinstance(value, str): + return value + sanitized = sanitize_public_error_message(value, fallback=str(fallback or SENSITIVE_ERROR_FALLBACK)) + if sanitized == SENSITIVE_ERROR_FALLBACK and fallback is not None: + return fallback + return sanitized + + +def sanitize_openai_error_payload(payload: dict[str, Any]) -> dict[str, Any]: + error = payload.get("error") + if not isinstance(error, dict): + return payload + sanitized_error = dict(error) + sanitized_error["message"] = sanitize_public_error_message(error.get("message")) + sanitized_error["param"] = _sanitize_payload_value(error.get("param")) + sanitized_error["code"] = _sanitize_payload_value(error.get("code"), fallback="upstream_error") + return {**payload, "error": sanitized_error} + + def openai_error_payload( detail: object, status_code: int, @@ -70,7 +161,7 @@ def openai_error_payload( ) -> dict[str, Any]: error_detail = detail.get("error") if isinstance(detail, dict) else None if isinstance(error_detail, dict): - return { + payload = { "error": { "message": error_message_from_detail(error_detail) or "request failed", "type": str(error_detail.get("type") or error_type or _default_error_type(status_code)), @@ -78,7 +169,8 @@ def openai_error_payload( "code": error_detail.get("code", code if code is not None else _default_error_code(status_code)), } } - return { + return sanitize_openai_error_payload(payload) + payload = { "error": { "message": error_message_from_detail(detail) or "request failed", "type": error_type or _default_error_type(status_code), @@ -86,6 +178,7 @@ def openai_error_payload( "code": code if code is not None else _default_error_code(status_code), } } + return sanitize_openai_error_payload(payload) def openai_error_response( diff --git a/test/test_error_response.py b/test/test_error_response.py new file mode 100644 index 0000000..7df2d6c --- /dev/null +++ b/test/test_error_response.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import unittest + +from test.optional_stubs import install_fastapi_stubs + +install_fastapi_stubs() + +from services.protocol.error_response import ( + MAX_PUBLIC_ERROR_MESSAGE_LENGTH, + openai_error_payload, + sanitize_openai_error_payload, + sanitize_public_error_message, +) + + +class PublicErrorSanitizationTests(unittest.TestCase): + def test_preserves_clear_validation_messages(self) -> None: + self.assertEqual( + sanitize_public_error_message("image_url must be a data URL or http URL"), + "image_url must be a data URL or http URL", + ) + self.assertEqual( + openai_error_payload("file_id image references are not supported", 400)["error"]["message"], + "file_id image references are not supported", + ) + validation_payload = openai_error_payload( + [{"loc": ["body", "prompt"], "msg": "Field required"}, {"loc": ["body", "n"], "msg": "Input should be greater than or equal to 1"}], + 422, + ) + self.assertIn("prompt: Field required", validation_payload["error"]["message"]) + self.assertIn("n: Input should be greater than or equal to 1", validation_payload["error"]["message"]) + + def test_replaces_tracebacks_and_python_exception_reprs(self) -> None: + traceback_message = 'Traceback (most recent call last): File "/srv/app/provider.py", line 42, in call RuntimeError("boom")' + self.assertEqual(sanitize_public_error_message(traceback_message), "request failed") + self.assertEqual(sanitize_public_error_message('UpstreamHTTPError(status=500, body="secret")'), "request failed") + + def test_replaces_network_browser_and_raw_upstream_details(self) -> None: + messages = [ + "curl: (35) TLS connect error: OPENSSL_internal:WRONG_VERSION_NUMBER", + "HTTPSConnectionPool(host='chatgpt.com', port=443): Max retries exceeded with url: /backend-api/conversation", + "status=500 body={'error': 'internal'}", + ] + for message in messages: + with self.subTest(message=message): + self.assertEqual(sanitize_public_error_message(message), "request failed") + + def test_replaces_token_cookie_auth_email_and_inline_image_details(self) -> None: + messages = [ + "authorization: Bearer sk-secret", + "Set-Cookie: session=secret", + "refresh_token=rt-secret", + "account user@example.com failed upstream", + "bad payload data:image/png;base64,AA==", + "raw " + ("A" * 160), + ] + for message in messages: + with self.subTest(message=message): + self.assertEqual(sanitize_public_error_message(message), "request failed") + + def test_truncates_oversized_safe_messages(self) -> None: + message = "safe detail " * 100 + sanitized = sanitize_public_error_message(message) + self.assertEqual(len(sanitized), MAX_PUBLIC_ERROR_MESSAGE_LENGTH) + self.assertTrue(sanitized.endswith("...")) + self.assertIn("safe detail", sanitized) + + def test_sanitizes_openai_payload_message_param_and_code(self) -> None: + payload = sanitize_openai_error_payload( + { + "error": { + "message": "Traceback with access_token=secret", + "type": "server_error", + "param": "authorization", + "code": "cookie=secret", + } + } + ) + + self.assertEqual(payload["error"]["message"], "request failed") + self.assertEqual(payload["error"]["param"], "request failed") + self.assertEqual(payload["error"]["code"], "upstream_error") + self.assertEqual(payload["error"]["type"], "server_error") + + +if __name__ == "__main__": + unittest.main() From 3330ef55bb372f83be97b4bfc8eaecc4aa13edb3 Mon Sep 17 00:00:00 2001 From: zqbxdev Date: Sun, 31 May 2026 08:00:53 +0800 Subject: [PATCH 5/6] test: cover gemini multimodal parity --- test/test_gemini_native.py | 105 ++++++++++++++++++++++++-- test/test_gemini_provider.py | 138 +++++++++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+), 5 deletions(-) diff --git a/test/test_gemini_native.py b/test/test_gemini_native.py index 8eec4d9..e65cd83 100644 --- a/test/test_gemini_native.py +++ b/test/test_gemini_native.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import json import sys import time @@ -25,7 +26,7 @@ from services import gemini_deep_research from services.providers import gemini as gemini_provider from services.providers.gemini import models as gemini_models -from services.protocol import gemini_native, openai_v1_chat_complete +from services.protocol import gemini_native, openai_v1_chat_complete, openai_v1_response AUTH_HEADERS = {"Authorization": "Bearer webchat2api"} @@ -129,16 +130,110 @@ def test_stream_generate_content_skips_empty_text_chunk(self) -> None: self.assertEqual(chunks[0]["candidates"][0]["content"]["parts"], []) self.assertEqual(chunks[0]["candidates"][0]["finishReason"], "STOP") - def test_inline_media_without_text_rejects(self) -> None: + def test_inline_media_without_text_is_preserved(self) -> None: + seen_messages: list[dict[str, Any]] = [] + + response = gemini_native.generate_content( + "gemini-2.5-pro", + {"contents": [{"role": "user", "parts": [{"inlineData": {"mimeType": "image/png", "data": "AA=="}}]}]}, + completion_func=lambda body, spec, messages: seen_messages.extend(messages) or gemini_provider.GeminiCompletion("described"), + ) + + self.assertEqual(response["candidates"][0]["content"]["parts"], [{"text": "described"}]) + self.assertEqual(seen_messages, [{ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,AA=="}}], + }]) + self.assertEqual(gemini_native.request_text_from_body({"contents": [{"role": "user", "parts": [{"inlineData": {"mimeType": "image/png", "data": "AA=="}}]}]}), "[image:image/png]") + + def test_inline_media_with_text_is_preserved(self) -> None: + seen_messages: list[dict[str, Any]] = [] + body = {"contents": [{"role": "user", "parts": [ + {"text": "Describe this"}, + {"inline_data": {"mime_type": "image/jpeg", "data": "AQI="}}, + ]}]} + + gemini_native.generate_content( + "gemini-2.5-pro", + body, + completion_func=lambda body, spec, messages: seen_messages.extend(messages) or gemini_provider.GeminiCompletion("ok"), + ) + + self.assertEqual(seen_messages, [{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,AQI="}}, + ], + }]) + self.assertEqual(gemini_native.request_text_from_body(body), "Describe this") + + def test_responses_gemini_input_image_is_preserved_until_provider_error(self) -> None: + with self.assertRaises(HTTPException) as raised: + openai_v1_response.handle({ + "model": "gemini-2.5-pro", + "input": [ + {"type": "input_text", "text": "Describe this"}, + {"type": "input_image", "image_url": "data:image/png;base64,AA=="}, + ], + }) + + detail = cast(dict[str, Any], getattr(raised.exception, "detail")) + self.assertEqual(getattr(raised.exception, "status_code"), 400) + self.assertEqual(detail["error"], "Gemini Web image input is not supported by this upstream adapter") + + + def test_inline_media_rejects_invalid_or_oversized_data(self) -> None: + invalid_cases = [ + ({"mimeType": "text/plain", "data": "AA=="}, "unsupported inline media mime type"), + ({"mimeType": "image/png", "data": "not-base64"}, "invalid inline media data"), + ({"mimeType": "image/png", "data": ""}, "Gemini generateContent requires at least one text part"), + ] + for inline, message in invalid_cases: + with self.subTest(message=message): + with self.assertRaises(HTTPException) as raised: + gemini_native.generate_content( + "gemini-2.5-pro", + {"contents": [{"role": "user", "parts": [{"inlineData": inline}]}]}, + completion_func=lambda body, spec, messages: gemini_provider.GeminiCompletion("ignored"), + ) + self.assertIn(message, str(getattr(raised.exception, "detail"))) + + oversized = base64.b64encode(b"0" * (10 * 1024 * 1024 + 1)).decode("ascii") with self.assertRaises(HTTPException) as raised: gemini_native.generate_content( "gemini-2.5-pro", - {"contents": [{"role": "user", "parts": [{"inlineData": {"mimeType": "image/png", "data": "AA=="}}]}]}, + {"contents": [{"role": "user", "parts": [{"inlineData": {"mimeType": "image/png", "data": oversized}}]}]}, completion_func=lambda body, spec, messages: gemini_provider.GeminiCompletion("ignored"), ) + self.assertIn("inline media is too large", str(getattr(raised.exception, "detail"))) - self.assertEqual(getattr(raised.exception, "status_code"), 400) - self.assertIn("inline media", str(getattr(raised.exception, "detail"))) + def test_native_tool_text_status_returns_content_text(self) -> None: + body = _native_body("weather") | {"tools": [_tool()]} + response = gemini_native.generate_content( + "gemini-2.5-pro", + body, + completion_func=lambda body, spec, messages: gemini_provider.GeminiCompletion('{"status":"text","content":"No tool needed"}'), + ) + + self.assertEqual(response["candidates"][0]["content"]["parts"], [{"text": "No tool needed"}]) + + def test_deepresearch_options_are_included_in_prompts(self) -> None: + prompts: list[str] = [] + outputs = iter([ + '{"questions":["What is A?"]}', + '{"summary":"A facts","sources":[]}', + '{"summary":"Final report"}', + ]) + + result = gemini_deep_research.run_deep_research( + {"query": "A", "language": "fr", "max_sources": 2}, + completion_func=lambda model, prompt: prompts.append(prompt) or next(outputs), + ) + + self.assertEqual(result["status"], "completed") + self.assertTrue(any("fr" in prompt for prompt in prompts)) + self.assertTrue(any("2" in prompt for prompt in prompts)) class GeminiDeepResearchTests(unittest.TestCase): diff --git a/test/test_gemini_provider.py b/test/test_gemini_provider.py index 0ea3e56..d180d43 100644 --- a/test/test_gemini_provider.py +++ b/test/test_gemini_provider.py @@ -39,6 +39,78 @@ def test_build_web_payload_converts_messages_to_prompt(self) -> None: self.assertEqual(payload["max_tokens"], 64) self.assertEqual(payload["prompt"], "System: Be brief.\n\nUser: Hello\n\nAssistant: Hi") + def test_build_web_payload_rejects_image_url_part(self) -> None: + with self.assertRaises(HTTPException) as raised: + gemini.build_web_payload( + resolve_model("gemini-2.5-pro"), + {}, + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AA=="}}, + ], + } + ], + ) + + self.assertEqual(getattr(raised.exception, "status_code"), 400) + detail = cast(dict[str, Any], getattr(raised.exception, "detail")) + self.assertEqual(detail["error"], "Gemini Web image input is not supported by this upstream adapter") + + def test_build_web_payload_rejects_responses_input_image_part(self) -> None: + with self.assertRaises(HTTPException) as raised: + gemini.build_web_payload( + resolve_model("gemini-2.5-pro"), + {}, + [ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "Describe it"}, + {"type": "input_image", "image_url": "data:image/png;base64,AA=="}, + ], + } + ], + ) + + self.assertEqual(getattr(raised.exception, "status_code"), 400) + self.assertIn("Gemini Web image input", str(getattr(raised.exception, "detail"))) + + def test_build_web_payload_rejects_native_inline_data_part(self) -> None: + with self.assertRaises(HTTPException) as raised: + gemini.build_web_payload( + resolve_model("gemini-2.5-pro"), + {}, + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe it"}, + {"inlineData": {"mimeType": "image/png", "data": "AA=="}}, + ], + } + ], + ) + + self.assertEqual(getattr(raised.exception, "status_code"), 400) + self.assertIn("Gemini Web image input", str(getattr(raised.exception, "detail"))) + + def test_chat_completion_rejects_image_message_before_account_lookup(self) -> None: + account_service = mock.Mock() + with mock.patch.dict(sys.modules, {"services.account_service": mock.Mock(account_service=account_service)}), \ + self.assertRaises(HTTPException) as raised: + gemini.chat_completion( + {}, + resolve_model("gemini-2.5-pro"), + [{"role": "user", "content": [{"type": "image", "data": b"image-bytes", "mime": "image/png"}]}], + ) + + self.assertEqual(getattr(raised.exception, "status_code"), 400) + self.assertIn("Gemini Web image input", str(getattr(raised.exception, "detail"))) + account_service.get_text_access_token.assert_not_called() + def test_account_cookie_header_requires_both_session_cookies(self) -> None: with self.assertRaises(HTTPException) as raised: gemini.account_cookie_header({"access_token": "__Secure-1PSID=psid"}) @@ -642,6 +714,34 @@ def test_chat_completion_success_marks_account_used_once(self) -> None: account_service.mark_text_used.assert_called_once_with("gemini-token") account_service.remove_invalid_token.assert_not_called() + def test_gemini_chat_image_without_modalities_is_not_silently_stripped(self) -> None: + with self.assertRaises(HTTPException) as raised: + openai_v1_chat_complete.handle({ + "model": "gemini-2.5-pro", + "messages": [{"role": "user", "content": [ + {"type": "text", "text": "Describe this"}, + {"type": "image_url", "image_url": {"url": "https://example.test/cat.png"}}, + ]}], + }) + + detail = cast(dict[str, Any], getattr(raised.exception, "detail")) + self.assertEqual(detail["error"], "Gemini Web image input is not supported by this upstream adapter") + + def test_gemini_stream_chat_image_without_modalities_is_not_silently_stripped(self) -> None: + with self.assertRaises(HTTPException) as raised: + result = openai_v1_chat_complete.handle({ + "model": "gemini-2.5-pro", + "stream": True, + "messages": [{"role": "user", "content": [ + {"type": "text", "text": "Describe this"}, + {"type": "image_url", "image_url": {"url": "https://example.test/cat.png"}}, + ]}], + }) + list(cast(Any, result)) + + detail = cast(dict[str, Any], getattr(raised.exception, "detail")) + self.assertEqual(detail["error"], "Gemini Web image input is not supported by this upstream adapter") + def test_gemini_image_chat_request_is_rejected(self) -> None: with self.assertRaises(HTTPException) as raised: openai_v1_chat_complete.handle({ @@ -667,6 +767,44 @@ def test_gemini_streaming_image_chat_request_is_rejected(self) -> None: self.assertEqual(getattr(raised.exception, "status_code"), 400) + def test_gemini_image_input_chat_routes_to_text_provider_error(self) -> None: + with self.assertRaises(HTTPException) as raised: + openai_v1_chat_complete.handle({ + "model": "gemini-2.5-pro", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AA=="}}, + ], + }], + "modalities": ["image"], + }) + + detail = cast(dict[str, Any], getattr(raised.exception, "detail")) + self.assertEqual(getattr(raised.exception, "status_code"), 400) + self.assertEqual(detail["error"], "Gemini Web image input is not supported by this upstream adapter") + + def test_gemini_streaming_image_input_chat_routes_to_text_provider_error(self) -> None: + with self.assertRaises(HTTPException) as raised: + result = openai_v1_chat_complete.handle({ + "model": "gemini-2.5-pro", + "stream": True, + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AA=="}}, + ], + }], + "modalities": ["image"], + }) + list(cast(Any, result)) + + detail = cast(dict[str, Any], getattr(raised.exception, "detail")) + self.assertEqual(getattr(raised.exception, "status_code"), 400) + self.assertEqual(detail["error"], "Gemini Web image input is not supported by this upstream adapter") + def test_non_streaming_chat_dispatch_uses_gemini_provider(self) -> None: with mock.patch.object(gemini, "chat_completion", return_value=gemini.GeminiCompletion("Hello from Gemini")) as chat_completion: response = openai_v1_chat_complete.handle({ From 7aabe963c97bdeefbabf616dacd1a152399e7036 Mon Sep 17 00:00:00 2001 From: zerogzy Date: Fri, 12 Jun 2026 18:06:55 +0000 Subject: [PATCH 6/6] feat: support per-account proxy settings --- api/accounts.py | 31 +++++++++---- services/account_service.py | 46 ++++++++++++------- services/network/client.py | 9 ++-- services/openai_backend_api.py | 10 ++-- services/providers/gemini/accounts.py | 2 +- services/providers/gemini/client.py | 9 ++-- services/providers/grok/client.py | 8 +++- services/proxy_service.py | 11 ++++- .../components/account-import-dialog.tsx | 40 ++++++++++++---- web/src/app/accounts/page.tsx | 32 +++++++++++-- web/src/lib/api.ts | 6 ++- 11 files changed, 151 insertions(+), 53 deletions(-) diff --git a/api/accounts.py b/api/accounts.py index 93acc85..1c4c5fe 100644 --- a/api/accounts.py +++ b/api/accounts.py @@ -81,11 +81,14 @@ class AccountExportRequest(BaseModel): class AccountUpdateRequest(BaseModel): access_token: str = "" + account_id: str | None = None + row_id: str | None = None type: str | None = None provider: str | None = None target_provider: Literal["gpt", "grok", "gemini"] | None = None status: str | None = None quota: int | None = None + proxy: str | None = None class CPAPoolCreateRequest(BaseModel): @@ -241,7 +244,7 @@ def _validate_gemini_import_payloads(tokens: list[str], payloads: list[dict[str, if not payloads: raise HTTPException(status_code=400, detail={"error": "Gemini 目前只支持 Cookie 双字段导入"}) top_level_gemini = normalize_provider(provider) == GEMINI_PROVIDER - allowed_keys = {"provider", "__Secure-1PSID", "__Secure-1PSIDTS"} + allowed_keys = {"provider", "__Secure-1PSID", "__Secure-1PSIDTS", "proxy"} for item in payloads: item_provider = str(item.get("provider") or "").strip() if top_level_gemini and item_provider and normalize_account_provider(item_provider) != GEMINI_PROVIDER: @@ -285,13 +288,20 @@ def _grok_import_requested(provider: str | None, payloads: list[dict[str, Any]]) def _validate_grok_import_payloads(tokens: list[str], payloads: list[dict[str, Any]], provider: str | None) -> list[str]: if not _grok_import_requested(provider, payloads): return tokens - if payloads: - raise HTTPException(status_code=400, detail={"error": "Grok 导入只接受裸 SSO 值,或每行一个 sso=<值>;不支持 sso-rw、完整 Cookie header、JSON、CPA、cookies 或 accounts 账号 payload"}) - if not tokens: + if not tokens and not payloads: raise HTTPException(status_code=400, detail={"error": "Grok 导入只接受裸 SSO 值,或每行一个 sso=<值>"}) + allowed_keys = {"provider", "sso", "proxy"} + for item in payloads: + item_provider = str(item.get("provider") or provider or "").strip() + if item_provider and normalize_account_provider(item_provider) != GROK_PROVIDER: + raise HTTPException(status_code=400, detail={"error": "Grok 导入不能混用其他供应商账号"}) + extra_keys = {key for key, value in item.items() if value is not None} - allowed_keys + sso = str(item.get("sso") or "").strip() + if extra_keys or not _normalize_grok_sso_import_token(sso): + raise HTTPException(status_code=400, detail={"error": "Grok 导入只接受裸 SSO 值,或每行一个 sso=<值>;不支持 sso-rw、完整 Cookie header、其他 Cookie 名称、JSON、CPA 或 cookies"}) normalized_tokens = [_normalize_grok_sso_import_token(token) for token in tokens] if not all(normalized_tokens): - raise HTTPException(status_code=400, detail={"error": "Grok 导入只接受裸 SSO 值,或每行一个 sso=<值>;不支持 sso-rw、完整 Cookie header、其他 Cookie 名称、JSON、CPA、cookies 或 accounts 账号 payload"}) + raise HTTPException(status_code=400, detail={"error": "Grok 导入只接受裸 SSO 值,或每行一个 sso=<值>;不支持 sso-rw、完整 Cookie header、其他 Cookie 名称、JSON、CPA 或 cookies"}) return normalized_tokens @@ -480,8 +490,9 @@ async def export_accounts(body: AccountExportRequest, authorization: str | None async def update_account(body: AccountUpdateRequest, authorization: str | None = Header(default=None)): require_admin(authorization) access_token = str(body.access_token or "").strip() - if not access_token: - raise HTTPException(status_code=400, detail={"error": "access_token is required"}) + identifier = _delete_identifiers([{"account_id": body.account_id, "row_id": body.row_id}]) + if not access_token and not identifier: + raise HTTPException(status_code=400, detail={"error": "access_token or identifier is required"}) updates = { key: value for key, value in { @@ -489,12 +500,16 @@ async def update_account(body: AccountUpdateRequest, authorization: str | None = "provider": body.provider, "status": body.status, "quota": body.quota, + "proxy": body.proxy, }.items() if value is not None } if not updates: raise HTTPException(status_code=400, detail={"error": "还没有检测到改动,请修改后再保存"}) - account = account_service.update_account(access_token, updates, provider=body.target_provider) + if access_token: + account = account_service.update_account(access_token, updates, provider=body.target_provider) + else: + account = account_service.update_account_by_identifier(identifier[0], updates, provider=body.target_provider) if account is None: raise HTTPException(status_code=404, detail={"error": "account not found"}) return {"item": sanitize_account(account), "items": sanitize_accounts(account_service.list_accounts(provider=body.target_provider))} diff --git a/services/account_service.py b/services/account_service.py index 9a24905..022e821 100644 --- a/services/account_service.py +++ b/services/account_service.py @@ -387,6 +387,7 @@ def _normalize_account(self, item: dict) -> dict | None: normalized["success"] = int(normalized.get("success") or 0) normalized["fail"] = int(normalized.get("fail") or 0) normalized["last_used_at"] = normalized.get("last_used_at") + normalized["proxy"] = _clean_string(normalized.get("proxy")) return normalized def list_tokens(self, provider: str | None = None) -> list[str]: @@ -824,27 +825,40 @@ def delete_limited_accounts(self, provider: str | None = None) -> dict: items = self._list_account_items_locked(provider) return {"removed": removed, "items": items} + def _update_account_locked(self, access_token: str, updates: dict, provider: str | None = None) -> dict | None: + current = self._get_account_locked(access_token, provider) + if current is None: + return None + account = self._normalize_account({**current, **updates, "access_token": access_token}) + if account is None: + return None + if account.get("status") == "限流" and config.auto_remove_rate_limited_accounts: + self._pop_account_locked(access_token, provider) + self._save_accounts() + log_service.add(LOG_TYPE_ACCOUNT, "自动移除限流账号", {"token": anonymize_token(access_token)}) + return None + self._set_account_locked(account) + self._save_accounts() + log_service.add(LOG_TYPE_ACCOUNT, "更新账号", {"token": anonymize_token(access_token), "status": account.get("status")}) + return dict(account) + def update_account(self, access_token: str, updates: dict, provider: str | None = None) -> dict | None: if not access_token: return None with self._lock: - current = self._get_account_locked(access_token, provider) - if current is None: - return None - account = self._normalize_account({**current, **updates, "access_token": access_token}) - if account is None: - return None - if account.get("status") == "限流" and config.auto_remove_rate_limited_accounts: - self._pop_account_locked(access_token, provider) - self._save_accounts() - log_service.add(LOG_TYPE_ACCOUNT, "自动移除限流账号", {"token": anonymize_token(access_token)}) + return self._update_account_locked(access_token, updates, provider) + + def update_account_by_identifier(self, identifier: dict[str, str], updates: dict, provider: str | None = None) -> dict | None: + target_provider = self._provider_filter(provider) + if target_provider is None: + return None + with self._lock: + provider_accounts = self._accounts.get(target_provider, {}) + matched_tokens = _matched_account_tokens_by_identifiers([identifier], provider_accounts, target_provider) + if len(matched_tokens) != 1: return None - self._set_account_locked(account) - self._save_accounts() - log_service.add(LOG_TYPE_ACCOUNT, "更新账号", - {"token": anonymize_token(access_token), "status": account.get("status")}) - return dict(account) - return None + access_token = next(iter(matched_tokens)) + return self._update_account_locked(access_token, updates, target_provider) def mark_image_result(self, access_token: str, success: bool) -> dict | None: if not access_token: diff --git a/services/network/client.py b/services/network/client.py index 1364453..1c09d34 100644 --- a/services/network/client.py +++ b/services/network/client.py @@ -2,16 +2,17 @@ from typing import Any -def build_session_kwargs(*, impersonate: str | None = None, verify: bool = True, **session_kwargs: Any) -> dict[str, object]: + +def build_session_kwargs(*, impersonate: str | None = None, verify: bool = True, account: dict | None = None, **session_kwargs: Any) -> dict[str, object]: if impersonate: session_kwargs["impersonate"] = impersonate session_kwargs["verify"] = verify from services.proxy_service import proxy_settings - return proxy_settings.build_session_kwargs(**session_kwargs) + return proxy_settings.build_session_kwargs(account=account, **session_kwargs) -def create_session(*, impersonate: str | None = None, verify: bool = True, **session_kwargs: Any): +def create_session(*, impersonate: str | None = None, verify: bool = True, account: dict | None = None, **session_kwargs: Any): from curl_cffi import requests - return requests.Session(**build_session_kwargs(impersonate=impersonate, verify=verify, **session_kwargs)) + return requests.Session(**build_session_kwargs(impersonate=impersonate, verify=verify, account=account, **session_kwargs)) diff --git a/services/openai_backend_api.py b/services/openai_backend_api.py index 72f2413..141dc9c 100644 --- a/services/openai_backend_api.py +++ b/services/openai_backend_api.py @@ -79,6 +79,7 @@ def __init__(self, access_token: str = "") -> None: self.client_version = DEFAULT_CLIENT_VERSION self.client_build_number = DEFAULT_CLIENT_BUILD_NUMBER self.access_token = access_token + self.account = self._load_account() self.network_profile = self._build_network_profile() self.fp = self.network_profile.as_fingerprint() self.user_agent = self.fp["user-agent"] @@ -86,7 +87,7 @@ def __init__(self, access_token: str = "") -> None: self.session_id = self.fp["oai-session-id"] self.pow_script_sources: list[str] = [] self.pow_data_build = "" - self.session = create_session(impersonate=self.network_profile.impersonate, verify=self.network_profile.verify) + self.session = create_session(impersonate=self.network_profile.impersonate, verify=self.network_profile.verify, account=self.account) self.session.headers.update(build_chatgpt_web_headers( self.network_profile, base_url=self.base_url, @@ -105,9 +106,12 @@ def __enter__(self) -> "OpenAIBackendAPI": def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: self.close() - def _build_network_profile(self): + def _load_account(self) -> dict: account = account_service.get_account(self.access_token) if self.access_token else {} - account = account if isinstance(account, dict) else {} + return account if isinstance(account, dict) else {} + + def _build_network_profile(self): + account = self.account global_fp = config.data.get("chatgpt_fingerprint") global_fp = global_fp if isinstance(global_fp, dict) else {} return build_chatgpt_web_profile(account, global_fp) diff --git a/services/providers/gemini/accounts.py b/services/providers/gemini/accounts.py index 947d756..cd97ecf 100644 --- a/services/providers/gemini/accounts.py +++ b/services/providers/gemini/accounts.py @@ -289,7 +289,7 @@ def validate_remote_info(access_token: str, account: dict[str, Any] | None = Non if access_token: source.setdefault("access_token", access_token) cookie_header_value = account_cookie_header(source) - with GeminiWebClient(cookie_header_value, source.get("user_agent")) as client: + with GeminiWebClient(cookie_header_value, source.get("user_agent"), account=source) as client: client.rotate_psidts() session_token = client.bootstrap_session_token() return gemini_session_writeback(source, client.cookie_header, session_token) diff --git a/services/providers/gemini/client.py b/services/providers/gemini/client.py index 29a54e4..3ada1d0 100644 --- a/services/providers/gemini/client.py +++ b/services/providers/gemini/client.py @@ -476,11 +476,12 @@ def parse_web_response_text(raw_text: str) -> object: class GeminiWebClient: - def __init__(self, cookie_header: str, user_agent: str | None = None) -> None: + def __init__(self, cookie_header: str, user_agent: str | None = None, account: dict[str, Any] | None = None) -> None: self.cookie_header = cookie_header self.user_agent = user_agent or GEMINI_BROWSER_USER_AGENT self.session_token = "" - self.session = create_session() + self.account = account if isinstance(account, dict) else None + self.session = create_session(account=self.account) def __enter__(self) -> "GeminiWebClient": return self @@ -604,7 +605,7 @@ def fetch_authenticated_init_body() -> str: return "" account = account_service.get_account(access_token) or {"access_token": access_token, "provider": "gemini"} cookie_header = account_cookie_header(account) - with GeminiWebClient(cookie_header, account.get("user_agent")) as client: + with GeminiWebClient(cookie_header, account.get("user_agent"), account=account) as client: init_body = client.fetch_init_body() persist_gemini_session(account_service, access_token, account, client.cookie_header) return init_body @@ -631,7 +632,7 @@ def chat_completion(body: dict[str, Any], spec: ModelSpec, messages: list[dict[s if session_token: payload["session_token"] = session_token try: - with GeminiWebClient(cookie_header, account.get("user_agent")) as client: + with GeminiWebClient(cookie_header, account.get("user_agent"), account=account) as client: response_payload = client.generate(payload) persist_gemini_session(account_service, access_token, account, client.cookie_header, client.session_token) except GeminiWebError as exc: diff --git a/services/providers/grok/client.py b/services/providers/grok/client.py index 07b5f9c..f311155 100644 --- a/services/providers/grok/client.py +++ b/services/providers/grok/client.py @@ -668,7 +668,11 @@ class GrokConsoleClient: def __init__(self, access_token: str) -> None: self.access_token = access_token self.network_profile = _grok_console_profile() - self.session = create_session(impersonate=self.network_profile.impersonate, verify=self.network_profile.verify) + from services.account_service import account_service + + account = account_service.get_account(access_token, provider=GROK_PROVIDER) + account = account if isinstance(account, dict) else None + self.session = create_session(impersonate=self.network_profile.impersonate, verify=self.network_profile.verify, account=account) def close(self) -> None: self.session.close() @@ -1840,7 +1844,7 @@ def __init__(self, access_token: str, account: dict[str, Any] | None = None) -> self.account = account if isinstance(account, dict) else None self.network_profile = _grok_app_chat_profile() impersonate = _app_chat_impersonate(self.network_profile, self.account) - self.session = create_session(impersonate=impersonate, verify=self.network_profile.verify) + self.session = create_session(impersonate=impersonate, verify=self.network_profile.verify, account=self.account) def close(self) -> None: self.session.close() diff --git a/services/proxy_service.py b/services/proxy_service.py index c3587af..abf67b6 100644 --- a/services/proxy_service.py +++ b/services/proxy_service.py @@ -11,8 +11,15 @@ class ProxySettingsStore: - def build_session_kwargs(self, **session_kwargs) -> dict[str, object]: - proxy = config.get_proxy_settings() + def resolve_proxy(self, account: dict | None = None) -> str: + if isinstance(account, dict): + proxy = _clean(account.get("proxy")) + if proxy: + return proxy + return config.get_proxy_settings() + + def build_session_kwargs(self, *, account: dict | None = None, **session_kwargs) -> dict[str, object]: + proxy = self.resolve_proxy(account) if proxy: session_kwargs["proxy"] = proxy return session_kwargs diff --git a/web/src/app/accounts/components/account-import-dialog.tsx b/web/src/app/accounts/components/account-import-dialog.tsx index 12172cc..43ba79d 100644 --- a/web/src/app/accounts/components/account-import-dialog.tsx +++ b/web/src/app/accounts/components/account-import-dialog.tsx @@ -259,6 +259,7 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo const [tokenInput, setTokenInput] = useState(""); const [importProvider, setImportProvider] = useState("gpt"); const [sessionInput, setSessionInput] = useState(""); + const [importProxy, setImportProxy] = useState(""); const [geminiSecure1Psid, setGeminiSecure1Psid] = useState(""); const [geminiSecure1Psidts, setGeminiSecure1Psidts] = useState(""); const [isSubmitting, setIsSubmitting] = useState(false); @@ -276,6 +277,7 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo setTokenInput(""); setImportProvider("gpt"); setSessionInput(""); + setImportProxy(""); setGeminiSecure1Psid(""); setGeminiSecure1Psidts(""); setPendingCpaImport(null); @@ -321,12 +323,18 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo } }; - const buildTokenPayloads = (tokens: string[]): AccountImportPayload[] => { - if (importProvider === "grok") { - return []; - } + const withImportProxy = (payload: AccountImportPayload): AccountImportPayload => { + const proxy = importProxy.trim(); + return proxy ? { ...payload, proxy } : payload; + }; - return tokens.map((token) => ({ access_token: token, provider: importProvider })); + const buildTokenPayloads = (tokens: string[]): AccountImportPayload[] => { + return tokens.map((token) => { + const payload: AccountImportPayload = importProvider === "grok" + ? { sso: token, provider: importProvider } + : { access_token: token, provider: importProvider }; + return withImportProxy(payload); + }); }; const normalizeImportTokens = (tokens: string[]) => { @@ -344,12 +352,25 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo toast.error("Grok 每行仅支持裸 SSO 值或单个 sso=完整值;不支持 sso-rw、完整 Cookie header 或其他 name=value。"); }; + const renderImportProxyField = () => ( +
+ + setImportProxy(event.target.value)} + placeholder="留空使用全局代理,例如 http://127.0.0.1:7890" + className="h-11 rounded-xl border-stone-200 bg-white" + /> +

本次导入的账号会保存该代理;留空则使用系统全局代理。

+
+ ); + const buildGeminiSessionPayload = (secure1Psid: string, secure1Psidts: string): AccountImportPayload => { - return { + return withImportProxy({ provider: "gemini", "__Secure-1PSID": secure1Psid, "__Secure-1PSIDTS": secure1Psidts, - }; + }); }; const handleImportTokenText = async () => { @@ -469,7 +490,7 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo }), ); - const accounts = results.flatMap((item) => item.accounts); + const accounts = results.flatMap((item) => item.accounts).map(withImportProxy); const tokens = accounts.map((item) => item.access_token).filter((token): token is string => Boolean(token)); const parsedFileCount = results.filter((item) => item.accounts.length > 0).length; const errorCount = results.length - parsedFileCount; @@ -589,6 +610,7 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo className="min-h-48 resize-none rounded-xl border-stone-200" /> + {renderImportProxyField()}
@@ -637,6 +659,7 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo
风险提示
不要使用自己的大号,尽量使用不常用的小号进行导入,避免出现封号风险。本项目不承担任何封号风险责任。
+ {renderImportProxyField()}
{importProvider === "gemini" ? ( <> @@ -688,6 +711,7 @@ export function AccountImportDialog({ disabled, onImported }: AccountImportDialo 返回导入方式 + {renderImportProxyField()}
多选本地 JSON 文件
diff --git a/web/src/app/accounts/page.tsx b/web/src/app/accounts/page.tsx index cc30d6d..3dc3905 100644 --- a/web/src/app/accounts/page.tsx +++ b/web/src/app/accounts/page.tsx @@ -730,6 +730,7 @@ function AccountsPageContent() { const [pageSize] = useState("10"); const [editingAccount, setEditingAccount] = useState(null); const [editStatus, setEditStatus] = useState("正常"); + const [editProxy, setEditProxy] = useState(""); const [isLoading, setIsLoading] = useState(true); const [isRefreshing, setIsRefreshing] = useState(false); const [isValidating, setIsValidating] = useState(false); @@ -940,6 +941,7 @@ function AccountsPageContent() { const openEditDialog = (account: Account) => { setEditingAccount(account); setEditStatus(account.status); + setEditProxy(String(account.proxy ?? "")); }; const handleUpdateAccount = async () => { @@ -949,17 +951,29 @@ function AccountsPageContent() { const provider = accountProviderId(editingAccount); const token = accountToken(editingAccount); - if (!token) { + const proxy = editProxy.trim(); + const statusChanged = editStatus !== editingAccount.status; + const proxyChanged = proxy !== String(editingAccount.proxy ?? "").trim(); + if (!token && statusChanged) { toast.error("脱敏账号不能在列表中直接编辑,请重新导入或通过后端管理接口处理"); return; } + if (!token && !proxyChanged) { + toast.error("还没有检测到改动,请修改后再保存"); + return; + } setIsUpdating(true); try { - const data = await updateAccount(token, { status: editStatus }, provider); + const data = await updateAccount(token, { + ...(token ? { status: editStatus } : {}), + proxy, + account_id: editingAccount.account_id, + row_id: editingAccount.row_id, + }, provider); handleProviderMutationResult(provider, data.items); setEditingAccount(null); - toast.success("账号状态已更新"); + toast.success("账号设置已更新"); } catch (error) { const message = error instanceof Error ? error.message : "更新账号失败"; toast.error(message); @@ -1035,7 +1049,7 @@ function AccountsPageContent() { (!open ? setEditingAccount(null) : null)}> - 编辑账户状态 + 编辑账户设置 当前账号归属 {editingAccount ? getAccountProviderLabel(editingAccount.provider) : ""};更新请求会按该服务商定位账号。 @@ -1057,6 +1071,16 @@ function AccountsPageContent() {
+
+ + setEditProxy(event.target.value)} + placeholder="留空使用全局代理,例如 http://127.0.0.1:7890" + className="h-11 rounded-xl border-stone-200 bg-white" + /> +

设置后该账号请求优先使用此代理;留空则使用系统全局代理。

+