diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml deleted file mode 100644 index 7609a68b9b..0000000000 --- a/.github/workflows/docker-image.yml +++ /dev/null @@ -1,140 +0,0 @@ -name: docker-image - -on: - workflow_dispatch: - push: - tags: - - v* - -env: - APP_NAME: CLIProxyAPI - DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/cli-proxy-api-plus - -jobs: - docker_amd64: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Login to DockerHub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Generate Build Metadata - run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV - echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV - echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - - name: Build and push (amd64) - uses: docker/build-push-action@v6 - with: - context: . - platforms: linux/amd64 - push: true - build-args: | - VERSION=${{ env.VERSION }} - COMMIT=${{ env.COMMIT }} - BUILD_DATE=${{ env.BUILD_DATE }} - tags: | - ${{ env.DOCKERHUB_REPO }}:latest-amd64 - ${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-amd64 - - docker_arm64: - runs-on: ubuntu-24.04-arm - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Login to DockerHub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Generate Build Metadata - run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV - echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV - echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - - name: Build and push (arm64) - uses: docker/build-push-action@v6 - with: - context: . - platforms: linux/arm64 - push: true - build-args: | - VERSION=${{ env.VERSION }} - COMMIT=${{ env.COMMIT }} - BUILD_DATE=${{ env.BUILD_DATE }} - tags: | - ${{ env.DOCKERHUB_REPO }}:latest-arm64 - ${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-arm64 - - docker_manifest: - runs-on: ubuntu-latest - needs: - - docker_amd64 - - docker_arm64 - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Login to DockerHub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Generate Build Metadata - run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV - echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV - echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - - name: Create and push multi-arch manifests - run: | - docker buildx imagetools create \ - --tag "${DOCKERHUB_REPO}:latest" \ - "${DOCKERHUB_REPO}:latest-amd64" \ - "${DOCKERHUB_REPO}:latest-arm64" - docker buildx imagetools create \ - --tag "${DOCKERHUB_REPO}:${VERSION}" \ - "${DOCKERHUB_REPO}:${VERSION}-amd64" \ - "${DOCKERHUB_REPO}:${VERSION}-arm64" - - name: Cleanup temporary tags - continue-on-error: true - env: - DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} - DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} - run: | - set -euo pipefail - namespace="${DOCKERHUB_REPO%%/*}" - repo_name="${DOCKERHUB_REPO#*/}" - - token="$( - curl -fsSL \ - -H 'Content-Type: application/json' \ - -d "{\"username\":\"${DOCKERHUB_USERNAME}\",\"password\":\"${DOCKERHUB_TOKEN}\"}" \ - 'https://hub.docker.com/v2/users/login/' \ - | python3 -c 'import json,sys; print(json.load(sys.stdin)["token"])' - )" - - delete_tag() { - local tag="$1" - local url="https://hub.docker.com/v2/repositories/${namespace}/${repo_name}/tags/${tag}/" - local http_code - http_code="$(curl -sS -o /dev/null -w "%{http_code}" -X DELETE -H "Authorization: JWT ${token}" "${url}" || true)" - if [ "${http_code}" = "204" ] || [ "${http_code}" = "404" ]; then - echo "Docker Hub tag removed (or missing): ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})" - return 0 - fi - echo "Docker Hub tag delete failed: ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})" - return 0 - } - - delete_tag "latest-amd64" - delete_tag "latest-arm64" - delete_tag "${VERSION}-amd64" - delete_tag "${VERSION}-arm64" diff --git a/.github/workflows/pr-path-guard.yml b/.github/workflows/pr-path-guard.yml index 4fe3d93881..fc143c1614 100644 --- a/.github/workflows/pr-path-guard.yml +++ b/.github/workflows/pr-path-guard.yml @@ -20,6 +20,10 @@ jobs: with: files: | internal/translator/** + !internal/translator/kiro/** + !internal/translator/antigravity/** + !internal/translator/codex/** + !internal/translator/gemini-cli/** - name: Fail when restricted paths change if: steps.changed-files.outputs.any_changed == 'true' run: | diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 04ec21a9a5..a0c99875c6 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -37,3 +37,9 @@ jobs: VERSION: ${{ env.VERSION }} COMMIT: ${{ env.COMMIT }} BUILD_DATE: ${{ env.BUILD_DATE }} + - name: Discord Notification + if: success() + run: | + curl -X POST ${{ secrets.DISCORD_WEBHOOK_URL }} \ + -H "Content-Type: application/json" \ + -d '{"content": "โœ… **CLIProxyAPIPlus** Build Complete!\n\n**Version:** '${{ env.VERSION }}'\n**Commit:** '${{ env.COMMIT }}'\n**Build Date:** '${{ env.BUILD_DATE }}'\n\n๐Ÿ”— [Release](https://github.com/'${{ github.repository }}'/releases/tag/'${{ env.VERSION }}')"}' diff --git a/.gitignore b/.gitignore index e6e6ab0aaa..1fd21b4f2a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ # Binaries cli-proxy-api -cliproxy +/cliproxy *.exe @@ -54,3 +54,8 @@ _bmad-output/* .DS_Store ._* *.bak +*.json +.cli-proxy-api/ +.sisyphus/ +.tldr/ +server diff --git a/.tldrignore b/.tldrignore new file mode 100644 index 0000000000..e01df83cb2 --- /dev/null +++ b/.tldrignore @@ -0,0 +1,84 @@ +# TLDR ignore patterns (gitignore syntax) +# Auto-generated - review and customize for your project +# Docs: https://git-scm.com/docs/gitignore + +# =================== +# Dependencies +# =================== +node_modules/ +.venv/ +venv/ +env/ +__pycache__/ +.tox/ +.nox/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +vendor/ +Pods/ + +# =================== +# Build outputs +# =================== +dist/ +build/ +out/ +target/ +*.egg-info/ +*.whl +*.pyc +*.pyo + +# =================== +# Binary/large files +# =================== +*.so +*.dylib +*.dll +*.exe +*.bin +*.o +*.a +*.lib + +# =================== +# IDE/editors +# =================== +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# =================== +# Security (always exclude) +# =================== +.env +.env.* +*.pem +*.key +*.p12 +*.pfx +credentials.* +secrets.* + +# =================== +# Version control +# =================== +.git/ +.hg/ +.svn/ + +# =================== +# OS files +# =================== +.DS_Store +Thumbs.db + +# =================== +# Project-specific +# Add your custom patterns below +# =================== +# large_test_fixtures/ +# data/ diff --git a/assets/cubence.png b/assets/cubence.png new file mode 100644 index 0000000000..c61f12f61e Binary files /dev/null and b/assets/cubence.png differ diff --git a/cmd/server/main.go b/cmd/server/main.go index 9a204ebb73..942239a04b 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -92,6 +92,8 @@ func main() { var kiroIDCRegion string var kiroIDCFlow string var githubCopilotLogin bool + var kilocodeLogin bool + var clineLogin bool var projectID string var vertexImport string var configPath string @@ -126,6 +128,8 @@ func main() { flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)") flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device") flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow") + flag.BoolVar(&kilocodeLogin, "kilocode-login", false, "Login to Kilocode using device flow") + flag.BoolVar(&clineLogin, "cline-login", false, "Login to Cline using OAuth") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -509,6 +513,9 @@ func main() { } else if githubCopilotLogin { // Handle GitHub Copilot login cmd.DoGitHubCopilotLogin(cfg, options) + } else if kilocodeLogin { + // Handle Kilocode login + cmd.DoKilocodeLogin(cfg, options) } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) @@ -528,6 +535,8 @@ func main() { cmd.DoIFlowCookieAuth(cfg, options) } else if kimiLogin { cmd.DoKimiLogin(cfg, options) + } else if clineLogin { + cmd.DoClineLogin(cfg, options) } else if kiroLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito diff --git a/config.example.yaml b/config.example.yaml index 9c4313b35d..cca0e98f77 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -94,7 +94,15 @@ quota-exceeded: # Routing strategy for selecting credentials when multiple match. routing: - strategy: 'round-robin' # round-robin (default), fill-first + strategy: "round-robin" # round-robin (default), fill-first + mode: "" # "" (default): rotate per provider:model, "key-based": rotate per model only (ignores provider) + # fallback-models: # (optional) auto-fallback on 429/401/5xx errors (chat/completion only) + # gpt-4o: claude-sonnet-4-20250514 + # opus: sonnet + # fallback-chain: # (optional) general fallback chain for models not in fallback-models + # - glm-4.7 + # - grok-code-fast-1 + # fallback-max-depth: 3 # (optional) maximum fallback depth (default: 3) # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false diff --git a/go.mod b/go.mod index 461d5517d7..e80036fed3 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.26.0 require ( github.com/andybalholm/brotli v1.0.6 + github.com/denisbrodbeck/machineid v1.0.1 github.com/atotto/clipboard v0.1.4 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 @@ -91,8 +92,8 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect - github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/x448/float16 v0.8.4 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect diff --git a/go.sum b/go.sum index 8a4a967d9a..2d7f39b859 100644 --- a/go.sum +++ b/go.sum @@ -201,10 +201,10 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= -github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= diff --git a/internal/api/.tldrignore b/internal/api/.tldrignore new file mode 100644 index 0000000000..e01df83cb2 --- /dev/null +++ b/internal/api/.tldrignore @@ -0,0 +1,84 @@ +# TLDR ignore patterns (gitignore syntax) +# Auto-generated - review and customize for your project +# Docs: https://git-scm.com/docs/gitignore + +# =================== +# Dependencies +# =================== +node_modules/ +.venv/ +venv/ +env/ +__pycache__/ +.tox/ +.nox/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +vendor/ +Pods/ + +# =================== +# Build outputs +# =================== +dist/ +build/ +out/ +target/ +*.egg-info/ +*.whl +*.pyc +*.pyo + +# =================== +# Binary/large files +# =================== +*.so +*.dylib +*.dll +*.exe +*.bin +*.o +*.a +*.lib + +# =================== +# IDE/editors +# =================== +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# =================== +# Security (always exclude) +# =================== +.env +.env.* +*.pem +*.key +*.p12 +*.pfx +credentials.* +secrets.* + +# =================== +# Version control +# =================== +.git/ +.hg/ +.svn/ + +# =================== +# OS files +# =================== +.DS_Store +Thumbs.db + +# =================== +# Project-specific +# Add your custom patterns below +# =================== +# large_test_fixtures/ +# data/ diff --git a/internal/api/handlers/management/.tldrignore b/internal/api/handlers/management/.tldrignore new file mode 100644 index 0000000000..e01df83cb2 --- /dev/null +++ b/internal/api/handlers/management/.tldrignore @@ -0,0 +1,84 @@ +# TLDR ignore patterns (gitignore syntax) +# Auto-generated - review and customize for your project +# Docs: https://git-scm.com/docs/gitignore + +# =================== +# Dependencies +# =================== +node_modules/ +.venv/ +venv/ +env/ +__pycache__/ +.tox/ +.nox/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +vendor/ +Pods/ + +# =================== +# Build outputs +# =================== +dist/ +build/ +out/ +target/ +*.egg-info/ +*.whl +*.pyc +*.pyo + +# =================== +# Binary/large files +# =================== +*.so +*.dylib +*.dll +*.exe +*.bin +*.o +*.a +*.lib + +# =================== +# IDE/editors +# =================== +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# =================== +# Security (always exclude) +# =================== +.env +.env.* +*.pem +*.key +*.p12 +*.pfx +credentials.* +secrets.* + +# =================== +# Version control +# =================== +.git/ +.hg/ +.svn/ + +# =================== +# OS files +# =================== +.DS_Store +Thumbs.db + +# =================== +# Project-specific +# Add your custom patterns below +# =================== +# large_test_fixtures/ +# data/ diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index 666ff24884..7781f04b58 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -476,6 +476,15 @@ func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth * return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token") } + // Preserve tier info before refresh + var tierID, tierName string + var tierIsPaid bool + if auth.Metadata != nil { + tierID, _ = auth.Metadata["tier_id"].(string) + tierName, _ = auth.Metadata["tier_name"].(string) + tierIsPaid, _ = auth.Metadata["tier_is_paid"].(bool) + } + if auth.Metadata == nil { auth.Metadata = make(map[string]any) } @@ -491,6 +500,17 @@ func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth * } auth.Metadata["type"] = "antigravity" + // Restore preserved tier info + if tierID != "" { + auth.Metadata["tier_id"] = tierID + } + if tierName != "" { + auth.Metadata["tier_name"] = tierName + } + if tierIsPaid { + auth.Metadata["tier_is_paid"] = tierIsPaid + } + if h != nil && h.authManager != nil { auth.LastRefreshedAt = now auth.UpdatedAt = now diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 992d13e6c0..5b5b6e15bd 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -26,6 +26,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cline" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" @@ -52,6 +53,7 @@ const ( anthropicCallbackPort = 54545 geminiCallbackPort = 8085 codexCallbackPort = 1455 + clineCallbackPort = 1456 geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" geminiCLIVersion = "v1internal" ) @@ -226,6 +228,14 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) { log.Infof("callback forwarder on port %d stopped", port) } +func sanitizeAntigravityFileName(email string) string { + if strings.TrimSpace(email) == "" { + return "antigravity.json" + } + replacer := strings.NewReplacer("@", "_", ".", "_") + return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email)) +} + func (h *Handler) managementCallbackURL(path string) (string, error) { if h == nil || h.cfg == nil || h.cfg.Port <= 0 { return "", fmt.Errorf("server port is not configured") @@ -421,9 +431,139 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { if claims := extractCodexIDTokenClaims(auth); claims != nil { entry["id_token"] = claims } + // Add Antigravity tier info (fetch if missing) + if auth.Provider == "antigravity" && auth.Metadata != nil { + tierID, _ := auth.Metadata["tier_id"].(string) + tierName, _ := auth.Metadata["tier_name"].(string) + + // If tier info missing, try to fetch it + if tierID == "" { + tierID, tierName = h.fetchAndCacheAntigravityTier(auth, false) + } + + if tierID != "" { + entry["tier"] = tierID + } + if tierName != "" { + entry["tier_name"] = tierName + } + } + entry["quota"] = gin.H{ + "exceeded": auth.Quota.Exceeded, + "reason": auth.Quota.Reason, + "next_recover_at": auth.Quota.NextRecoverAt, + "backoff_level": auth.Quota.BackoffLevel, + } + if auth.LastError != nil { + entry["last_error"] = gin.H{ + "code": auth.LastError.Code, + "message": auth.LastError.Message, + "retryable": auth.LastError.Retryable, + "http_status": auth.LastError.HTTPStatus, + } + } + if !auth.NextRetryAfter.IsZero() { + entry["next_retry_after"] = auth.NextRetryAfter + } return entry } +// fetchAndCacheAntigravityTier fetches tier info for an antigravity auth and caches it in metadata. +// Returns tierID, tierName. On error, returns empty strings. +// If forceRefresh is true, it will fetch the tier info even if it's already cached. +func (h *Handler) fetchAndCacheAntigravityTier(auth *coreauth.Auth, forceRefresh bool) (string, string) { + if auth == nil || auth.Provider != "antigravity" || auth.Metadata == nil { + return "", "" + } + + // Check if already has tier info (skip if forceRefresh) + if !forceRefresh { + if tierID, ok := auth.Metadata["tier_id"].(string); ok && tierID != "" { + tierName, _ := auth.Metadata["tier_name"].(string) + return tierID, tierName + } + } + + // Get access token + accessToken, ok := auth.Metadata["access_token"].(string) + if !ok || strings.TrimSpace(accessToken) == "" { + return "", "" + } + + // Fetch tier info + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) + projectInfo, err := sdkAuth.FetchAntigravityProjectInfo(ctx, accessToken, httpClient) + if err != nil { + log.Debugf("antigravity: failed to fetch tier for %s: %v", auth.ID, err) + return "", "" + } + + // Cache in metadata + auth.Metadata["tier_id"] = projectInfo.TierID + auth.Metadata["tier_name"] = projectInfo.TierName + auth.Metadata["tier_is_paid"] = projectInfo.IsPaid + + // Try to persist to disk if authManager is available + if h.authManager != nil { + if _, err := h.authManager.Update(ctx, auth); err != nil { + log.Debugf("antigravity: failed to persist tier for %s: %v", auth.ID, err) + } + } + + log.Infof("antigravity: fetched tier %s for existing auth %s", projectInfo.TierID, auth.ID) + return projectInfo.TierID, projectInfo.TierName +} + +func (h *Handler) RefreshTier(c *gin.Context) { + if h.authManager == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"}) + return + } + + authID := strings.TrimSpace(c.Param("id")) + if authID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "auth id is required"}) + return + } + + auth, ok := h.authManager.GetByID(authID) + if !ok { + auths := h.authManager.List() + for _, a := range auths { + if a.FileName == authID || a.ID == authID { + auth = a + ok = true + break + } + } + } + + if !ok || auth == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "auth not found"}) + return + } + + if auth.Provider != "antigravity" { + c.JSON(http.StatusBadRequest, gin.H{"error": "tier refresh only supported for antigravity provider"}) + return + } + + tierID, tierName := h.fetchAndCacheAntigravityTier(auth, true) + if tierID == "" { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch tier info"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "ok", + "tier": tierID, + "tier_name": tierName, + }) +} + func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H { if auth == nil || auth.Metadata == nil { return nil @@ -1107,14 +1247,67 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { rawCode := resultMap["code"] code := strings.Split(rawCode, "#")[0] - // Exchange code for tokens using internal auth service - bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes) - if errExchange != nil { - authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange) + // Exchange code for tokens (replicate logic using updated redirect_uri) + // Extract client_id from the modified auth URL + clientID := "" + if u2, errP := url.Parse(authURL); errP == nil { + clientID = u2.Query().Get("client_id") + } + // Build request + bodyMap := map[string]any{ + "code": code, + "state": state, + "grant_type": "authorization_code", + "client_id": clientID, + "redirect_uri": "http://localhost:54545/callback", + "code_verifier": pkceCodes.CodeVerifier, + } + bodyJSON, _ := json.Marshal(bodyMap) + + httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) + req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON))) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + resp, errDo := httpClient.Do(req) + if errDo != nil { + authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") return } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("failed to close response body: %v", errClose) + } + }() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) + SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) + return + } + var tResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + Account struct { + EmailAddress string `json:"email_address"` + } `json:"account"` + } + if errU := json.Unmarshal(respBody, &tResp); errU != nil { + log.Errorf("failed to parse token response: %v", errU) + SetOAuthSessionError(state, "Failed to parse token response") + return + } + bundle := &claude.ClaudeAuthBundle{ + TokenData: claude.ClaudeTokenData{ + AccessToken: tResp.AccessToken, + RefreshToken: tResp.RefreshToken, + Email: tResp.Account.EmailAddress, + Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339), + }, + LastRefresh: time.Now().Format(time.RFC3339), + } // Create token storage tokenStorage := anthropicAuth.CreateTokenStorage(bundle) @@ -1155,13 +1348,17 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Println("Initializing Google authentication...") - // OAuth2 configuration using exported constants from internal/auth/gemini + // OAuth2 configuration (mirrors internal/auth/gemini) conf := &oauth2.Config{ ClientID: geminiAuth.ClientID, ClientSecret: geminiAuth.ClientSecret, - RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort), - Scopes: geminiAuth.Scopes, - Endpoint: google.Endpoint, + RedirectURL: "http://localhost:8085/oauth2callback", + Scopes: []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + }, + Endpoint: google.Endpoint, } // Build authorization URL and return it immediately @@ -1285,7 +1482,11 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { ifToken["token_uri"] = "https://oauth2.googleapis.com/token" ifToken["client_id"] = geminiAuth.ClientID ifToken["client_secret"] = geminiAuth.ClientSecret - ifToken["scopes"] = geminiAuth.Scopes + ifToken["scopes"] = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + } ifToken["universe_domain"] = "googleapis.com" ts := geminiAuth.GeminiTokenStorage{ @@ -1497,25 +1698,73 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } log.Debug("Authorization code received, exchanging for tokens...") - // Exchange code for tokens using internal auth service - bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes) - if errExchange != nil { - authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange) + // Extract client_id from authURL + clientID := "" + if u2, errP := url.Parse(authURL); errP == nil { + clientID = u2.Query().Get("client_id") + } + // Exchange code for tokens with redirect equal to mgmtRedirect + form := url.Values{ + "grant_type": {"authorization_code"}, + "client_id": {clientID}, + "code": {code}, + "redirect_uri": {"http://localhost:1455/auth/callback"}, + "code_verifier": {pkceCodes.CodeVerifier}, + } + httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) + req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + resp, errDo := httpClient.Do(req) + if errDo != nil { + authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } - - // Extract additional info for filename generation - claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken) + defer func() { _ = resp.Body.Close() }() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) + log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) + return + } + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + ExpiresIn int `json:"expires_in"` + } + if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { + SetOAuthSessionError(state, "Failed to parse token response") + log.Errorf("failed to parse token response: %v", errU) + return + } + claims, _ := codex.ParseJWTToken(tokenResp.IDToken) + email := "" + accountID := "" planType := "" - hashAccountID := "" if claims != nil { + email = claims.GetUserEmail() + accountID = claims.GetAccountID() planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) - if accountID := claims.GetAccountID(); accountID != "" { - digest := sha256.Sum256([]byte(accountID)) - hashAccountID = hex.EncodeToString(digest[:])[:8] - } + } + hashAccountID := "" + if accountID != "" { + digest := sha256.Sum256([]byte(accountID)) + hashAccountID = hex.EncodeToString(digest[:])[:8] + } + // Build bundle compatible with existing storage + bundle := &codex.CodexAuthBundle{ + TokenData: codex.CodexTokenData{ + IDToken: tokenResp.IDToken, + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + AccountID: accountID, + Email: email, + Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339), + }, + LastRefresh: time.Now().Format(time.RFC3339), } // Create token storage and persist @@ -1549,14 +1798,143 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } +func (h *Handler) RequestClineToken(c *gin.Context) { + ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) + + fmt.Println("Initializing Cline authentication...") + + state, errState := misc.GenerateRandomState() + if errState != nil { + log.Errorf("Failed to generate state parameter: %v", errState) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"}) + return + } + + redirectURL := fmt.Sprintf("http://localhost:%d/callback", clineCallbackPort) + clineAuth := cline.NewClineAuth(h.cfg) + authURL := clineAuth.GenerateAuthURL(state, redirectURL) + + RegisterOAuthSession(state, "cline") + + isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/cline/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute cline callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + var errStart error + if forwarder, errStart = startCallbackForwarder(clineCallbackPort, "cline", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start cline callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + + go func() { + if isWebUI { + defer stopCallbackForwarderInstance(clineCallbackPort, forwarder) + } + + waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-cline-%s.oauth", state)) + deadline := time.Now().Add(cline.AuthTimeout) + var authCode string + for { + if !IsOAuthSessionPending(state, "cline") { + return + } + if time.Now().After(deadline) { + log.Error("oauth flow timed out") + SetOAuthSessionError(state, "OAuth flow timed out") + return + } + if data, errRead := os.ReadFile(waitFile); errRead == nil { + var payload map[string]string + _ = json.Unmarshal(data, &payload) + _ = os.Remove(waitFile) + + if errStr := strings.TrimSpace(payload["error"]); errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + SetOAuthSessionError(state, "Authentication failed") + return + } + if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { + log.Errorf("Authentication failed: state mismatch") + SetOAuthSessionError(state, "Authentication failed: state mismatch") + return + } + authCode = strings.TrimSpace(payload["code"]) + if authCode == "" { + log.Error("Authentication failed: code not found") + SetOAuthSessionError(state, "Authentication failed: code not found") + return + } + break + } + time.Sleep(500 * time.Millisecond) + } + + tokenResp, errExchange := clineAuth.ExchangeCode(ctx, authCode, redirectURL) + if errExchange != nil { + log.Errorf("Failed to exchange token: %v", errExchange) + SetOAuthSessionError(state, "Failed to exchange token") + return + } + + // Parse expiresAt from string to int64 + var expiresAtInt int64 + if tokenResp.ExpiresAt != "" { + if t, err := time.Parse(time.RFC3339Nano, tokenResp.ExpiresAt); err == nil { + expiresAtInt = t.Unix() + } else if t, err := time.Parse(time.RFC3339, tokenResp.ExpiresAt); err == nil { + expiresAtInt = t.Unix() + } + } + + tokenStorage := &cline.ClineTokenStorage{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAtInt, + Email: tokenResp.Email, + Type: "cline", + } + + fileName := cline.CredentialFileName(tokenStorage.Email) + record := &coreauth.Auth{ + ID: fileName, + Provider: "cline", + FileName: fileName, + Storage: tokenStorage, + Metadata: map[string]any{ + "email": tokenStorage.Email, + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + SetOAuthSessionError(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + fmt.Println("You can now use Cline services through this CLI") + CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("cline") + }() + + c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) +} + func (h *Handler) RequestAntigravityToken(c *gin.Context) { ctx := context.Background() ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Antigravity authentication...") - authSvc := antigravity.NewAntigravityAuth(h.cfg, nil) - state, errState := misc.GenerateRandomState() if errState != nil { log.Errorf("Failed to generate state parameter: %v", errState) @@ -1565,7 +1943,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort) - authURL := authSvc.BuildAuthURL(state, redirectURI) + + params := url.Values{} + params.Set("access_type", "offline") + params.Set("client_id", antigravity.ClientID) + params.Set("prompt", "consent") + params.Set("redirect_uri", redirectURI) + params.Set("response_type", "code") + params.Set("scope", strings.Join(antigravity.Scopes, " ")) + params.Set("state", state) + authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() RegisterOAuthSession(state, "antigravity") @@ -1628,41 +2015,104 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { time.Sleep(500 * time.Millisecond) } - tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI) - if errToken != nil { - log.Errorf("Failed to exchange token: %v", errToken) - SetOAuthSessionError(state, "Failed to exchange token") + httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) + form := url.Values{} + form.Set("code", authCode) + form.Set("client_id", antigravity.ClientID) + form.Set("client_secret", antigravity.ClientSecret) + form.Set("redirect_uri", redirectURI) + form.Set("grant_type", "authorization_code") + + req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) + if errNewRequest != nil { + log.Errorf("Failed to build token request: %v", errNewRequest) + SetOAuthSessionError(state, "Failed to build token request") return } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - accessToken := strings.TrimSpace(tokenResp.AccessToken) - if accessToken == "" { - log.Error("antigravity: token exchange returned empty access token") + resp, errDo := httpClient.Do(req) + if errDo != nil { + log.Errorf("Failed to execute token request: %v", errDo) SetOAuthSessionError(state, "Failed to exchange token") return } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity token exchange close error: %v", errClose) + } + }() - email, errInfo := authSvc.FetchUserInfo(ctx, accessToken) - if errInfo != nil { - log.Errorf("Failed to fetch user info: %v", errInfo) - SetOAuthSessionError(state, "Failed to fetch user info") + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, _ := io.ReadAll(resp.Body) + log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } - email = strings.TrimSpace(email) - if email == "" { - log.Error("antigravity: user info returned empty email") - SetOAuthSessionError(state, "Failed to fetch user info") + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` + } + if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { + log.Errorf("Failed to parse token response: %v", errDecode) + SetOAuthSessionError(state, "Failed to parse token response") return } + email := "" + if strings.TrimSpace(tokenResp.AccessToken) != "" { + infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) + if errInfoReq != nil { + log.Errorf("Failed to build user info request: %v", errInfoReq) + SetOAuthSessionError(state, "Failed to build user info request") + return + } + infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) + + infoResp, errInfo := httpClient.Do(infoReq) + if errInfo != nil { + log.Errorf("Failed to execute user info request: %v", errInfo) + SetOAuthSessionError(state, "Failed to execute user info request") + return + } + defer func() { + if errClose := infoResp.Body.Close(); errClose != nil { + log.Errorf("antigravity user info close error: %v", errClose) + } + }() + + if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices { + var infoPayload struct { + Email string `json:"email"` + } + if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil { + email = strings.TrimSpace(infoPayload.Email) + } + } else { + bodyBytes, _ := io.ReadAll(infoResp.Body) + log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) + SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) + return + } + } + projectID := "" - if accessToken != "" { - fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken) + tierID := "unknown" + tierName := "Unknown" + tierIsPaid := false + if strings.TrimSpace(tokenResp.AccessToken) != "" { + projectInfo, errProject := sdkAuth.FetchAntigravityProjectInfo(ctx, tokenResp.AccessToken, httpClient) if errProject != nil { - log.Warnf("antigravity: failed to fetch project ID: %v", errProject) + log.Warnf("antigravity: failed to fetch project info: %v", errProject) } else { - projectID = fetchedProjectID - log.Infof("antigravity: obtained project ID %s", projectID) + projectID = projectInfo.ProjectID + tierID = projectInfo.TierID + tierName = projectInfo.TierName + tierIsPaid = projectInfo.IsPaid + log.Infof("antigravity: obtained project ID %s, tier %s", projectID, tierID) } } @@ -1674,6 +2124,9 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { "expires_in": tokenResp.ExpiresIn, "timestamp": now.UnixMilli(), "expired": now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339), + "tier_id": tierID, + "tier_name": tierName, + "tier_is_paid": tierIsPaid, } if email != "" { metadata["email"] = email @@ -1682,7 +2135,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { metadata["project_id"] = projectID } - fileName := antigravity.CredentialFileName(email) + fileName := sanitizeAntigravityFileName(email) label := strings.TrimSpace(email) if label == "" { label = "antigravity" @@ -1933,13 +2386,30 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { identifier = fmt.Sprintf("%d", time.Now().UnixMilli()) tokenStorage.Email = identifier } + now := time.Now().UTC() + nextRefreshAfter := time.Time{} + if expiresAt, errParse := time.Parse(time.RFC3339, tokenStorage.Expire); errParse == nil { + nextRefreshAfter = expiresAt.Add(-36 * time.Hour) + } record := &coreauth.Auth{ - ID: fmt.Sprintf("iflow-%s.json", identifier), - Provider: "iflow", - FileName: fmt.Sprintf("iflow-%s.json", identifier), - Storage: tokenStorage, - Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey}, - Attributes: map[string]string{"api_key": tokenStorage.APIKey}, + ID: fmt.Sprintf("iflow-%s.json", identifier), + Provider: "iflow", + FileName: fmt.Sprintf("iflow-%s.json", identifier), + Storage: tokenStorage, + Metadata: map[string]any{ + "email": identifier, + "api_key": tokenStorage.APIKey, + "access_token": tokenStorage.AccessToken, + "refresh_token": tokenStorage.RefreshToken, + "expired": tokenStorage.Expire, + "type": "iflow", + "last_refresh": now.Format(time.RFC3339), + }, + Attributes: map[string]string{"api_key": tokenStorage.APIKey}, + CreatedAt: now, + UpdatedAt: now, + LastRefreshedAt: now, + NextRefreshAfter: nextRefreshAfter, } savedPath, errSave := h.saveTokenRecord(ctx, record) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index d83d583a20..066bf445fb 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -729,7 +729,11 @@ func (h *Handler) PutOAuthModelAlias(c *gin.Context) { entries = wrapper.Items } h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries) - h.persist(c) + if h.persist(c) { + if h.authManager != nil { + h.authManager.SetOAuthModelAlias(h.cfg.OAuthModelAlias) + } + } } func (h *Handler) PatchOAuthModelAlias(c *gin.Context) { @@ -773,14 +777,22 @@ func (h *Handler) PatchOAuthModelAlias(c *gin.Context) { h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias) } h.cfg.OAuthModelAlias[channel] = []config.OAuthModelAlias{} - h.persist(c) + if h.persist(c) { + if h.authManager != nil { + h.authManager.SetOAuthModelAlias(h.cfg.OAuthModelAlias) + } + } return } if h.cfg.OAuthModelAlias == nil { h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias) } h.cfg.OAuthModelAlias[channel] = normalized - h.persist(c) + if h.persist(c) { + if h.authManager != nil { + h.authManager.SetOAuthModelAlias(h.cfg.OAuthModelAlias) + } + } } func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) { diff --git a/internal/api/handlers/management/config_routing.go b/internal/api/handlers/management/config_routing.go new file mode 100644 index 0000000000..c5c7a3a741 --- /dev/null +++ b/internal/api/handlers/management/config_routing.go @@ -0,0 +1,100 @@ +package management + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +// normalizeRoutingMode normalizes the routing mode value. +// Supported values: "" (default, provider-based), "key-based" (model-only key). +func normalizeRoutingMode(mode string) (string, bool) { + normalized := strings.ToLower(strings.TrimSpace(mode)) + switch normalized { + case "", "provider-based", "provider": + return "provider-based", true + case "key-based", "key", "model-only": + return "key-based", true + default: + return "", false + } +} + +// GetRoutingMode returns the current routing mode. +func (h *Handler) GetRoutingMode(c *gin.Context) { + mode, ok := normalizeRoutingMode(h.cfg.Routing.Mode) + if !ok { + c.JSON(200, gin.H{"mode": strings.TrimSpace(h.cfg.Routing.Mode)}) + return + } + c.JSON(200, gin.H{"mode": mode}) +} + +// PutRoutingMode updates the routing mode. +func (h *Handler) PutRoutingMode(c *gin.Context) { + var body struct { + Value *string `json:"value"` + } + if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + normalized, ok := normalizeRoutingMode(*body.Value) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mode"}) + return + } + h.cfg.Routing.Mode = normalized + h.persist(c) +} + +// GetFallbackModels returns the fallback models configuration. +func (h *Handler) GetFallbackModels(c *gin.Context) { + models := h.cfg.Routing.FallbackModels + if models == nil { + models = make(map[string]string) + } + c.JSON(200, gin.H{"fallback-models": models}) +} + +// PutFallbackModels updates the fallback models configuration. +func (h *Handler) PutFallbackModels(c *gin.Context) { + var body struct { + Value map[string]string `json:"value"` + } + if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + if body.Value == nil { + body.Value = make(map[string]string) + } + h.cfg.Routing.FallbackModels = body.Value + h.persist(c) +} + +// GetFallbackChain returns the fallback chain configuration. +func (h *Handler) GetFallbackChain(c *gin.Context) { + chain := h.cfg.Routing.FallbackChain + if chain == nil { + chain = []string{} + } + c.JSON(200, gin.H{"fallback-chain": chain}) +} + +// PutFallbackChain updates the fallback chain configuration. +func (h *Handler) PutFallbackChain(c *gin.Context) { + var body struct { + Value []string `json:"value"` + } + if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) + return + } + if body.Value == nil { + body.Value = []string{} + } + h.cfg.Routing.FallbackChain = body.Value + h.persist(c) +} diff --git a/internal/api/handlers/management/config_routing_test.go b/internal/api/handlers/management/config_routing_test.go new file mode 100644 index 0000000000..1064e85da0 --- /dev/null +++ b/internal/api/handlers/management/config_routing_test.go @@ -0,0 +1,252 @@ +package management + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func setupTestRouter(h *Handler) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + return r +} + +func createTempConfigFile(t *testing.T) string { + t.Helper() + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + initialConfig := []byte("routing:\n strategy: round-robin\n") + if err := os.WriteFile(configPath, initialConfig, 0644); err != nil { + t.Fatalf("failed to create temp config: %v", err) + } + return configPath +} + +func TestGetRoutingMode(t *testing.T) { + tests := []struct { + name string + configMode string + expectedMode string + }{ + {"empty mode returns provider-based", "", "provider-based"}, + {"provider-based mode", "provider-based", "provider-based"}, + {"key-based mode", "key-based", "key-based"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + Routing: config.RoutingConfig{ + Mode: tt.configMode, + }, + } + h := &Handler{cfg: cfg} + r := setupTestRouter(h) + r.GET("/routing/mode", h.GetRoutingMode) + + req := httptest.NewRequest(http.MethodGet, "/routing/mode", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp["mode"] != tt.expectedMode { + t.Errorf("expected mode %q, got %q", tt.expectedMode, resp["mode"]) + } + }) + } +} + +func TestPutRoutingMode(t *testing.T) { + tests := []struct { + name string + inputValue string + expectedStatus int + expectedMode string + }{ + {"valid key-based", "key-based", http.StatusOK, "key-based"}, + {"valid provider-based", "provider-based", http.StatusOK, "provider-based"}, + {"alias key", "key", http.StatusOK, "key-based"}, + {"alias provider", "provider", http.StatusOK, "provider-based"}, + {"invalid mode", "invalid-mode", http.StatusBadRequest, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configPath := createTempConfigFile(t) + cfg := &config.Config{} + h := &Handler{cfg: cfg, configFilePath: configPath} + r := setupTestRouter(h) + r.PUT("/routing/mode", h.PutRoutingMode) + + body, _ := json.Marshal(map[string]string{"value": tt.inputValue}) + req := httptest.NewRequest(http.MethodPut, "/routing/mode", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + if tt.expectedStatus == http.StatusOK && cfg.Routing.Mode != tt.expectedMode { + t.Errorf("expected config mode %q, got %q", tt.expectedMode, cfg.Routing.Mode) + } + }) + } +} + +func TestGetFallbackModels(t *testing.T) { + tests := []struct { + name string + configModels map[string]string + expectedModels map[string]string + }{ + {"nil models returns empty map", nil, map[string]string{}}, + {"empty models returns empty map", map[string]string{}, map[string]string{}}, + {"with models", map[string]string{"model-a": "model-b"}, map[string]string{"model-a": "model-b"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + Routing: config.RoutingConfig{ + FallbackModels: tt.configModels, + }, + } + h := &Handler{cfg: cfg} + r := setupTestRouter(h) + r.GET("/fallback/models", h.GetFallbackModels) + + req := httptest.NewRequest(http.MethodGet, "/fallback/models", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string]map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + models := resp["fallback-models"] + if len(models) != len(tt.expectedModels) { + t.Errorf("expected %d models, got %d", len(tt.expectedModels), len(models)) + } + }) + } +} + +func TestPutFallbackModels(t *testing.T) { + configPath := createTempConfigFile(t) + cfg := &config.Config{} + h := &Handler{cfg: cfg, configFilePath: configPath} + r := setupTestRouter(h) + r.PUT("/fallback/models", h.PutFallbackModels) + + inputModels := map[string]string{"model-a": "model-b", "model-c": "model-d"} + body, _ := json.Marshal(map[string]interface{}{"value": inputModels}) + req := httptest.NewRequest(http.MethodPut, "/fallback/models", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if len(cfg.Routing.FallbackModels) != 2 { + t.Errorf("expected 2 models, got %d", len(cfg.Routing.FallbackModels)) + } + + if cfg.Routing.FallbackModels["model-a"] != "model-b" { + t.Errorf("expected model-a -> model-b, got %s", cfg.Routing.FallbackModels["model-a"]) + } +} + +func TestGetFallbackChain(t *testing.T) { + tests := []struct { + name string + configChain []string + expectedChain []string + }{ + {"nil chain returns empty array", nil, []string{}}, + {"empty chain returns empty array", []string{}, []string{}}, + {"with chain", []string{"model-a", "model-b"}, []string{"model-a", "model-b"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + Routing: config.RoutingConfig{ + FallbackChain: tt.configChain, + }, + } + h := &Handler{cfg: cfg} + r := setupTestRouter(h) + r.GET("/fallback/chain", h.GetFallbackChain) + + req := httptest.NewRequest(http.MethodGet, "/fallback/chain", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp map[string][]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + chain := resp["fallback-chain"] + if len(chain) != len(tt.expectedChain) { + t.Errorf("expected %d items, got %d", len(tt.expectedChain), len(chain)) + } + }) + } +} + +func TestPutFallbackChain(t *testing.T) { + configPath := createTempConfigFile(t) + cfg := &config.Config{} + h := &Handler{cfg: cfg, configFilePath: configPath} + r := setupTestRouter(h) + r.PUT("/fallback/chain", h.PutFallbackChain) + + inputChain := []string{"model-a", "model-b", "model-c"} + body, _ := json.Marshal(map[string]interface{}{"value": inputChain}) + req := httptest.NewRequest(http.MethodPut, "/fallback/chain", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if len(cfg.Routing.FallbackChain) != 3 { + t.Errorf("expected 3 items, got %d", len(cfg.Routing.FallbackChain)) + } + + if cfg.Routing.FallbackChain[0] != "model-a" { + t.Errorf("expected first item model-a, got %s", cfg.Routing.FallbackChain[0]) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 5c3274c9fc..958ced7117 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -616,6 +616,16 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy) mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy) + mgmt.GET("/routing/mode", s.mgmt.GetRoutingMode) + mgmt.PUT("/routing/mode", s.mgmt.PutRoutingMode) + mgmt.PATCH("/routing/mode", s.mgmt.PutRoutingMode) + + mgmt.GET("/fallback/models", s.mgmt.GetFallbackModels) + mgmt.PUT("/fallback/models", s.mgmt.PutFallbackModels) + + mgmt.GET("/fallback/chain", s.mgmt.GetFallbackChain) + mgmt.PUT("/fallback/chain", s.mgmt.PutFallbackChain) + mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys) mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys) mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey) @@ -654,6 +664,7 @@ func (s *Server) registerManagementRoutes() { mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile) mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus) mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields) + mgmt.POST("/auth-files/:id/refresh-tier", s.mgmt.RefreshTier) mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential) mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) @@ -667,6 +678,7 @@ func (s *Server) registerManagementRoutes() { mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken) + mgmt.POST("/request-cline-token", s.mgmt.RequestClineToken) mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } diff --git a/internal/auth/cline/cline_auth.go b/internal/auth/cline/cline_auth.go new file mode 100644 index 0000000000..9ee3e3c361 --- /dev/null +++ b/internal/auth/cline/cline_auth.go @@ -0,0 +1,167 @@ +// Package cline provides authentication and token management functionality +// for Cline AI services using WorkOS OAuth. +package cline + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // BaseURL is the base URL for the Cline API. + BaseURL = "https://api.cline.bot" + + // AuthTimeout is the timeout for OAuth authentication flow. + AuthTimeout = 10 * time.Minute +) + +// TokenResponse represents the response from Cline token endpoints. +type TokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresAt string `json:"expiresAt"` // Cline returns ISO 8601 timestamp string + Email string `json:"email"` +} + +// ClineAuth provides methods for handling the Cline WorkOS authentication flow. +type ClineAuth struct { + client *http.Client + cfg *config.Config +} + +// NewClineAuth creates a new instance of ClineAuth. +func NewClineAuth(cfg *config.Config) *ClineAuth { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + client.Timeout = 30 * time.Second + return &ClineAuth{ + client: client, + cfg: cfg, + } +} + +// GenerateAuthURL generates the Cline OAuth authorization URL. +// The state parameter is used for CSRF protection. +func (c *ClineAuth) GenerateAuthURL(state, callbackURL string) string { + // Cline uses WorkOS OAuth with the following parameters: + // client_type=extension&callback_url={cb}&redirect_uri={cb} + authURL := fmt.Sprintf("%s/api/v1/auth/authorize?client_type=extension&callback_url=%s&redirect_uri=%s&state=%s", + BaseURL, + callbackURL, + callbackURL, + state) + return authURL +} + +// ExchangeCode exchanges the authorization code for access and refresh tokens. +func (c *ClineAuth) ExchangeCode(ctx context.Context, code, redirectURI string) (*TokenResponse, error) { + payload := map[string]string{ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirectURI, + "client_type": "extension", + "provider": "workos", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("cline: failed to marshal token request: %w", err) + } + + tokenURL := BaseURL + "/api/v1/auth/token" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("cline: failed to create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "Cline/3.0.0") + req.Header.Set("HTTP-Referer", "https://cline.bot") + req.Header.Set("X-Title", "Cline") + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("cline: token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("cline: failed to read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("cline: token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("cline: token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("cline: failed to parse token response: %w", err) + } + + return &tokenResp, nil +} + +// RefreshToken refreshes an expired access token using the refresh token. +func (c *ClineAuth) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { + payload := map[string]string{ + "grantType": "refresh_token", + "refreshToken": refreshToken, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("cline: failed to marshal refresh request: %w", err) + } + + refreshURL := BaseURL + "/api/v1/auth/refresh" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("cline: failed to create refresh request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "Cline/3.0.0") + req.Header.Set("HTTP-Referer", "https://cline.bot") + req.Header.Set("X-Title", "Cline") + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("cline: refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("cline: failed to read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("cline: token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("cline: token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + } + + var tokenResp TokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("cline: failed to parse refresh response: %w", err) + } + + return &tokenResp, nil +} + +// ShouldRefresh checks if the token should be refreshed (expires in less than 5 minutes). +func ShouldRefresh(expiresAt int64) bool { + return time.Until(time.Unix(expiresAt, 0)) < 5*time.Minute +} diff --git a/internal/auth/cline/cline_token.go b/internal/auth/cline/cline_token.go new file mode 100644 index 0000000000..4f2029ec4d --- /dev/null +++ b/internal/auth/cline/cline_token.go @@ -0,0 +1,82 @@ +// Package cline provides authentication and token management functionality +// for Cline AI services. +package cline + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + log "github.com/sirupsen/logrus" +) + +// ClineTokenStorage stores token information for Cline authentication. +type ClineTokenStorage struct { + // AccessToken is the Cline access token (stored without workos: prefix). + AccessToken string `json:"accessToken"` + + // RefreshToken is the Cline refresh token. + RefreshToken string `json:"refreshToken"` + + // ExpiresAt is the Unix timestamp when the access token expires. + ExpiresAt int64 `json:"expiresAt"` + + // Email is the email address of the authenticated user. + Email string `json:"email"` + + // Type indicates the authentication provider type, always "cline" for this storage. + Type string `json:"type"` +} + +// SaveTokenToFile serializes the Cline token storage to a JSON file. +func (ts *ClineTokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "cline" + if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + if errClose := f.Close(); errClose != nil { + log.Errorf("failed to close file: %v", errClose) + } + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} + +// LoadTokenFromFile loads a Cline token from a JSON file. +func LoadTokenFromFile(authFilePath string) (*ClineTokenStorage, error) { + data, err := os.ReadFile(authFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var storage ClineTokenStorage + if err := json.Unmarshal(data, &storage); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &storage, nil +} + +// CredentialFileName returns the filename used to persist Cline credentials. +// Format: cline-{email}.json +func CredentialFileName(email string) string { + return fmt.Sprintf("cline-%s.json", email) +} + +// GetAuthHeaderValue returns the Authorization header value with workos: prefix. +// The token is stored without the prefix, but requests need it. +func (ts *ClineTokenStorage) GetAuthHeaderValue() string { + return "workos:" + ts.AccessToken +} diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index 279d7339d3..8fbe0bbfd5 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -65,15 +65,16 @@ func NewIFlowAuth(cfg *config.Config) *IFlowAuth { } // AuthorizationURL builds the authorization URL and matching redirect URI. +// Parameter order matches official iFlow CLI: loginMethod, type, redirect, state, client_id func (ia *IFlowAuth) AuthorizationURL(state string, port int) (authURL, redirectURI string) { redirectURI = fmt.Sprintf("http://localhost:%d/oauth2callback", port) - values := url.Values{} - values.Set("loginMethod", "phone") - values.Set("type", "phone") - values.Set("redirect", redirectURI) - values.Set("state", state) - values.Set("client_id", iFlowOAuthClientID) - authURL = fmt.Sprintf("%s?%s", iFlowOAuthAuthorizeEndpoint, values.Encode()) + + // Build URL with explicit parameter order to match iFlow CLI + params := fmt.Sprintf("loginMethod=phone&type=phone&redirect=%s&state=%s&client_id=%s", + url.QueryEscape(redirectURI), + url.QueryEscape(state), + iFlowOAuthClientID) + authURL = fmt.Sprintf("%s?%s", iFlowOAuthAuthorizeEndpoint, params) return authURL, redirectURI } @@ -145,6 +146,17 @@ func (ia *IFlowAuth) doTokenRequest(ctx context.Context, req *http.Request) (*IF return nil, fmt.Errorf("iflow token: decode response failed: %w", err) } + // Check for API-level errors (iflow returns HTTP 200 with success:false on errors) + if !tokenResp.Success && tokenResp.Message != "" { + log.Debugf("iflow token request failed: success=false code=%s message=%s", tokenResp.Code, tokenResp.Message) + return nil, fmt.Errorf("iflow token: API error (code %s): %s", tokenResp.Code, tokenResp.Message) + } + + if tokenResp.AccessToken == "" { + log.Debugf("iflow token: missing access token in response, body: %s", string(body)) + return nil, fmt.Errorf("iflow token: missing access token in response (body: %s)", strings.TrimSpace(string(body))) + } + data := &IFlowTokenData{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, @@ -153,11 +165,6 @@ func (ia *IFlowAuth) doTokenRequest(ctx context.Context, req *http.Request) (*IF Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339), } - if tokenResp.AccessToken == "" { - log.Debug(string(body)) - return nil, fmt.Errorf("iflow token: missing access token in response") - } - info, errAPI := ia.FetchUserInfo(ctx, tokenResp.AccessToken) if errAPI != nil { return nil, fmt.Errorf("iflow token: fetch user info failed: %w", errAPI) @@ -261,6 +268,9 @@ func (ia *IFlowAuth) UpdateTokenStorage(storage *IFlowTokenStorage, data *IFlowT // IFlowTokenResponse models the OAuth token endpoint response. type IFlowTokenResponse struct { + Success bool `json:"success"` + Code string `json:"code"` + Message string `json:"message"` AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` diff --git a/internal/auth/kilocode/errors.go b/internal/auth/kilocode/errors.go new file mode 100644 index 0000000000..fada86008d --- /dev/null +++ b/internal/auth/kilocode/errors.go @@ -0,0 +1,128 @@ +package kilocode + +import ( + "errors" + "fmt" + "net/http" +) + +// AuthenticationError represents authentication-related errors for Kilocode. +type AuthenticationError struct { + // Type is the type of authentication error. + Type string `json:"type"` + // Message is a human-readable message describing the error. + Message string `json:"message"` + // Code is the HTTP status code associated with the error. + Code int `json:"code"` + // Cause is the underlying error that caused this authentication error. + Cause error `json:"-"` +} + +// Error returns a string representation of the authentication error. +func (e *AuthenticationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("%s: %s", e.Type, e.Message) +} + +// Unwrap returns the underlying cause of the error. +func (e *AuthenticationError) Unwrap() error { + return e.Cause +} + +// Common authentication error types for Kilocode device flow. +var ( + // ErrDeviceCodeFailed represents an error when requesting the device code fails. + ErrDeviceCodeFailed = &AuthenticationError{ + Type: "device_code_failed", + Message: "Failed to request device code from Kilocode", + Code: http.StatusBadRequest, + } + + // ErrDeviceCodeExpired represents an error when the device code has expired. + ErrDeviceCodeExpired = &AuthenticationError{ + Type: "device_code_expired", + Message: "Device code has expired. Please try again.", + Code: http.StatusGone, + } + + // ErrAuthorizationPending represents a pending authorization state (not an error, used for polling). + ErrAuthorizationPending = &AuthenticationError{ + Type: "authorization_pending", + Message: "Authorization is pending. Waiting for user to authorize.", + Code: http.StatusAccepted, + } + + // ErrAccessDenied represents an error when the user denies authorization. + ErrAccessDenied = &AuthenticationError{ + Type: "access_denied", + Message: "User denied authorization", + Code: http.StatusForbidden, + } + + // ErrPollingTimeout represents an error when polling times out. + ErrPollingTimeout = &AuthenticationError{ + Type: "polling_timeout", + Message: "Timeout waiting for user authorization", + Code: http.StatusRequestTimeout, + } + + // ErrTokenExchangeFailed represents an error when token exchange fails. + ErrTokenExchangeFailed = &AuthenticationError{ + Type: "token_exchange_failed", + Message: "Failed to exchange device code for access token", + Code: http.StatusBadRequest, + } + + // ErrUserInfoFailed represents an error when fetching user info fails. + ErrUserInfoFailed = &AuthenticationError{ + Type: "user_info_failed", + Message: "Failed to fetch Kilocode user information", + Code: http.StatusBadRequest, + } +) + +// NewAuthenticationError creates a new authentication error with a cause based on a base error. +func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError { + return &AuthenticationError{ + Type: baseErr.Type, + Message: baseErr.Message, + Code: baseErr.Code, + Cause: cause, + } +} + +// IsAuthenticationError checks if an error is an authentication error. +func IsAuthenticationError(err error) bool { + var authenticationError *AuthenticationError + ok := errors.As(err, &authenticationError) + return ok +} + +// GetUserFriendlyMessage returns a user-friendly error message based on the error type. +func GetUserFriendlyMessage(err error) string { + var authErr *AuthenticationError + if errors.As(err, &authErr) { + switch authErr.Type { + case "device_code_failed": + return "Failed to start Kilocode authentication. Please check your network connection and try again." + case "device_code_expired": + return "The authentication code has expired. Please try again." + case "authorization_pending": + return "Waiting for you to authorize the application on Kilocode." + case "access_denied": + return "Authentication was cancelled or denied." + case "token_exchange_failed": + return "Failed to complete authentication. Please try again." + case "polling_timeout": + return "Authentication timed out. Please try again." + case "user_info_failed": + return "Failed to get your Kilocode account information. Please try again." + default: + return "Authentication failed. Please try again." + } + } + + return "An unexpected error occurred. Please try again." +} diff --git a/internal/auth/kilocode/kilocode_auth.go b/internal/auth/kilocode/kilocode_auth.go new file mode 100644 index 0000000000..bb8255c2a2 --- /dev/null +++ b/internal/auth/kilocode/kilocode_auth.go @@ -0,0 +1,368 @@ +// Package kilocode provides authentication and token management for Kilocode API. +// It handles the device flow for secure authentication with the Kilocode API. +package kilocode + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // kilocodeAPIBaseURL is the base URL for Kilocode API. + kilocodeAPIBaseURL = "https://api.kilo.ai" + // kilocodeDeviceCodeURL is the endpoint for requesting device codes. + kilocodeDeviceCodeURL = "https://api.kilo.ai/api/device-auth/codes" + // kilocodeVerifyURL is the URL where users verify their device codes. + kilocodeVerifyURL = "https://kilo.ai/device/verify" + // defaultPollInterval is the default interval for polling token endpoint. + defaultPollInterval = 3 * time.Second + // maxPollDuration is the maximum time to wait for user authorization. + maxPollDuration = 15 * time.Minute +) + +// DeviceCodeResponse represents Kilocode's device code response. +type DeviceCodeResponse struct { + // Code is the device verification code. + Code string `json:"code"` + // VerificationURL is the URL where the user should enter the code. + VerificationURL string `json:"verificationUrl"` + // ExpiresIn is the number of seconds until the device code expires. + ExpiresIn int `json:"expiresIn"` +} + +// PollResponse represents the polling response from Kilocode. +type PollResponse struct { + // Status indicates the current status: pending, approved, denied, expired. + Status string `json:"status"` + // Token is the access token (only present when status is "approved"). + Token string `json:"token,omitempty"` + // UserID is the user ID (only present when status is "approved"). + UserID string `json:"userId,omitempty"` + // UserEmail is the user email (only present when status is "approved"). + UserEmail string `json:"userEmail,omitempty"` +} + +// DeviceFlowClient handles the device flow for Kilocode. +type DeviceFlowClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewDeviceFlowClient creates a new device flow client. +func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &DeviceFlowClient{ + httpClient: client, + cfg: cfg, + } +} + +// RequestDeviceCode initiates the device flow by requesting a device code from Kilocode. +func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, kilocodeDeviceCodeURL, nil) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("kilocode device code: close body error: %v", errClose) + } + }() + + if !isHTTPSuccess(resp.StatusCode) { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var deviceCode DeviceCodeResponse + if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + + return &deviceCode, nil +} + +// PollForToken polls the token endpoint until the user authorizes or the device code expires. +func (c *DeviceFlowClient) PollForToken(ctx context.Context, code string) (*PollResponse, error) { + if code == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is empty")) + } + + pollURL := fmt.Sprintf("%s/%s", kilocodeDeviceCodeURL, url.PathEscape(code)) + deadline := time.Now().Add(maxPollDuration) + + ticker := time.NewTicker(defaultPollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err()) + case <-ticker.C: + if time.Now().After(deadline) { + return nil, ErrPollingTimeout + } + + pollResp, err := c.pollDeviceCode(ctx, pollURL) + if err != nil { + return nil, err + } + + switch pollResp.Status { + case "pending": + // Continue polling + continue + case "approved": + // Success - return the response + return pollResp, nil + case "denied": + return nil, ErrAccessDenied + case "expired": + return nil, ErrDeviceCodeExpired + default: + return nil, NewAuthenticationError(ErrTokenExchangeFailed, + fmt.Errorf("unknown status: %s", pollResp.Status)) + } + } + } +} + +// pollDeviceCode makes a single polling request to check the device code status. +func (c *DeviceFlowClient) pollDeviceCode(ctx context.Context, pollURL string) (*PollResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, pollURL, nil) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("kilocode token poll: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + // Handle different HTTP status codes + switch resp.StatusCode { + case http.StatusOK: + // Success - parse the response + var pollResp PollResponse + if err = json.Unmarshal(bodyBytes, &pollResp); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + return &pollResp, nil + case http.StatusAccepted: + // Still pending + return &PollResponse{Status: "pending"}, nil + case http.StatusForbidden: + // Access denied + return &PollResponse{Status: "denied"}, nil + case http.StatusGone: + // Code expired + return &PollResponse{Status: "expired"}, nil + default: + return nil, NewAuthenticationError(ErrTokenExchangeFailed, + fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } +} + +// KilocodeAuth handles Kilocode authentication flow. +// It provides methods for device flow authentication and token management. +type KilocodeAuth struct { + httpClient *http.Client + deviceClient *DeviceFlowClient + cfg *config.Config +} + +// NewKilocodeAuth creates a new KilocodeAuth service instance. +// It initializes an HTTP client with proxy settings from the provided configuration. +func NewKilocodeAuth(cfg *config.Config) *KilocodeAuth { + return &KilocodeAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + deviceClient: NewDeviceFlowClient(cfg), + cfg: cfg, + } +} + +// StartDeviceFlow initiates the device flow authentication. +// Returns the device code response containing the user code and verification URI. +func (k *KilocodeAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) { + return k.deviceClient.RequestDeviceCode(ctx) +} + +// WaitForAuthorization polls for user authorization and returns the auth bundle. +func (k *KilocodeAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KilocodeAuthBundle, error) { + if deviceCode == nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil")) + } + + pollResp, err := k.deviceClient.PollForToken(ctx, deviceCode.Code) + if err != nil { + return nil, err + } + + if pollResp.Status != "approved" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, + fmt.Errorf("unexpected status: %s", pollResp.Status)) + } + + if pollResp.Token == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty token in response")) + } + + return &KilocodeAuthBundle{ + Token: pollResp.Token, + UserID: pollResp.UserID, + UserEmail: pollResp.UserEmail, + }, nil +} + +// GetAPIEndpoint returns the Kilocode API endpoint URL for OpenRouter compatibility. +func (k *KilocodeAuth) GetAPIEndpoint() string { + return "https://kilo.ai/api/openrouter" +} + +// ValidateToken checks if a Kilocode access token is valid. +// Since Kilocode API only supports /chat/completions, we skip validation here. +// Token validity will be verified during actual API requests. +func (k *KilocodeAuth) ValidateToken(ctx context.Context, token string) (bool, error) { + if token == "" { + return false, nil + } + + // Kilocode API only supports /chat/completions endpoint + // We assume token is valid if it's not empty; actual validation happens during requests + return true, nil +} + +// CreateTokenStorage creates a new KilocodeTokenStorage from auth bundle. +func (k *KilocodeAuth) CreateTokenStorage(bundle *KilocodeAuthBundle) *KilocodeTokenStorage { + return &KilocodeTokenStorage{ + Token: bundle.Token, + UserID: bundle.UserID, + UserEmail: bundle.UserEmail, + Type: "kilocode", + } +} + +// LoadAndValidateToken loads a token from storage and validates it. +// Returns true if valid, false if invalid or expired. +func (k *KilocodeAuth) LoadAndValidateToken(ctx context.Context, storage *KilocodeTokenStorage) (bool, error) { + if storage == nil || storage.Token == "" { + return false, fmt.Errorf("no token available") + } + + // Mask token for logging + maskedToken := maskToken(storage.Token) + log.Debugf("kilocode: validating token %s for user %s", maskedToken, storage.UserID) + + valid, err := k.ValidateToken(ctx, storage.Token) + if err != nil { + log.Debugf("kilocode: token validation failed for %s: %v", maskedToken, err) + return false, err + } + + if !valid { + log.Debugf("kilocode: token %s is invalid", maskedToken) + return false, fmt.Errorf("token is invalid") + } + + log.Debugf("kilocode: token %s is valid", maskedToken) + return true, nil +} + +// isHTTPSuccess checks if the status code indicates success (2xx). +func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 +} + +// FetchModels retrieves available models from the Kilocode API and filters for free models. +// This method fetches the list of AI models available from Kilocode and returns only +// those that are free (pricing.prompt == "0" && pricing.completion == "0"). +// +// Parameters: +// - ctx: The context for the request +// - token: The access token for authentication +// +// Returns: +// - []*registry.ModelInfo: The list of available free models converted to internal format +// - error: An error if the request fails +func (k *KilocodeAuth) FetchModels(ctx context.Context, token string) ([]*registry.ModelInfo, error) { + if token == "" { + return nil, fmt.Errorf("kilocode: access token is required") + } + + // Make request to Kilocode models endpoint + req, err := http.NewRequestWithContext(ctx, http.MethodGet, k.GetAPIEndpoint()+"/models", nil) + if err != nil { + return nil, fmt.Errorf("kilocode: failed to create models request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Accept", "application/json") + + resp, err := k.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("kilocode: failed to fetch models: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("kilocode fetch models: close body error: %v", errClose) + } + }() + + if !isHTTPSuccess(resp.StatusCode) { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("kilocode: models API returned status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + // Parse the API response + var apiResponse registry.KilocodeAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + return nil, fmt.Errorf("kilocode: failed to parse models response: %w", err) + } + + // Convert API models to internal format (filters for free models automatically) + models := registry.ConvertKilocodeAPIModels(apiResponse.Data) + + maskedToken := maskToken(token) + log.Debugf("kilocode: fetched %d free models with token %s", len(models), maskedToken) + + return models, nil +} + +// maskToken masks a token for safe logging by showing only first and last few characters. +func maskToken(token string) string { + if len(token) <= 8 { + return "***" + } + return token[:4] + "***" + token[len(token)-4:] +} diff --git a/internal/auth/kilocode/token.go b/internal/auth/kilocode/token.go new file mode 100644 index 0000000000..a91f4abfdd --- /dev/null +++ b/internal/auth/kilocode/token.go @@ -0,0 +1,67 @@ +// Package kilocode provides authentication and token management functionality +// for Kilocode AI services. It handles device flow token storage, +// serialization, and retrieval for maintaining authenticated sessions with the Kilocode API. +package kilocode + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +// KilocodeTokenStorage stores token information for Kilocode API authentication. +// It maintains compatibility with the existing auth system while adding Kilocode-specific fields +// for managing access tokens and user account information. +type KilocodeTokenStorage struct { + // Token is the access token used for authenticating API requests. + Token string `json:"token"` + // UserID is the Kilocode user ID associated with this token. + UserID string `json:"user_id"` + // UserEmail is the Kilocode user email associated with this token. + UserEmail string `json:"user_email"` + // Type indicates the authentication provider type, always "kilocode" for this storage. + Type string `json:"type"` +} + +// KilocodeAuthBundle bundles authentication data for storage. +type KilocodeAuthBundle struct { + // Token is the access token. + Token string + // UserID is the Kilocode user ID. + UserID string + // UserEmail is the Kilocode user email. + UserEmail string +} + +// SaveTokenToFile serializes the Kilocode token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path for persistent storage. +// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *KilocodeTokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "kilocode" + if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index 2a3407be49..6c8e7d0107 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -23,6 +23,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewKiroAuthenticator(), sdkAuth.NewGitHubCopilotAuthenticator(), sdkAuth.NewKiloAuthenticator(), + sdkAuth.NewClineAuthenticator(), ) return manager } diff --git a/internal/cmd/cline_login.go b/internal/cmd/cline_login.go new file mode 100644 index 0000000000..181636280e --- /dev/null +++ b/internal/cmd/cline_login.go @@ -0,0 +1,54 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" +) + +// DoClineLogin handles the Cline device flow using the shared authentication manager. +// It initiates the device-based authentication process for Cline AI services and saves +// the authentication tokens to the configured auth directory. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including browser behavior and prompts +func DoClineLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + promptFn := options.Prompt + if promptFn == nil { + promptFn = func(prompt string) (string, error) { + fmt.Print(prompt) + var value string + fmt.Scanln(&value) + return strings.TrimSpace(value), nil + } + } + + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "cline", cfg, authOpts) + if err != nil { + fmt.Printf("Cline authentication failed: %v\n", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + + fmt.Println("Cline authentication successful!") +} diff --git a/internal/cmd/kilocode_login.go b/internal/cmd/kilocode_login.go new file mode 100644 index 0000000000..969fbba25f --- /dev/null +++ b/internal/cmd/kilocode_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoKilocodeLogin triggers the device flow for Kilocode and saves tokens. +// It initiates the device flow authentication, displays the user code for the user to enter +// at Kilocode's verification URL, and waits for authorization before saving the tokens. +// +// Parameters: +// - cfg: The application configuration containing proxy and auth directory settings +// - options: Login options including browser behavior settings +func DoKilocodeLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + record, savedPath, err := manager.Login(context.Background(), "kilocode", cfg, authOpts) + if err != nil { + log.Errorf("Kilocode authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kilocode authentication successful!") +} diff --git a/internal/config/config.go b/internal/config/config.go index bb081c7846..4882509600 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -196,6 +196,22 @@ type RoutingConfig struct { // Strategy selects the credential selection strategy. // Supported values: "round-robin" (default), "fill-first". Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` + + // Mode configures the routing mode. + // Supported values: "" (default, provider-scoped), "key-based" (model-only key). + Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` + + // FallbackModels maps original model names to fallback model names. + // When all credentials for the original model fail with 429/401/5xx, + // the request is automatically retried with the fallback model. + FallbackModels map[string]string `yaml:"fallback-models,omitempty" json:"fallback-models,omitempty"` + + // FallbackChain is a general fallback chain for models not in FallbackModels. + // Models are tried in order when the original model fails. + FallbackChain []string `yaml:"fallback-chain,omitempty" json:"fallback-chain,omitempty"` + + // FallbackMaxDepth limits the number of fallback attempts (default: 3). + FallbackMaxDepth int `yaml:"fallback-max-depth,omitempty" json:"fallback-max-depth,omitempty"` } // OAuthModelAlias defines a model ID alias for a specific channel. @@ -773,7 +789,8 @@ func payloadRawString(value any) ([]byte, bool) { // SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases. // It trims whitespace, normalizes channel keys to lower-case, drops empty entries, -// allows multiple aliases per upstream name, and ensures aliases are unique within each channel. +// allows multiple source models to share the same alias, and ensures each name+alias +// combination is unique within each channel. // It also injects default aliases for channels that have built-in defaults (e.g., kiro) // when no user-configured aliases exist for those channels. func (cfg *Config) SanitizeOAuthModelAlias() { @@ -816,7 +833,9 @@ func (cfg *Config) SanitizeOAuthModelAlias() { out[channel] = nil continue } - seenAlias := make(map[string]struct{}, len(aliases)) + // Deduplicate by name+alias combination (not just alias) + // This allows multiple source models to share the same alias + seenNameAlias := make(map[string]struct{}, len(aliases)) clean := make([]OAuthModelAlias, 0, len(aliases)) for _, entry := range aliases { name := strings.TrimSpace(entry.Name) @@ -827,11 +846,12 @@ func (cfg *Config) SanitizeOAuthModelAlias() { if strings.EqualFold(name, alias) { continue } - aliasKey := strings.ToLower(alias) - if _, ok := seenAlias[aliasKey]; ok { + // Deduplicate by name+alias combination (case-insensitive) + nameAliasKey := strings.ToLower(name + "::" + alias) + if _, ok := seenNameAlias[nameAliasKey]; ok { continue } - seenAlias[aliasKey] = struct{}{} + seenNameAlias[nameAliasKey] = struct{}{} clean = append(clean, OAuthModelAlias{Name: name, Alias: alias, Fork: entry.Fork}) } if len(clean) > 0 { diff --git a/internal/config/routing_config_test.go b/internal/config/routing_config_test.go new file mode 100644 index 0000000000..3878c054ed --- /dev/null +++ b/internal/config/routing_config_test.go @@ -0,0 +1,35 @@ +package config + +import ( + "testing" + + "gopkg.in/yaml.v3" +) + +func TestRoutingConfigModeParsing(t *testing.T) { + yamlData := ` +routing: + mode: key-based +` + var cfg Config + if err := yaml.Unmarshal([]byte(yamlData), &cfg); err != nil { + t.Fatalf("failed to parse: %v", err) + } + if cfg.Routing.Mode != "key-based" { + t.Errorf("expected 'key-based', got %q", cfg.Routing.Mode) + } +} + +func TestRoutingConfigModeEmpty(t *testing.T) { + yamlData := ` +routing: + strategy: round-robin +` + var cfg Config + if err := yaml.Unmarshal([]byte(yamlData), &cfg); err != nil { + t.Fatalf("failed to parse: %v", err) + } + if cfg.Routing.Mode != "" { + t.Errorf("expected empty string, got %q", cfg.Routing.Mode) + } +} diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 9b7d31aab6..baf88a9451 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -30,4 +30,7 @@ const ( // Kilo represents the Kilo AI provider identifier. Kilo = "kilo" + + // Cline represents the Cline AI provider identifier. + Cline = "cline" ) diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index b94d7afe6d..9fca812c20 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -4,8 +4,11 @@ package logging import ( + "bytes" + "context" "errors" "fmt" + "io" "net/http" "runtime/debug" "strings" @@ -14,6 +17,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" ) // aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking. @@ -27,12 +31,67 @@ var aiAPIPrefixes = []string{ } const skipGinLogKey = "__gin_skip_request_logging__" +const requestBodyKey = "__gin_request_body__" +const providerAuthContextKey = "cliproxy.provider_auth" +const ginProviderAuthKey = "providerAuth" +const fallbackInfoContextKey = "cliproxy.fallback_info" +const ginFallbackInfoKey = "fallbackInfo" + +func getProviderAuthFromContext(c *gin.Context) (provider, authID, authLabel string) { + if c == nil { + return "", "", "" + } + + // First try to get from Gin context (set by conductor.go) + if v, exists := c.Get(ginProviderAuthKey); exists { + if authInfo, ok := v.(map[string]string); ok { + return authInfo["provider"], authInfo["auth_id"], authInfo["auth_label"] + } + } + + // Fallback to request context + if c.Request == nil { + return "", "", "" + } + ctx := c.Request.Context() + if ctx == nil { + return "", "", "" + } + if v, ok := ctx.Value(providerAuthContextKey).(map[string]string); ok { + return v["provider"], v["auth_id"], v["auth_label"] + } + return "", "", "" +} + +func getFallbackInfoFromContext(c *gin.Context) (requestedModel, actualModel string) { + if c == nil { + return "", "" + } + + if v, exists := c.Get(ginFallbackInfoKey); exists { + if info, ok := v.(map[string]string); ok { + return info["requested_model"], info["actual_model"] + } + } + + if c.Request == nil { + return "", "" + } + ctx := c.Request.Context() + if ctx == nil { + return "", "" + } + if v, ok := ctx.Value(fallbackInfoContextKey).(map[string]string); ok { + return v["requested_model"], v["actual_model"] + } + return "", "" +} // GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses // using logrus. It captures request details including method, path, status code, latency, -// client IP, and any error messages. Request ID is only added for AI API requests. +// client IP, model name, and auth key name. Request ID is only added for AI API requests. // -// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ... +// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ... | model (auth) // Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ... // // Returns: @@ -43,12 +102,20 @@ func GinLogrusLogger() gin.HandlerFunc { path := c.Request.URL.Path raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery) + var requestBody []byte + if isAIAPIPath(path) && c.Request.Body != nil { + requestBody, _ = io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(requestBody)) + c.Set(requestBodyKey, requestBody) + } + // Only generate request ID for AI API paths var requestID string if isAIAPIPath(path) { requestID = GenerateRequestID() SetGinRequestID(c, requestID) ctx := WithRequestID(c.Request.Context(), requestID) + ctx = context.WithValue(ctx, "gin", c) c.Request = c.Request.WithContext(ctx) } @@ -74,10 +141,66 @@ func GinLogrusLogger() gin.HandlerFunc { method := c.Request.Method errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String() + modelName := "" + if len(requestBody) == 0 { + if storedBody, exists := c.Get(requestBodyKey); exists { + if bodyBytes, ok := storedBody.([]byte); ok { + requestBody = bodyBytes + } + } + } + if len(requestBody) > 0 { + modelName = gjson.GetBytes(requestBody, "model").String() + modelName = strings.TrimSpace(modelName) + } + + authKeyName := "" + if apiKey, exists := c.Get("apiKey"); exists { + if keyStr, ok := apiKey.(string); ok { + authKeyName = keyStr + } + } + + provider, authID, authLabel := getProviderAuthFromContext(c) + requestedModel, actualModel := getFallbackInfoFromContext(c) + providerInfo := "" + if provider != "" { + displayAuth := authLabel + if displayAuth == "" { + displayAuth = authID + } + if displayAuth != "" { + providerInfo = fmt.Sprintf("%s:%s", provider, displayAuth) + } else { + providerInfo = provider + } + } + if requestID == "" { requestID = "--------" } + logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path) + + if isAIAPIPath(path) && (modelName != "" || providerInfo != "" || authKeyName != "") { + displayModelName := modelName + if requestedModel != "" && actualModel != "" && requestedModel != actualModel { + displayModelName = fmt.Sprintf("%s โ†’ %s", requestedModel, actualModel) + } + + if displayModelName != "" && providerInfo != "" { + logLine = logLine + " | " + fmt.Sprintf("%s | %s", displayModelName, providerInfo) + } else if displayModelName != "" && authKeyName != "" { + logLine = logLine + " | " + fmt.Sprintf("%s | %s", displayModelName, authKeyName) + } else if displayModelName != "" { + logLine = logLine + " | " + displayModelName + } else if providerInfo != "" { + logLine = logLine + " | " + providerInfo + } else if authKeyName != "" { + logLine = logLine + " | " + authKeyName + } + } + if errorMessage != "" { logLine = logLine + " | " + errorMessage } @@ -148,3 +271,23 @@ func shouldSkipGinRequestLogging(c *gin.Context) bool { flag, ok := val.(bool) return ok && flag } + +// GetRequestBody retrieves the request body from context or reads it from the request. +// This allows handlers to read the body multiple times. +func GetRequestBody(c *gin.Context) []byte { + if c == nil { + return nil + } + if body, exists := c.Get(requestBodyKey); exists { + if bodyBytes, ok := body.([]byte); ok { + return bodyBytes + } + } + if c.Request.Body != nil { + body, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + c.Set(requestBodyKey, body) + return body + } + return nil +} diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index c14f39d2fb..996591677a 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -12,10 +12,11 @@ import ( // for OAuth2 flows to prevent CSRF attacks. // // Returns: -// - string: A hexadecimal encoded random state string +// - string: A 64-character hexadecimal encoded random state string (32 bytes) // - error: An error if the random generation fails, nil otherwise func GenerateRandomState() (string, error) { - bytes := make([]byte, 16) + // Use 32 bytes to generate 64 hex characters, matching iFlow CLI's state format + bytes := make([]byte, 32) if _, err := rand.Read(bytes); err != nil { return "", fmt.Errorf("failed to generate random bytes: %w", err) } diff --git a/internal/registry/cline_models.go b/internal/registry/cline_models.go new file mode 100644 index 0000000000..36d3aec2f0 --- /dev/null +++ b/internal/registry/cline_models.go @@ -0,0 +1,20 @@ +// Package registry provides model definitions for various AI service providers. +package registry + +// GetClineModels returns the Cline model definitions +func GetClineModels() []*ModelInfo { + return []*ModelInfo{ + // --- Base Models --- + { + ID: "cline/auto", + Object: "model", + Created: 1732752000, + OwnedBy: "cline", + Type: "cline", + DisplayName: "Cline Auto", + Description: "Automatic model selection by Cline", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} diff --git a/internal/registry/kilocode_model_converter.go b/internal/registry/kilocode_model_converter.go new file mode 100644 index 0000000000..25edc9d485 --- /dev/null +++ b/internal/registry/kilocode_model_converter.go @@ -0,0 +1,376 @@ +// Package registry provides Kilocode model conversion utilities. +// This file handles converting dynamic Kilocode API model lists to the internal ModelInfo format, +// and filtering for free models based on pricing information. +package registry + +import ( + "strconv" + "strings" + "time" +) + +// KilocodeAPIModel represents a model from Kilocode API response. +// This structure mirrors the OpenRouter-compatible API format used by Kilocode. +type KilocodeAPIModel struct { + // ID is the unique identifier for the model (e.g., "devstral-2-2512") + ID string `json:"id"` + // Name is the human-readable name + Name string `json:"name"` + // Pricing contains cost information for prompt and completion tokens + Pricing struct { + // Prompt is the cost per prompt token (string format, e.g., "0" for free) + Prompt string `json:"prompt"` + // Completion is the cost per completion token (string format, e.g., "0" for free) + Completion string `json:"completion"` + } `json:"pricing"` + // ContextLength is the maximum context window size + ContextLength int `json:"context_length"` +} + +// KilocodeAPIResponse represents the full API response from Kilocode models endpoint. +type KilocodeAPIResponse struct { + // Data contains the list of available models + Data []*KilocodeAPIModel `json:"data"` +} + +// DefaultKilocodeThinkingSupport defines the default thinking configuration for Kilocode models. +// All Kilocode models support thinking with the following budget range. +var DefaultKilocodeThinkingSupport = &ThinkingSupport{ + Min: 1024, // Minimum thinking budget tokens + Max: 32000, // Maximum thinking budget tokens + ZeroAllowed: true, // Allow disabling thinking with 0 + DynamicAllowed: true, // Allow dynamic thinking budget (-1) +} + +// DefaultKilocodeContextLength is the default context window size for Kilocode models. +const DefaultKilocodeContextLength = 128000 + +// DefaultKilocodeMaxCompletionTokens is the default max completion tokens for Kilocode models. +const DefaultKilocodeMaxCompletionTokens = 32000 + +// ConvertKilocodeAPIModels converts Kilocode API models to internal ModelInfo format. +// It performs the following transformations: +// - Normalizes model ID (e.g., devstral-2-2512 โ†’ kilocode-devstral-2-2512) +// - Filters for free models only (pricing.prompt == "0" && pricing.completion == "0") +// - Adds default thinking support metadata +// - Sets context length from API or uses default if not provided +// +// Parameters: +// - kilocodeModels: List of models from Kilocode API response +// +// Returns: +// - []*ModelInfo: Converted model information list (free models only, filtered by allowed providers) +func ConvertKilocodeAPIModels(kilocodeModels []*KilocodeAPIModel) []*ModelInfo { + if len(kilocodeModels) == 0 { + return nil + } + + now := time.Now().Unix() + result := make([]*ModelInfo, 0, len(kilocodeModels)) + + for _, km := range kilocodeModels { + if km == nil { + continue + } + + if km.ID == "" { + continue + } + + if !isFreeModel(km) { + continue + } + + if !isAllowedKilocodeProvider(km.ID) { + continue + } + + normalizedID := normalizeKilocodeModelID(km.ID) + + info := &ModelInfo{ + ID: normalizedID, + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: generateKilocodeDisplayName(km.Name, normalizedID), + Description: generateKilocodeDescription(km.Name, normalizedID), + ContextLength: getKilocodeContextLength(km.ContextLength), + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + } + + result = append(result, info) + } + + return result +} + +// allowedKilocodeProviders defines which model providers are allowed to be listed. +var allowedKilocodeProviders = []string{ + "deepseek/", + "minimax/", + "openai/gpt-oss", + "tngtech/", + "upstage/", + "z-ai/", +} + +// isAllowedKilocodeProvider checks if a model ID belongs to an allowed provider. +func isAllowedKilocodeProvider(modelID string) bool { + idLower := strings.ToLower(modelID) + for _, prefix := range allowedKilocodeProviders { + if strings.HasPrefix(idLower, prefix) { + return true + } + } + return false +} + +// isFreeModel checks if a Kilocode model is free based on pricing information. +// A model is considered free if both prompt and completion costs are zero. +// Handles various pricing formats: "0", "0.0", "0.0000000", etc. +func isFreeModel(model *KilocodeAPIModel) bool { + if model == nil { + return false + } + + promptPrice, err1 := strconv.ParseFloat(strings.TrimSpace(model.Pricing.Prompt), 64) + completionPrice, err2 := strconv.ParseFloat(strings.TrimSpace(model.Pricing.Completion), 64) + + if err1 != nil || err2 != nil { + return false + } + + return promptPrice == 0 && completionPrice == 0 +} + +// normalizeKilocodeModelID converts Kilocode API model IDs to internal format. +// Transformation rules: +// - Adds "kilocode-" prefix if not present +// - Handles special cases and ensures consistent naming +// +// Examples: +// - "devstral-2-2512" โ†’ "kilocode-devstral-2-2512" +// - "trinity-large-preview" โ†’ "kilocode-trinity-large-preview" +// - "kilocode-mimo-v2-flash" โ†’ "kilocode-mimo-v2-flash" (unchanged) +func normalizeKilocodeModelID(modelID string) string { + if modelID == "" { + return "" + } + + // Trim whitespace + modelID = strings.TrimSpace(modelID) + + // Add kilocode- prefix if not present + if !strings.HasPrefix(modelID, "kilocode-") { + modelID = "kilocode-" + modelID + } + + return modelID +} + +// generateKilocodeDisplayName creates a human-readable display name. +// Uses the API-provided model name if available, otherwise generates from ID. +func generateKilocodeDisplayName(modelName, normalizedID string) string { + if modelName != "" && modelName != normalizedID { + return "Kilocode " + modelName + } + + // Generate from normalized ID by removing kilocode- prefix and formatting + displayID := strings.TrimPrefix(normalizedID, "kilocode-") + // Capitalize first letter of each word + words := strings.Split(displayID, "-") + for i, word := range words { + if len(word) > 0 { + words[i] = strings.ToUpper(word[:1]) + word[1:] + } + } + return "Kilocode " + strings.Join(words, " ") +} + +// generateKilocodeDescription creates a description for Kilocode models. +func generateKilocodeDescription(modelName, normalizedID string) string { + if modelName != "" && modelName != normalizedID { + return "Kilocode AI model: " + modelName + " (Free tier)" + } + + displayID := strings.TrimPrefix(normalizedID, "kilocode-") + return "Kilocode AI model: " + displayID + " (Free tier)" +} + +// getKilocodeContextLength returns the context length, using default if not provided. +func getKilocodeContextLength(contextLength int) int { + if contextLength > 0 { + return contextLength + } + return DefaultKilocodeContextLength +} + +// ResolveKilocodeModelAlias normalizes model names for Kilocode API. +// It strips the "kilocode-" prefix if present and passes through the model name. +// +// Model alias resolution (e.g., "kimi" โ†’ "moonshotai/kimi-k2.5:free") should be +// configured via openai-compatibility.models[] in config.yaml, NOT hardcoded here. +// +// Examples: +// - "kilocode-moonshotai/kimi-k2.5:free" โ†’ "moonshotai/kimi-k2.5:free" +// - "moonshotai/kimi-k2.5:free" โ†’ "moonshotai/kimi-k2.5:free" (unchanged) +// - "kimi" โ†’ "kimi" (unchanged - config alias handles this BEFORE executor) +func ResolveKilocodeModelAlias(alias string) string { + alias = strings.TrimSpace(alias) + if alias == "" { + return alias + } + + // Strip kilocode- prefix if present + return strings.TrimPrefix(alias, "kilocode-") +} + +// GetKilocodeModels returns a static list of free Kilocode models. +// The Kilocode API does not support the /models endpoint (returns 405 Method Not Allowed), +// so we maintain a static list of known free models. +// Only includes: deepseek, minimax, gpt-oss, chimera, upstage, z-ai +func GetKilocodeModels() []*ModelInfo { + now := int64(1738368000) // 2025-02-01 + return []*ModelInfo{ + // DeepSeek + { + ID: "kilocode-deepseek/deepseek-r1-0528:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode DeepSeek R1 0528 (Free)", + Description: "DeepSeek R1 0528 (Free tier)", + ContextLength: 163840, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + // MiniMax + { + ID: "kilocode-minimax/minimax-m2.1:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode MiniMax M2.1 (Free)", + Description: "MiniMax M2.1 (Free tier)", + ContextLength: 204800, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + { + ID: "kilocode-minimax/minimax-m2.5:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode MiniMax M2.5 (Free)", + Description: "MiniMax M2.5 (Free tier)", + ContextLength: 204800, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + // GPT-OSS + { + ID: "kilocode-openai/gpt-oss-20b:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode GPT-OSS 20B (Free)", + Description: "OpenAI GPT-OSS 20B (Free tier)", + ContextLength: 131072, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + { + ID: "kilocode-openai/gpt-oss-120b:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode GPT-OSS 120B (Free)", + Description: "OpenAI GPT-OSS 120B (Free tier)", + ContextLength: 131072, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + // Chimera (TNG Tech) + { + ID: "kilocode-tngtech/deepseek-r1t-chimera:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode DeepSeek R1T Chimera (Free)", + Description: "TNG DeepSeek R1T Chimera (Free tier)", + ContextLength: 163840, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + { + ID: "kilocode-tngtech/deepseek-r1t2-chimera:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode DeepSeek R1T2 Chimera (Free)", + Description: "TNG DeepSeek R1T2 Chimera (Free tier)", + ContextLength: 163840, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + { + ID: "kilocode-tngtech/tng-r1t-chimera:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode TNG R1T Chimera (Free)", + Description: "TNG R1T Chimera (Free tier)", + ContextLength: 163840, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + // Upstage + { + ID: "kilocode-upstage/solar-pro-3:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode Solar Pro 3 (Free)", + Description: "Upstage Solar Pro 3 (Free tier)", + ContextLength: 128000, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + // Z-AI (GLM) + { + ID: "kilocode-z-ai/glm-4.5-air:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode GLM 4.5 Air (Free)", + Description: "Z.AI GLM 4.5 Air (Free tier)", + ContextLength: 131072, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + { + ID: "kilocode-z-ai/glm-5:free", + Object: "model", + Created: now, + OwnedBy: "kilocode", + Type: "kilocode", + DisplayName: "Kilocode GLM 5 (Free)", + Description: "Z.AI GLM 5 (Free tier)", + ContextLength: 202800, + MaxCompletionTokens: DefaultKilocodeMaxCompletionTokens, + Thinking: cloneThinkingSupport(DefaultKilocodeThinkingSupport), + }, + } +} diff --git a/internal/registry/kilocode_model_converter_test.go b/internal/registry/kilocode_model_converter_test.go new file mode 100644 index 0000000000..7e5ca6054e --- /dev/null +++ b/internal/registry/kilocode_model_converter_test.go @@ -0,0 +1,40 @@ +package registry + +import ( + "testing" +) + +func TestResolveKilocodeModelAlias(t *testing.T) { + tests := []struct { + name string + alias string + expected string + }{ + // kilocode- prefix stripping + {"with kilocode prefix full format", "kilocode-moonshotai/kimi-k2.5:free", "moonshotai/kimi-k2.5:free"}, + {"with kilocode prefix simple", "kilocode-kimi", "kimi"}, + + // Already full format (passthrough) + {"already full format", "moonshotai/kimi-k2.5:free", "moonshotai/kimi-k2.5:free"}, + {"already full format glm", "z-ai/glm-4.7:free", "z-ai/glm-4.7:free"}, + + // Short names passthrough (config alias handles these) + {"kimi short passthrough", "kimi", "kimi"}, + {"glm short passthrough", "glm", "glm"}, + {"unknown model passthrough", "unknown-model", "unknown-model"}, + + // Edge cases + {"empty string", "", ""}, + {"whitespace only", " ", ""}, + {"whitespace around", " kimi ", "kimi"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ResolveKilocodeModelAlias(tt.alias) + if result != tt.expected { + t.Errorf("ResolveKilocodeModelAlias(%q) = %q, want %q", tt.alias, result, tt.expected) + } + }) + } +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index d7a6d75b3c..f26dff1d68 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -24,6 +24,7 @@ import ( // - kilo // - github-copilot // - amazonq +// - kilocode (alias for kilo) // - antigravity (returns static overrides only) func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { key := strings.ToLower(strings.TrimSpace(channel)) @@ -50,7 +51,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { return GetGitHubCopilotModels() case "kiro": return GetKiroModels() - case "kilo": + case "kilo", "kilocode": return GetKiloModels() case "amazonq": return GetAmazonQModels() diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go index b9b56677bd..b1fca73399 100644 --- a/internal/registry/model_definitions_static_data.go +++ b/internal/registry/model_definitions_static_data.go @@ -508,6 +508,21 @@ func GetGeminiCLIModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, }, + { + ID: "gemini-3.1-pro-preview", + Object: "model", + Created: 1765929600, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3.1-pro-preview", + Version: "3.1", + DisplayName: "Gemini 3.1 Pro Preview", + Description: "Preview release of Gemini 3.1 Pro with enhanced reasoning and multimodal capabilities", + InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, + }, } } diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index b1e23860cf..d891b92fb3 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -141,7 +141,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, URL: endpoint, Method: http.MethodPost, Headers: wsReq.Headers.Clone(), - Body: body.payload, + Body: bytes.Clone(body.payload), Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -156,14 +156,14 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, } recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone()) if len(wsResp.Body) > 0 { - appendAPIResponseChunk(ctx, e.cfg, wsResp.Body) + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body)) } if wsResp.Status < 200 || wsResp.Status >= 300 { return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)} } reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) var param any - out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m) + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, bytes.Clone(wsResp.Body), ¶m) resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()} return resp, nil } @@ -199,7 +199,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth URL: endpoint, Method: http.MethodPost, Headers: wsReq.Headers.Clone(), - Body: body.payload, + Body: bytes.Clone(body.payload), Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -225,7 +225,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth } var body bytes.Buffer if len(firstEvent.Payload) > 0 { - appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload) + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload)) body.Write(firstEvent.Payload) } if firstEvent.Type == wsrelay.MessageTypeStreamEnd { @@ -244,7 +244,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth metadataLogged = true } if len(event.Payload) > 0 { - appendAPIResponseChunk(ctx, e.cfg, event.Payload) + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) body.Write(event.Payload) } if event.Type == wsrelay.MessageTypeStreamEnd { @@ -273,12 +273,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth } case wsrelay.MessageTypeStreamChunk: if len(event.Payload) > 0 { - appendAPIResponseChunk(ctx, e.cfg, event.Payload) + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) filtered := FilterSSEUsageMetadata(event.Payload) if detail, ok := parseGeminiStreamUsage(filtered); ok { reporter.publish(ctx, detail) } - lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m) + lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, bytes.Clone(filtered), ¶m) for i := range lines { out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))} } @@ -398,7 +398,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream) - payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream) + payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream) payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { return nil, translatedPayload{}, err diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 43891019db..1213478450 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -287,7 +287,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -315,6 +315,14 @@ attemptLoop: return resp, err } + log.WithFields(log.Fields{ + "auth_id": auth.ID, + "provider": e.Identifier(), + "model": baseModel, + "url": httpReq.URL.String(), + "method": httpReq.Method, + }).Infof("external HTTP request: %s %s", httpReq.Method, httpReq.URL.String()) + httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { recordAPIResponseError(ctx, e.cfg, errDo) @@ -429,7 +437,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -821,7 +829,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -1022,7 +1030,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut respCtx := context.WithValue(ctx, "alt", opts.Alt) // Prepare payload once (doesn't depend on baseURL) - payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -1036,11 +1044,14 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) - var authID, authLabel, authType, authValue string + var authID, authLabel, authType, authValue, authTier string if auth != nil { authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() + if tierID, ok := auth.Metadata["tier_id"].(string); ok { + authTier = tierID + } } var lastStatus int @@ -1083,6 +1094,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, + Tier: authTier, }) httpResp, errDo := httpClient.Do(httpReq) @@ -1394,6 +1406,11 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau return auth, errUnmarshal } + // Preserve tier info before refresh + tierID, _ := auth.Metadata["tier_id"].(string) + tierName, _ := auth.Metadata["tier_name"].(string) + tierIsPaid, _ := auth.Metadata["tier_is_paid"].(bool) + if auth.Metadata == nil { auth.Metadata = make(map[string]any) } @@ -1408,7 +1425,31 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau auth.Metadata["type"] = antigravityAuthType if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil { log.Warnf("antigravity executor: ensure project id failed: %v", errProject) + log.Infof("antigravity executor: blocking auth %s for 30 minutes due to project id failure", auth.ID) + if auth.ModelStates == nil { + auth.ModelStates = make(map[string]*cliproxyauth.ModelState) + } + auth.ModelStates[""] = &cliproxyauth.ModelState{ + Status: cliproxyauth.StatusDisabled, + Unavailable: true, + NextRetryAfter: time.Now().Add(30 * time.Minute), + UpdatedAt: time.Now(), + LastError: &cliproxyauth.Error{Code: "project_id_failed", Message: errProject.Error()}, + StatusMessage: "blocked due to project id failure", + } + } + + // Restore preserved tier info + if tierID != "" { + auth.Metadata["tier_id"] = tierID + } + if tierName != "" { + auth.Metadata["tier_name"] = tierName } + if tierIsPaid { + auth.Metadata["tier_is_paid"] = tierIsPaid + } + return auth, nil } @@ -1481,7 +1522,12 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau } } payload = geminiToAntigravity(modelName, payload, projectID) - payload, _ = sjson.SetBytes(payload, "model", modelName) + resolvedModel := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if resolvedModel == "" { + resolvedModel = modelName + } + payload, _ = sjson.SetBytes(payload, "model", resolvedModel) + modelName = resolvedModel useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro") payloadStr := string(payload) @@ -1528,11 +1574,14 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau httpReq.Host = host } - var authID, authLabel, authType, authValue string + var authID, authLabel, authType, authValue, authTier string if auth != nil { authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() + if tierID, ok := auth.Metadata["tier_id"].(string); ok { + authTier = tierID + } } var payloadLog []byte if e.cfg != nil && e.cfg.RequestLog { @@ -1548,6 +1597,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, + Tier: authTier, }) return httpReq, nil @@ -1733,7 +1783,25 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string { } func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte { - template, _ := sjson.Set(string(payload), "model", modelName) + requestType := gjson.GetBytes(payload, "requestType").String() + if strings.TrimSpace(requestType) == "" { + if gjson.GetBytes(payload, "request.tools.0.googleSearch").Exists() { + requestType = "web_search" + } else { + requestType = "agent" + } + } + resolvedModel := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if requestType == "web_search" { + if resolvedModel == "" { + resolvedModel = "gemini-2.5-flash" + } + } + if resolvedModel == "" { + resolvedModel = modelName + } + + template, _ := sjson.Set(string(payload), "model", resolvedModel) template, _ = sjson.Set(template, "userAgent", "antigravity") isImageModel := strings.Contains(modelName, "image") @@ -1765,6 +1833,46 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw) template, _ = sjson.Delete(template, "toolConfig") } + if strings.Contains(modelName, "claude") { + template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") + } + + // Clean tool parameters schema for all models (both Claude and Gemini) + // This handles unsupported keywords like anyOf, oneOf, $ref, complex type arrays, etc. + gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool { + tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool { + // Check both parametersJsonSchema and parameters fields + var paramsRaw string + var paramsPath string + if funcDecl.Get("parametersJsonSchema").Exists() { + paramsRaw = funcDecl.Get("parametersJsonSchema").Raw + paramsPath = fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()) + } else if funcDecl.Get("parameters").Exists() { + paramsRaw = funcDecl.Get("parameters").Raw + paramsPath = fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()) + } + + if paramsRaw != "" { + // Clean the schema to be compatible with Gemini API + cleanedSchema := util.CleanJSONSchemaForAntigravity(paramsRaw) + // Set to parameters field (Gemini API expects "parameters", not "parametersJsonSchema") + template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), cleanedSchema) + // Remove $schema if present + template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int())) + // Remove parametersJsonSchema if it was the source + if paramsPath != fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()) { + template, _ = sjson.Delete(template, paramsPath) + } + } + return true + }) + return true + }) + + if !strings.Contains(modelName, "claude") { + template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens") + } + return []byte(template) } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7d0ddcf2d2..d2ea32b65b 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -179,6 +179,14 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r AuthValue: authValue, }) + log.WithFields(log.Fields{ + "auth_id": authID, + "provider": e.Identifier(), + "model": baseModel, + "url": url, + "method": http.MethodPost, + }).Infof("external HTTP request: POST %s", url) + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpResp, err := httpClient.Do(httpReq) if err != nil { diff --git a/internal/runtime/executor/cline_executor.go b/internal/runtime/executor/cline_executor.go new file mode 100644 index 0000000000..11beae3fd5 --- /dev/null +++ b/internal/runtime/executor/cline_executor.go @@ -0,0 +1,657 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "runtime" + "strconv" + "strings" + "time" + + clineauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cline" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + clineVersion = "3.0.0" + clineBaseURL = "https://api.cline.bot/api/v1" + clineModelsEndpoint = "/ai/cline/models" + clineChatEndpoint = "/chat/completions" +) + +func clineTokenAuthValue(token string) string { + t := strings.TrimSpace(token) + if t == "" { + return "" + } + if strings.HasPrefix(t, "workos:") { + return "Bearer " + t + } + return "Bearer workos:" + t +} + +// ClineExecutor handles requests to Cline API. +type ClineExecutor struct { + cfg *config.Config +} + +// NewClineExecutor creates a new Cline executor instance. +func NewClineExecutor(cfg *config.Config) *ClineExecutor { + return &ClineExecutor{cfg: cfg} +} + +// Identifier returns the unique identifier for this executor. +func (e *ClineExecutor) Identifier() string { return "cline" } + +// PrepareRequest prepares the HTTP request before execution. +func (e *ClineExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if req == nil { + return nil + } + accessToken, err := e.ensureFreshAccessToken(req.Context(), auth) + if err != nil { + return err + } + if strings.TrimSpace(accessToken) == "" { + return fmt.Errorf("cline: missing access token") + } + + req.Header.Set("Authorization", clineTokenAuthValue(accessToken)) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(req, attrs) + return nil +} + +// HttpRequest executes a raw HTTP request. +func (e *ClineExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("cline executor: request is nil") + } + if ctx == nil { + ctx = req.Context() + } + httpReq := req.WithContext(ctx) + if err := e.PrepareRequest(httpReq, auth); err != nil { + return nil, err + } + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + return httpClient.Do(httpReq) +} + +// Execute performs a non-streaming request. +func (e *ClineExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + + reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer reporter.trackFailure(ctx, &err) + + accessToken, err := e.ensureFreshAccessToken(ctx, auth) + if err != nil { + return resp, err + } + if accessToken == "" { + return resp, fmt.Errorf("cline: missing access token") + } + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + endpoint := clineChatEndpoint + + originalPayloadSource := req.Payload + if len(opts.OriginalRequest) > 0 { + originalPayloadSource = opts.OriginalRequest + } + originalPayload := originalPayloadSource + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream) + translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream) + requestedModel := payloadRequestedModel(opts, req.Model) + translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) + translated = applyClineOpenRouterParity(translated, false) + + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return resp, err + } + + url := clineBaseURL + endpoint + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated)) + if err != nil { + return resp, err + } + applyClineHeaders(httpReq, accessToken, false) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translated, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer httpResp.Body.Close() + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return resp, err + } + + body, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, body) + reporter.publish(ctx, parseOpenAIUsage(body)) + reporter.ensurePublished(ctx) + + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil +} + +// ExecuteStream performs a streaming request. +func (e *ClineExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + + reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer reporter.trackFailure(ctx, &err) + + accessToken, err := e.ensureFreshAccessToken(ctx, auth) + if err != nil { + return nil, err + } + if accessToken == "" { + return nil, fmt.Errorf("cline: missing access token") + } + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + endpoint := clineChatEndpoint + + originalPayloadSource := req.Payload + if len(opts.OriginalRequest) > 0 { + originalPayloadSource = opts.OriginalRequest + } + originalPayload := originalPayloadSource + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) + translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + requestedModel := payloadRequestedModel(opts, req.Model) + translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) + translated = applyClineOpenRouterParity(translated, true) + + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return nil, err + } + + url := clineBaseURL + endpoint + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated)) + if err != nil { + return nil, err + } + applyClineHeaders(httpReq, accessToken, true) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translated, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + httpResp.Body.Close() + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return nil, err + } + + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + defer httpResp.Body.Close() + + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 52_428_800) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := parseOpenAIStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + reporter.ensurePublished(ctx) + }() + + return &cliproxyexecutor.StreamResult{ + Headers: httpResp.Header.Clone(), + Chunks: out, + }, nil +} + +// Refresh validates the Cline token. +func (e *ClineExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, fmt.Errorf("missing auth") + } + return auth, nil +} + +// CountTokens returns the token count for the given request. +func (e *ClineExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, fmt.Errorf("cline: count tokens not supported") +} + +// clineAccessToken extracts access token from auth. +func clineAccessToken(auth *cliproxyauth.Auth) string { + if auth == nil { + return "" + } + + // Check metadata first, then attributes + if auth.Metadata != nil { + if token, ok := auth.Metadata["accessToken"].(string); ok && token != "" { + return token + } + if token, ok := auth.Metadata["access_token"].(string); ok && token != "" { + return token + } + if token, ok := auth.Metadata["token"].(string); ok && token != "" { + return token + } + } + + if auth.Attributes != nil { + if token := auth.Attributes["accessToken"]; token != "" { + return token + } + if token := auth.Attributes["access_token"]; token != "" { + return token + } + if token := auth.Attributes["token"]; token != "" { + return token + } + } + + return "" +} + +func clineRefreshToken(auth *cliproxyauth.Auth) string { + if auth == nil { + return "" + } + if auth.Metadata != nil { + if token, ok := auth.Metadata["refreshToken"].(string); ok && strings.TrimSpace(token) != "" { + return strings.TrimSpace(token) + } + if token, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(token) != "" { + return strings.TrimSpace(token) + } + } + if auth.Attributes != nil { + if token := strings.TrimSpace(auth.Attributes["refreshToken"]); token != "" { + return token + } + if token := strings.TrimSpace(auth.Attributes["refresh_token"]); token != "" { + return token + } + } + return "" +} + +func (e *ClineExecutor) ensureFreshAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) { + accessToken := clineAccessToken(auth) + if strings.TrimSpace(accessToken) == "" { + return "", fmt.Errorf("cline: missing access token") + } + + refreshToken := clineRefreshToken(auth) + if refreshToken == "" { + return accessToken, nil + } + + authSvc := clineauth.NewClineAuth(e.cfg) + refreshed, err := authSvc.RefreshToken(ctx, refreshToken) + if err != nil { + log.Warnf("cline: token refresh failed, fallback to current token: %v", err) + return accessToken, nil + } + if refreshed == nil || strings.TrimSpace(refreshed.AccessToken) == "" { + return accessToken, nil + } + + newAccessToken := strings.TrimSpace(refreshed.AccessToken) + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["accessToken"] = newAccessToken + auth.Metadata["access_token"] = newAccessToken + + if strings.TrimSpace(refreshed.RefreshToken) != "" { + newRefresh := strings.TrimSpace(refreshed.RefreshToken) + auth.Metadata["refreshToken"] = newRefresh + auth.Metadata["refresh_token"] = newRefresh + } + + if strings.TrimSpace(refreshed.ExpiresAt) != "" { + if t, parseErr := time.Parse(time.RFC3339Nano, refreshed.ExpiresAt); parseErr == nil { + auth.Metadata["expiresAt"] = t.Unix() + auth.Metadata["expires_at"] = t.Format(time.RFC3339) + } else if t, parseErr2 := time.Parse(time.RFC3339, refreshed.ExpiresAt); parseErr2 == nil { + auth.Metadata["expiresAt"] = t.Unix() + auth.Metadata["expires_at"] = t.Format(time.RFC3339) + } + } + + return newAccessToken, nil +} + +// applyClineHeaders sets the standard Cline headers. +func applyClineHeaders(r *http.Request, token string, stream bool) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", clineTokenAuthValue(token)) + r.Header.Set("HTTP-Referer", "https://cline.bot") + r.Header.Set("X-Title", "Cline") + r.Header.Set("X-Task-ID", "") + r.Header.Set("X-CLIENT-TYPE", "cli") + r.Header.Set("X-CORE-VERSION", clineVersion) + r.Header.Set("X-IS-MULTIROOT", "false") + r.Header.Set("X-CLIENT-VERSION", clineVersion) + r.Header.Set("X-PLATFORM", runtime.GOOS) + r.Header.Set("X-PLATFORM-VERSION", runtime.Version()) + r.Header.Set("User-Agent", "Cline/"+clineVersion) + if stream { + r.Header.Set("Accept", "text/event-stream") + r.Header.Set("Cache-Control", "no-cache") + } else { + r.Header.Set("Accept", "application/json") + } +} + +func applyClineOpenRouterParity(payload []byte, stream bool) []byte { + if len(payload) == 0 { + return payload + } + + out := payload + if stream { + if updated, err := sjson.SetRawBytes(out, "stream_options", []byte(`{"include_usage":true}`)); err == nil { + out = updated + } + if updated, err := sjson.SetBytes(out, "include_reasoning", true); err == nil { + out = updated + } + } else { + if updated, err := sjson.DeleteBytes(out, "stream_options"); err == nil { + out = updated + } + if updated, err := sjson.SetBytes(out, "include_reasoning", true); err == nil { + out = updated + } + } + + modelID := strings.TrimSpace(gjson.GetBytes(out, "model").String()) + if modelID == "" { + return out + } + + if strings.Contains(modelID, "kwaipilot/kat-coder-pro") { + trimmedModel := strings.TrimSuffix(modelID, ":free") + if updated, err := sjson.SetBytes(out, "model", trimmedModel); err == nil { + out = updated + } + if updated, err := sjson.SetRawBytes(out, "provider", []byte(`{"sort":"throughput"}`)); err == nil { + out = updated + } + } + + return out +} + +// ClineModel represents a model from Cline API. +type ClineModel struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + MaxTokens int `json:"max_tokens"` + ContextLen int `json:"context_length"` + Pricing struct { + Prompt string `json:"prompt"` + Completion string `json:"completion"` + InputCacheRead string `json:"input_cache_read"` + WebSearch string `json:"web_search"` + } `json:"pricing"` +} + +func clineIsFreeModel(m ClineModel) bool { + promptRaw := strings.TrimSpace(m.Pricing.Prompt) + completionRaw := strings.TrimSpace(m.Pricing.Completion) + if promptRaw == "" || completionRaw == "" { + return false + } + promptPrice, errPrompt := strconv.ParseFloat(promptRaw, 64) + completionPrice, errCompletion := strconv.ParseFloat(completionRaw, 64) + if errPrompt != nil || errCompletion != nil { + return false + } + return promptPrice == 0 && completionPrice == 0 +} + +// FetchClineModels fetches models from Cline API. +// The model list endpoint does not require authentication. +func FetchClineModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { + log.Debugf("cline: fetching dynamic models from API") + + httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, clineBaseURL+clineModelsEndpoint, nil) + if err != nil { + log.Warnf("cline: failed to create model fetch request: %v", err) + return nil + } + + req.Header.Set("User-Agent", "Cline/"+clineVersion) + req.Header.Set("HTTP-Referer", "https://cline.bot") + req.Header.Set("X-Title", "Cline") + req.Header.Set("X-CLIENT-TYPE", "cli") + req.Header.Set("X-CORE-VERSION", clineVersion) + req.Header.Set("X-IS-MULTIROOT", "false") + req.Header.Set("X-CLIENT-VERSION", clineVersion) + req.Header.Set("X-PLATFORM", runtime.GOOS) + req.Header.Set("X-PLATFORM-VERSION", runtime.Version()) + + resp, err := httpClient.Do(req) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Warnf("cline: fetch models canceled: %v", err) + } else { + log.Warnf("cline: fetch models failed: %v", err) + } + return nil + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Warnf("cline: failed to read models response: %v", err) + return nil + } + + if resp.StatusCode != http.StatusOK { + log.Warnf("cline: fetch models failed: status %d, body: %s", resp.StatusCode, string(body)) + return nil + } + + // Parse models response + var modelsResponse struct { + Data []ClineModel `json:"data"` + } + if err := json.Unmarshal(body, &modelsResponse); err != nil { + log.Warnf("cline: failed to parse models response: %v", err) + return nil + } + + // Also try gjson parsing as fallback + if len(modelsResponse.Data) == 0 { + result := gjson.GetBytes(body, "data") + if !result.Exists() { + // Try root if data field is missing + result = gjson.ParseBytes(body) + if !result.IsArray() { + log.Debugf("cline: response body: %s", string(body)) + log.Warn("cline: invalid API response format (expected array or data field with array)") + return nil + } + } + result.ForEach(func(key, value gjson.Result) bool { + id := value.Get("id").String() + if id == "" { + return true + } + modelsResponse.Data = append(modelsResponse.Data, ClineModel{ + ID: id, + Name: value.Get("name").String(), + ContextLen: int(value.Get("context_length").Int()), + MaxTokens: int(value.Get("max_tokens").Int()), + Pricing: struct { + Prompt string `json:"prompt"` + Completion string `json:"completion"` + InputCacheRead string `json:"input_cache_read"` + WebSearch string `json:"web_search"` + }{ + Prompt: value.Get("pricing.prompt").String(), + Completion: value.Get("pricing.completion").String(), + InputCacheRead: value.Get("pricing.input_cache_read").String(), + WebSearch: value.Get("pricing.web_search").String(), + }, + }) + return true + }) + } + + now := time.Now().Unix() + var dynamicModels []*registry.ModelInfo + count := 0 + + for _, m := range modelsResponse.Data { + if m.ID == "" { + continue + } + if !clineIsFreeModel(m) { + continue + } + contextLen := m.ContextLen + if contextLen == 0 { + contextLen = 200000 // Default context length + } + maxTokens := m.MaxTokens + if maxTokens == 0 { + maxTokens = 64000 // Default max tokens + } + displayName := m.Name + if displayName == "" { + displayName = m.ID + } + + dynamicModels = append(dynamicModels, ®istry.ModelInfo{ + ID: m.ID, + DisplayName: displayName, + Description: m.Description, + ContextLength: contextLen, + MaxCompletionTokens: maxTokens, + OwnedBy: "cline", + Type: "cline", + Object: "model", + Created: now, + }) + count++ + } + + log.Infof("cline: fetched %d free models from API", count) + return dynamicModels +} diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 1be245b702..f69399951c 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -124,7 +124,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -230,7 +230,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth lastStatus = httpResp.StatusCode lastBody = append([]byte(nil), data...) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), data) if httpResp.StatusCode == 429 { if idx+1 < len(models) { log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) @@ -278,7 +278,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -367,7 +367,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut appendAPIResponseChunk(ctx, e.cfg, data) lastStatus = httpResp.StatusCode lastBody = append([]byte(nil), data...) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), data) if httpResp.StatusCode == 429 { if idx+1 < len(models) { log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) @@ -406,7 +406,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } } - segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone([]byte("[DONE]")), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } @@ -433,7 +433,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } - segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m) + segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone([]byte("[DONE]")), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } @@ -485,7 +485,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. // The loop variable attemptModel is only used as the concrete model id sent to the upstream // Gemini CLI endpoint when iterating fallback variants. for range models { - payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 7c25b8935f..6f6b354318 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -122,7 +122,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -192,7 +192,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) err = statusErr{code: httpResp.StatusCode, msg: string(b)} return resp, err } @@ -229,7 +229,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -290,7 +290,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("gemini executor: close response body error: %v", errClose) } @@ -345,7 +345,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -409,7 +409,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } appendAPIResponseChunk(ctx, e.cfg, data) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data)) + logDetailedAPIError(ctx, e.Identifier(), url, resp.StatusCode, resp.Header.Get("Content-Type"), data) return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)} } diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 7ad1c6186b..673ef83c40 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -324,7 +324,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -396,7 +396,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) err = statusErr{code: httpResp.StatusCode, msg: string(b)} return resp, err } @@ -439,7 +439,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -511,7 +511,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) err = statusErr{code: httpResp.StatusCode, msg: string(b)} return resp, err } @@ -544,7 +544,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -610,7 +610,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("vertex executor: close response body error: %v", errClose) } @@ -668,7 +668,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -734,7 +734,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("vertex executor: close response body error: %v", errClose) } @@ -783,7 +783,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -846,7 +846,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} } data, errRead := io.ReadAll(httpResp.Body) @@ -867,7 +867,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) if err != nil { @@ -930,7 +930,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} } data, errRead := io.ReadAll(httpResp.Body) diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 65a0b8f81e..f0f84aa839 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net/http" + "runtime" "strings" "time" @@ -28,7 +29,8 @@ import ( const ( iflowDefaultEndpoint = "/chat/completions" - iflowUserAgent = "iFlow-Cli" + // iflowUserAgentPrefix matches the official iFlow CLI format: iFlowCLI/0.5.14 + iflowUserAgentPrefix = "iFlowCLI/0.5.14" ) // IFlowExecutor executes OpenAI-compatible chat completions against the iFlow API using API keys derived from OAuth. @@ -97,7 +99,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, _ = sjson.SetBytes(body, "model", baseModel) body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier()) @@ -150,7 +152,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + log.Debugf("iflow request error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) err = statusErr{code: httpResp.StatusCode, msg: string(b)} return resp, err } @@ -200,7 +202,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier()) @@ -256,8 +258,10 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au log.Errorf("iflow executor: close response body error: %v", errClose) } appendAPIResponseChunk(ctx, e.cfg, data) - logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - err = statusErr{code: httpResp.StatusCode, msg: string(data)} + bodyStr := string(data) + summary := summarizeErrorBody(httpResp.Header.Get("Content-Type"), data) + log.Errorf("iflow streaming error: status %d, summary: %s, full body: %s", httpResp.StatusCode, summary, bodyStr) + err = statusErr{code: httpResp.StatusCode, msg: bodyStr} return nil, err } @@ -301,7 +305,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) enc, err := tokenizerForModel(baseModel) if err != nil { @@ -436,10 +440,15 @@ func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyaut auth.Metadata["api_key"] = tokenData.APIKey } auth.Metadata["expired"] = tokenData.Expire + auth.Metadata["expires_at"] = tokenData.Expire auth.Metadata["type"] = "iflow" auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) - // Log the new access token (masked) after successful refresh + if expiresAt, err := time.Parse(time.RFC3339, tokenData.Expire); err == nil { + auth.NextRefreshAfter = expiresAt.Add(-36 * time.Hour) + log.Debugf("iflow executor: set NextRefreshAfter to %v", auth.NextRefreshAfter.Format(time.RFC3339)) + } + log.Debugf("iflow executor: token refresh successful, new: %s", util.HideAPIKey(tokenData.AccessToken)) if auth.Attributes == nil { @@ -455,7 +464,10 @@ func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyaut func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) { r.Header.Set("Content-Type", "application/json") r.Header.Set("Authorization", "Bearer "+apiKey) - r.Header.Set("User-Agent", iflowUserAgent) + + // Build User-Agent matching official iFlow CLI: iFlowCLI/0.5.14 (linux; amd64) + userAgent := buildIFlowUserAgent() + r.Header.Set("User-Agent", userAgent) // Generate session-id sessionID := "session-" + generateUUID() @@ -465,7 +477,8 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) { timestamp := time.Now().UnixMilli() r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp)) - signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey) + // Signature uses the same User-Agent string for HMAC calculation + signature := createIFlowSignature(userAgent, sessionID, timestamp, apiKey) if signature != "" { r.Header.Set("x-iflow-signature", signature) } @@ -477,6 +490,22 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) { } } +// buildIFlowUserAgent constructs a User-Agent string matching the official iFlow CLI format. +// Example: iFlowCLI/0.5.14 (linux; amd64) +func buildIFlowUserAgent() string { + // Map Go's runtime.GOARCH to common architecture names + arch := runtime.GOARCH + switch arch { + case "amd64": + arch = "x64" + case "arm64": + arch = "arm64" + case "386": + arch = "x86" + } + return fmt.Sprintf("%s (%s; %s)", iflowUserAgentPrefix, runtime.GOOS, arch) +} + // createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests. // The signature payload format is: userAgent:sessionId:timestamp func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string { diff --git a/internal/runtime/executor/kilo_executor.go b/internal/runtime/executor/kilo_executor.go index 34f620230f..5c3db89aab 100644 --- a/internal/runtime/executor/kilo_executor.go +++ b/internal/runtime/executor/kilo_executor.go @@ -22,6 +22,11 @@ import ( "github.com/tidwall/gjson" ) +const ( + kiloVersion = "3.26.0" + kiloTesterHeader = "X-Kilocode-Tester" +) + // KiloExecutor handles requests to Kilo API. type KiloExecutor struct { cfg *config.Config @@ -106,12 +111,7 @@ func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req if err != nil { return resp, err } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if orgID != "" { - httpReq.Header.Set("X-Kilocode-OrganizationID", orgID) - } - httpReq.Header.Set("User-Agent", "cli-proxy-kilo") + applyKiloHeaders(httpReq, accessToken, orgID, false) var attrs map[string]string if auth != nil { attrs = auth.Attributes @@ -203,14 +203,7 @@ func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut if err != nil { return nil, err } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - if orgID != "" { - httpReq.Header.Set("X-Kilocode-OrganizationID", orgID) - } - httpReq.Header.Set("User-Agent", "cli-proxy-kilo") - httpReq.Header.Set("Accept", "text/event-stream") - httpReq.Header.Set("Cache-Control", "no-cache") + applyKiloHeaders(httpReq, accessToken, orgID, true) var attrs map[string]string if auth != nil { @@ -315,6 +308,8 @@ func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) { if auth.Metadata != nil { if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" { accessToken = token + } else if token, ok := auth.Metadata["token"].(string); ok && token != "" { + accessToken = token } else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" { accessToken = token } @@ -329,6 +324,8 @@ func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) { if accessToken == "" && auth.Attributes != nil { if token := auth.Attributes["kilocodeToken"]; token != "" { accessToken = token + } else if token := auth.Attributes["token"]; token != "" { + accessToken = token } else if token := auth.Attributes["access_token"]; token != "" { accessToken = token } @@ -458,3 +455,23 @@ func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.C return allModels } + +func applyKiloHeaders(r *http.Request, token, orgID string, stream bool) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", "Bearer "+token) + if orgID != "" { + r.Header.Set("X-Kilocode-OrganizationID", orgID) + } + r.Header.Set("HTTP-Referer", "https://kilocode.ai") + r.Header.Set("X-Title", "Kilo Code") + r.Header.Set("X-KiloCode-Version", kiloVersion) + r.Header.Set("User-Agent", "Kilo-Code/"+kiloVersion) + r.Header.Set(kiloTesterHeader, "SUPPRESS") + r.Header.Set("X-KiloCode-EditorName", "Visual Studio Code 1.96.0") + if stream { + r.Header.Set("Accept", "text/event-stream") + r.Header.Set("Cache-Control", "no-cache") + } else { + r.Header.Set("Accept", "application/json") + } +} diff --git a/internal/runtime/executor/kilocode_executor.go b/internal/runtime/executor/kilocode_executor.go new file mode 100644 index 0000000000..c46eb4cc60 --- /dev/null +++ b/internal/runtime/executor/kilocode_executor.go @@ -0,0 +1,366 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" +) + +const ( + // Kilocode API base URL - must match VS Code extension format + // VS Code extension uses: getKiloUrlFromToken("https://api.kilo.ai/api/", token) + "openrouter/" + kilocodeBaseURL = "https://api.kilo.ai/api/openrouter" + kilocodeChatPath = "/chat/completions" + kilocodeAuthType = "kilocode" + // Kilocode VS Code extension version - used for API compatibility + kilocodeVersion = "3.26.0" +) + +// KilocodeExecutor handles requests to the Kilocode API. +type KilocodeExecutor struct { + cfg *config.Config +} + +// normalizeKilocodeModelForAPI strips "kilocode-" prefix and normalizes model names for API calls. +// Preserves ":free" suffix which Kilocode API requires for free model access. +func normalizeKilocodeModelForAPI(model string) string { + resolved := registry.ResolveKilocodeModelAlias(model) + normalized := strings.TrimPrefix(resolved, "kilocode-") + + freeSuffix := "" + if strings.HasSuffix(normalized, ":free") { + freeSuffix = ":free" + normalized = strings.TrimSuffix(normalized, ":free") + } + + if strings.HasPrefix(normalized, "glm-4-") { + normalized = strings.Replace(normalized, "glm-4-", "glm-4.", 1) + } + + if strings.HasPrefix(normalized, "kimi-k2-") { + normalized = strings.Replace(normalized, "kimi-k2-", "kimi-k2.", 1) + } + + normalized = normalized + freeSuffix + + log.Debugf("[DEBUG] normalizeKilocodeModelForAPI: input=%s -> output=%s", model, normalized) + return normalized +} + +// NewKilocodeExecutor constructs a new executor instance. +func NewKilocodeExecutor(cfg *config.Config) *KilocodeExecutor { + return &KilocodeExecutor{ + cfg: cfg, + } +} + +// Identifier implements ProviderExecutor. +func (e *KilocodeExecutor) Identifier() string { return kilocodeAuthType } + +// PrepareRequest implements ProviderExecutor. +func (e *KilocodeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if req == nil { + return nil + } + + token := metaStringValue(auth.Metadata, "token") + if token == "" { + return statusErr{code: http.StatusUnauthorized, msg: "missing kilocode token"} + } + + e.applyHeaders(req, token) + return nil +} + +// HttpRequest injects Kilocode credentials into the request and executes it. +func (e *KilocodeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("kilocode executor: request is nil") + } + if ctx == nil { + ctx = req.Context() + } + httpReq := req.WithContext(ctx) + if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil { + return nil, errPrepare + } + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + return httpClient.Do(httpReq) +} + +// Execute handles non-streaming requests to Kilocode. +func (e *KilocodeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + token := metaStringValue(auth.Metadata, "token") + if token == "" { + return resp, statusErr{code: http.StatusUnauthorized, msg: "missing kilocode token"} + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + log.Infof("[KILOCODE-EXEC] Execute: req.Model=%s", req.Model) + normalizedModel := normalizeKilocodeModelForAPI(req.Model) + log.Infof("[KILOCODE-EXEC] Execute: normalizedModel=%s", normalizedModel) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + originalPayload := bytes.Clone(req.Payload) + if len(opts.OriginalRequest) > 0 { + originalPayload = bytes.Clone(opts.OriginalRequest) + } + originalTranslated := sdktranslator.TranslateRequest(from, to, normalizedModel, originalPayload, false) + body := sdktranslator.TranslateRequest(from, to, normalizedModel, bytes.Clone(req.Payload), false) + requestedModel := payloadRequestedModel(opts, normalizedModel) + body = applyPayloadConfigWithRoot(e.cfg, normalizedModel, to.String(), "", body, originalTranslated, requestedModel) + body, _ = sjson.SetBytes(body, "stream", false) + body, _ = sjson.SetBytes(body, "model", normalizedModel) + + url := kilocodeBaseURL + kilocodeChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return resp, err + } + e.applyHeaders(httpReq, token) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("kilocode executor: close response body error: %v", errClose) + } + }() + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if !isHTTPSuccess(httpResp.StatusCode) { + data, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("kilocode executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return resp, err + } + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + + detail := parseOpenAIUsage(data) + if detail.TotalTokens > 0 { + reporter.publish(ctx, detail) + } + + var param any + converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(converted)} + reporter.ensurePublished(ctx) + return resp, nil +} + +// ExecuteStream handles streaming requests to Kilocode. +func (e *KilocodeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + token := metaStringValue(auth.Metadata, "token") + if token == "" { + return nil, statusErr{code: http.StatusUnauthorized, msg: "missing kilocode token"} + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + normalizedModel := normalizeKilocodeModelForAPI(req.Model) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + originalPayload := bytes.Clone(req.Payload) + if len(opts.OriginalRequest) > 0 { + originalPayload = bytes.Clone(opts.OriginalRequest) + } + originalTranslated := sdktranslator.TranslateRequest(from, to, normalizedModel, originalPayload, false) + body := sdktranslator.TranslateRequest(from, to, normalizedModel, bytes.Clone(req.Payload), true) + requestedModel := payloadRequestedModel(opts, normalizedModel) + body = applyPayloadConfigWithRoot(e.cfg, normalizedModel, to.String(), "", body, originalTranslated, requestedModel) + body, _ = sjson.SetBytes(body, "stream", true) + body, _ = sjson.SetBytes(body, "model", normalizedModel) + // Enable stream options for usage stats in stream + body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) + + url := kilocodeBaseURL + kilocodeChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + e.applyHeaders(httpReq, token) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if !isHTTPSuccess(httpResp.StatusCode) { + data, readErr := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("kilocode executor: close response body error: %v", errClose) + } + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + return nil, readErr + } + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("kilocode executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return nil, err + } + + out := make(chan cliproxyexecutor.StreamChunk) + + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("kilocode executor: close response body error: %v", errClose) + } + }() + + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, maxScannerBufferSize) + var param any + + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + + // Skip empty lines (SSE keepalive) + if len(line) == 0 { + continue + } + + // Skip non-data lines (SSE comments like ": OPENROUTER PROCESSING", event types, etc.) + // This prevents JSON parse errors when OpenRouter sends keepalive comments + if !bytes.HasPrefix(line, dataTag) { + continue + } + + // Parse SSE data + data := bytes.TrimSpace(line[5:]) + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if detail, ok := parseOpenAIStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } else { + reporter.ensurePublished(ctx) + } + }() + + return &cliproxyexecutor.StreamResult{ + Headers: httpResp.Header.Clone(), + Chunks: out, + }, nil +} + +// CountTokens is not supported for Kilocode. +func (e *KilocodeExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for kilocode"} +} + +// Refresh validates the Kilocode token is still working. +// Kilocode API only supports /chat/completions endpoint, so we skip validation +// and return the auth as-is. Token validation will happen naturally during actual requests. +func (e *KilocodeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + token := metaStringValue(auth.Metadata, "token") + if token == "" { + return auth, nil + } + + // Kilocode API only supports /chat/completions, so we skip token validation here + // Token validity will be checked during actual API requests + return auth, nil +} + +const kilocodeTesterHeader = "X-Kilocode-Tester" + +func (e *KilocodeExecutor) applyHeaders(r *http.Request, token string) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", "Bearer "+token) + r.Header.Set("Accept", "application/json") + r.Header.Set("HTTP-Referer", "https://kilocode.ai") + r.Header.Set("X-Title", "Kilo Code") + r.Header.Set("X-KiloCode-Version", kilocodeVersion) + r.Header.Set("User-Agent", "Kilo-Code/"+kilocodeVersion) + r.Header.Set(kilocodeTesterHeader, "SUPPRESS") + r.Header.Set("X-KiloCode-EditorName", "Visual Studio Code 1.96.0") +} diff --git a/internal/runtime/executor/logging_helpers.go b/internal/runtime/executor/logging_helpers.go index ae2aee3ffd..126f8bbd2e 100644 --- a/internal/runtime/executor/logging_helpers.go +++ b/internal/runtime/executor/logging_helpers.go @@ -14,6 +14,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -35,6 +36,7 @@ type upstreamRequestLog struct { AuthLabel string AuthType string AuthValue string + Tier string } type upstreamAttempt struct { @@ -80,7 +82,7 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ writeHeaders(builder, info.Headers) builder.WriteString("\nBody:\n") if len(info.Body) > 0 { - builder.WriteString(string(info.Body)) + builder.WriteString(string(bytes.Clone(info.Body))) } else { builder.WriteString("") } @@ -152,7 +154,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt if cfg == nil || !cfg.RequestLog { return } - data := bytes.TrimSpace(chunk) + data := bytes.TrimSpace(bytes.Clone(chunk)) if len(data) == 0 { return } @@ -296,6 +298,9 @@ func formatAuthInfo(info upstreamRequestLog) string { if trimmed := strings.TrimSpace(info.AuthLabel); trimmed != "" { parts = append(parts, fmt.Sprintf("label=%s", trimmed)) } + if trimmed := strings.TrimSpace(info.Tier); trimmed != "" { + parts = append(parts, fmt.Sprintf("tier=%s", trimmed)) + } authType := strings.ToLower(strings.TrimSpace(info.AuthType)) authValue := strings.TrimSpace(info.AuthValue) @@ -389,3 +394,39 @@ func logWithRequestID(ctx context.Context) *log.Entry { } return log.WithField("request_id", requestID) } + +// logDetailedAPIError logs detailed error information for API errors at Warn/Error level. +// This function logs the full error body, URL, status code, and provider information. +// 4xx errors are logged at Warn level, 5xx errors at Error level. +func logDetailedAPIError(ctx context.Context, provider string, url string, statusCode int, contentType string, body []byte) { + entry := logWithRequestID(ctx) + + // 4xx๋Š” Warn, 5xx๋Š” Error + logFn := entry.Warnf + if statusCode >= 500 { + logFn = entry.Errorf + } + + // ์ „์ฒด ์—๋Ÿฌ ๋ฐ”๋”” ๋กœ๊น… (๋‹จ, ๋„ˆ๋ฌด ๊ธธ๋ฉด ์ž˜๋ผ๋ƒ„) + bodyStr := string(body) + if len(bodyStr) > 4096 { + bodyStr = bodyStr[:4096] + "...[truncated]" + } + + // Extract auth info from context for logging + providerDisplay := provider + if ctxProvider, _, authLabel := cliproxyauth.GetProviderAuthFromContext(ctx); ctxProvider != "" { + displayAuth := authLabel + if displayAuth == "" { + if _, authID, _ := cliproxyauth.GetProviderAuthFromContext(ctx); authID != "" { + displayAuth = authID + } + } + if displayAuth != "" { + providerDisplay = fmt.Sprintf("%s:%s", provider, displayAuth) + } + } + + logFn("[%s] API error - URL: %s, Status: %d, Content-Type: %s, Response: %s", + providerDisplay, url, statusCode, contentType, bodyStr) +} diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index d28b36251a..1c3e25b272 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -94,7 +94,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream) - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream) + translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream) requestedModel := payloadRequestedModel(opts, req.Model) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) if opts.Alt == "responses/compact" { @@ -156,7 +156,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) err = statusErr{code: httpResp.StatusCode, msg: string(b)} return resp, err } @@ -196,7 +196,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) requestedModel := payloadRequestedModel(opts, req.Model) translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel) @@ -250,7 +250,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + logDetailedAPIError(ctx, e.Identifier(), url, httpResp.StatusCode, httpResp.Header.Get("Content-Type"), b) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("openai compat executor: close response body error: %v", errClose) } @@ -305,7 +305,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau from := opts.SourceFormat to := sdktranslator.FromString("openai") - translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) modelForCounting := baseModel diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 8998eb236b..c524167e67 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -21,6 +21,23 @@ var ( httpClientCacheMutex sync.RWMutex ) +// Default timeout constants for HTTP client transport +const ( + // defaultDialTimeout is the timeout for establishing TCP connections + defaultDialTimeout = 30 * time.Second + // defaultKeepAlive is the TCP keep-alive interval + defaultKeepAlive = 30 * time.Second + // defaultTLSHandshakeTimeout is the timeout for TLS handshake + defaultTLSHandshakeTimeout = 10 * time.Second + // defaultResponseHeaderTimeout is the timeout for receiving response headers + // This timeout only applies AFTER the request is sent - it does NOT affect streaming body reads + defaultResponseHeaderTimeout = 60 * time.Second + // defaultIdleConnTimeout is how long idle connections stay in the pool + defaultIdleConnTimeout = 90 * time.Second + // defaultExpectContinueTimeout is the timeout for 100-continue responses + defaultExpectContinueTimeout = 1 * time.Second +) + // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // 1. Use auth.ProxyURL if configured (highest priority) // 2. Use cfg.ProxyURL if auth proxy is not configured @@ -28,11 +45,16 @@ var ( // // This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse. // +// IMPORTANT: For streaming responses (SSE, AI model outputs), Client.Timeout is NOT set. +// Instead, we use Transport-level timeouts (ResponseHeaderTimeout, DialTimeout) which +// only apply to connection establishment and header reception, NOT to body reading. +// This prevents "context deadline exceeded" errors during long-running streaming responses. +// // Parameters: // - ctx: The context containing optional RoundTripper // - cfg: The application configuration // - auth: The authentication information -// - timeout: The client timeout (0 means no timeout) +// - timeout: The client timeout (0 means streaming-safe mode with no body read timeout) // // Returns: // - *http.Client: An HTTP client with configured proxy or transport @@ -55,7 +77,6 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip httpClientCacheMutex.RLock() if cachedClient, ok := httpClientCache[cacheKey]; ok { httpClientCacheMutex.RUnlock() - // Return a wrapper with the requested timeout but shared transport if timeout > 0 { return &http.Client{ Transport: cachedClient.Transport, @@ -66,13 +87,12 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip } httpClientCacheMutex.RUnlock() - // Create new client httpClient := &http.Client{} + if timeout > 0 { httpClient.Timeout = timeout } - // If we have a proxy URL configured, set up the transport if proxyURL != "" { transport := buildProxyTransport(proxyURL) if transport != nil { @@ -87,9 +107,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL) } - // Priority 3: Use RoundTripper from context (typically from RoundTripperFor) if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil { httpClient.Transport = rt + } else { + httpClient.Transport = buildDefaultTransport() } // Cache the client for no-proxy case @@ -123,9 +144,7 @@ func buildProxyTransport(proxyURL string) *http.Transport { var transport *http.Transport - // Handle different proxy schemes if parsedURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication var proxyAuth *proxy.Auth if parsedURL.User != nil { username := parsedURL.User.Username() @@ -137,15 +156,33 @@ func buildProxyTransport(proxyURL string) *http.Transport { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) return nil } - // Set up a custom transport using the SOCKS5 dialer transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, + TLSHandshakeTimeout: defaultTLSHandshakeTimeout, + ResponseHeaderTimeout: defaultResponseHeaderTimeout, + IdleConnTimeout: defaultIdleConnTimeout, + ExpectContinueTimeout: defaultExpectContinueTimeout, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + ForceAttemptHTTP2: true, } } else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy - transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} + transport = &http.Transport{ + Proxy: http.ProxyURL(parsedURL), + DialContext: (&net.Dialer{ + Timeout: defaultDialTimeout, + KeepAlive: defaultKeepAlive, + }).DialContext, + TLSHandshakeTimeout: defaultTLSHandshakeTimeout, + ResponseHeaderTimeout: defaultResponseHeaderTimeout, + IdleConnTimeout: defaultIdleConnTimeout, + ExpectContinueTimeout: defaultExpectContinueTimeout, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + ForceAttemptHTTP2: true, + } } else { log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) return nil @@ -153,3 +190,21 @@ func buildProxyTransport(proxyURL string) *http.Transport { return transport } + +// buildDefaultTransport creates an HTTP transport with streaming-safe timeout settings. +// ResponseHeaderTimeout protects against unresponsive servers without affecting body reads. +func buildDefaultTransport() *http.Transport { + return &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: defaultDialTimeout, + KeepAlive: defaultKeepAlive, + }).DialContext, + TLSHandshakeTimeout: defaultTLSHandshakeTimeout, + ResponseHeaderTimeout: defaultResponseHeaderTimeout, + IdleConnTimeout: defaultIdleConnTimeout, + ExpectContinueTimeout: defaultExpectContinueTimeout, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + ForceAttemptHTTP2: true, + } +} diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index e7957d2918..3fdec3d754 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -239,7 +239,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) body, _ = sjson.SetBytes(body, "model", baseModel) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) @@ -342,7 +342,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } originalPayload := originalPayloadSource originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true) body, _ = sjson.SetBytes(body, "model", baseModel) body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) @@ -442,7 +442,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, from := opts.SourceFormat to := sdktranslator.FromString("openai") - body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) modelName := gjson.GetBytes(body, "model").String() if strings.TrimSpace(modelName) == "" { diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index b8a0fcaee5..c0bdc18d73 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -405,12 +405,18 @@ func extractClaudeConfig(body []byte) ThinkingConfig { // // Priority: thinkingLevel is checked first (Gemini 3 format), then thinkingBudget (Gemini 2.5 format). // This allows newer Gemini 3 level-based configs to take precedence. +// +// Note: If both thinkingLevel and thinkingBudget are present, only thinkingLevel is used. +// This prevents the 400 error: "thinking_budget and thinking_level are not supported together" func extractGeminiConfig(body []byte, provider string) ThinkingConfig { prefix := "generationConfig.thinkingConfig" if provider == "gemini-cli" || provider == "antigravity" { prefix = "request.generationConfig.thinkingConfig" } + //levelExists := gjson.GetBytes(body, prefix+".thinkingLevel").Exists() + //budgetExists := gjson.GetBytes(body, prefix+".thinkingBudget").Exists() + // Check thinkingLevel first (Gemini 3 format takes precedence) level := gjson.GetBytes(body, prefix+".thinkingLevel") if !level.Exists() { diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 3c834f6f21..3c867f0373 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -9,6 +9,8 @@ package claude import ( "bytes" "context" + "encoding/base64" + "encoding/json" "fmt" "strings" "sync/atomic" @@ -42,6 +44,11 @@ type Params struct { // Signature caching support CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching + + // Web search support + WebSearchQuery string + WebSearchResults []map[string]any + WebSearchEmitted bool } // toolUseIDCounter provides a process-wide unique counter for tool use identifiers. @@ -276,6 +283,15 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq params.FinishReason = finishReasonResult.String() } + if q, results := extractWebSearchFromAntigravity(rawJSON); q != "" || len(results) > 0 { + if q != "" { + params.WebSearchQuery = q + } + if len(results) > 0 { + params.WebSearchResults = results + } + } + if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { params.HasUsageMetadata = true params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int() @@ -292,6 +308,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq } if params.HasUsageMetadata && params.HasFinishReason { + appendWebSearchBlocks(params, &output) appendFinalEvents(params, &output, false) } @@ -359,6 +376,114 @@ func resolveStopReason(params *Params) string { return "end_turn" } +func buildEncryptedContent(url, title string) string { + payload := map[string]string{"url": url, "title": title} + encoded, err := json.Marshal(payload) + if err != nil { + return "" + } + return base64.StdEncoding.EncodeToString(encoded) +} + +func extractWebSearchFromAntigravity(rawJSON []byte) (string, []map[string]any) { + candidate := gjson.GetBytes(rawJSON, "response.candidates.0") + if !candidate.Exists() { + candidate = gjson.GetBytes(rawJSON, "candidates.0") + } + if !candidate.Exists() { + return "", nil + } + + query := candidate.Get("groundingMetadata.webSearchQueries.0").String() + + chunks := candidate.Get("groundingChunks") + if !chunks.Exists() { + chunks = candidate.Get("groundingMetadata.groundingChunks") + } + if !chunks.Exists() || !chunks.IsArray() { + return query, nil + } + + results := make([]map[string]any, 0, len(chunks.Array())) + for _, chunk := range chunks.Array() { + web := chunk.Get("web") + if !web.Exists() { + continue + } + url := web.Get("uri").String() + if url == "" { + url = web.Get("url").String() + } + title := web.Get("title").String() + if title == "" { + title = web.Get("domain").String() + } + if url == "" && title == "" { + continue + } + item := map[string]any{ + "type": "web_search_result", + "title": title, + "url": url, + "encrypted_content": buildEncryptedContent(url, title), + "page_age": nil, + } + results = append(results, item) + } + + if len(results) == 0 { + return query, nil + } + return query, results +} + +func appendWebSearchBlocks(params *Params, output *string) { + if params.WebSearchEmitted { + return + } + if params.WebSearchQuery == "" && len(params.WebSearchResults) == 0 { + return + } + + if params.ResponseType != 0 { + *output = *output + "event: content_block_stop\n" + *output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex) + *output = *output + "\n\n\n" + params.ResponseType = 0 + params.ResponseIndex++ + } + + toolUseID := fmt.Sprintf("srvtoolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)) + serverTool := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"server_tool_use","id":"","name":"web_search","input":{}}}`, params.ResponseIndex) + serverTool, _ = sjson.Set(serverTool, "content_block.id", toolUseID) + if params.WebSearchQuery != "" { + serverTool, _ = sjson.Set(serverTool, "content_block.input.query", params.WebSearchQuery) + } + *output = *output + "event: content_block_start\n" + *output = *output + fmt.Sprintf("data: %s\n\n\n", serverTool) + *output = *output + "event: content_block_stop\n" + *output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex) + *output = *output + "\n\n\n" + params.ResponseIndex++ + + resultBlock := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"web_search_tool_result","tool_use_id":"","content":[]}}`, params.ResponseIndex) + resultBlock, _ = sjson.Set(resultBlock, "content_block.tool_use_id", toolUseID) + if len(params.WebSearchResults) > 0 { + if raw, err := json.Marshal(params.WebSearchResults); err == nil { + resultBlock, _ = sjson.SetRaw(resultBlock, "content_block.content", string(raw)) + } + } + *output = *output + "event: content_block_start\n" + *output = *output + fmt.Sprintf("data: %s\n\n\n", resultBlock) + *output = *output + "event: content_block_stop\n" + *output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex) + *output = *output + "\n\n\n" + params.ResponseIndex++ + + params.HasContent = true + params.WebSearchEmitted = true +} + // ConvertAntigravityResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. // // Parameters: @@ -491,6 +616,25 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or flushThinking() flushText() + if query, results := extractWebSearchFromAntigravity(rawJSON); query != "" || len(results) > 0 { + ensureContentArray() + toolUseID := fmt.Sprintf("srvtoolu_%d", time.Now().UnixNano()) + serverTool := `{"type":"server_tool_use","id":"","name":"web_search","input":{}}` + serverTool, _ = sjson.Set(serverTool, "id", toolUseID) + if query != "" { + serverTool, _ = sjson.Set(serverTool, "input.query", query) + } + responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", serverTool) + + resultBlock := `{"type":"web_search_tool_result","tool_use_id":"","content":[]}` + resultBlock, _ = sjson.Set(resultBlock, "tool_use_id", toolUseID) + if len(results) > 0 { + if raw, err := json.Marshal(results); err == nil { + resultBlock, _ = sjson.SetRaw(resultBlock, "content", string(raw)) + } + } + responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", resultBlock) + } stopReason := "end_turn" if hasToolCall { diff --git a/internal/translator/kiro/claude/kiro_claude_response.go b/internal/translator/kiro/claude/kiro_claude_response.go index 17ffdde239..8c813810b0 100644 --- a/internal/translator/kiro/claude/kiro_claude_response.go +++ b/internal/translator/kiro/claude/kiro_claude_response.go @@ -93,7 +93,7 @@ func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, u } response := map[string]interface{}{ - "id": "msg_" + uuid.New().String()[:24], + "id": "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24], "type": "message", "role": "assistant", "model": model, diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go index c86b6e023e..6de95ab1eb 100644 --- a/internal/translator/kiro/claude/kiro_claude_stream.go +++ b/internal/translator/kiro/claude/kiro_claude_stream.go @@ -5,17 +5,17 @@ package claude import ( "encoding/json" + "strings" "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" ) -// BuildClaudeMessageStartEvent creates the message_start SSE event func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte { event := map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ - "id": "msg_" + uuid.New().String()[:24], + "id": "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24], "type": "message", "role": "assistant", "content": []interface{}{}, diff --git a/internal/translator/kiro/claude/kiro_claude_tools.go b/internal/translator/kiro/claude/kiro_claude_tools.go index d00c74932c..6c3561910b 100644 --- a/internal/translator/kiro/claude/kiro_claude_tools.go +++ b/internal/translator/kiro/claude/kiro_claude_tools.go @@ -7,7 +7,6 @@ import ( "regexp" "strings" - "github.com/google/uuid" kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" log "github.com/sirupsen/logrus" ) @@ -102,8 +101,7 @@ func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, continue } - // Generate unique tool ID - toolUseID := "toolu_" + uuid.New().String()[:12] + toolUseID := kirocommon.GenerateToolUseID() // Check for duplicates using name+input as key dedupeKey := toolName + ":" + repairedJSON @@ -389,7 +387,11 @@ func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseSt tu = nested } - toolUseID := kirocommon.GetString(tu, "toolUseId") + toolUseID := kirocommon.SanitizeToolUseID(kirocommon.GetString(tu, "toolUseId")) + if toolUseID == "" { + log.Warnf("kiro: skipping tool use with empty/invalid toolUseId") + return nil, nil + } toolName := kirocommon.GetString(tu, "name") isStop := false if stop, ok := tu["stop"].(bool); ok { diff --git a/internal/translator/kiro/common/utils.go b/internal/translator/kiro/common/utils.go index f5f5788ab2..74881532d1 100644 --- a/internal/translator/kiro/common/utils.go +++ b/internal/translator/kiro/common/utils.go @@ -1,6 +1,13 @@ // Package common provides shared constants and utilities for Kiro translator. package common +import ( + "strings" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" +) + // GetString safely extracts a string from a map. // Returns empty string if the key doesn't exist or the value is not a string. func GetString(m map[string]interface{}, key string) string { @@ -13,4 +20,42 @@ func GetString(m map[string]interface{}, key string) string { // GetStringValue is an alias for GetString for backward compatibility. func GetStringValue(m map[string]interface{}, key string) string { return GetString(m, key) -} \ No newline at end of file +} + +// SanitizeToolUseID ensures tool_use.id matches Claude API pattern ^[a-zA-Z0-9_-]+$ +// Returns sanitized ID or generates new one if input is invalid. +func SanitizeToolUseID(id string) string { + if id == "" { + return "" + } + + var sanitized strings.Builder + sanitized.Grow(len(id)) + + for _, r := range id { + if (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-' { + sanitized.WriteRune(r) + } + } + + result := sanitized.String() + + if len(result) < 8 { + log.Warnf("kiro: tool_use.id '%s' sanitized to '%s' (too short), generating new ID", id, result) + return GenerateToolUseID() + } + + if result != id { + log.Warnf("kiro: tool_use.id sanitized: '%s' -> '%s'", id, result) + } + + return result +} + +// GenerateToolUseID creates a valid tool_use.id without hyphens +func GenerateToolUseID() string { + return "toolu_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:12] +} diff --git a/internal/translator/kiro/common/utils_test.go b/internal/translator/kiro/common/utils_test.go new file mode 100644 index 0000000000..7ed987b476 --- /dev/null +++ b/internal/translator/kiro/common/utils_test.go @@ -0,0 +1,133 @@ +package common + +import ( + "strings" + "testing" +) + +func TestSanitizeToolUseID(t *testing.T) { + tests := []struct { + name string + input string + wantLen int + }{ + { + name: "valid alphanumeric with hyphen", + input: "toolu_abc123-def456", + wantLen: 19, + }, + { + name: "UUID with hyphens (hyphens are valid)", + input: "e9577a7d-809c-4e3f", + wantLen: 18, + }, + { + name: "invalid characters removed", + input: "tool@use#id$123", + wantLen: 12, + }, + { + name: "empty string", + input: "", + wantLen: 0, + }, + { + name: "too short after sanitization generates new ID", + input: "abc", + wantLen: 18, + }, + { + name: "special characters only generates new ID", + input: "@#$%^&*()", + wantLen: 18, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeToolUseID(tt.input) + + if len(got) != tt.wantLen { + t.Errorf("SanitizeToolUseID() length = %v, want %v (got: %s)", len(got), tt.wantLen, got) + } + + for _, r := range got { + if !((r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-') { + t.Errorf("SanitizeToolUseID() contains invalid character: %c in %s", r, got) + } + } + }) + } +} + +func TestGenerateToolUseID(t *testing.T) { + ids := make(map[string]bool) + + for i := 0; i < 100; i++ { + id := GenerateToolUseID() + + if !strings.HasPrefix(id, "toolu_") { + t.Errorf("GenerateToolUseID() doesn't start with 'toolu_': %s", id) + } + + if len(id) != 18 { + t.Errorf("GenerateToolUseID() length = %v, want 18 (got: %s)", len(id), id) + } + + if strings.Contains(id, "-") { + t.Errorf("GenerateToolUseID() contains hyphen: %s", id) + } + + if ids[id] { + t.Errorf("GenerateToolUseID() generated duplicate: %s", id) + } + ids[id] = true + + for _, r := range id { + if !((r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-') { + t.Errorf("GenerateToolUseID() contains invalid character: %c in %s", r, id) + } + } + } +} + +func TestSanitizeToolUseID_ClaudeAPICompliance(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "UUID slice with hyphen (hyphens are valid in pattern)", + input: "e9577a7d-809", + }, + { + name: "Full UUID (hyphens are valid in pattern)", + input: "550e8400-e29b-41d4-a716-446655440000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeToolUseID(tt.input) + + if len(got) < 8 { + t.Errorf("SanitizeToolUseID() too short: %s (len=%d)", got, len(got)) + } + + for _, r := range got { + if !((r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-') { + t.Errorf("SanitizeToolUseID() contains invalid character: %c in %s", r, got) + } + } + }) + } +} diff --git a/internal/translator/kiro/openai/kiro_openai_request_test.go b/internal/translator/kiro/openai/kiro_openai_request_test.go index 22953bbc27..be20317fc1 100644 --- a/internal/translator/kiro/openai/kiro_openai_request_test.go +++ b/internal/translator/kiro/openai/kiro_openai_request_test.go @@ -13,7 +13,7 @@ func TestToolResultsAttachedToCurrentMessage(t *testing.T) { // Sequence: user -> assistant (with tool_calls) -> tool (result) -> user // The last user message should have the tool results attached input := []byte(`{ - "model": "kiro-claude-opus-4-5-agentic", + "model": "kiro-claude-sonnet-4-5-agentic", "messages": [ {"role": "user", "content": "Hello, can you read a file for me?"}, { @@ -78,7 +78,7 @@ func TestToolResultsInHistoryUserMessage(t *testing.T) { // Sequence: user -> assistant (with tool_calls) -> tool (result) -> user -> assistant -> user // The first user after tool should have tool results in history input := []byte(`{ - "model": "kiro-claude-opus-4-5-agentic", + "model": "kiro-claude-sonnet-4-5-agentic", "messages": [ {"role": "user", "content": "Hello"}, { @@ -146,7 +146,7 @@ func TestToolResultsInHistoryUserMessage(t *testing.T) { // TestToolResultsWithMultipleToolCalls verifies handling of multiple tool calls func TestToolResultsWithMultipleToolCalls(t *testing.T) { input := []byte(`{ - "model": "kiro-claude-opus-4-5-agentic", + "model": "kiro-claude-sonnet-4-5-agentic", "messages": [ {"role": "user", "content": "Read two files for me"}, { @@ -222,7 +222,7 @@ func TestToolResultsWithMultipleToolCalls(t *testing.T) { // the conversation ends with tool results (no following user message) func TestToolResultsAtEndOfConversation(t *testing.T) { input := []byte(`{ - "model": "kiro-claude-opus-4-5-agentic", + "model": "kiro-claude-sonnet-4-5-agentic", "messages": [ {"role": "user", "content": "Read a file"}, { @@ -280,7 +280,7 @@ func TestToolResultsFollowedByAssistant(t *testing.T) { // assistant: "I've read them" // user: "What did they say?" input := []byte(`{ - "model": "kiro-claude-opus-4-5-agentic", + "model": "kiro-claude-sonnet-4-5-agentic", "messages": [ {"role": "user", "content": "Read two files for me"}, { @@ -362,7 +362,7 @@ func TestToolResultsFollowedByAssistant(t *testing.T) { // TestAssistantEndsConversation verifies handling when assistant is the last message func TestAssistantEndsConversation(t *testing.T) { input := []byte(`{ - "model": "kiro-claude-opus-4-5-agentic", + "model": "kiro-claude-sonnet-4-5-agentic", "messages": [ {"role": "user", "content": "Hello"}, { diff --git a/internal/usage/logger_plugin.go b/internal/usage/logger_plugin.go index e4371e8d39..4a31d6fdb0 100644 --- a/internal/usage/logger_plugin.go +++ b/internal/usage/logger_plugin.go @@ -76,6 +76,7 @@ type RequestStatistics struct { // apiStats holds aggregated metrics for a single API key. type apiStats struct { TotalRequests int64 + FailureCount int64 TotalTokens int64 Models map[string]*modelStats } @@ -83,6 +84,7 @@ type apiStats struct { // modelStats holds aggregated metrics for a specific model within an API. type modelStats struct { TotalRequests int64 + FailureCount int64 TotalTokens int64 Details []RequestDetail } @@ -123,6 +125,7 @@ type StatisticsSnapshot struct { // APISnapshot summarises metrics for a single API key. type APISnapshot struct { TotalRequests int64 `json:"total_requests"` + FailureCount int64 `json:"failure_count"` TotalTokens int64 `json:"total_tokens"` Models map[string]ModelSnapshot `json:"models"` } @@ -130,6 +133,7 @@ type APISnapshot struct { // ModelSnapshot summarises metrics for a specific model. type ModelSnapshot struct { TotalRequests int64 `json:"total_requests"` + FailureCount int64 `json:"failure_count"` TotalTokens int64 `json:"total_tokens"` Details []RequestDetail `json:"details"` } @@ -212,6 +216,9 @@ func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record) func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail RequestDetail) { stats.TotalRequests++ + if detail.Failed { + stats.FailureCount++ + } stats.TotalTokens += detail.Tokens.TotalTokens modelStatsValue, ok := stats.Models[model] if !ok { @@ -219,6 +226,9 @@ func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail stats.Models[model] = modelStatsValue } modelStatsValue.TotalRequests++ + if detail.Failed { + modelStatsValue.FailureCount++ + } modelStatsValue.TotalTokens += detail.Tokens.TotalTokens modelStatsValue.Details = append(modelStatsValue.Details, detail) } @@ -242,6 +252,7 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot { for apiName, stats := range s.apis { apiSnapshot := APISnapshot{ TotalRequests: stats.TotalRequests, + FailureCount: stats.FailureCount, TotalTokens: stats.TotalTokens, Models: make(map[string]ModelSnapshot, len(stats.Models)), } @@ -250,6 +261,7 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot { copy(requestDetails, modelStatsValue.Details) apiSnapshot.Models[modelName] = ModelSnapshot{ TotalRequests: modelStatsValue.TotalRequests, + FailureCount: modelStatsValue.FailureCount, TotalTokens: modelStatsValue.TotalTokens, Details: requestDetails, } diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index ea96118b5e..14fe8f9396 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -109,8 +109,9 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e Status: status, Disabled: disabled, Attributes: map[string]string{ - "source": full, - "path": full, + "source": full, + "path": full, + "auth_kind": "oauth", }, ProxyURL: proxyURL, Metadata: metadata, diff --git a/management.html b/management.html new file mode 100644 index 0000000000..86b416e7f9 --- /dev/null +++ b/management.html @@ -0,0 +1,45 @@ + + + + + + + CLI Proxy API Management Center + + + + +
+ + diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index f099116da9..92ad6eaf08 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -475,18 +475,14 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel - payload := rawJSON - if len(payload) == 0 { - payload = nil - } req := coreexecutor.Request{ Model: normalizedModel, - Payload: payload, + Payload: cloneBytes(rawJSON), } opts := coreexecutor.Options{ Stream: false, Alt: alt, - OriginalRequest: rawJSON, + OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } opts.Metadata = reqMeta @@ -507,9 +503,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } if !PassthroughHeadersEnabled(h.Cfg) { - return resp.Payload, nil, nil + return cloneBytes(resp.Payload), nil, nil } - return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil + return cloneBytes(resp.Payload), FilterUpstreamHeaders(resp.Headers), nil } // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. @@ -521,18 +517,14 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel - payload := rawJSON - if len(payload) == 0 { - payload = nil - } req := coreexecutor.Request{ Model: normalizedModel, - Payload: payload, + Payload: cloneBytes(rawJSON), } opts := coreexecutor.Options{ Stream: false, Alt: alt, - OriginalRequest: rawJSON, + OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } opts.Metadata = reqMeta @@ -553,9 +545,9 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} } if !PassthroughHeadersEnabled(h.Cfg) { - return resp.Payload, nil, nil + return cloneBytes(resp.Payload), nil, nil } - return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil + return cloneBytes(resp.Payload), FilterUpstreamHeaders(resp.Headers), nil } // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. @@ -571,18 +563,14 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl } reqMeta := requestExecutionMetadata(ctx) reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel - payload := rawJSON - if len(payload) == 0 { - payload = nil - } req := coreexecutor.Request{ Model: normalizedModel, - Payload: payload, + Payload: cloneBytes(rawJSON), } opts := coreexecutor.Options{ Stream: true, Alt: alt, - OriginalRequest: rawJSON, + OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } opts.Metadata = reqMeta @@ -870,7 +858,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro var previous []byte if existing, exists := c.Get("API_RESPONSE"); exists { if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { - previous = existingBytes + previous = bytes.Clone(existingBytes) } } appendAPIResponse(c, body) diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 6ed31d6d72..2b4ceab2ce 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -2,13 +2,15 @@ package auth import ( "context" + "encoding/json" "fmt" + "io" "net" "net/http" + "net/url" "strings" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -17,6 +19,28 @@ import ( log "github.com/sirupsen/logrus" ) +// AntigravityProjectInfo contains project ID and subscription tier info +type AntigravityProjectInfo struct { + ProjectID string + TierID string // "ultra", "pro", "standard", "free", or "unknown" + TierName string // Display name from API (e.g., "Gemini Code Assist Pro") + IsPaid bool // true if tier is "pro" or "ultra" +} + +const ( + antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + antigravityClientSecret = "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" + antigravityCallbackPort = 51121 +) + +var antigravityScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", +} + // AntigravityAuthenticator implements OAuth login for the antigravity provider. type AntigravityAuthenticator struct{} @@ -43,12 +67,12 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o opts = &LoginOptions{} } - callbackPort := antigravity.CallbackPort + callbackPort := antigravityCallbackPort if opts.CallbackPort > 0 { callbackPort = opts.CallbackPort } - authSvc := antigravity.NewAntigravityAuth(cfg, nil) + httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) state, err := misc.GenerateRandomState() if err != nil { @@ -66,7 +90,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o }() redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port) - authURL := authSvc.BuildAuthURL(state, redirectURI) + authURL := buildAntigravityAuthURL(redirectURI, state) if !opts.NoBrowser { fmt.Println("Opening browser for antigravity authentication") @@ -147,34 +171,33 @@ waitForCallback: return nil, fmt.Errorf("antigravity: missing authorization code") } - tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI) + tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient) if errToken != nil { return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken) } - accessToken := strings.TrimSpace(tokenResp.AccessToken) - if accessToken == "" { - return nil, fmt.Errorf("antigravity: token exchange returned empty access token") - } - - email, errInfo := authSvc.FetchUserInfo(ctx, accessToken) - if errInfo != nil { - return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo) - } - email = strings.TrimSpace(email) - if email == "" { - return nil, fmt.Errorf("antigravity: empty email returned from user info") + email := "" + if tokenResp.AccessToken != "" { + if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" { + email = strings.TrimSpace(info.Email) + } } // Fetch project ID via loadCodeAssist (same approach as Gemini CLI) projectID := "" - if accessToken != "" { - fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken) + tierID := "unknown" + tierName := "Unknown" + tierIsPaid := false + if tokenResp.AccessToken != "" { + projectInfo, errProject := FetchAntigravityProjectInfo(ctx, tokenResp.AccessToken, httpClient) if errProject != nil { - log.Warnf("antigravity: failed to fetch project ID: %v", errProject) + log.Warnf("antigravity: failed to fetch project info: %v", errProject) } else { - projectID = fetchedProjectID - log.Infof("antigravity: obtained project ID %s", projectID) + projectID = projectInfo.ProjectID + tierID = projectInfo.TierID + tierName = projectInfo.TierName + tierIsPaid = projectInfo.IsPaid + log.Infof("antigravity: obtained project ID %s, tier %s", projectID, tierID) } } @@ -186,6 +209,9 @@ waitForCallback: "expires_in": tokenResp.ExpiresIn, "timestamp": now.UnixMilli(), "expired": now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339), + "tier_id": tierID, + "tier_name": tierName, + "tier_is_paid": tierIsPaid, } if email != "" { metadata["email"] = email @@ -194,7 +220,7 @@ waitForCallback: metadata["project_id"] = projectID } - fileName := antigravity.CredentialFileName(email) + fileName := sanitizeAntigravityFileName(email) label := email if label == "" { label = "antigravity" @@ -221,7 +247,7 @@ type callbackResult struct { func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) { if port <= 0 { - port = antigravity.CallbackPort + port = antigravityCallbackPort } addr := fmt.Sprintf(":%d", port) listener, err := net.Listen("tcp", addr) @@ -257,9 +283,354 @@ func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbac return srv, port, resultCh, nil } +type antigravityTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` +} + +func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) { + data := url.Values{} + data.Set("code", code) + data.Set("client_id", antigravityClientID) + data.Set("client_secret", antigravityClientSecret) + data.Set("redirect_uri", redirectURI) + data.Set("grant_type", "authorization_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, errDo := httpClient.Do(req) + if errDo != nil { + return nil, errDo + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity token exchange: close body error: %v", errClose) + } + }() + + var token antigravityTokenResponse + if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil { + return nil, errDecode + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode) + } + return &token, nil +} + +type antigravityUserInfo struct { + Email string `json:"email"` +} + +func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) { + if strings.TrimSpace(accessToken) == "" { + return &antigravityUserInfo{}, nil + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, errDo := httpClient.Do(req) + if errDo != nil { + return nil, errDo + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity userinfo: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return &antigravityUserInfo{}, nil + } + var info antigravityUserInfo + if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil { + return nil, errDecode + } + return &info, nil +} + +func buildAntigravityAuthURL(redirectURI, state string) string { + params := url.Values{} + params.Set("access_type", "offline") + params.Set("client_id", antigravityClientID) + params.Set("prompt", "consent") + params.Set("redirect_uri", redirectURI) + params.Set("response_type", "code") + params.Set("scope", strings.Join(antigravityScopes, " ")) + params.Set("state", state) + return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() +} + +func sanitizeAntigravityFileName(email string) string { + if strings.TrimSpace(email) == "" { + return "antigravity.json" + } + replacer := strings.NewReplacer("@", "_", ".", "_") + return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email)) +} + +func extractTierInfo(resp map[string]any) (tierID, tierName string, isPaid bool) { + var effectiveTier map[string]any + if pt, ok := resp["paidTier"].(map[string]any); ok && pt != nil { + effectiveTier = pt + } else if ct, ok := resp["currentTier"].(map[string]any); ok { + effectiveTier = ct + } + + if effectiveTier == nil { + return "unknown", "Unknown", false + } + + id, _ := effectiveTier["id"].(string) + name, _ := effectiveTier["name"].(string) + + idLower := strings.ToLower(id) + nameLower := strings.ToLower(name) + + // Check tier by ID first, then by name patterns + switch { + case strings.Contains(idLower, "ultra"): + return "ultra", name, true + case strings.Contains(idLower, "pro"): + return "pro", name, true + case strings.Contains(idLower, "standard"), strings.Contains(idLower, "free"): + return "free", name, false + // Check by tier name patterns when ID doesn't match + case strings.Contains(nameLower, "google one ai pro"): + // "Gemini Code Assist in Google One AI Pro" -> Pro tier + return "pro", name, true + case strings.Contains(nameLower, "for individuals"): + // "Gemini Code Assist for individuals" -> Free tier + return "free", name, false + default: + return id, name, false + } +} + +// Antigravity API constants for project discovery +const ( + antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com" + antigravityAPIVersion = "v1internal" + antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1" + antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" + antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` +) + // FetchAntigravityProjectID exposes project discovery for external callers. func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) { - cfg := &config.Config{} - authSvc := antigravity.NewAntigravityAuth(cfg, httpClient) - return authSvc.FetchProjectID(ctx, accessToken) + info, err := FetchAntigravityProjectInfo(ctx, accessToken, httpClient) + if err != nil { + return "", err + } + return info.ProjectID, nil +} + +// FetchAntigravityProjectInfo fetches project ID and tier info from the Antigravity API. +func FetchAntigravityProjectInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*AntigravityProjectInfo, error) { + loadReqBody := map[string]any{ + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(loadReqBody) + if errMarshal != nil { + return nil, fmt.Errorf("marshal request body: %w", errMarshal) + } + + endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) + req.Header.Set("Client-Metadata", antigravityClientMetadata) + + resp, errDo := httpClient.Do(req) + if errDo != nil { + return nil, fmt.Errorf("execute request: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose) + } + }() + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errRead != nil { + return nil, fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes))) + } + + var loadResp map[string]any + if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil { + return nil, fmt.Errorf("decode response: %w", errDecode) + } + + tierID, tierName, isPaid := extractTierInfo(loadResp) + + projectID := "" + if id, ok := loadResp["cloudaicompanionProject"].(string); ok { + projectID = strings.TrimSpace(id) + } + if projectID == "" { + if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok { + if id, okID := projectMap["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + } + } + + if projectID == "" { + onboardTierID := "legacy-tier" + if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers { + for _, rawTier := range tiers { + tier, okTier := rawTier.(map[string]any) + if !okTier { + continue + } + if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault { + if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" { + onboardTierID = strings.TrimSpace(id) + break + } + } + } + } + + projectID, err = antigravityOnboardUser(ctx, accessToken, onboardTierID, httpClient) + if err != nil { + return nil, err + } + } + + return &AntigravityProjectInfo{ + ProjectID: projectID, + TierID: tierID, + TierName: tierName, + IsPaid: isPaid, + }, nil +} + +// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion. +// It returns an empty string when the operation times out or completes without a project ID. +func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) { + if httpClient == nil { + httpClient = http.DefaultClient + } + fmt.Println("Antigravity: onboarding user...", tierID) + requestBody := map[string]any{ + "tierId": tierID, + "metadata": map[string]string{ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + } + + rawBody, errMarshal := json.Marshal(requestBody) + if errMarshal != nil { + return "", fmt.Errorf("marshal request body: %w", errMarshal) + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Debugf("Polling attempt %d/%d", attempt, maxAttempts) + + reqCtx := ctx + var cancel context.CancelFunc + if reqCtx == nil { + reqCtx = context.Background() + } + reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second) + + endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion) + req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody))) + if errRequest != nil { + cancel() + return "", fmt.Errorf("create request: %w", errRequest) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityAPIClient) + req.Header.Set("Client-Metadata", antigravityClientMetadata) + + resp, errDo := httpClient.Do(req) + if errDo != nil { + cancel() + return "", fmt.Errorf("execute request: %w", errDo) + } + + bodyBytes, errRead := io.ReadAll(resp.Body) + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("close body error: %v", errClose) + } + cancel() + + if errRead != nil { + return "", fmt.Errorf("read response: %w", errRead) + } + + if resp.StatusCode == http.StatusOK { + var data map[string]any + if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil { + return "", fmt.Errorf("decode response: %w", errDecode) + } + + if done, okDone := data["done"].(bool); okDone && done { + projectID := "" + if responseData, okResp := data["response"].(map[string]any); okResp { + switch projectValue := responseData["cloudaicompanionProject"].(type) { + case map[string]any: + if id, okID := projectValue["id"].(string); okID { + projectID = strings.TrimSpace(id) + } + case string: + projectID = strings.TrimSpace(projectValue) + } + } + + if projectID != "" { + log.Infof("Successfully fetched project_id: %s", projectID) + return projectID, nil + } + + return "", fmt.Errorf("no project_id in response") + } + + time.Sleep(2 * time.Second) + continue + } + + responsePreview := strings.TrimSpace(string(bodyBytes)) + if len(responsePreview) > 500 { + responsePreview = responsePreview[:500] + } + + responseErr := responsePreview + if len(responseErr) > 200 { + responseErr = responseErr[:200] + } + return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr) + } + + return "", nil } diff --git a/sdk/auth/cline.go b/sdk/auth/cline.go new file mode 100644 index 0000000000..139cdf39f6 --- /dev/null +++ b/sdk/auth/cline.go @@ -0,0 +1,343 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/cline" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +func extractFirstJSONObject(input []byte) []byte { + start := -1 + depth := 0 + inString := false + escapeNext := false + + for i, b := range input { + if start == -1 { + if b == '{' { + start = i + depth = 1 + } + continue + } + + if inString { + if escapeNext { + escapeNext = false + continue + } + if b == '\\' { + escapeNext = true + continue + } + if b == '"' { + inString = false + } + continue + } + + if b == '"' { + inString = true + continue + } + + if b == '{' { + depth++ + continue + } + + if b == '}' { + depth-- + if depth == 0 { + return input[start : i+1] + } + } + } + + if start != -1 { + return input[start:] + } + + return nil +} + +const defaultClineCallbackPort = 1455 + +type ClineAuthenticator struct { + CallbackPort int +} + +func NewClineAuthenticator() *ClineAuthenticator { + return &ClineAuthenticator{CallbackPort: defaultClineCallbackPort} +} + +func (a *ClineAuthenticator) Provider() string { + return "cline" +} + +func (a *ClineAuthenticator) RefreshLead() *time.Duration { + d := 5 * time.Minute + return &d +} + +func (a *ClineAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if ctx == nil { + ctx = context.Background() + } + if opts == nil { + opts = &LoginOptions{} + } + + callbackPort := a.CallbackPort + if opts.CallbackPort > 0 { + callbackPort = opts.CallbackPort + } + + state, err := misc.GenerateRandomState() + if err != nil { + return nil, fmt.Errorf("cline state generation failed: %w", err) + } + + callbackURL := fmt.Sprintf("http://localhost:%d/callback", callbackPort) + authSvc := cline.NewClineAuth(cfg) + authURL := authSvc.GenerateAuthURL(state, callbackURL) + + if !opts.NoBrowser { + fmt.Println("Opening browser for Cline authentication") + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } else if err = browser.OpenURL(authURL); err != nil { + log.Warnf("Failed to open browser automatically: %v", err) + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + } else { + util.PrintSSHTunnelInstructions(callbackPort) + fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) + } + + fmt.Println("Waiting for Cline authentication callback...") + result, err := waitForClineCallback(ctx, callbackPort, opts.Prompt) + if err != nil { + return nil, err + } + + if result.Error != "" { + if result.ErrorDescription != "" { + return nil, fmt.Errorf("cline oauth error: %s (%s)", result.Error, result.ErrorDescription) + } + return nil, fmt.Errorf("cline oauth error: %s", result.Error) + } + // Cline may not return state in callback, only validate if both are present + if result.State != "" && state != "" && result.State != state { + return nil, fmt.Errorf("cline authentication failed: state mismatch") + } + + // Cline returns the token directly in the code parameter as base64-encoded JSON + // Try to parse it directly first, fall back to exchange if needed + var tokenResp *cline.TokenResponse + codeStr := result.Code + + // Try multiple base64 decoding strategies + decodeStrategies := []func(string) ([]byte, error){ + base64.URLEncoding.DecodeString, + base64.RawURLEncoding.DecodeString, + base64.StdEncoding.DecodeString, + base64.RawStdEncoding.DecodeString, + } + + for _, decode := range decodeStrategies { + if decoded, decodeErr := decode(codeStr); decodeErr == nil { + var directToken cline.TokenResponse + parseErr := json.Unmarshal(decoded, &directToken) + if parseErr != nil { + if jsonOnly := extractFirstJSONObject(decoded); len(jsonOnly) > 0 { + parseErr = json.Unmarshal(jsonOnly, &directToken) + } + } + if parseErr == nil && directToken.AccessToken != "" { + tokenResp = &directToken + break + } + log.Debugf("cline: base64 decode succeeded but JSON parse failed: %v", parseErr) + } + } + + // Fall back to token exchange if direct parsing didn't work + if tokenResp == nil { + var err error + tokenResp, err = authSvc.ExchangeCode(ctx, result.Code, callbackURL) + if err != nil { + return nil, fmt.Errorf("cline token exchange failed: %w", err) + } + } + + if tokenResp == nil { + return nil, fmt.Errorf("cline authentication failed: no token response") + } + + email := strings.TrimSpace(tokenResp.Email) + if email == "" { + return nil, fmt.Errorf("cline authentication failed: missing account email") + } + + // Parse expiresAt from string to int64 + var expiresAtInt int64 + if tokenResp.ExpiresAt != "" { + if t, err := time.Parse(time.RFC3339Nano, tokenResp.ExpiresAt); err == nil { + expiresAtInt = t.Unix() + } else if t, err := time.Parse(time.RFC3339, tokenResp.ExpiresAt); err == nil { + expiresAtInt = t.Unix() + } else { + log.Debugf("cline: failed to parse expiresAt: %v", err) + } + } + + ts := &cline.ClineTokenStorage{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAtInt, + Email: email, + Type: "cline", + } + + fileName := cline.CredentialFileName(email) + metadata := map[string]any{ + "email": email, + "fileName": fileName, + "expires_at": expiresAtInt, + } + + fmt.Printf("Cline authentication successful for %s\n", email) + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Storage: ts, + Metadata: metadata, + }, nil +} + +type clineOAuthResult struct { + Code string + State string + Error string + ErrorDescription string +} + +func waitForClineCallback(ctx context.Context, callbackPort int, prompt func(prompt string) (string, error)) (*clineOAuthResult, error) { + if ctx == nil { + ctx = context.Background() + } + + resultCh := make(chan *clineOAuthResult, 1) + errCh := make(chan error, 1) + + mux := http.NewServeMux() + server := &http.Server{ + Addr: ":" + strconv.Itoa(callbackPort), + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, + } + + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + res := &clineOAuthResult{ + Code: strings.TrimSpace(q.Get("code")), + State: strings.TrimSpace(q.Get("state")), + Error: strings.TrimSpace(q.Get("error")), + ErrorDescription: strings.TrimSpace(q.Get("error_description")), + } + + select { + case resultCh <- res: + default: + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = w.Write([]byte("

Cline login complete

You can close this window and return to CLI.

")) + }) + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errCh <- fmt.Errorf("cline callback server failed: %w", err) + } + }() + + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + log.Warnf("cline callback server shutdown error: %v", err) + } + }() + + var manualTimer *time.Timer + var manualTimerC <-chan time.Time + if prompt != nil { + manualTimer = time.NewTimer(15 * time.Second) + manualTimerC = manualTimer.C + defer manualTimer.Stop() + } + + timeout := cline.AuthTimeout + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 && remaining < timeout { + timeout = remaining + } + } + timeoutTimer := time.NewTimer(timeout) + defer timeoutTimer.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timeoutTimer.C: + return nil, fmt.Errorf("cline callback wait timeout after %s", timeout.String()) + case err := <-errCh: + return nil, err + case res := <-resultCh: + return res, nil + case <-manualTimerC: + manualTimerC = nil + input, err := prompt("Paste the Cline callback URL (or press Enter to keep waiting): ") + if err != nil { + return nil, err + } + parsed, err := misc.ParseOAuthCallback(input) + if err != nil { + return nil, err + } + if parsed == nil { + continue + } + return &clineOAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + ErrorDescription: parsed.ErrorDescription, + }, nil + } + } +} diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 794ea51326..b1167f9d19 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -238,7 +238,6 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, status = cliproxyauth.StatusDisabled } - // Calculate NextRefreshAfter from expires_at (20 minutes before expiry) var nextRefreshAfter time.Time if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" { if expiresAt, err := time.Parse(time.RFC3339, expiresAtStr); err == nil { @@ -246,6 +245,19 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, } } + if nextRefreshAfter.IsZero() { + if expiredStr, ok := metadata["expired"].(string); ok && expiredStr != "" { + if expiresAt, err := time.Parse(time.RFC3339, expiredStr); err == nil { + refreshLead := 24 * time.Hour + if provider == "iflow" { + nextRefreshAfter = expiresAt.Add(-refreshLead) + } else { + nextRefreshAfter = expiresAt.Add(-20 * time.Minute) + } + } + } + } + auth := &cliproxyauth.Auth{ ID: id, Provider: provider, diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index a695311db2..eb1fadb7cc 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -26,7 +26,8 @@ func (a *IFlowAuthenticator) Provider() string { return "iflow" } // RefreshLead indicates how soon before expiry a refresh should be attempted. func (a *IFlowAuthenticator) RefreshLead() *time.Duration { - return new(24 * time.Hour) + d := 36 * time.Hour + return &d } // Login performs the OAuth code flow using a local callback server. @@ -167,6 +168,12 @@ waitForCallback: } fileName := fmt.Sprintf("iflow-%s-%d.json", email, time.Now().Unix()) + + expiresAt, err := time.Parse(time.RFC3339, tokenStorage.Expire) + if err != nil { + expiresAt = time.Now().Add(7 * 24 * time.Hour) + } + metadata := map[string]any{ "email": email, "api_key": tokenStorage.APIKey, @@ -175,16 +182,63 @@ waitForCallback: "expired": tokenStorage.Expire, } + now := time.Now() + fmt.Println("iFlow authentication successful") return &coreauth.Auth{ - ID: fileName, - Provider: a.Provider(), - FileName: fileName, - Storage: tokenStorage, - Metadata: metadata, + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Storage: tokenStorage, + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + NextRefreshAfter: expiresAt.Add(-36 * time.Hour), Attributes: map[string]string{ "api_key": tokenStorage.APIKey, }, }, nil } + +func (a *IFlowAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) { + if auth == nil || auth.Metadata == nil { + return nil, fmt.Errorf("iflow: invalid auth record") + } + + refreshToken, ok := auth.Metadata["refresh_token"].(string) + if !ok || refreshToken == "" { + return nil, fmt.Errorf("iflow: refresh token not found") + } + + authSvc := iflow.NewIFlowAuth(cfg) + + tokenData, err := authSvc.RefreshTokens(ctx, refreshToken) + if err != nil { + return nil, fmt.Errorf("iflow: token refresh failed: %w", err) + } + + expiresAt, err := time.Parse(time.RFC3339, tokenData.Expire) + if err != nil { + expiresAt = time.Now().Add(7 * 24 * time.Hour) + } + + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expired"] = tokenData.Expire + updated.Metadata["api_key"] = tokenData.APIKey + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) + updated.NextRefreshAfter = expiresAt.Add(-36 * time.Hour) + + if tokenData.APIKey != "" { + updated.Attributes["api_key"] = tokenData.APIKey + } + + log.Infof("iflow: token refreshed successfully for %s", auth.ID) + + return updated, nil +} diff --git a/sdk/auth/kilo.go b/sdk/auth/kilo.go index 7e98f7c4b7..5e1dd6ed8d 100644 --- a/sdk/auth/kilo.go +++ b/sdk/auth/kilo.go @@ -6,8 +6,10 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" ) // KiloAuthenticator implements the login flow for Kilo AI accounts. @@ -39,16 +41,25 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts } kilocodeAuth := kilo.NewKiloAuth() - + fmt.Println("Initiating Kilo device authentication...") resp, err := kilocodeAuth.InitiateDeviceFlow(ctx) if err != nil { return nil, fmt.Errorf("failed to initiate device flow: %w", err) } - fmt.Printf("Please visit: %s\n", resp.VerificationURL) - fmt.Printf("And enter code: %s\n", resp.Code) - + fmt.Printf("\nTo authenticate, please visit: %s\n", resp.VerificationURL) + fmt.Printf("And enter the code: %s\n\n", resp.Code) + + // Try to open the browser automatically + if !opts.NoBrowser { + if browser.IsAvailable() { + if errOpen := browser.OpenURL(resp.VerificationURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + } + fmt.Println("Waiting for authorization...") status, err := kilocodeAuth.PollForToken(ctx, resp.Code) if err != nil { @@ -68,7 +79,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts for i, org := range profile.Orgs { fmt.Printf("[%d] %s (%s)\n", i+1, org.Name, org.ID) } - + if opts.Prompt != nil { input, err := opts.Prompt("Enter the number of the organization: ") if err != nil { @@ -108,7 +119,7 @@ func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts metadata := map[string]any{ "email": status.UserEmail, "organization_id": orgID, - "model": defaults.Model, + "model": defaults.Model, } return &coreauth.Auth{ diff --git a/sdk/auth/kilocode.go b/sdk/auth/kilocode.go new file mode 100644 index 0000000000..8cb70bebe4 --- /dev/null +++ b/sdk/auth/kilocode.go @@ -0,0 +1,105 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilocode" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// KilocodeAuthenticator implements the device flow login for Kilocode. +type KilocodeAuthenticator struct{} + +// NewKilocodeAuthenticator constructs a new Kilocode authenticator. +func NewKilocodeAuthenticator() Authenticator { + return &KilocodeAuthenticator{} +} + +// Provider returns the provider key for kilocode. +func (KilocodeAuthenticator) Provider() string { + return "kilocode" +} + +// RefreshLead returns nil since Kilocode tokens don't expire traditionally. +func (KilocodeAuthenticator) RefreshLead() *time.Duration { + return nil +} + +// Login initiates the device flow authentication for Kilocode. +func (a KilocodeAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if opts == nil { + opts = &LoginOptions{} + } + + authSvc := kilocode.NewKilocodeAuth(cfg) + + // Start the device flow + fmt.Println("Starting Kilocode authentication...") + deviceCode, err := authSvc.StartDeviceFlow(ctx) + if err != nil { + return nil, fmt.Errorf("kilocode: failed to start device flow: %w", err) + } + + // Display the user code and verification URL + fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURL) + fmt.Printf("And enter the code: %s\n\n", deviceCode.Code) + + // Try to open the browser automatically + if !opts.NoBrowser { + if browser.IsAvailable() { + if errOpen := browser.OpenURL(deviceCode.VerificationURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + } + + fmt.Println("Waiting for Kilocode authorization...") + fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn) + + // Wait for user authorization + authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode) + if err != nil { + errMsg := kilocode.GetUserFriendlyMessage(err) + return nil, fmt.Errorf("kilocode: %s", errMsg) + } + + // Create the token storage + tokenStorage := authSvc.CreateTokenStorage(authBundle) + + // Build metadata with token information for the executor + metadata := map[string]any{ + "type": "kilocode", + "user_id": authBundle.UserID, + "email": authBundle.UserEmail, + "token": authBundle.Token, + "timestamp": time.Now().UnixMilli(), + } + + fileName := fmt.Sprintf("kilocode-%s.json", authBundle.UserID) + label := authBundle.UserEmail + if label == "" { + label = authBundle.UserID + } + + fmt.Printf("\nKilocode authentication successful for user: %s\n", label) + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: label, + Storage: tokenStorage, + Metadata: metadata, + Attributes: map[string]string{ + "auth_kind": "oauth", + }, + }, nil +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index ecf8e820af..b1fd4b75a6 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -17,6 +17,8 @@ func init() { registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() }) registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() }) registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() }) + registerRefreshLead("kilocode", func() Authenticator { return NewKilocodeAuthenticator() }) + registerRefreshLead("cline", func() Authenticator { return NewClineAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/auth/.tldrignore b/sdk/cliproxy/auth/.tldrignore new file mode 100644 index 0000000000..e01df83cb2 --- /dev/null +++ b/sdk/cliproxy/auth/.tldrignore @@ -0,0 +1,84 @@ +# TLDR ignore patterns (gitignore syntax) +# Auto-generated - review and customize for your project +# Docs: https://git-scm.com/docs/gitignore + +# =================== +# Dependencies +# =================== +node_modules/ +.venv/ +venv/ +env/ +__pycache__/ +.tox/ +.nox/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +vendor/ +Pods/ + +# =================== +# Build outputs +# =================== +dist/ +build/ +out/ +target/ +*.egg-info/ +*.whl +*.pyc +*.pyo + +# =================== +# Binary/large files +# =================== +*.so +*.dylib +*.dll +*.exe +*.bin +*.o +*.a +*.lib + +# =================== +# IDE/editors +# =================== +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# =================== +# Security (always exclude) +# =================== +.env +.env.* +*.pem +*.key +*.p12 +*.pfx +credentials.* +secrets.* + +# =================== +# Version control +# =================== +.git/ +.hg/ +.svn/ + +# =================== +# OS files +# =================== +.DS_Store +Thumbs.db + +# =================== +# Project-specific +# Add your custom patterns below +# =================== +# large_test_fixtures/ +# data/ diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 6e89adcbc4..d4aab510c4 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + "github.com/gin-gonic/gin" "github.com/google/uuid" internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" @@ -24,6 +25,62 @@ import ( log "github.com/sirupsen/logrus" ) +const providerAuthContextKey = "cliproxy.provider_auth" +const GinProviderAuthKey = "providerAuth" +const fallbackInfoContextKey = "cliproxy.fallback_info" +const GinFallbackInfoKey = "fallbackInfo" + +func SetProviderAuthInContext(ctx context.Context, provider, authID, authLabel string) context.Context { + authInfo := map[string]string{ + "provider": provider, + "auth_id": authID, + "auth_label": authLabel, + } + + if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil { + ginCtx.Set(GinProviderAuthKey, authInfo) + } + + return context.WithValue(ctx, providerAuthContextKey, authInfo) +} + +func GetProviderAuthFromContext(ctx context.Context) (provider, authID, authLabel string) { + if ctx == nil { + return "", "", "" + } + if v, ok := ctx.Value(providerAuthContextKey).(map[string]string); ok { + return v["provider"], v["auth_id"], v["auth_label"] + } + return "", "", "" +} + +func SetFallbackInfoInContext(ctx context.Context, requestedModel, actualModel string) context.Context { + if requestedModel == "" || actualModel == "" || requestedModel == actualModel { + return ctx + } + + fallbackInfo := map[string]string{ + "requested_model": requestedModel, + "actual_model": actualModel, + } + + if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil { + ginCtx.Set(GinFallbackInfoKey, fallbackInfo) + } + + return context.WithValue(ctx, fallbackInfoContextKey, fallbackInfo) +} + +func GetFallbackInfoFromContext(ctx context.Context) (requestedModel, actualModel string) { + if ctx == nil { + return "", "" + } + if v, ok := ctx.Value(fallbackInfoContextKey).(map[string]string); ok { + return v["requested_model"], v["actual_model"] + } + return "", "" +} + // ProviderExecutor defines the contract required by Manager to execute provider calls. type ProviderExecutor interface { // Identifier returns the provider key handled by this executor. @@ -156,6 +213,15 @@ type Manager struct { // Optional HTTP RoundTripper provider injected by host. rtProvider RoundTripperProvider + // fallbackModels stores model fallback mappings (original -> fallback). + fallbackModels atomic.Value + + // fallbackChain stores the general fallback chain for models not in fallbackModels. + fallbackChain atomic.Value + + // fallbackMaxDepth limits the number of fallback attempts. + fallbackMaxDepth atomic.Int32 + // Auto refresh state refreshCancel context.CancelFunc refreshSemaphore chan struct{} @@ -404,6 +470,64 @@ func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration, maxR m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds()) } +func (m *Manager) SetFallbackModels(models map[string]string) { + if m == nil { + return + } + if models == nil { + models = make(map[string]string) + } + m.fallbackModels.Store(models) +} + +func (m *Manager) getFallbackModel(originalModel string) (string, bool) { + if m == nil { + return "", false + } + models, ok := m.fallbackModels.Load().(map[string]string) + if !ok || models == nil { + return "", false + } + fallback, exists := models[originalModel] + return fallback, exists && fallback != "" +} + +func (m *Manager) SetFallbackChain(chain []string, maxDepth int) { + if m == nil { + return + } + if chain == nil { + chain = []string{} + } + m.fallbackChain.Store(chain) + if maxDepth <= 0 { + maxDepth = 3 + } + m.fallbackMaxDepth.Store(int32(maxDepth)) +} + +func (m *Manager) getFallbackChain() []string { + if m == nil { + return nil + } + chain, ok := m.fallbackChain.Load().([]string) + if !ok { + return nil + } + return chain +} + +func (m *Manager) getFallbackMaxDepth() int { + if m == nil { + return 3 + } + depth := m.fallbackMaxDepth.Load() + if depth <= 0 { + return 3 + } + return int(depth) +} + // RegisterExecutor registers a provider executor with the manager. func (m *Manager) RegisterExecutor(executor ProviderExecutor) { if executor == nil { @@ -628,6 +752,8 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req debugLogAuthSelection(entry, auth, provider, req.Model) publishSelectedAuthMetadata(opts.Metadata, auth.ID) + // Set provider auth info in context for gin logger + SetProviderAuthInContext(ctx, provider, auth.ID, auth.Label) tried[auth.ID] = struct{}{} execCtx := ctx if rt := m.roundTripperFor(auth); rt != nil { @@ -638,6 +764,10 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + // Store actual model name in context for logging + if execReq.Model != routeModel { + execCtx = SetFallbackInfoInContext(execCtx, routeModel, execReq.Model) + } resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -691,15 +821,25 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, publishSelectedAuthMetadata(opts.Metadata, auth.ID) tried[auth.ID] = struct{}{} + // Set provider auth info in context for gin logger + SetProviderAuthInContext(ctx, provider, auth.ID, auth.Label) execCtx := ctx if rt := m.roundTripperFor(auth); rt != nil { execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) + execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + // Store actual model name in context for logging + if execReq.Model != routeModel { + execCtx = SetFallbackInfoInContext(execCtx, routeModel, execReq.Model) + } execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + // Store actual model name in context for logging + if execReq.Model != routeModel { + execCtx = SetFallbackInfoInContext(execCtx, routeModel, execReq.Model) + } resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -753,6 +893,8 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string publishSelectedAuthMetadata(opts.Metadata, auth.ID) tried[auth.ID] = struct{}{} + // Set provider auth info in context for gin logger + SetProviderAuthInContext(ctx, provider, auth.ID, auth.Label) execCtx := ctx if rt := m.roundTripperFor(auth); rt != nil { execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) @@ -762,6 +904,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + // Store actual model name in context for logging + if execReq.Model != routeModel { + execCtx = SetFallbackInfoInContext(execCtx, routeModel, execReq.Model) + } streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { @@ -1136,6 +1282,31 @@ func (m *Manager) normalizeProviders(providers []string) []string { return result } +func (m *Manager) rotateProviders(model string, providers []string) []string { + if len(providers) == 0 { + return nil + } + + m.mu.Lock() + offset := m.providerOffsets[model] + m.providerOffsets[model] = (offset + 1) % len(providers) + m.mu.Unlock() + + if len(providers) > 0 { + offset %= len(providers) + } + if offset < 0 { + offset = 0 + } + if offset == 0 { + return providers + } + rotated := make([]string, 0, len(providers)) + rotated = append(rotated, providers[offset:]...) + rotated = append(rotated, providers[:offset]...) + return rotated +} + func (m *Manager) retrySettings() (int, int, time.Duration) { if m == nil { return 0, 0, 0 diff --git a/sdk/cliproxy/auth/oauth_model_alias.go b/sdk/cliproxy/auth/oauth_model_alias.go index 8563aac463..8be6bba6eb 100644 --- a/sdk/cliproxy/auth/oauth_model_alias.go +++ b/sdk/cliproxy/auth/oauth_model_alias.go @@ -5,6 +5,7 @@ import ( internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + log "github.com/sirupsen/logrus" ) type modelAliasEntry interface { @@ -73,10 +74,14 @@ func (m *Manager) SetOAuthModelAlias(aliases map[string][]internalconfig.OAuthMo // applyOAuthModelAlias resolves the upstream model from OAuth model alias. // If an alias exists, the returned model is the upstream model. func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string { + channel := modelAliasChannel(auth) + log.Debugf("[DEBUG] applyOAuthModelAlias: provider=%s model=%s channel=%s auth_kind=%v", auth.Provider, requestedModel, channel, auth.Attributes) upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel) if upstreamModel == "" { + log.Debugf("[DEBUG] applyOAuthModelAlias: no alias found, returning original model=%s", requestedModel) return requestedModel } + log.Debugf("[DEBUG] applyOAuthModelAlias: resolved %s -> %s", requestedModel, upstreamModel) return upstreamModel } @@ -147,6 +152,7 @@ func resolveUpstreamModelFromAliasTable(m *Manager, auth *Auth, requestedModel, return "" } if channel == "" { + log.Debugf("[DEBUG] resolveUpstreamModelFromAliasTable: empty channel for provider=%s", auth.Provider) return "" } @@ -163,12 +169,19 @@ func resolveUpstreamModelFromAliasTable(m *Manager, auth *Auth, requestedModel, raw := m.oauthModelAlias.Load() table, _ := raw.(*oauthModelAliasTable) if table == nil || table.reverse == nil { + log.Debugf("[DEBUG] resolveUpstreamModelFromAliasTable: no alias table loaded") return "" } rev := table.reverse[channel] if rev == nil { + var availableChannels []string + for k := range table.reverse { + availableChannels = append(availableChannels, k) + } + log.Debugf("[DEBUG] resolveUpstreamModelFromAliasTable: no entries for channel=%s, available=%v", channel, availableChannels) return "" } + log.Debugf("[DEBUG] resolveUpstreamModelFromAliasTable: channel=%s has %d aliases, looking for candidates=%v", channel, len(rev), candidates) for _, candidate := range candidates { key := strings.ToLower(strings.TrimSpace(candidate)) @@ -221,7 +234,7 @@ func modelAliasChannel(auth *Auth) string { // and auth kind. Returns empty string if the provider/authKind combination doesn't support // OAuth model alias (e.g., API key authentication). // -// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi. +// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi, kilo, kilocode. func OAuthModelAliasChannel(provider, authKind string) string { provider = strings.ToLower(strings.TrimSpace(provider)) authKind = strings.ToLower(strings.TrimSpace(authKind)) @@ -245,7 +258,7 @@ func OAuthModelAliasChannel(provider, authKind string) string { return "" } return "codex" - case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi": + case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "cline", "github-copilot", "kimi", "kilo", "kilocode": return provider default: return "" diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index cf79e17337..3cbca01b59 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -22,6 +22,7 @@ type RoundRobinSelector struct { mu sync.Mutex cursors map[string]int maxKeys int + Mode string // "key-based" or empty for default behavior } // FillFirstSelector selects the first available credential (deterministic ordering). diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index 0e6d14213b..b0faa149c6 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -208,15 +208,17 @@ func (b *Builder) Build() (*Service, error) { } strategy := "" + mode := "" if b.cfg != nil { strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy)) + mode = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Mode)) } var selector coreauth.Selector switch strategy { case "fill-first", "fillfirst", "ff": selector = &coreauth.FillFirstSelector{} default: - selector = &coreauth.RoundRobinSelector{} + selector = &coreauth.RoundRobinSelector{Mode: mode} } coreManager = coreauth.NewManager(tokenStore, selector, nil) @@ -225,6 +227,8 @@ func (b *Builder) Build() (*Service, error) { coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider()) coreManager.SetConfig(b.cfg) coreManager.SetOAuthModelAlias(b.cfg.OAuthModelAlias) + coreManager.SetFallbackModels(b.cfg.Routing.FallbackModels) + coreManager.SetFallbackChain(b.cfg.Routing.FallbackChain, b.cfg.Routing.FallbackMaxDepth) service := &Service{ cfg: b.cfg, diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 0867a4fea8..42361a1d7c 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -13,6 +13,7 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" + kilocodeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilocode" kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" @@ -119,6 +120,7 @@ func newDefaultAuthManager() *sdkAuth.Manager { sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewQwenAuthenticator(), + sdkAuth.NewClineAuthenticator(), ) } @@ -434,10 +436,14 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg)) case "kiro": s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg)) + case "cline": + s.coreManager.RegisterExecutor(executor.NewClineExecutor(s.cfg)) case "kilo": s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg)) case "github-copilot": s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg)) + case "kilocode": + s.coreManager.RegisterExecutor(executor.NewKilocodeExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -575,9 +581,11 @@ func (s *Service) Run(ctx context.Context) error { var watcherWrapper *WatcherWrapper reloadCallback := func(newCfg *config.Config) { previousStrategy := "" + previousMode := "" s.cfgMu.RLock() if s.cfg != nil { previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy)) + previousMode = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Mode)) } s.cfgMu.RUnlock() @@ -591,6 +599,7 @@ func (s *Service) Run(ctx context.Context) error { } nextStrategy := strings.ToLower(strings.TrimSpace(newCfg.Routing.Strategy)) + nextMode := strings.ToLower(strings.TrimSpace(newCfg.Routing.Mode)) normalizeStrategy := func(strategy string) string { switch strategy { case "fill-first", "fillfirst", "ff": @@ -601,15 +610,16 @@ func (s *Service) Run(ctx context.Context) error { } previousStrategy = normalizeStrategy(previousStrategy) nextStrategy = normalizeStrategy(nextStrategy) - if s.coreManager != nil && previousStrategy != nextStrategy { + if s.coreManager != nil && (previousStrategy != nextStrategy || previousMode != nextMode) { var selector coreauth.Selector switch nextStrategy { case "fill-first": selector = &coreauth.FillFirstSelector{} default: - selector = &coreauth.RoundRobinSelector{} + selector = &coreauth.RoundRobinSelector{Mode: nextMode} } s.coreManager.SetSelector(selector) + log.Infof("routing strategy updated to %s (mode: %s)", nextStrategy, nextMode) } s.applyRetryConfig(newCfg) @@ -623,6 +633,8 @@ func (s *Service) Run(ctx context.Context) error { if s.coreManager != nil { s.coreManager.SetConfig(newCfg) s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias) + s.coreManager.SetFallbackModels(newCfg.Routing.FallbackModels) + s.coreManager.SetFallbackChain(newCfg.Routing.FallbackChain, newCfg.Routing.FallbackMaxDepth) } s.rebindExecutors() } @@ -870,7 +882,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = applyExcludedModels(models, excluded) case "kimi": models = registry.GetKimiModels() - models = applyExcludedModels(models, excluded) + models = applyExcludedModels(models, excluded) case "github-copilot": ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() @@ -879,7 +891,12 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "kiro": models = s.fetchKiroModels(a) models = applyExcludedModels(models, excluded) - case "kilo": + case "cline": + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + models = executor.FetchClineModels(ctx, a, s.cfg) + cancel() + models = applyExcludedModels(models, excluded) + case "kilo", "kilocode": models = executor.FetchKiloModels(context.Background(), a, s.cfg) models = applyExcludedModels(models, excluded) default: @@ -1399,12 +1416,15 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models } channel := coreauth.OAuthModelAliasChannel(provider, authKind) if channel == "" || len(cfg.OAuthModelAlias) == 0 { + log.Debugf("applyOAuthModelAlias: no channel or aliases (provider=%s, authKind=%s, channel=%s)", provider, authKind, channel) return models } aliases := cfg.OAuthModelAlias[channel] if len(aliases) == 0 { + log.Debugf("applyOAuthModelAlias: no aliases for channel=%s", channel) return models } + log.Debugf("applyOAuthModelAlias: processing %d aliases for channel=%s with %d models", len(aliases), channel, len(models)) type aliasEntry struct { alias string @@ -1484,6 +1504,7 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models } out = append(out, &clone) addedAlias = true + log.Debugf("applyOAuthModelAlias: created alias model id=%s from target=%s", mappedID, id) } if !keepOriginal && !addedAlias { @@ -1545,6 +1566,83 @@ func (s *Service) fetchKiroModels(a *coreauth.Auth) []*ModelInfo { return models } +// fetchKilocodeModels attempts to dynamically fetch Kilocode models from the API. +// If dynamic fetch fails, it falls back to static registry.GetKilocodeModels(). +func (s *Service) fetchKilocodeModels(a *coreauth.Auth) []*ModelInfo { + if a == nil { + log.Debug("kilocode: auth is nil, using static models") + return registry.GetKilocodeModels() + } + + token := s.extractKilocodeToken(a) + if token == "" { + log.Debug("kilocode: no valid token in auth, using static models") + return registry.GetKilocodeModels() + } + + // Create KilocodeAuth instance + kAuth := kilocodeauth.NewKilocodeAuth(s.cfg) + if kAuth == nil { + log.Warn("kilocode: failed to create KilocodeAuth instance, using static models") + return registry.GetKilocodeModels() + } + + // Use timeout context for API call + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + // Attempt to fetch dynamic models + models, err := kAuth.FetchModels(ctx, token) + if err != nil { + log.Warnf("kilocode: failed to fetch dynamic models: %v, using static models", err) + return registry.GetKilocodeModels() + } + + if len(models) == 0 { + log.Debug("kilocode: API returned no models, using static models") + return registry.GetKilocodeModels() + } + + log.Infof("kilocode: successfully fetched %d free models from API", len(models)) + return models +} + +// extractKilocodeToken extracts Kilocode access token from auth attributes and metadata. +// It supports both config-based tokens (stored in Attributes) and file-based tokens (stored in Metadata). +func (s *Service) extractKilocodeToken(a *coreauth.Auth) string { + if a == nil { + return "" + } + + var token string + + // Priority 1: Try to get from Attributes (config.yaml source) + if a.Attributes != nil { + token = strings.TrimSpace(a.Attributes["token"]) + if token == "" { + token = strings.TrimSpace(a.Attributes["access_token"]) + } + } + + // Priority 2: If not found in Attributes, try Metadata (JSON file source) + if token == "" && a.Metadata != nil { + if tokenVal, ok := a.Metadata["token"]; ok { + if tokenStr, isStr := tokenVal.(string); isStr { + token = strings.TrimSpace(tokenStr) + } + } + if token == "" { + if accessTokenVal, ok := a.Metadata["access_token"]; ok { + if accessTokenStr, isStr := accessTokenVal.(string); isStr { + token = strings.TrimSpace(accessTokenStr) + } + } + } + } + + return token +} + // extractKiroTokenData extracts KiroTokenData from auth attributes and metadata. // It supports both config-based tokens (stored in Attributes) and file-based tokens (stored in Metadata). func (s *Service) extractKiroTokenData(a *coreauth.Auth) *kiroauth.KiroTokenData { @@ -1666,6 +1764,50 @@ func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo { result := make([]*ModelInfo, 0, len(models)*2) result = append(result, models...) + // [์ƒˆ๋กœ ์ถ”๊ฐ€] KiroExecutor๊ฐ€ ์ง€์›ํ•˜๋Š” ๊ฐ€์ƒ Friendly ID๋“ค์„ ๋ช…์‹œ์ ์œผ๋กœ ์ถ”๊ฐ€ + // ์ด๋ฅผ ํ†ตํ•ด ์‚ฌ์šฉ์ž๊ฐ€ OAuthModelAlias์—์„œ ์ด ์ด๋ฆ„๋“ค์„ ํƒ€๊ฒŸ์œผ๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•จ + virtualModels := []struct { + ID string + DisplayName string + }{ + {"kiro-claude-sonnet-4-5", "Kiro Claude Sonnet 4.5"}, + {"kiro-claude-sonnet-4", "Kiro Claude Sonnet 4"}, + {"kiro-claude-haiku-4-5", "Kiro Claude Haiku 4.5"}, + {"kiro-claude-sonnet-4-5-agentic", "Kiro Claude Sonnet 4.5 (Agentic)"}, + {"kiro-claude-sonnet-4-agentic", "Kiro Claude Sonnet 4 (Agentic)"}, + {"kiro-claude-haiku-4-5-agentic", "Kiro Claude Haiku 4.5 (Agentic)"}, + } + + seen := make(map[string]bool) + for _, m := range models { + seen[m.ID] = true + } + + // ๊ฐ€์ƒ ๋ชจ๋ธ ์ค‘ ์•„์ง ๋“ฑ๋ก๋˜์ง€ ์•Š์€ ๊ฒƒ๋งŒ ์ถ”๊ฐ€ + addedVirtuals := 0 + for _, vm := range virtualModels { + if !seen[vm.ID] { + virtual := &ModelInfo{ + ID: vm.ID, + Object: "model", + Created: time.Now().Unix(), + OwnedBy: "aws", + Type: "kiro", + DisplayName: vm.DisplayName, + Description: "Virtual model compatible with Kiro Executor", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + } + result = append(result, virtual) + seen[vm.ID] = true + addedVirtuals++ + } + } + if addedVirtuals > 0 { + log.Debugf("generateKiroAgenticVariants: added %d virtual models", addedVirtuals) + } + for _, m := range models { if m == nil { continue @@ -1681,9 +1823,15 @@ func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo { continue } + // Skip if agentic variant already exists (from virtual models) + agenticID := m.ID + "-agentic" + if seen[agenticID] { + continue + } + // Create agentic variant agentic := &ModelInfo{ - ID: m.ID + "-agentic", + ID: agenticID, Object: m.Object, Created: m.Created, OwnedBy: m.OwnedBy, @@ -1705,6 +1853,7 @@ func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo { } result = append(result, agentic) + seen[agenticID] = true } return result diff --git a/test/config_migration_test.go b/test/config_migration_test.go new file mode 100644 index 0000000000..2ed8788277 --- /dev/null +++ b/test/config_migration_test.go @@ -0,0 +1,195 @@ +package test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestLegacyConfigMigration(t *testing.T) { + t.Run("onlyLegacyFields", func(t *testing.T) { + path := writeConfig(t, ` +port: 8080 +generative-language-api-key: + - "legacy-gemini-1" +openai-compatibility: + - name: "legacy-provider" + base-url: "https://example.com" + api-keys: + - "legacy-openai-1" +amp-upstream-url: "https://amp.example.com" +amp-upstream-api-key: "amp-legacy-key" +amp-restrict-management-to-localhost: false +amp-model-mappings: + - from: "old-model" + to: "new-model" +`) + cfg, err := config.LoadConfig(path) + if err != nil { + t.Fatalf("load legacy config: %v", err) + } + if got := len(cfg.GeminiKey); got != 1 || cfg.GeminiKey[0].APIKey != "legacy-gemini-1" { + t.Fatalf("gemini migration mismatch: %+v", cfg.GeminiKey) + } + if got := len(cfg.OpenAICompatibility); got != 1 { + t.Fatalf("expected 1 openai-compat provider, got %d", got) + } + if entries := cfg.OpenAICompatibility[0].APIKeyEntries; len(entries) != 1 || entries[0].APIKey != "legacy-openai-1" { + t.Fatalf("openai-compat migration mismatch: %+v", entries) + } + if cfg.AmpCode.UpstreamURL != "https://amp.example.com" || cfg.AmpCode.UpstreamAPIKey != "amp-legacy-key" { + t.Fatalf("amp migration failed: %+v", cfg.AmpCode) + } + if cfg.AmpCode.RestrictManagementToLocalhost { + t.Fatalf("expected amp restriction to be false after migration") + } + if got := len(cfg.AmpCode.ModelMappings); got != 1 || cfg.AmpCode.ModelMappings[0].From != "old-model" { + t.Fatalf("amp mappings migration mismatch: %+v", cfg.AmpCode.ModelMappings) + } + updated := readFile(t, path) + if strings.Contains(updated, "generative-language-api-key") { + t.Fatalf("legacy gemini key still present:\n%s", updated) + } + if strings.Contains(updated, "amp-upstream-url") || strings.Contains(updated, "amp-restrict-management-to-localhost") { + t.Fatalf("legacy amp keys still present:\n%s", updated) + } + if strings.Contains(updated, "\n api-keys:") { + t.Fatalf("legacy openai compat keys still present:\n%s", updated) + } + }) + + t.Run("mixedLegacyAndNewFields", func(t *testing.T) { + path := writeConfig(t, ` +gemini-api-key: + - api-key: "new-gemini" +generative-language-api-key: + - "new-gemini" + - "legacy-gemini-only" +openai-compatibility: + - name: "mixed-provider" + base-url: "https://mixed.example.com" + api-key-entries: + - api-key: "new-entry" + api-keys: + - "legacy-entry" + - "new-entry" +`) + cfg, err := config.LoadConfig(path) + if err != nil { + t.Fatalf("load mixed config: %v", err) + } + if got := len(cfg.GeminiKey); got != 2 { + t.Fatalf("expected 2 gemini entries, got %d: %+v", got, cfg.GeminiKey) + } + seen := make(map[string]struct{}, len(cfg.GeminiKey)) + for _, entry := range cfg.GeminiKey { + if _, exists := seen[entry.APIKey]; exists { + t.Fatalf("duplicate gemini key %q after migration", entry.APIKey) + } + seen[entry.APIKey] = struct{}{} + } + provider := cfg.OpenAICompatibility[0] + if got := len(provider.APIKeyEntries); got != 2 { + t.Fatalf("expected 2 openai entries, got %d: %+v", got, provider.APIKeyEntries) + } + entrySeen := make(map[string]struct{}, len(provider.APIKeyEntries)) + for _, entry := range provider.APIKeyEntries { + if _, ok := entrySeen[entry.APIKey]; ok { + t.Fatalf("duplicate openai key %q after migration", entry.APIKey) + } + entrySeen[entry.APIKey] = struct{}{} + } + }) + + t.Run("onlyNewFields", func(t *testing.T) { + path := writeConfig(t, ` +gemini-api-key: + - api-key: "new-only" +openai-compatibility: + - name: "new-only-provider" + base-url: "https://new-only.example.com" + api-key-entries: + - api-key: "new-only-entry" +ampcode: + upstream-url: "https://amp.new" + upstream-api-key: "new-amp-key" + restrict-management-to-localhost: true + model-mappings: + - from: "a" + to: "b" +`) + cfg, err := config.LoadConfig(path) + if err != nil { + t.Fatalf("load new config: %v", err) + } + if len(cfg.GeminiKey) != 1 || cfg.GeminiKey[0].APIKey != "new-only" { + t.Fatalf("unexpected gemini entries: %+v", cfg.GeminiKey) + } + if len(cfg.OpenAICompatibility) != 1 || len(cfg.OpenAICompatibility[0].APIKeyEntries) != 1 { + t.Fatalf("unexpected openai compat entries: %+v", cfg.OpenAICompatibility) + } + if cfg.AmpCode.UpstreamURL != "https://amp.new" || cfg.AmpCode.UpstreamAPIKey != "new-amp-key" { + t.Fatalf("unexpected amp config: %+v", cfg.AmpCode) + } + }) + + t.Run("duplicateNamesDifferentBase", func(t *testing.T) { + path := writeConfig(t, ` +openai-compatibility: + - name: "dup-provider" + base-url: "https://provider-a" + api-keys: + - "key-a" + - name: "dup-provider" + base-url: "https://provider-b" + api-keys: + - "key-b" +`) + cfg, err := config.LoadConfig(path) + if err != nil { + t.Fatalf("load duplicate config: %v", err) + } + if len(cfg.OpenAICompatibility) != 2 { + t.Fatalf("expected 2 providers, got %d", len(cfg.OpenAICompatibility)) + } + for _, entry := range cfg.OpenAICompatibility { + if len(entry.APIKeyEntries) != 1 { + t.Fatalf("expected 1 key entry per provider: %+v", entry) + } + switch entry.BaseURL { + case "https://provider-a": + if entry.APIKeyEntries[0].APIKey != "key-a" { + t.Fatalf("provider-a key mismatch: %+v", entry.APIKeyEntries) + } + case "https://provider-b": + if entry.APIKeyEntries[0].APIKey != "key-b" { + t.Fatalf("provider-b key mismatch: %+v", entry.APIKeyEntries) + } + default: + t.Fatalf("unexpected provider base url: %s", entry.BaseURL) + } + } + }) +} + +func writeConfig(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(strings.TrimSpace(content)+"\n"), 0o644); err != nil { + t.Fatalf("write temp config: %v", err) + } + return path +} + +func readFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read temp config: %v", err) + } + return string(data) +}