diff --git a/.bun-version b/.bun-version new file mode 100644 index 000000000..0c00f6108 --- /dev/null +++ b/.bun-version @@ -0,0 +1 @@ +1.3.10 diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 6e06f3664..cf36c4f9d 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -13,6 +13,7 @@ permissions: jobs: dev-build-deploy: runs-on: ubuntu-latest + timeout-minutes: 90 # 跳过由GitHub Actions创建的提交,避免死循环 if: github.event.pusher.name != 'github-actions[bot]' && !contains(github.event.head_commit.message, '[skip ci]') steps: @@ -61,9 +62,22 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 + with: + bun-version-file: .bun-version + + - name: Cache Bun package cache + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- - name: Install dependencies, type check, and format code + timeout-minutes: 15 run: | + # No lockfile is committed in this repository, so use a non-frozen install. bun install bun run typecheck bun run format diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index e90516235..b9ae96a87 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -17,6 +17,7 @@ jobs: code-quality: runs-on: ubuntu-latest name: Code Quality Check + timeout-minutes: 15 steps: - name: 📥 Checkout repository @@ -24,13 +25,26 @@ jobs: - name: 📦 Setup Bun uses: oven-sh/setup-bun@v2 + with: + bun-version-file: .bun-version - name: 🟢 Setup Node.js uses: actions/setup-node@v4 with: node-version: '20' + - name: Cache Bun package cache + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- + - name: 📦 Install dependencies + timeout-minutes: 8 + # No lockfile is committed in this repository, so use a non-frozen install. run: bun install - name: 🔍 Type check @@ -53,6 +67,7 @@ jobs: build-check: runs-on: ubuntu-latest name: Docker Build Test + timeout-minutes: 45 steps: - name: 📥 Checkout repository diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a495efb80..5f86dbf09 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -30,6 +30,7 @@ permissions: jobs: release-pipeline: runs-on: ubuntu-latest + timeout-minutes: 120 # 跳过由GitHub Actions创建的提交,避免死循环 (仅对push事件生效) if: | github.event_name == 'workflow_dispatch' || @@ -197,10 +198,24 @@ jobs: - name: Setup Bun if: steps.check.outputs.needs_bump == 'true' || github.event_name == 'workflow_dispatch' uses: oven-sh/setup-bun@v2 + with: + bun-version-file: .bun-version + + - name: Cache Bun package cache + if: steps.check.outputs.needs_bump == 'true' || github.event_name == 'workflow_dispatch' + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- - name: Install dependencies, type check, and format code if: steps.check.outputs.needs_bump == 'true' || github.event_name == 'workflow_dispatch' + timeout-minutes: 15 run: | + # No lockfile is committed in this repository, so use a non-frozen install. bun install bun run typecheck bun run format diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e32c45ada..67f9f199b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,6 +16,7 @@ jobs: quality: name: 📋 Code Quality runs-on: ubuntu-latest + timeout-minutes: 15 steps: - name: Checkout code @@ -24,10 +25,21 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 with: - bun-version: latest + bun-version-file: .bun-version + + - name: Cache Bun package cache + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- - name: Install dependencies - run: bun install --frozen-lockfile + timeout-minutes: 8 + # No lockfile is committed in this repository, so use a non-frozen install. + run: bun install - name: Run linting run: bun run lint @@ -42,6 +54,7 @@ jobs: unit-tests: name: ⚡ Unit Tests runs-on: ubuntu-latest + timeout-minutes: 15 steps: - name: Checkout code @@ -49,9 +62,22 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 + with: + bun-version-file: .bun-version + + - name: Cache Bun package cache + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- - name: Install dependencies - run: bun install --frozen-lockfile + timeout-minutes: 8 + # No lockfile is committed in this repository, so use a non-frozen install. + run: bun install - name: Run unit tests run: bun run test -- tests/unit/ --passWithNoTests @@ -60,6 +86,7 @@ jobs: integration-tests: name: 🔗 Integration Tests runs-on: ubuntu-latest + timeout-minutes: 25 services: postgres: @@ -93,6 +120,8 @@ jobs: AUTO_MIGRATE: true ENABLE_RATE_LIMIT: true SESSION_TTL: 300 + VITEST_STATEFUL_MAX_WORKERS: 2 + VITEST_STATEFUL_MAX_CONCURRENCY: 3 steps: - name: Checkout code @@ -100,24 +129,37 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 + with: + bun-version-file: .bun-version + + - name: Cache Bun package cache + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- - name: Install dependencies - run: bun install --frozen-lockfile + timeout-minutes: 8 + # No lockfile is committed in this repository, so use a non-frozen install. + run: bun install - name: Run database migrations run: bun run db:migrate - name: Run integration tests run: > - bunx vitest run - tests/integration/usage-ledger.test.ts - tests/integration/my-usage-imported-ledger.test.ts + bun x vitest run + --config tests/configs/integration.config.ts --passWithNoTests # ==================== API 测试(需要运行服务)==================== api-tests: name: 🌐 API Tests runs-on: ubuntu-latest + timeout-minutes: 35 services: postgres: @@ -152,6 +194,8 @@ jobs: PORT: 13500 ENABLE_RATE_LIMIT: true SESSION_TTL: 300 + VITEST_STATEFUL_MAX_WORKERS: 1 + VITEST_STATEFUL_MAX_CONCURRENCY: 3 steps: - name: Checkout code @@ -159,14 +203,36 @@ jobs: - name: Setup Bun uses: oven-sh/setup-bun@v2 + with: + bun-version-file: .bun-version + + - name: Cache Bun package cache + uses: actions/cache@v4 + with: + path: ~/.bun/install/cache + # This repo intentionally does not track Bun lockfiles; rotate cache on package/runtime inputs. + key: ${{ runner.os }}-bun-${{ hashFiles('package.json', '.bun-version') }} + restore-keys: | + ${{ runner.os }}-bun- - name: Install dependencies - run: bun install --frozen-lockfile + timeout-minutes: 8 + # No lockfile is committed in this repository, so use a non-frozen install. + run: bun install - name: Run database migrations run: bun run db:migrate + - name: Cache Next.js build cache + uses: actions/cache@v4 + with: + path: ${{ github.workspace }}/.next/cache + key: ${{ runner.os }}-nextjs-${{ hashFiles('package.json', '.bun-version') }}-${{ hashFiles('src/**/*.js', 'src/**/*.jsx', 'src/**/*.ts', 'src/**/*.tsx', 'next.config.*', 'tsconfig.json', 'postcss.config.*') }} + restore-keys: | + ${{ runner.os }}-nextjs-${{ hashFiles('package.json', '.bun-version') }}- + - name: Build application + timeout-minutes: 15 run: bun run build - name: Start server (background) diff --git a/Dockerfile b/Dockerfile index 8576fbc54..f356cccc2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ # syntax=docker/dockerfile:1 FROM oven/bun:debian AS deps WORKDIR /app -COPY package.json bun.lockb* ./ -RUN bun install --frozen-lockfile +COPY package.json ./ +RUN bun install FROM oven/bun:debian AS builder WORKDIR /app diff --git a/messages/en/settings/statusPage.json b/messages/en/settings/statusPage.json index cb9becde4..5eae1b710 100644 --- a/messages/en/settings/statusPage.json +++ b/messages/en/settings/statusPage.json @@ -10,6 +10,7 @@ "intervalOption": "{minutes} min", "slug": "Slug", "slugTooltip": "URL identifier for this group's dedicated status page. Leave empty to auto-generate from the public display name. Lowercase letters, digits, and hyphens only.", + "duplicateSlug": "Slug {slug} is used by multiple public groups. Please make each group slug unique before saving.", "copy": "Copy", "sortOrder": "Sort Order", "sortOrderTooltip": "Display order on the public status page. Lower values appear first (same convention as provider priority). Equal values fall back to group name.", diff --git a/messages/ja/settings/statusPage.json b/messages/ja/settings/statusPage.json index 99a5ff89a..f41ea87b7 100644 --- a/messages/ja/settings/statusPage.json +++ b/messages/ja/settings/statusPage.json @@ -10,6 +10,7 @@ "intervalOption": "{minutes} 分", "slug": "Slug", "slugTooltip": "グループ専用ステータスページの URL 識別子。空欄の場合は公開表示名から自動生成されます。小文字、数字、ハイフンのみ使用可能です。", + "duplicateSlug": "Slug {slug} は複数の公開グループで使われています。保存前に各グループの slug を一意にしてください。", "copy": "説明文", "sortOrder": "並び順", "sortOrderTooltip": "公開ステータスページでの表示順。値が小さいほど先頭に表示されます(プロバイダー優先度と同じ規則)。同値の場合はグループ名順となります。", diff --git a/messages/ru/settings/statusPage.json b/messages/ru/settings/statusPage.json index 86d5cd63a..d7bb45386 100644 --- a/messages/ru/settings/statusPage.json +++ b/messages/ru/settings/statusPage.json @@ -10,6 +10,7 @@ "intervalOption": "{minutes} мин", "slug": "Slug", "slugTooltip": "Идентификатор URL для выделенной страницы статуса этой группы. Если оставить пустым, будет создан автоматически из публичного имени. Разрешены только строчные буквы, цифры и дефисы.", + "duplicateSlug": "Slug {slug} используется несколькими публичными группами. Перед сохранением сделайте slug каждой группы уникальным.", "copy": "Пояснение", "sortOrder": "Порядок", "sortOrderTooltip": "Порядок отображения на публичной странице статуса. Меньшие значения — выше (как у приоритета провайдеров). При равных значениях порядок определяется по имени группы.", diff --git a/messages/zh-CN/settings/statusPage.json b/messages/zh-CN/settings/statusPage.json index 83f9d5fb7..8019df3b0 100644 --- a/messages/zh-CN/settings/statusPage.json +++ b/messages/zh-CN/settings/statusPage.json @@ -10,6 +10,7 @@ "intervalOption": "{minutes} 分钟", "slug": "Slug", "slugTooltip": "用于分组独立状态页的 URL 标识。留空将基于对外显示名自动生成。只能包含小写字母、数字和连字符。", + "duplicateSlug": "Slug {slug} 被多个公开分组使用。请先修改为互不重复的 slug 再保存。", "copy": "说明文案", "sortOrder": "排序", "sortOrderTooltip": "控制公开状态页中分组的显示顺序。数值越小越靠前(与供应商优先级一致)。相同数值按分组名排序。", diff --git a/messages/zh-TW/settings/statusPage.json b/messages/zh-TW/settings/statusPage.json index 107bae522..2c8ef1ec3 100644 --- a/messages/zh-TW/settings/statusPage.json +++ b/messages/zh-TW/settings/statusPage.json @@ -10,6 +10,7 @@ "intervalOption": "{minutes} 分鐘", "slug": "Slug", "slugTooltip": "用於分組獨立狀態頁的 URL 識別碼。留空將根據對外顯示名自動產生。只能包含小寫字母、數字與連字號。", + "duplicateSlug": "Slug {slug} 被多個公開分組使用。請先修改為互不重複的 slug 再儲存。", "copy": "說明文案", "sortOrder": "排序", "sortOrderTooltip": "控制公開狀態頁中分組的顯示順序。數值越小越靠前(與供應商優先級一致)。相同數值按分組名排序。", diff --git a/package.json b/package.json index 8a5ffff7d..c033fb255 100644 --- a/package.json +++ b/package.json @@ -122,7 +122,7 @@ "@types/react": "^19", "@types/react-dom": "^19", "@types/react-syntax-highlighter": "^15", - "@typescript/native-preview": "7.0.0-dev.20260321.1", + "@typescript/native-preview": "7.0.0-dev.20260425.1", "@vitest/coverage-v8": "^4", "@vitest/ui": "^4", "bun-types": "^1", diff --git a/src/app/[locale]/internal/dashboard/big-screen/page.tsx b/src/app/[locale]/internal/dashboard/big-screen/page.tsx index a5e7145b9..8366aa72b 100644 --- a/src/app/[locale]/internal/dashboard/big-screen/page.tsx +++ b/src/app/[locale]/internal/dashboard/big-screen/page.tsx @@ -38,6 +38,7 @@ import { import useSWR from "swr"; import { getDashboardRealtimeData } from "@/actions/dashboard-realtime"; import { type Locale, localeLabels, locales } from "@/i18n/config"; +import { normalizePathnameForLocaleNavigation } from "@/i18n/pathname"; import { usePathname, useRouter } from "@/i18n/routing"; import { CURRENCY_CONFIG, type CurrencyCode } from "@/lib/utils/currency"; @@ -726,7 +727,7 @@ export default function BigScreenPage() { const nextIndex = (currentIndex + 1) % locales.length; const nextLocale = locales[nextIndex]; - router.push(pathname || "/dashboard", { locale: nextLocale }); + router.push(normalizePathnameForLocaleNavigation(pathname), { locale: nextLocale }); }; const theme = THEMES[themeMode as keyof typeof THEMES]; diff --git a/src/app/[locale]/layout.tsx b/src/app/[locale]/layout.tsx index 302cc7c81..30904d22b 100644 --- a/src/app/[locale]/layout.tsx +++ b/src/app/[locale]/layout.tsx @@ -3,7 +3,7 @@ import "../globals.css"; import { headers } from "next/headers"; import { notFound } from "next/navigation"; import { NextIntlClientProvider } from "next-intl"; -import { getMessages } from "next-intl/server"; +import { getMessages, setRequestLocale } from "next-intl/server"; import { Footer } from "@/components/customs/footer"; import { Toaster } from "@/components/ui/sonner"; import { type Locale, locales } from "@/i18n/config"; @@ -82,6 +82,9 @@ export default async function RootLayout({ notFound(); } + // 将路由段 locale 固定到 next-intl 请求上下文,避免后续导航回落到默认语言。 + setRequestLocale(locale); + // Load translation messages const messages = await getMessages({ locale }); const timeZone = isPublicStatusRequest @@ -93,7 +96,7 @@ export default async function RootLayout({ return ( - +
{children}
diff --git a/src/app/[locale]/login/redirect-safety.ts b/src/app/[locale]/login/redirect-safety.ts index 641ea8a6a..55ca38b04 100644 --- a/src/app/[locale]/login/redirect-safety.ts +++ b/src/app/[locale]/login/redirect-safety.ts @@ -1,3 +1,5 @@ +import { normalizePathnameForLocaleNavigation } from "@/i18n/pathname"; + const DEFAULT_REDIRECT_PATH = "/dashboard"; const PROTOCOL_LIKE_PATTERN = /^[a-zA-Z][a-zA-Z\d+.-]*:/; @@ -25,7 +27,7 @@ export function sanitizeRedirectPath(from: string): string { return DEFAULT_REDIRECT_PATH; } - return candidate; + return normalizePathnameForLocaleNavigation(candidate, DEFAULT_REDIRECT_PATH); } export function resolveLoginRedirectTarget(redirectTo: unknown, from: string): string { diff --git a/src/app/[locale]/settings/status-page/_components/public-status-settings-form.tsx b/src/app/[locale]/settings/status-page/_components/public-status-settings-form.tsx index 4118dad2f..7bd17a57c 100644 --- a/src/app/[locale]/settings/status-page/_components/public-status-settings-form.tsx +++ b/src/app/[locale]/settings/status-page/_components/public-status-settings-form.tsx @@ -3,7 +3,7 @@ import { ChevronDown, ChevronRight, ExternalLink, Info, Save } from "lucide-react"; import { useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; -import { useMemo, useState, useTransition } from "react"; +import { useMemo, useRef, useState, useTransition } from "react"; import { toast } from "sonner"; import { type SavePublicStatusSettingsInput, @@ -31,8 +31,13 @@ import { getProviderTypeTranslationKey, getUserFacingProviderTypes, } from "@/lib/provider-type-utils"; -import type { PublicStatusModelConfig } from "@/lib/public-status/config"; +import { + normalizePublicGroupSlug, + type PublicStatusModelConfig, + slugifyPublicGroup, +} from "@/lib/public-status/config"; import { PUBLIC_STATUS_INTERVAL_OPTIONS } from "@/lib/public-status/constants"; +import { cn } from "@/lib/utils"; import type { ProviderType } from "@/types/provider"; import { normalizePublicStatusModels, @@ -59,14 +64,37 @@ function getPublishableGroupCount(groups: PublicStatusSettingsFormGroup[]): numb return groups.filter((group) => group.enabled && group.publicModels.length > 0).length; } -function slugifyGroupName(input: string): string { - const trimmed = (input || "").trim().toLowerCase(); - if (!trimmed) return ""; - return trimmed - .replace(/[^a-z0-9\s-]+/g, "") - .replace(/\s+/g, "-") - .replace(/-+/g, "-") - .replace(/^-+|-+$/g, ""); +interface DuplicateSlugErrorState { + slug: string; + groupNames: string[]; +} + +function findDuplicateSlugError( + groups: PublicStatusSettingsFormGroup[] +): DuplicateSlugErrorState | null { + const groupNamesBySlug = new Map(); + + for (const group of groups) { + if (!group.enabled || normalizePublicStatusModels(group.publicModels).length === 0) { + continue; + } + + const normalizedSlug = normalizePublicGroupSlug(group.groupName, group.publicGroupSlug); + const groupNames = groupNamesBySlug.get(normalizedSlug); + if (groupNames) { + groupNames.push(group.groupName); + } else { + groupNamesBySlug.set(normalizedSlug, [group.groupName]); + } + } + + for (const [slug, groupNames] of groupNamesBySlug) { + if (groupNames.length > 1) { + return { slug, groupNames }; + } + } + + return null; } function InfoTip({ text }: { text: string }) { @@ -111,7 +139,11 @@ export function PublicStatusSettingsForm({ const [collapsedGroups, setCollapsedGroups] = useState>(() => Object.fromEntries(initialGroups.map((group) => [group.groupName, !group.enabled])) ); + const [duplicateSlugError, setDuplicateSlugError] = useState( + null + ); const [isPending, startTransition] = useTransition(); + const slugInputRefs = useRef(new Map()); const enabledGroupCount = useMemo(() => getPublishableGroupCount(groups), [groups]); const previewHref = "/status"; @@ -137,6 +169,27 @@ export function PublicStatusSettingsForm({ }; const handleSave = () => { + const nextDuplicateSlugError = findDuplicateSlugError(groups); + if (nextDuplicateSlugError) { + setDuplicateSlugError(nextDuplicateSlugError); + setCollapsedGroups((current) => ({ + ...current, + ...Object.fromEntries( + nextDuplicateSlugError.groupNames.map((groupName) => [groupName, false]) + ), + })); + toast.error(t("statusPage.form.duplicateSlug", { slug: nextDuplicateSlugError.slug })); + + window.requestAnimationFrame(() => { + const firstInput = slugInputRefs.current.get(nextDuplicateSlugError.groupNames[0]); + firstInput?.scrollIntoView({ behavior: "smooth", block: "center" }); + firstInput?.focus(); + }); + return; + } + + setDuplicateSlugError(null); + const payload: SavePublicStatusSettingsInput = { publicStatusWindowHours: Number(windowHours), publicStatusAggregationIntervalMinutes: Number(aggregationIntervalMinutes), @@ -249,9 +302,17 @@ export function PublicStatusSettingsForm({ {groups.map((group, index) => { const isCollapsed = collapsedGroups[group.groupName] ?? false; const selectedModelKeys = group.publicModels.map((model) => model.modelKey); + const isSlugConflict = + duplicateSlugError?.groupNames.includes(group.groupName) ?? false; return ( - +
{ + if (element) { + slugInputRefs.current.set(group.groupName, element); + } else { + slugInputRefs.current.delete(group.groupName); + } + }} value={group.publicGroupSlug} - onChange={(event) => + onChange={(event) => { + setDuplicateSlugError(null); updateGroup(index, { publicGroupSlug: event.target.value, - }) - } - placeholder={slugifyGroupName(group.displayName || group.groupName)} + }); + }} + placeholder={slugifyPublicGroup(group.displayName || group.groupName)} disabled={isPending} + aria-invalid={isSlugConflict || undefined} + className={cn( + isSlugConflict && + "border-destructive bg-destructive/5 focus-visible:border-destructive focus-visible:ring-destructive/30" + )} /> + {isSlugConflict ? ( +

+ {t("statusPage.form.duplicateSlug", { + slug: duplicateSlugError?.slug ?? "", + })} +

+ ) : null}
diff --git a/src/app/[locale]/settings/status-page/loader.ts b/src/app/[locale]/settings/status-page/loader.ts index 1eda4630e..de9613e1f 100644 --- a/src/app/[locale]/settings/status-page/loader.ts +++ b/src/app/[locale]/settings/status-page/loader.ts @@ -1,5 +1,9 @@ import { bootstrapProviderGroupsFromProviders } from "@/lib/provider-groups/bootstrap"; -import { parsePublicStatusDescription } from "@/lib/public-status/config"; +import { + createUniquePublicGroupSlug, + normalizePublicGroupSlug, + parsePublicStatusDescription, +} from "@/lib/public-status/config"; import { getSystemSettings } from "@/repository/system-config"; import type { PublicStatusSettingsFormGroup } from "./_components/public-status-settings-form"; @@ -10,18 +14,32 @@ export async function loadStatusPageSettings(): Promise<{ }> { const settings = await getSystemSettings(); const { groups } = await bootstrapProviderGroupsFromProviders(); + const parsedGroups = groups.map((group) => ({ + group, + parsed: parsePublicStatusDescription(group.description), + })); + const usedDefaultSlugs = new Set(); + for (const { group, parsed } of parsedGroups) { + if (parsed.publicStatus?.publicGroupSlug) { + usedDefaultSlugs.add( + normalizePublicGroupSlug(group.name, parsed.publicStatus.publicGroupSlug) + ); + } + } return { initialWindowHours: settings.publicStatusWindowHours, initialAggregationIntervalMinutes: settings.publicStatusAggregationIntervalMinutes, - initialGroups: groups.map((group) => { - const parsed = parsePublicStatusDescription(group.description); + initialGroups: parsedGroups.map(({ group, parsed }) => { + const publicGroupSlug = + parsed.publicStatus?.publicGroupSlug ?? + createUniquePublicGroupSlug(group.name, usedDefaultSlugs); return { groupName: group.name, enabled: (parsed.publicStatus?.publicModels.length ?? 0) > 0, displayName: parsed.publicStatus?.displayName ?? "", - publicGroupSlug: parsed.publicStatus?.publicGroupSlug ?? "", + publicGroupSlug, explanatoryCopy: parsed.publicStatus?.explanatoryCopy ?? "", sortOrder: parsed.publicStatus?.sortOrder ?? 0, publicModels: parsed.publicStatus?.publicModels ?? [], diff --git a/src/app/page.tsx b/src/app/page.tsx index 75e5ec785..4abdc759b 100644 --- a/src/app/page.tsx +++ b/src/app/page.tsx @@ -1,6 +1,11 @@ -import { defaultLocale } from "@/i18n/config"; +import { cookies } from "next/headers"; +import { defaultLocale, localeCookieName } from "@/i18n/config"; +import { getLocaleFromValue } from "@/i18n/pathname"; import { redirect } from "@/i18n/routing"; -export default function RootPage() { - redirect({ href: "/dashboard", locale: defaultLocale }); +export default async function RootPage() { + const cookieStore = await cookies(); + const locale = getLocaleFromValue(cookieStore.get(localeCookieName)?.value) || defaultLocale; + + redirect({ href: "/dashboard", locale }); } diff --git a/src/app/v1/_lib/proxy/client-abort-listener.ts b/src/app/v1/_lib/proxy/client-abort-listener.ts new file mode 100644 index 000000000..d47758907 --- /dev/null +++ b/src/app/v1/_lib/proxy/client-abort-listener.ts @@ -0,0 +1,25 @@ +export function bindClientAbortListener( + signal: AbortSignal | null | undefined, + onAbort: () => void +): () => void { + if (!signal) { + return () => {}; + } + + if (signal.aborted) { + onAbort(); + return () => {}; + } + + let cleaned = false; + signal.addEventListener("abort", onAbort, { once: true }); + + return () => { + if (cleaned) { + return; + } + cleaned = true; + // 正常完成时也要解绑,避免 listener 闭包继续持有 session 与请求体。 + signal.removeEventListener("abort", onAbort); + }; +} diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index 1e97622db..8b6f383f1 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -26,6 +26,7 @@ import { getPreferredProviderEndpoints, } from "@/lib/provider-endpoints/endpoint-selector"; import { getGlobalAgentPool, getProxyAgentForProvider } from "@/lib/proxy-agent"; +import { RateLimitService } from "@/lib/rate-limit/service"; import { SessionManager } from "@/lib/session-manager"; import { detectUpstreamErrorFromSseOrJsonText, @@ -49,6 +50,7 @@ import { GEMINI_PROTOCOL } from "../gemini/protocol"; import { HeaderProcessor, resolveAnthropicAuthHeaders } from "../headers"; import { buildProxyUrl } from "../url"; import { rectifyBillingHeader } from "./billing-header-rectifier"; +import { bindClientAbortListener } from "./client-abort-listener"; import { deriveClientSafeUpstreamErrorMessage } from "./client-error-message"; import { isStandardProxyEndpointPath } from "./endpoint-family-catalog"; import { resolveEndpointPolicy, shouldEnforceStrictEndpointPoolPolicy } from "./endpoint-policy"; @@ -1077,7 +1079,7 @@ export class ProxyForwarder { }); } - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); attemptCount = maxAttemptsPerProvider; } else { endpointCandidates.push({ endpointId: null, baseUrl: currentProvider.url }); @@ -1140,7 +1142,7 @@ export class ProxyForwarder { vendorId: currentProvider.providerVendorId, providerType: currentProvider.providerType, }); - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); attemptCount = maxAttemptsPerProvider; } @@ -1708,7 +1710,7 @@ export class ProxyForwarder { const env = getEnvConfig(); // 无论是否计入熔断器,都要加入 failedProviderIds(避免重复选择同一供应商) - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); if (env.ENABLE_CIRCUIT_BREAKER_ON_NETWORK_ERRORS) { logger.warn( @@ -1806,7 +1808,7 @@ export class ProxyForwarder { } // 重试耗尽:加入失败列表并切换供应商 - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // ⭐ 跳出内层循环,进入供应商切换逻辑 } @@ -1878,7 +1880,7 @@ export class ProxyForwarder { } } - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // 跳出内层循环,进入供应商切换逻辑 } @@ -1927,7 +1929,7 @@ export class ProxyForwarder { currentProvider.providerVendorId, currentProvider.providerType ); - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; } @@ -2023,7 +2025,7 @@ export class ProxyForwarder { } // 加入失败列表并切换供应商 - failedProviderIds.push(currentProvider.id); + ProxyForwarder.markProviderFailed(session, failedProviderIds, currentProvider.id); break; // 跳出内层循环,进入供应商切换逻辑 } } @@ -3397,6 +3399,7 @@ export class ProxyForwarder { let lastError: Error | null = null; let lastErrorCategory: ErrorCategory | null = null; const attempts = new Set(); + const failedProviderIds: number[] = []; let resolveResult: ((result: { response?: Response; error?: Error }) => void) | null = null; const resultPromise = new Promise<{ response?: Response; error?: Error }>((resolve) => { @@ -3444,6 +3447,7 @@ export class ProxyForwarder { attemptNumber: attempt.sequence, modelRedirect: getAttemptModelRedirect(attempt), }); + ProxyForwarder.markProviderFailed(session, failedProviderIds, attempt.provider.id); } try { attempt.responseController?.abort(new Error(reason)); @@ -3511,21 +3515,24 @@ export class ProxyForwarder { } launchingAlternative = (async () => { - const alternativeProvider = await ProxyForwarder.selectAlternative( - session, - Array.from(launchedProviderIds) - ); - if (!alternativeProvider) { - noMoreProviders = true; - // No alternative providers available — let in-flight attempt(s) continue. - // If all attempts already completed, settle with last error. - if (attempts.size === 0) { - await finishIfExhausted(); + while (!settled && !winnerCommitted && !noMoreProviders) { + const alternativeProvider = await ProxyForwarder.selectAlternative( + session, + Array.from(launchedProviderIds) + ); + if (!alternativeProvider) { + noMoreProviders = true; + // No alternative providers available — let in-flight attempt(s) continue. + // If all attempts already completed, settle with last error. + if (attempts.size === 0) { + await finishIfExhausted(); + } + return; } - return; - } - await startAttempt(alternativeProvider, false); + const launched = await startAttempt(alternativeProvider, false); + if (launched) return; + } })() .catch(async (error) => { const normalizedError = error instanceof Error ? error : new Error(String(error)); @@ -3767,6 +3774,7 @@ export class ProxyForwarder { attempt.thresholdTimer = null; } attempts.delete(attempt); + ProxyForwarder.markProviderFailed(session, failedProviderIds, attempt.provider.id); if (errorCategory === ErrorCategory.PROVIDER_ERROR && statusCode !== 404) { await recordFailure(attempt.provider.id, error); @@ -3916,11 +3924,40 @@ export class ProxyForwarder { settleSuccess(response); }; - const startAttempt = async (provider: Provider, useOriginalSession: boolean) => { - if (settled || winnerCommitted || launchedProviderIds.has(provider.id)) return; + const startAttempt = async ( + provider: Provider, + useOriginalSession: boolean + ): Promise => { + if (settled || winnerCommitted || noMoreProviders || launchedProviderIds.has(provider.id)) { + return false; + } launchedProviderIds.add(provider.id); + if (!useOriginalSession && session.sessionId) { + const limit = provider.limitConcurrentSessions || 0; + const checkResult = await RateLimitService.checkAndTrackProviderSession( + provider.id, + session.sessionId, + limit + ); + + if (!checkResult.allowed) { + ProxyForwarder.markProviderFailed(session, failedProviderIds, provider.id); + session.addProviderToChain(provider, { + reason: "concurrent_limit_failed", + circuitState: getCircuitState(provider.id), + attemptNumber: launchedProviderCount + 1, + errorMessage: checkResult.reason || "并发限制已达到", + }); + return false; + } + + if (checkResult.referenced) { + session.recordProviderSessionRef(provider.id); + } + } + let endpointSelection: { endpointId: number | null; baseUrl: string; @@ -3931,9 +3968,9 @@ export class ProxyForwarder { } catch (endpointError) { lastError = endpointError as Error; lastErrorCategory = null; - await launchAlternative(); + ProxyForwarder.markProviderFailed(session, failedProviderIds, provider.id); await finishIfExhausted(); - return; + return false; } launchedProviderCount += 1; @@ -3984,41 +4021,43 @@ export class ProxyForwarder { armAttemptThreshold(attempt); runAttempt(attempt); + return true; }; - if (session.clientAbortSignal) { - session.clientAbortSignal.addEventListener( - "abort", - () => { - if (settled || winnerCommitted) return; - noMoreProviders = true; - lastError = new ProxyError("Request aborted by client", 499, undefined, true); - lastErrorCategory = ErrorCategory.CLIENT_ABORT; - for (const attempt of Array.from(attempts)) { - if (!attempt.settled) { - session.addProviderToChain(attempt.provider, { - ...attempt.endpointAudit, - reason: "client_abort", - attemptNumber: attempt.sequence, - errorMessage: "Client aborted request", - modelRedirect: getAttemptModelRedirect(attempt), - }); - } - } - abortAllAttempts(undefined, "client_abort"); - void finishIfExhausted(); - }, - { once: true } - ); - } + const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { + if (settled || winnerCommitted) return; + noMoreProviders = true; + lastError = new ProxyError("Request aborted by client", 499, undefined, true); + lastErrorCategory = ErrorCategory.CLIENT_ABORT; + for (const attempt of Array.from(attempts)) { + if (!attempt.settled) { + session.addProviderToChain(attempt.provider, { + ...attempt.endpointAudit, + reason: "client_abort", + attemptNumber: attempt.sequence, + errorMessage: "Client aborted request", + modelRedirect: getAttemptModelRedirect(attempt), + }); + } + } + abortAllAttempts(undefined, "client_abort"); + void finishIfExhausted(); + }); - await startAttempt(initialProvider, true); - await finishIfExhausted(); - const result = await resultPromise; - if (result.error) { - throw result.error; + try { + const initialLaunched = await startAttempt(initialProvider, true); + if (!initialLaunched) { + await launchAlternative(); + } + await finishIfExhausted(); + const result = await resultPromise; + if (result.error) { + throw result.error; + } + return result.response as Response; + } finally { + cleanupClientAbortListener(); } - return result.response as Response; } private static async resolveStreamingHedgeEndpoint( @@ -4250,6 +4289,32 @@ export class ProxyForwarder { await SessionManager.clearSessionProvider(session.sessionId); } + private static markProviderFailed( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ): void { + if (failedProviderIds.includes(providerId)) { + return; + } + + failedProviderIds.push(providerId); + + if (!session.sessionId) { + return; + } + + const providerSessionRefConsumer = ( + session as { consumeProviderSessionRef?: (providerId: number) => boolean } + ).consumeProviderSessionRef; + + if (!providerSessionRefConsumer?.call(session, providerId)) { + return; + } + + void RateLimitService.releaseProviderSession(providerId, session.sessionId); + } + private static buildAllProvidersUnavailableError(finalError?: Error | null): ProxyError { const safeClientMessageCandidate = finalError instanceof ProxyError && diff --git a/src/app/v1/_lib/proxy/provider-selector.ts b/src/app/v1/_lib/proxy/provider-selector.ts index 8bdd4ef5f..8b9fde2b9 100644 --- a/src/app/v1/_lib/proxy/provider-selector.ts +++ b/src/app/v1/_lib/proxy/provider-selector.ts @@ -295,6 +295,10 @@ export class ProxyProviderResolver { } // === 成功 === + if (checkResult.referenced) { + session.recordProviderSessionRef(session.provider.id); + } + logger.debug("ProviderSelector: Session tracked atomically", { sessionId: session.sessionId, providerName: session.provider.name, diff --git a/src/app/v1/_lib/proxy/response-handler.ts b/src/app/v1/_lib/proxy/response-handler.ts index b7ac734cc..e2dd19edc 100644 --- a/src/app/v1/_lib/proxy/response-handler.ts +++ b/src/app/v1/_lib/proxy/response-handler.ts @@ -40,6 +40,7 @@ import type { LongContextPricingSpecialSetting } from "@/types/special-settings" import { GeminiAdapter } from "../gemini/adapter"; import type { GeminiResponse } from "../gemini/types"; import { extractActualResponseModelForProvider } from "./actual-response-model"; +import { bindClientAbortListener } from "./client-abort-listener"; import { isClientAbortError, isTransportError } from "./errors"; import type { ProxySession } from "./session"; import { consumeDeferredStreamingFinalization } from "./stream-finalization"; @@ -1073,6 +1074,10 @@ export class ProxyResponseHandler { // 使用 AsyncTaskManager 管理后台处理任务 const taskId = `non-stream-${messageContext?.id || `unknown-${Date.now()}`}`; const abortController = new AbortController(); + const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { + AsyncTaskManager.cancel(taskId); + abortController.abort(); + }); const processingPromise = (async () => { const finalizeNonStreamAbort = async (): Promise => { @@ -1502,6 +1507,7 @@ export class ProxyResponseHandler { }); } } finally { + cleanupClientAbortListener(); releaseSessionAgent(session); AsyncTaskManager.cleanup(taskId); } @@ -1526,14 +1532,6 @@ export class ProxyResponseHandler { }); }); - // 客户端断开时取消任务 - if (session.clientAbortSignal) { - session.clientAbortSignal.addEventListener("abort", () => { - AsyncTaskManager.cancel(taskId); - abortController.abort(); - }); - } - void persistNonStreamAfterSnapshot(finalResponse).catch((error) => { logger.error("[ResponseHandler] Failed to persist non-stream after snapshot", { error }); }); @@ -2128,6 +2126,26 @@ export class ProxyResponseHandler { // ⭐ 提升 idleTimeoutId 到外部作用域,以便客户端断开时能清除 let idleTimeoutId: NodeJS.Timeout | null = null; + const cleanupClientAbortListener = bindClientAbortListener(session.clientAbortSignal, () => { + logger.debug("ResponseHandler: Client disconnected, cleaning up", { + taskId, + providerId: provider.id, + messageId: messageContext.id, + }); + + // 客户端断开时清除 idle timeout,避免任务已取消后仍误触发。 + if (idleTimeoutId) { + clearTimeout(idleTimeoutId); + idleTimeoutId = null; + logger.debug("ResponseHandler: Idle timeout cleared due to client disconnect", { + taskId, + providerId: provider.id, + }); + } + + AsyncTaskManager.cancel(taskId); + abortController.abort(); + }); const processingPromise = (async () => { const reader = internalStream.getReader(); @@ -2757,6 +2775,7 @@ export class ProxyResponseHandler { } } finally { // 确保资源释放 + cleanupClientAbortListener(); clearIdleTimer(); // ⭐ 清除静默期计时器(防止泄漏) try { reader.releaseLock(); @@ -2791,34 +2810,6 @@ export class ProxyResponseHandler { }); }); - // 客户端断开时取消任务并清除 idle timer - if (session.clientAbortSignal) { - session.clientAbortSignal.addEventListener("abort", () => { - logger.debug("ResponseHandler: Client disconnected, cleaning up", { - taskId, - providerId: provider.id, - messageId: messageContext.id, - }); - - // ⭐ 1. 清除 idle timeout(避免误触发) - if (idleTimeoutId) { - clearTimeout(idleTimeoutId); - idleTimeoutId = null; - logger.debug("ResponseHandler: Idle timeout cleared due to client disconnect", { - taskId, - providerId: provider.id, - }); - } - - // 2. 取消后台任务 - AsyncTaskManager.cancel(taskId); - abortController.abort(); - - // 注意:不需要 streamController.error()(客户端已断开) - // 注意:不需要 responseController.abort()(上游会自然结束) - }); - } - // ⭐ 修复 Bun 运行时的 Transfer-Encoding 重复问题 // 清理上游的传输 headers,让 Response API 自动管理 const finalStreamHeaders = cleanResponseHeaders(response.headers); diff --git a/src/app/v1/_lib/proxy/session.ts b/src/app/v1/_lib/proxy/session.ts index 5e892c1c7..692e695d2 100644 --- a/src/app/v1/_lib/proxy/session.ts +++ b/src/app/v1/_lib/proxy/session.ts @@ -176,6 +176,10 @@ export class ProxySession { */ private providersSnapshot: Provider[] | null = null; + // 本请求已通过 Provider 并发检查获得的引用。 + // 失败切换 provider 时只能释放这里记录过的引用,避免 hedge/fallback 释放未 acquire 的 Redis 计数。 + private providerSessionRefs = new Set(); + private constructor(init: { startTime: number; method: string; @@ -313,6 +317,25 @@ export class ProxySession { } } + recordProviderSessionRef(providerId: number): void { + if (!this.providerSessionRefs) { + this.providerSessionRefs = new Set(); + } + + if (Number.isInteger(providerId) && providerId > 0) { + this.providerSessionRefs.add(providerId); + } + } + + consumeProviderSessionRef(providerId: number): boolean { + if (!this.providerSessionRefs?.has(providerId)) { + return false; + } + + this.providerSessionRefs.delete(providerId); + return true; + } + setCacheTtlResolved(ttl: CacheTtlResolved | null): void { this.cacheTtlResolved = ttl; } diff --git a/src/components/ui/language-switcher.tsx b/src/components/ui/language-switcher.tsx index 60c74ade7..1f93b3bf3 100644 --- a/src/components/ui/language-switcher.tsx +++ b/src/components/ui/language-switcher.tsx @@ -11,6 +11,7 @@ import { DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { type Locale, localeLabels, locales } from "@/i18n/config"; +import { normalizePathnameForLocaleNavigation } from "@/i18n/pathname"; import { usePathname, useRouter } from "@/i18n/routing"; import { cn } from "@/lib/utils/index"; @@ -40,7 +41,7 @@ export function LanguageSwitcher({ className, size = "sm" }: LanguageSwitcherPro setIsTransitioning(true); try { - router.push(pathname || "/dashboard", { locale: newLocale }); + router.push(normalizePathnameForLocaleNavigation(pathname), { locale: newLocale }); } catch (error) { console.error("Failed to switch locale:", error); setIsTransitioning(false); diff --git a/src/i18n/config.ts b/src/i18n/config.ts index 158139bc6..4769f1a05 100644 --- a/src/i18n/config.ts +++ b/src/i18n/config.ts @@ -12,6 +12,9 @@ export type Locale = (typeof locales)[number]; // Default locale (Chinese Simplified) export const defaultLocale: Locale = "zh-CN"; +// Locale cookie shared by next-intl middleware and app-level routing helpers +export const localeCookieName = "NEXT_LOCALE"; + // Locale labels for language switcher UI export const localeLabels: Record = { "zh-CN": "简体中文", diff --git a/src/i18n/pathname.ts b/src/i18n/pathname.ts new file mode 100644 index 000000000..fb9007cfc --- /dev/null +++ b/src/i18n/pathname.ts @@ -0,0 +1,66 @@ +import { type Locale, locales } from "./config"; + +const DEFAULT_INTERNAL_PATH = "/dashboard"; +const PROTOCOL_LIKE_PATTERN = /^[a-zA-Z][a-zA-Z\d+.-]*:/; + +function isLocale(value: string): value is Locale { + return locales.some((locale) => locale === value); +} + +function normalizeFallback(fallback: string): string { + const candidate = fallback.trim(); + + if (!candidate || !candidate.startsWith("/") || candidate.startsWith("//")) { + return DEFAULT_INTERNAL_PATH; + } + + return candidate === "/" ? DEFAULT_INTERNAL_PATH : candidate; +} + +export function getLocaleFromValue(value: string | null | undefined): Locale | null { + if (!value) return null; + + const candidate = value.trim(); + return isLocale(candidate) ? candidate : null; +} + +export function normalizePathnameForLocaleNavigation( + pathname: string | null | undefined, + fallback = DEFAULT_INTERNAL_PATH +): string { + const safeFallback = normalizeFallback(fallback); + const candidate = pathname?.trim() ?? ""; + + if (!candidate || !candidate.startsWith("/") || candidate.startsWith("//")) { + return safeFallback; + } + + if (PROTOCOL_LIKE_PATTERN.test(candidate) || PROTOCOL_LIKE_PATTERN.test(candidate.slice(1))) { + return safeFallback; + } + + const suffixStart = candidate.search(/[?#]/); + let path = suffixStart === -1 ? candidate : candidate.slice(0, suffixStart); + const suffix = suffixStart === -1 ? "" : candidate.slice(suffixStart); + + while (true) { + const localeMatch = path.match(/^\/([^/]+)(?=\/|$)/); + const locale = localeMatch?.[1]; + + if (!locale || !isLocale(locale)) { + break; + } + + path = path.slice(locale.length + 1) || "/"; + } + + if (path === "/") { + return `${safeFallback}${suffix}`; + } + + if (!path.startsWith("/") || path.startsWith("//") || PROTOCOL_LIKE_PATTERN.test(path.slice(1))) { + return safeFallback; + } + + return `${path}${suffix}`; +} diff --git a/src/i18n/routing.ts b/src/i18n/routing.ts index e191568e8..d3ac3c320 100644 --- a/src/i18n/routing.ts +++ b/src/i18n/routing.ts @@ -5,7 +5,7 @@ import { createNavigation } from "next-intl/navigation"; import { defineRouting } from "next-intl/routing"; -import { defaultLocale, locales } from "./config"; +import { defaultLocale, localeCookieName, locales } from "./config"; // Define routing configuration for next-intl export const routing = defineRouting({ @@ -23,7 +23,7 @@ export const routing = defineRouting({ // Locale cookie configuration localeCookie: { - name: "NEXT_LOCALE", + name: localeCookieName, // Cookie expires in 1 year maxAge: 365 * 24 * 60 * 60, // Available across the entire site diff --git a/src/lib/public-status/config-publisher.ts b/src/lib/public-status/config-publisher.ts index 6e8567238..24333619b 100644 --- a/src/lib/public-status/config-publisher.ts +++ b/src/lib/public-status/config-publisher.ts @@ -69,7 +69,8 @@ export async function publishCurrentPublicStatusConfigProjection(input: { providerGroups.map((group) => ({ groupName: group.name, ...parsePublicStatusDescription(group.description), - })) + })), + { duplicateSlugStrategy: "suffix" } ); const latestPrices = await findLatestPricesByModels( enabledGroups.flatMap((group) => getPublicStatusModelKeys(group.publicModels)) diff --git a/src/lib/public-status/config.ts b/src/lib/public-status/config.ts index d4ffadd13..556385e35 100644 --- a/src/lib/public-status/config.ts +++ b/src/lib/public-status/config.ts @@ -53,6 +53,13 @@ interface LegacyPublicStatusGroupConfigInput { } const CONFIG_CACHE_TTL_MS = 60 * 1000; +const PUBLIC_STATUS_SLUG_MAX_LENGTH = 64; +const PUBLIC_STATUS_SLUG_SUFFIX_LENGTH = 6; +const PUBLIC_STATUS_SLUG_FALLBACK_PREFIX = "group"; + +interface CollectEnabledPublicStatusGroupsOptions { + duplicateSlugStrategy?: "throw" | "suffix"; +} let cachedConfiguredGroups: EnabledPublicStatusGroup[] | null = null; let cachedConfiguredGroupsAt = 0; @@ -139,20 +146,96 @@ export function getPublicStatusModelKeys(publicModels: PublicStatusModelConfig[] return publicModels.map((model) => model.modelKey); } -export function slugifyPublicGroup(input: string): string { +function createStablePublicGroupSlugSuffix(input: string): string { + let hash = 0x811c9dc5; + for (const character of input) { + hash ^= character.codePointAt(0) ?? 0; + hash = Math.imul(hash, 0x01000193) >>> 0; + } + + return hash + .toString(36) + .padStart(PUBLIC_STATUS_SLUG_SUFFIX_LENGTH, "0") + .slice(0, PUBLIC_STATUS_SLUG_SUFFIX_LENGTH); +} + +function appendStablePublicGroupSlugSuffix(base: string, suffix: string): string { + const prefixLength = Math.max(1, PUBLIC_STATUS_SLUG_MAX_LENGTH - suffix.length - 1); + const prefix = + base.slice(0, prefixLength).replace(/-+$/g, "") || PUBLIC_STATUS_SLUG_FALLBACK_PREFIX; + return `${prefix}-${suffix}`; +} + +function slugifyPublicGroupAscii(input: string): string { return input .trim() .toLowerCase() .replace(/[^a-z0-9]+/g, "-") .replace(/^-+|-+$/g, "") - .slice(0, 64); + .slice(0, PUBLIC_STATUS_SLUG_MAX_LENGTH); +} + +export function slugifyPublicGroup(input: string): string { + const trimmed = input.trim().toLowerCase(); + if (!trimmed) { + return ""; + } + + const asciiSlug = slugifyPublicGroupAscii(trimmed); + const hasNonAsciiCharacters = Array.from(trimmed).some( + (character) => (character.codePointAt(0) ?? 0) > 0x7f + ); + if (!hasNonAsciiCharacters) { + return asciiSlug; + } + + const suffix = createStablePublicGroupSlugSuffix(trimmed); + if (!asciiSlug) { + return `${PUBLIC_STATUS_SLUG_FALLBACK_PREFIX}-${suffix}`; + } + + return appendStablePublicGroupSlugSuffix(asciiSlug, suffix); } -function normalizePublicGroupSlug(groupName: string, publicGroupSlug?: string): string { +export function normalizePublicGroupSlug(groupName: string, publicGroupSlug?: string): string { const normalized = slugifyPublicGroup(publicGroupSlug?.trim() || groupName); return normalized || slugifyPublicGroup(groupName); } +function createAvailablePublicGroupSlug( + baseSlug: string, + groupName: string, + usedSlugs: Set +): string { + let counter = 1; + let candidate = baseSlug; + while (usedSlugs.has(candidate)) { + const suffixSource = counter === 1 ? groupName : `${groupName}-${counter}`; + candidate = appendStablePublicGroupSlugSuffix( + baseSlug || PUBLIC_STATUS_SLUG_FALLBACK_PREFIX, + createStablePublicGroupSlugSuffix(suffixSource) + ); + counter += 1; + } + + return candidate; +} + +export function createUniquePublicGroupSlug(groupName: string, usedSlugs: Set): string { + const baseSlug = normalizePublicGroupSlug(groupName); + const uniqueSlug = createAvailablePublicGroupSlug(baseSlug, groupName, usedSlugs); + usedSlugs.add(uniqueSlug); + return uniqueSlug; +} + +function createCollisionPublicGroupSlug( + baseSlug: string, + groupName: string, + usedSlugs: Set +): string { + return createAvailablePublicGroupSlug(baseSlug, groupName, usedSlugs); +} + export function parsePublicStatusDescription( description: string | null | undefined ): ParsedPublicStatusDescription { @@ -267,38 +350,54 @@ export function serializePublicStatusDescription( } export function collectEnabledPublicStatusGroups( - groups: PublicStatusConfiguredGroupInput[] + groups: PublicStatusConfiguredGroupInput[], + options: CollectEnabledPublicStatusGroupsOptions = {} ): EnabledPublicStatusGroup[] { const seenGroupNamesBySlug = new Map(); + const usedSlugs = new Set(); return groups - .map((group) => { + .flatMap((group) => { const publicModels = sanitizePublicModels(group.publicStatus?.publicModels); - const publicGroupSlug = normalizePublicGroupSlug( + if (publicModels.length === 0) { + return []; + } + + const normalizedPublicGroupSlug = normalizePublicGroupSlug( group.groupName, group.publicStatus?.publicGroupSlug ); + let publicGroupSlug = normalizedPublicGroupSlug; const existingGroupName = seenGroupNamesBySlug.get(publicGroupSlug); if (existingGroupName) { - throw new DuplicatePublicStatusGroupSlugError(publicGroupSlug, [ - existingGroupName, + if (options.duplicateSlugStrategy !== "suffix") { + throw new DuplicatePublicStatusGroupSlugError(publicGroupSlug, [ + existingGroupName, + group.groupName, + ]); + } + publicGroupSlug = createCollisionPublicGroupSlug( + normalizedPublicGroupSlug, group.groupName, - ]); + usedSlugs + ); } seenGroupNamesBySlug.set(publicGroupSlug, group.groupName); - - return { - groupName: group.groupName, - displayName: group.publicStatus?.displayName?.trim() || group.groupName, - publicGroupSlug, - explanatoryCopy: group.publicStatus?.explanatoryCopy?.trim() || null, - sortOrder: group.publicStatus?.sortOrder ?? 0, - publicModels, - }; + usedSlugs.add(publicGroupSlug); + + return [ + { + groupName: group.groupName, + displayName: group.publicStatus?.displayName?.trim() || group.groupName, + publicGroupSlug, + explanatoryCopy: group.publicStatus?.explanatoryCopy?.trim() || null, + sortOrder: group.publicStatus?.sortOrder ?? 0, + publicModels, + }, + ]; }) - .filter((group) => group.publicModels.length > 0) .sort( (left, right) => left.sortOrder - right.sortOrder || left.displayName.localeCompare(right.displayName) diff --git a/src/lib/rate-limit/service.ts b/src/lib/rate-limit/service.ts index 3de962b46..1e7c7f304 100644 --- a/src/lib/rate-limit/service.ts +++ b/src/lib/rate-limit/service.ts @@ -77,6 +77,7 @@ import { CHECK_AND_TRACK_SESSION, GET_COST_5H_ROLLING_WINDOW, GET_COST_DAILY_ROLLING_WINDOW, + RELEASE_PROVIDER_SESSION, TRACK_COST_5H_ROLLING_WINDOW, TRACK_COST_DAILY_ROLLING_WINDOW, } from "@/lib/redis/lua-scripts"; @@ -804,43 +805,52 @@ export class RateLimitService { * @param providerId - Provider ID * @param sessionId - Session ID * @param limit - 并发限制 - * @returns { allowed, count, tracked } - 是否允许、当前并发数、是否已追踪 + * @returns { allowed, count, tracked, referenced } - 是否允许、当前并发数、是否新追踪、是否获得释放引用 */ static async checkAndTrackProviderSession( providerId: number, sessionId: string, limit: number - ): Promise<{ allowed: boolean; count: number; tracked: boolean; reason?: string }> { + ): Promise<{ + allowed: boolean; + count: number; + tracked: boolean; + referenced: boolean; + reason?: string; + }> { if (limit <= 0) { - return { allowed: true, count: 0, tracked: false }; + return { allowed: true, count: 0, tracked: false, referenced: false }; } if (!RateLimitService.redis || RateLimitService.redis.status !== "ready") { logger.warn("[RateLimit] Redis not ready, Fail Open"); - return { allowed: true, count: 0, tracked: false }; + return { allowed: true, count: 0, tracked: false, referenced: false }; } try { const key = `provider:${providerId}:active_sessions`; + const refKey = `provider:${providerId}:active_session_refs`; const now = Date.now(); const result = (await RateLimitService.redis.eval( CHECK_AND_TRACK_SESSION, - 1, // KEYS count + 2, // KEYS count key, // KEYS[1] + refKey, // KEYS[2] sessionId, // ARGV[1] limit.toString(), // ARGV[2] now.toString(), // ARGV[3] SESSION_TTL_MS.toString() // ARGV[4] - )) as [number, number, number]; + )) as [number, number, number, number]; - const [allowed, count, tracked] = result; + const [allowed, count, tracked, referenced] = result; if (allowed === 0) { return { allowed: false, count, tracked: false, + referenced: false, reason: `供应商并发 Session 上限已达到(${count}/${limit})`, }; } @@ -849,10 +859,53 @@ export class RateLimitService { allowed: true, count, tracked: tracked === 1, // Lua 返回 1 表示新追踪,0 表示已存在 + referenced: referenced === 1, }; } catch (error) { logger.error("[RateLimit] Atomic check-and-track failed:", error); - return { allowed: true, count: 0, tracked: false }; // Fail Open + return { allowed: true, count: 0, tracked: false, referenced: false }; // Fail Open + } + } + + /** + * Release a provider-level active session when a selected provider is abandoned. + * + * Provider concurrency is tracked before forwarding so fallback decisions can be atomic. + * If the provider later fails, the session must be removed immediately instead of waiting + * for TTL cleanup; otherwise outage storms inflate provider active_sessions ZSETs. + */ + static async releaseProviderSession(providerId: number, sessionId: string): Promise { + if (!Number.isInteger(providerId) || providerId <= 0 || sessionId.trim().length === 0) { + return; + } + + const redis = RateLimitService.redis; + if (!redis || redis.status !== "ready") { + return; + } + + const key = `provider:${providerId}:active_sessions`; + const refKey = `provider:${providerId}:active_session_refs`; + try { + const [removed, remainingRefs] = (await redis.eval( + RELEASE_PROVIDER_SESSION, + 2, + key, + refKey, + sessionId + )) as [number, number]; + logger.debug("[RateLimit] Released provider session", { + providerId, + sessionId, + removed, + remainingRefs, + }); + } catch (error) { + logger.error("[RateLimit] Failed to release provider session", { + providerId, + sessionId, + error, + }); } } diff --git a/src/lib/redis/lua-scripts.ts b/src/lib/redis/lua-scripts.ts index 402b702d0..7513e4119 100644 --- a/src/lib/redis/lua-scripts.ts +++ b/src/lib/redis/lua-scripts.ts @@ -14,18 +14,21 @@ * 4. If not exceeded, track new session (atomic operation) * * KEYS[1]: provider:${providerId}:active_sessions + * KEYS[2]: provider:${providerId}:active_session_refs * ARGV[1]: sessionId * ARGV[2]: limit (concurrency limit) * ARGV[3]: now (current timestamp, ms) * ARGV[4]: ttlMs (optional, cleanup window in ms, default 300000) * * Return: - * - {1, count, 1} - allowed (new tracking), returns new count and tracked=1 - * - {1, count, 0} - allowed (already tracked), returns current count and tracked=0 - * - {0, count, 0} - rejected (limit reached), returns current count and tracked=0 + * - {1, count, 1, 1} - allowed (new tracking), returns new count, tracked=1, referenced=1 + * - {1, count, 0, 1} - allowed (already tracked with refs), returns count, tracked=0, referenced=1 + * - {1, count, 0, 0} - allowed (legacy tracked without refs), returns count, tracked=0, referenced=0 + * - {0, count, 0, 0} - rejected (limit reached), returns current count and tracked=0 */ export const CHECK_AND_TRACK_SESSION = ` local provider_key = KEYS[1] +local ref_key = KEYS[2] local session_id = ARGV[1] local limit = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) @@ -38,37 +41,86 @@ end -- 1. Cleanup expired sessions (TTL window ago) local cutoff = now - ttl +local expired_sessions = redis.call('ZRANGEBYSCORE', provider_key, '-inf', cutoff) redis.call('ZREMRANGEBYSCORE', provider_key, '-inf', cutoff) +for _, expired_session_id in ipairs(expired_sessions) do + redis.call('HDEL', ref_key, expired_session_id) +end -- 2. Check if session is already tracked local is_tracked = redis.call('ZSCORE', provider_key, session_id) +-- Direct cleanup paths may remove the ZSET member before this script sees the session again. +-- When the member is absent, discard any stale reference hash value before acquiring a new ref. +if not is_tracked then + redis.call('HDEL', ref_key, session_id) +end + +local existing_refs = tonumber(redis.call('HGET', ref_key, session_id) or '0') + -- 3. Get current concurrency count local current_count = redis.call('ZCARD', provider_key) -- 4. Check limit (exclude already tracked session) if limit > 0 and not is_tracked and current_count >= limit then - return {0, current_count, 0} -- {allowed=false, current_count, tracked=0} + return {0, current_count, 0, 0} -- {allowed=false, current_count, tracked=0, referenced=0} end -- 5. Track session (ZADD updates timestamp for existing members) redis.call('ZADD', provider_key, now, session_id) +local referenced = 0 +if not is_tracked or existing_refs > 0 then + redis.call('HINCRBY', ref_key, session_id, 1) + referenced = 1 +end + -- 6. Set TTL based on session TTL (at least 1h to cover active sessions) local ttl_seconds = math.floor(ttl / 1000) local expire_ttl = math.max(3600, ttl_seconds) redis.call('EXPIRE', provider_key, expire_ttl) +redis.call('EXPIRE', ref_key, expire_ttl) -- 7. Return success if is_tracked then -- Already tracked, count unchanged - return {1, current_count, 0} -- {allowed=true, count, tracked=0} + return {1, current_count, 0, referenced} -- {allowed=true, count, tracked=0, referenced} else -- New tracking, count +1 - return {1, current_count + 1, 1} -- {allowed=true, new_count, tracked=1} + return {1, current_count + 1, 1, referenced} -- {allowed=true, new_count, tracked=1, referenced=1} end `; +/** + * Release provider-level active session membership with per-session references. + * + * KEYS[1]: provider:${providerId}:active_sessions + * KEYS[2]: provider:${providerId}:active_session_refs + * ARGV[1]: sessionId + * + * Return: {removed, remainingRefs} + */ +export const RELEASE_PROVIDER_SESSION = ` +local provider_key = KEYS[1] +local ref_key = KEYS[2] +local session_id = ARGV[1] + +local current_refs = tonumber(redis.call('HGET', ref_key, session_id) or '0') +if current_refs <= 0 then + return {0, 0} +end + +local remaining_refs = current_refs - 1 +if remaining_refs > 0 then + redis.call('HSET', ref_key, session_id, remaining_refs) + return {0, remaining_refs} +end + +redis.call('HDEL', ref_key, session_id) +local removed = redis.call('ZREM', provider_key, session_id) +return {removed, remaining_refs} +`; + /** * Key/User 并发:原子性检查 + 追踪(修复竞态条件) * diff --git a/src/lib/session-manager.ts b/src/lib/session-manager.ts index bac9d11bd..1acafdd1b 100644 --- a/src/lib/session-manager.ts +++ b/src/lib/session-manager.ts @@ -2433,6 +2433,7 @@ export class SessionManager { if (providerId) { pipeline.zrem(`provider:${providerId}:active_sessions`, sessionId); + pipeline.hdel(`provider:${providerId}:active_session_refs`, sessionId); } if (keyId) { diff --git a/src/lib/session-tracker.ts b/src/lib/session-tracker.ts index 8690d2128..dd278a521 100644 --- a/src/lib/session-tracker.ts +++ b/src/lib/session-tracker.ts @@ -6,6 +6,13 @@ import { } from "@/lib/redis/active-session-keys"; import { getRedisClient } from "./redis"; +const PROVIDER_ACTIVE_SESSIONS_PATTERN = /^provider:(\d+):active_sessions$/; + +function getProviderActiveSessionRefsKey(activeSessionsKey: string): string | null { + const match = PROVIDER_ACTIVE_SESSIONS_PATTERN.exec(activeSessionsKey); + return match ? `provider:${match[1]}:active_session_refs` : null; +} + /** * Session 追踪器 - 统一管理活跃 Session 集合 * @@ -141,8 +148,11 @@ export class SessionTracker { pipeline.zadd(globalKey, now, sessionId); // 添加到 provider 级集合(ZSET) - pipeline.zadd(`provider:${providerId}:active_sessions`, now, sessionId); - pipeline.expire(`provider:${providerId}:active_sessions`, 3600); + const providerZSetKey = `provider:${providerId}:active_sessions`; + const providerRefKey = `provider:${providerId}:active_session_refs`; + pipeline.zadd(providerZSetKey, now, sessionId); + pipeline.expire(providerZSetKey, 3600); + pipeline.expire(providerRefKey, 3600); const results = await pipeline.exec(); @@ -190,25 +200,42 @@ export class SessionTracker { const pipeline = redis.pipeline(); const ttlSeconds = SessionTracker.SESSION_TTL_SECONDS; const providerZSetKey = `provider:${providerId}:active_sessions`; + const providerRefKey = `provider:${providerId}:active_session_refs`; const globalKey = getGlobalActiveSessionsKey(); const keyZSetKey = getKeyActiveSessionsKey(keyId); + let commandIndex = 0; + let cleanupExpiredSessionsResultIndex: number | null = null; pipeline.zadd(globalKey, now, sessionId); + commandIndex++; pipeline.zadd(keyZSetKey, now, sessionId); + commandIndex++; pipeline.zadd(providerZSetKey, now, sessionId); + commandIndex++; // Use dynamic TTL based on session TTL (at least 1h to cover active sessions) pipeline.expire(providerZSetKey, Math.max(3600, ttlSeconds)); + commandIndex++; + pipeline.expire(providerRefKey, Math.max(3600, ttlSeconds)); + commandIndex++; if (userId !== undefined) { pipeline.zadd(getUserActiveSessionsKey(userId), now, sessionId); + commandIndex++; } pipeline.expire(`session:${sessionId}:provider`, ttlSeconds); + commandIndex++; pipeline.expire(`session:${sessionId}:key`, ttlSeconds); + commandIndex++; pipeline.setex(`session:${sessionId}:last_seen`, ttlSeconds, now.toString()); + commandIndex++; if (Math.random() < SessionTracker.CLEANUP_PROBABILITY) { const cutoffMs = now - SessionTracker.SESSION_TTL_MS; + cleanupExpiredSessionsResultIndex = commandIndex; + pipeline.zrangebyscore(providerZSetKey, "-inf", cutoffMs); + commandIndex++; pipeline.zremrangebyscore(providerZSetKey, "-inf", cutoffMs); + commandIndex++; } const results = await pipeline.exec(); @@ -227,6 +254,18 @@ export class SessionTracker { } } + if (cleanupExpiredSessionsResultIndex !== null && results) { + const expiredResult = results[cleanupExpiredSessionsResultIndex]; + if (!expiredResult?.[0] && Array.isArray(expiredResult?.[1])) { + const expiredSessionIds = expiredResult[1].filter( + (value): value is string => typeof value === "string" && value.length > 0 + ); + if (expiredSessionIds.length > 0) { + await redis.hdel(providerRefKey, ...expiredSessionIds); + } + } + } + logger.trace("SessionTracker: Refreshed session", { sessionId }); } catch (error) { logger.error("SessionTracker: Failed to refresh session", { error }); @@ -397,6 +436,7 @@ export class SessionTracker { for (const providerId of providerIds) { const key = `provider:${providerId}:active_sessions`; // 清理过期 session + cleanupPipeline.zrangebyscore(key, "-inf", cutoffMs); cleanupPipeline.zremrangebyscore(key, "-inf", cutoffMs); // 获取剩余 session IDs cleanupPipeline.zrange(key, 0, -1); @@ -410,11 +450,22 @@ export class SessionTracker { // 收集需要验证的 session IDs const providerSessionMap = new Map(); const allSessionIds: string[] = []; + const expiredProviderSessions = new Map(); for (let i = 0; i < providerIds.length; i++) { const providerId = providerIds[i]; - // 每个 provider 有 2 个命令(zremrangebyscore + zrange) - const zrangeResult = cleanupResults[i * 2 + 1]; + // 每个 provider 有 3 个命令(zrangebyscore + zremrangebyscore + zrange) + const expiredResult = cleanupResults[i * 3]; + const zrangeResult = cleanupResults[i * 3 + 2]; + + if (expiredResult && expiredResult[0] === null && Array.isArray(expiredResult[1])) { + expiredProviderSessions.set( + providerId, + expiredResult[1].filter( + (value): value is string => typeof value === "string" && value.length > 0 + ) + ); + } if (zrangeResult && zrangeResult[0] === null) { const sessionIds = zrangeResult[1] as string[]; @@ -425,6 +476,21 @@ export class SessionTracker { } } + const refCleanupPipeline = redis.pipeline(); + let hasRefCleanup = false; + for (const [providerId, expiredSessionIds] of expiredProviderSessions) { + if (expiredSessionIds.length > 0) { + refCleanupPipeline.hdel( + `provider:${providerId}:active_session_refs`, + ...expiredSessionIds + ); + hasRefCleanup = true; + } + } + if (hasRefCleanup) { + await refCleanupPipeline.exec(); + } + // 如果没有 session,直接返回 if (allSessionIds.length === 0) { return result; @@ -533,7 +599,14 @@ export class SessionTracker { const cutoffMs = now - SessionTracker.SESSION_TTL_MS; // 1. 清理过期 session(5 分钟前) + const providerRefKey = getProviderActiveSessionRefsKey(key); + const expiredSessionIds = providerRefKey + ? await redis.zrangebyscore(key, "-inf", cutoffMs) + : []; await redis.zremrangebyscore(key, "-inf", cutoffMs); + if (providerRefKey && expiredSessionIds.length > 0) { + await redis.hdel(providerRefKey, ...expiredSessionIds); + } // 2. 获取剩余的 session ID const sessionIds = await redis.zrange(key, 0, -1); diff --git a/src/lib/utils/upstream-error-detection.ts b/src/lib/utils/upstream-error-detection.ts index f8231a4ae..49e2be3d0 100644 --- a/src/lib/utils/upstream-error-detection.ts +++ b/src/lib/utils/upstream-error-detection.ts @@ -95,77 +95,77 @@ const ERROR_STATUS_MATCHERS: Array<{ statusCode: number; matcherId: string; re: { statusCode: 429, matcherId: "rate_limit", - re: /(?:\bHTTP\/\d(?:\.\d)?\s+429\b|\b429\s+too\s+many\s+requests\b|\btoo\s+many\s+requests\b|\brate\s*limit(?:ed|ing)?\b|\bthrottl(?:e|ed|ing)\b|\bretry-after\b|\bRESOURCE_EXHAUSTED\b|\bRequestLimitExceeded\b|\bThrottling(?:Exception)?\b|\bError\s*1015\b|超出频率|请求过于频繁|限流|稍后重试)/iu, + re: /(?:\bHTTP\/\d(?:\.\d)?\s+429(?![\p{L}\p{N}_]|\.\d)|(? { + test("unprefixed protected routes use NEXT_LOCALE for the login redirect", async () => { + const response = await fetchRedirect("/dashboard", `${localeCookieName}=en`); + + expect(response.status).toBeGreaterThanOrEqual(300); + expect(response.status).toBeLessThan(400); + expect(response.headers.get("location")).toContain("/en/login?from=%2Fdashboard"); + }); + + test("repeated locale prefixes do not leak into the login from parameter", async () => { + const response = await fetchRedirect("/en/en/dashboard"); + const location = response.headers.get("location"); + + expect(response.status).toBeGreaterThanOrEqual(300); + expect(response.status).toBeLessThan(400); + expect(location).toContain("/en/login?from=%2Fdashboard"); + expect(location).not.toContain("from=%2Fen%2Fdashboard"); + }); +}); diff --git a/tests/integration/public-status/config-publish.test.ts b/tests/integration/public-status/config-publish.test.ts index 7ef8c9830..8c24a21f1 100644 --- a/tests/integration/public-status/config-publish.test.ts +++ b/tests/integration/public-status/config-publish.test.ts @@ -11,6 +11,7 @@ const mockPublishCurrentPublicStatusConfigProjection = vi.hoisted(() => vi.fn()) const mockSchedulePublicStatusRebuild = vi.hoisted(() => vi.fn()); const mockInvalidateSystemSettingsCache = vi.hoisted(() => vi.fn()); const mockRevalidatePath = vi.hoisted(() => vi.fn()); +const mockLoggerInfo = vi.hoisted(() => vi.fn()); const mockLoggerError = vi.hoisted(() => vi.fn()); const mockLoggerWarn = vi.hoisted(() => vi.fn()); const mockDbTransaction = vi.hoisted(() => @@ -64,6 +65,7 @@ vi.mock("next-intl/server", () => ({ vi.mock("@/lib/logger", () => ({ logger: { + info: mockLoggerInfo, error: mockLoggerError, warn: mockLoggerWarn, }, diff --git a/tests/unit/auth/login-redirect-safety.test.ts b/tests/unit/auth/login-redirect-safety.test.ts index 2496f441f..a88c02bb6 100644 --- a/tests/unit/auth/login-redirect-safety.test.ts +++ b/tests/unit/auth/login-redirect-safety.test.ts @@ -30,6 +30,22 @@ describe("sanitizeRedirectPath", () => { expect(sanitizeRedirectPath("/settings?tab=general")).toBe("/settings?tab=general"); }); + it("strips a leading locale before passing the path to locale-aware navigation", () => { + expect(sanitizeRedirectPath("/en/dashboard")).toBe("/dashboard"); + }); + + it("strips repeated leading locales before passing the path to locale-aware navigation", () => { + expect(sanitizeRedirectPath("/en/en/dashboard?tab=logs")).toBe("/dashboard?tab=logs"); + }); + + it("preserves hash fragments when stripping leading locales", () => { + expect(sanitizeRedirectPath("/en/dashboard#section")).toBe("/dashboard#section"); + }); + + it("uses dashboard fallback when a localized redirect points to the locale root", () => { + expect(sanitizeRedirectPath("/zh-CN")).toBe("/dashboard"); + }); + it("rejects protocol-like path payload", () => { expect(sanitizeRedirectPath("/https://evil.example/path")).toBe("/dashboard"); }); diff --git a/tests/unit/i18n/locale-layout-request-locale.test.tsx b/tests/unit/i18n/locale-layout-request-locale.test.tsx new file mode 100644 index 000000000..0f0d4e73c --- /dev/null +++ b/tests/unit/i18n/locale-layout-request-locale.test.tsx @@ -0,0 +1,106 @@ +import type { ReactElement, ReactNode } from "react"; +import { describe, expect, test, vi } from "vitest"; + +const nextIntlMocks = vi.hoisted(() => ({ + provider: vi.fn(({ children }: { children: ReactNode }) => children), + getMessages: vi.fn(async () => ({ dashboard: { nav: { dashboard: "Dashboard" } } })), + setRequestLocale: vi.fn(), +})); + +vi.mock("next-intl", () => ({ + NextIntlClientProvider: nextIntlMocks.provider, +})); + +vi.mock("next-intl/server", () => ({ + getMessages: nextIntlMocks.getMessages, + setRequestLocale: nextIntlMocks.setRequestLocale, +})); + +vi.mock("next/headers", () => ({ + headers: vi.fn(async () => ({ + get: vi.fn(() => null), + })), +})); + +vi.mock("next/navigation", () => ({ + notFound: vi.fn(() => { + throw new Error("notFound"); + }), +})); + +vi.mock("@/components/customs/footer", () => ({ + Footer: () => null, +})); + +vi.mock("@/components/ui/sonner", () => ({ + Toaster: () => null, +})); + +vi.mock("@/lib/layout-site-metadata", () => ({ + resolveDefaultLayoutTimeZone: vi.fn(async () => "UTC"), + resolveDefaultSiteMetadataSource: vi.fn(async () => null), +})); + +vi.mock("@/lib/public-status/layout-metadata", () => ({ + resolveLayoutTimeZone: vi.fn(async () => "UTC"), + resolveSiteMetadataSource: vi.fn(async () => null), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + error: vi.fn(), + }, +})); + +vi.mock("@/app/providers", () => ({ + AppProviders: ({ children }: { children: ReactNode }) => children, +})); + +vi.mock("@/app/globals.css", () => ({})); + +function findProviderElement(node: ReactNode): ReactElement | null { + if (!node || typeof node !== "object") { + return null; + } + + if (!("props" in node)) { + return null; + } + + const element = node as ReactElement<{ children?: ReactNode }>; + + if (element.type === nextIntlMocks.provider) { + return element; + } + + const children = element.props.children; + if (Array.isArray(children)) { + for (const child of children) { + const match = findProviderElement(child); + if (match) return match; + } + return null; + } + + return findProviderElement(children); +} + +describe("locale root layout", () => { + test("pins next-intl request locale and provider locale to the route segment", async () => { + const { default: RootLayout } = await import("@/app/[locale]/layout"); + + const tree = await RootLayout({ + children:
, + params: Promise.resolve({ locale: "en" }), + }); + + expect.soft(nextIntlMocks.setRequestLocale).toHaveBeenCalledWith("en"); + expect(nextIntlMocks.getMessages).toHaveBeenCalledWith({ locale: "en" }); + + const provider = findProviderElement(tree); + expect.soft(provider?.props).toMatchObject({ + locale: "en", + timeZone: "UTC", + }); + }); +}); diff --git a/tests/unit/i18n/locale-pathname.test.ts b/tests/unit/i18n/locale-pathname.test.ts new file mode 100644 index 000000000..4bf9ddb51 --- /dev/null +++ b/tests/unit/i18n/locale-pathname.test.ts @@ -0,0 +1,36 @@ +import { describe, expect, it } from "vitest"; +import { normalizePathnameForLocaleNavigation } from "@/i18n/pathname"; + +describe("normalizePathnameForLocaleNavigation", () => { + it("keeps an already internal pathname unchanged", () => { + expect(normalizePathnameForLocaleNavigation("/dashboard/providers")).toBe( + "/dashboard/providers" + ); + }); + + it("strips a single leading locale", () => { + expect(normalizePathnameForLocaleNavigation("/en/dashboard")).toBe("/dashboard"); + }); + + it("strips repeated leading locales that would otherwise create /en/en", () => { + expect(normalizePathnameForLocaleNavigation("/en/en/dashboard")).toBe("/dashboard"); + }); + + it("preserves query string and hash after stripping locales", () => { + expect(normalizePathnameForLocaleNavigation("/zh-CN/en/dashboard?tab=logs#row-1")).toBe( + "/dashboard?tab=logs#row-1" + ); + }); + + it("uses the fallback for locale roots", () => { + expect(normalizePathnameForLocaleNavigation("/ja")).toBe("/dashboard"); + expect(normalizePathnameForLocaleNavigation("/ru/")).toBe("/dashboard"); + }); + + it("rejects non-internal pathnames", () => { + expect(normalizePathnameForLocaleNavigation("https://example.com/dashboard")).toBe( + "/dashboard" + ); + expect(normalizePathnameForLocaleNavigation("//example.com/dashboard")).toBe("/dashboard"); + }); +}); diff --git a/tests/unit/lib/rate-limit/provider-session-release.test.ts b/tests/unit/lib/rate-limit/provider-session-release.test.ts new file mode 100644 index 000000000..217083be0 --- /dev/null +++ b/tests/unit/lib/rate-limit/provider-session-release.test.ts @@ -0,0 +1,97 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +type RedisClientMock = { + status: string; + eval: (...args: unknown[]) => Promise<[number, number]>; +}; + +let redisClientRef: RedisClientMock | null; +let evalMock: ReturnType Promise<[number, number]>>>; + +vi.mock("server-only", () => ({})); + +vi.mock("@/lib/redis", () => ({ + getRedisClient: () => redisClientRef, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + }, +})); + +describe("RateLimitService.releaseProviderSession", () => { + beforeEach(() => { + vi.clearAllMocks(); + evalMock = vi.fn(async () => [1, 0]); + redisClientRef = { + status: "ready", + eval: evalMock, + }; + }); + + it("应通过引用计数脚本释放失败请求的 provider session", async () => { + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + expect(evalMock).toHaveBeenCalledTimes(1); + expect(evalMock).toHaveBeenCalledWith( + expect.any(String), + 2, + "provider:42:active_sessions", + "provider:42:active_session_refs", + "sess_failed" + ); + }); + + it("仍有并发引用时不应直接 ZREM active session", async () => { + evalMock.mockResolvedValueOnce([0, 1]); + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + expect(evalMock).toHaveBeenCalledTimes(1); + }); + + it("Redis 不可用或未 ready 时应静默跳过", async () => { + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + redisClientRef = null; + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + redisClientRef = { status: "connecting", eval: evalMock }; + await RateLimitService.releaseProviderSession(42, "sess_failed"); + + expect(evalMock).not.toHaveBeenCalled(); + }); + + it("非法 providerId 或空 sessionId 不应触发 Redis 命令", async () => { + const { RateLimitService } = await import("@/lib/rate-limit/service"); + + await RateLimitService.releaseProviderSession(0, "sess_failed"); + await RateLimitService.releaseProviderSession(-1, "sess_failed"); + await RateLimitService.releaseProviderSession(42, " "); + + expect(evalMock).not.toHaveBeenCalled(); + }); + + it("释放失败时应记录日志但不向请求链路抛错", async () => { + const error = new Error("redis down"); + evalMock.mockRejectedValueOnce(error); + const { RateLimitService } = await import("@/lib/rate-limit/service"); + const { logger } = await import("@/lib/logger"); + + await expect( + RateLimitService.releaseProviderSession(42, "sess_failed") + ).resolves.toBeUndefined(); + + expect(logger.error).toHaveBeenCalledWith("[RateLimit] Failed to release provider session", { + providerId: 42, + sessionId: "sess_failed", + error, + }); + }); +}); diff --git a/tests/unit/lib/rate-limit/service-extra.test.ts b/tests/unit/lib/rate-limit/service-extra.test.ts index 235a22780..46d411f77 100644 --- a/tests/unit/lib/rate-limit/service-extra.test.ts +++ b/tests/unit/lib/rate-limit/service-extra.test.ts @@ -145,7 +145,7 @@ describe("RateLimitService - other quota paths", () => { const { RateLimitService } = await import("@/lib/rate-limit"); const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 0); - expect(result).toEqual({ allowed: true, count: 0, tracked: false }); + expect(result).toEqual({ allowed: true, count: 0, tracked: false, referenced: false }); }); it("checkAndTrackProviderSession:Redis 非 ready 时应 Fail Open", async () => { @@ -153,13 +153,13 @@ describe("RateLimitService - other quota paths", () => { redisClientRef.status = "end"; const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); - expect(result).toEqual({ allowed: true, count: 0, tracked: false }); + expect(result).toEqual({ allowed: true, count: 0, tracked: false, referenced: false }); }); it("checkAndTrackProviderSession:达到上限时应返回 not allowed", async () => { const { RateLimitService } = await import("@/lib/rate-limit"); - redisClientRef.eval.mockResolvedValueOnce([0, 2, 0]); + redisClientRef.eval.mockResolvedValueOnce([0, 2, 0, 0]); const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); expect(result.allowed).toBe(false); expect(result.reason).toContain("供应商并发 Session 上限已达到(2/2)"); @@ -168,27 +168,37 @@ describe("RateLimitService - other quota paths", () => { it("checkAndTrackProviderSession:未达到上限时应返回 allowed 且可标记 tracked", async () => { const { RateLimitService } = await import("@/lib/rate-limit"); - redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]); + redisClientRef.eval.mockResolvedValueOnce([1, 1, 1, 1]); const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); - expect(result).toEqual({ allowed: true, count: 1, tracked: true }); + expect(result).toEqual({ allowed: true, count: 1, tracked: true, referenced: true }); + }); + + it("checkAndTrackProviderSession:旧 membership 无引用计数时不应返回 release 引用", async () => { + const { RateLimitService } = await import("@/lib/rate-limit"); + + redisClientRef.eval.mockResolvedValueOnce([1, 1, 0, 0]); + const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); + expect(result).toEqual({ allowed: true, count: 1, tracked: false, referenced: false }); }); it("checkAndTrackProviderSession: should pass SESSION_TTL_MS as ARGV[4] to Lua script", async () => { const { RateLimitService } = await import("@/lib/rate-limit"); - redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]); + redisClientRef.eval.mockResolvedValueOnce([1, 1, 1, 1]); await RateLimitService.checkAndTrackProviderSession(9, "sess", 2); // Verify eval was called with the correct args including ARGV[4] = SESSION_TTL_MS expect(redisClientRef.eval).toHaveBeenCalledTimes(1); const evalCall = redisClientRef.eval.mock.calls[0]; - // evalCall: [script, numkeys, key, sessionId, limit, now, ttlMs] - // Indices: 0 1 2 3 4 5 6 - expect(evalCall.length).toBe(7); // script + 1 key + 5 ARGV - - // ARGV[4] (index 6) should be SESSION_TTL_MS derived from env (default 300s = 300000ms) - const ttlMsArg = evalCall[6]; + // evalCall: [script, numkeys, activeKey, refKey, sessionId, limit, now, ttlMs] + // Indices: 0 1 2 3 4 5 6 7 + expect(evalCall.length).toBe(8); // script + 2 keys + 4 ARGV + expect(evalCall[2]).toBe("provider:9:active_sessions"); + expect(evalCall[3]).toBe("provider:9:active_session_refs"); + + // ARGV[4] (index 7) should be SESSION_TTL_MS derived from env (default 300s = 300000ms) + const ttlMsArg = evalCall[7]; expect(ttlMsArg).toBe("300000"); }); diff --git a/tests/unit/lib/session-manager-terminate-session.test.ts b/tests/unit/lib/session-manager-terminate-session.test.ts index f4de279ac..f61889538 100644 --- a/tests/unit/lib/session-manager-terminate-session.test.ts +++ b/tests/unit/lib/session-manager-terminate-session.test.ts @@ -27,6 +27,7 @@ describe("SessionManager.terminateSession", () => { pipelineRef = { del: vi.fn(() => pipelineRef), zrem: vi.fn(() => pipelineRef), + hdel: vi.fn(() => pipelineRef), exec: vi.fn(async () => [[null, 1]]), }; diff --git a/tests/unit/lib/session-tracker-cleanup.test.ts b/tests/unit/lib/session-tracker-cleanup.test.ts index 554c6723e..e13f718ab 100644 --- a/tests/unit/lib/session-tracker-cleanup.test.ts +++ b/tests/unit/lib/session-tracker-cleanup.test.ts @@ -25,10 +25,18 @@ const makePipeline = () => { pipelineCalls.push(["zremrangebyscore", ...args]); return pipeline; }), + zrangebyscore: vi.fn((...args: unknown[]) => { + pipelineCalls.push(["zrangebyscore", ...args]); + return pipeline; + }), zrange: vi.fn((...args: unknown[]) => { pipelineCalls.push(["zrange", ...args]); return pipeline; }), + hdel: vi.fn((...args: unknown[]) => { + pipelineCalls.push(["hdel", ...args]); + return pipeline; + }), exists: vi.fn((...args: unknown[]) => { pipelineCalls.push(["exists", ...args]); return pipeline; @@ -72,6 +80,8 @@ describe("SessionTracker - TTL and cleanup", () => { exists: vi.fn(async () => 1), type: vi.fn(async () => "zset"), del: vi.fn(async () => 1), + hdel: vi.fn(async () => 0), + zrangebyscore: vi.fn(async () => []), zremrangebyscore: vi.fn(async () => 0), zrange: vi.fn(async () => []), pipeline: vi.fn(() => makePipeline()), diff --git a/tests/unit/lib/upstream-error-detection-status.test.ts b/tests/unit/lib/upstream-error-detection-status.test.ts new file mode 100644 index 000000000..e34cf75a5 --- /dev/null +++ b/tests/unit/lib/upstream-error-detection-status.test.ts @@ -0,0 +1,130 @@ +import { describe, expect, it } from "vitest"; +import { inferUpstreamErrorStatusCodeFromText } from "@/lib/utils/upstream-error-detection"; + +const httpStatusCases = [ + { statusCode: 429, matcherId: "rate_limit" }, + { statusCode: 402, matcherId: "payment_required" }, + { statusCode: 401, matcherId: "unauthorized" }, + { statusCode: 403, matcherId: "forbidden" }, + { statusCode: 404, matcherId: "not_found" }, + { statusCode: 413, matcherId: "payload_too_large" }, + { statusCode: 415, matcherId: "unsupported_media_type" }, + { statusCode: 409, matcherId: "conflict" }, + { statusCode: 422, matcherId: "unprocessable_entity" }, + { statusCode: 408, matcherId: "request_timeout" }, + { statusCode: 451, matcherId: "legal_restriction" }, + { statusCode: 503, matcherId: "service_unavailable" }, + { statusCode: 504, matcherId: "gateway_timeout" }, + { statusCode: 500, matcherId: "internal_server_error" }, + { statusCode: 400, matcherId: "bad_request" }, +] as const; + +const cloudflareErrorCases = [ + { code: 1015, statusCode: 429, matcherId: "rate_limit" }, + { code: 1020, statusCode: 403, matcherId: "forbidden" }, + { code: 521, statusCode: 503, matcherId: "service_unavailable" }, + { code: 522, statusCode: 504, matcherId: "gateway_timeout" }, + { code: 524, statusCode: 504, matcherId: "gateway_timeout" }, +] as const; + +describe("inferUpstreamErrorStatusCodeFromText numeric boundaries", () => { + it.each(httpStatusCases)("keeps matching a standalone HTTP $statusCode status token", ({ + statusCode, + matcherId, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`HTTP/1.1 ${statusCode}`)).toEqual({ + statusCode, + matcherId, + }); + }); + + it.each( + httpStatusCases + )("does not treat HTTP $statusCode followed by a decimal fraction as a status token", ({ + statusCode, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`HTTP/1.1 ${statusCode}.12`)).toBeNull(); + }); + + it.each( + httpStatusCases + )("does not treat HTTP $statusCode embedded in a longer number as a status token", ({ + statusCode, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`HTTP/1.1 ${statusCode}12`)).toBeNull(); + }); + + it.each( + httpStatusCases + )("does not treat HTTP $statusCode followed by a letter as a status token", ({ statusCode }) => { + expect(inferUpstreamErrorStatusCodeFromText(`HTTP/1.1 ${statusCode}abc`)).toBeNull(); + }); + + it.each(httpStatusCases)("keeps matching HTTP $statusCode followed by sentence punctuation", ({ + statusCode, + matcherId, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`HTTP/1.1 ${statusCode}.`)).toEqual({ + statusCode, + matcherId, + }); + }); + + it.each(cloudflareErrorCases)("keeps matching a standalone Cloudflare Error $code token", ({ + code, + statusCode, + matcherId, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`Error ${code}`)).toEqual({ + statusCode, + matcherId, + }); + }); + + it.each( + cloudflareErrorCases + )("does not treat Cloudflare Error $code followed by a decimal fraction as a code token", ({ + code, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`Error ${code}.7`)).toBeNull(); + }); + + it.each( + cloudflareErrorCases + )("does not treat Cloudflare Error $code embedded in a longer number as a code token", ({ + code, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`Error ${code}7`)).toBeNull(); + }); + + it.each( + cloudflareErrorCases + )("does not treat Cloudflare Error $code followed by a letter as a code token", ({ code }) => { + expect(inferUpstreamErrorStatusCodeFromText(`Error ${code}x`)).toBeNull(); + }); + + it.each( + cloudflareErrorCases + )("keeps matching Cloudflare Error $code followed by sentence punctuation", ({ + code, + statusCode, + matcherId, + }) => { + expect(inferUpstreamErrorStatusCodeFromText(`Error ${code}.`)).toEqual({ + statusCode, + matcherId, + }); + }); + + it("does not infer service_unavailable from an AWS request id containing 503", () => { + const text = "request id: 202604250550399959"; + + expect(inferUpstreamErrorStatusCodeFromText(text)).toBeNull(); + }); + + it("does not infer any status from a decimal price sample", () => { + const text = "需要预扣费额度:¥0.352942"; + + expect(inferUpstreamErrorStatusCodeFromText(text)).toBeNull(); + }); +}); diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index 942234e60..0f41408b9 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -20,6 +20,13 @@ const mocks = vi.hoisted(() => ({ recordEndpointFailure: vi.fn(async () => {}), isVendorTypeCircuitOpen: vi.fn(async () => false), recordVendorTypeAllEndpointsTimeout: vi.fn(async () => {}), + checkAndTrackProviderSession: vi.fn(async () => ({ + allowed: true, + count: 1, + tracked: true, + referenced: true, + })), + releaseProviderSession: vi.fn(async (_providerId: number, _sessionId: string) => {}), categorizeErrorAsync: vi.fn(async () => 0), getErrorDetectionResultAsync: vi.fn(async () => ({ matched: false })), getCachedSystemSettings: vi.fn(async () => ({ @@ -73,6 +80,13 @@ vi.mock("@/lib/vendor-type-circuit-breaker", () => ({ recordVendorTypeAllEndpointsTimeout: mocks.recordVendorTypeAllEndpointsTimeout, })); +vi.mock("@/lib/rate-limit/service", () => ({ + RateLimitService: { + checkAndTrackProviderSession: mocks.checkAndTrackProviderSession, + releaseProviderSession: mocks.releaseProviderSession, + }, +})); + vi.mock("@/lib/session-manager", () => ({ SessionManager: { updateSessionBindingSmart: mocks.updateSessionBindingSmart, @@ -222,6 +236,11 @@ function createSession(clientAbortSignal: AbortSignal | null = null): ProxySessi return session as ProxySession; } +function setProviderWithSessionRef(session: ProxySession, provider: Provider): void { + session.setProvider(provider); + session.recordProviderSessionRef(provider.id); +} + function createStreamingResponse(params: { label: string; firstChunkDelayMs: number; @@ -310,6 +329,12 @@ function withThinkingBlocks(session: ProxySession): void { describe("ProxyForwarder - first-byte hedge scheduling", () => { beforeEach(() => { vi.clearAllMocks(); + mocks.checkAndTrackProviderSession.mockResolvedValue({ + allowed: true, + count: 1, + tracked: true, + referenced: true, + }); }); test("shadow session redirect should not overwrite initial provider redirect and winner should keep its own redirect", () => { @@ -506,7 +531,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const session = createSession(); session.request.model = requestedModel; session.request.message.model = requestedModel; - session.setProvider(fireworks); + setProviderWithSessionRef(session, fireworks); session.addProviderToChain(fireworks, { reason: "initial_selection" }); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(minimax); @@ -575,6 +600,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { redirectedModel: fireworksRedirect, billingModel: requestedModel, }); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(fireworks.id, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -817,7 +843,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -874,6 +900,100 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { true, null ); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); + } finally { + vi.useRealTimers(); + } + }); + + test("hedge skips provider when concurrent session acquire is rejected", async () => { + vi.useFakeTimers(); + + try { + const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const provider2 = createProvider({ + id: 2, + name: "p2", + firstByteTimeoutStreamingMs: 100, + limitConcurrentSessions: 1, + }); + const provider3 = createProvider({ id: 3, name: "p3", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(); + setProviderWithSessionRef(session, provider1); + + mocks.pickRandomProviderWithExclusion + .mockResolvedValueOnce(provider2) + .mockResolvedValueOnce(provider3); + mocks.checkAndTrackProviderSession + .mockResolvedValueOnce({ + allowed: false, + count: 1, + tracked: false, + referenced: false, + reason: "供应商并发 Session 上限已达到(1/1)", + }) + .mockResolvedValueOnce({ allowed: true, count: 1, tracked: true, referenced: true }); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + + const controller1 = new AbortController(); + const controller3 = new AbortController(); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller1; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p1", + firstChunkDelayMs: 220, + controller: controller1, + }); + }); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller3; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p3", + firstChunkDelayMs: 40, + controller: controller3, + }); + }); + + const responsePromise = ProxyForwarder.send(session); + + await vi.advanceTimersByTimeAsync(100); + expect(doForward).toHaveBeenCalledTimes(2); + expect(doForward).not.toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ id: 2 }), + expect.anything(), + expect.anything(), + expect.anything(), + expect.anything() + ); + + await vi.advanceTimersByTimeAsync(50); + const response = await responsePromise; + + expect(await response.text()).toContain('"provider":"p3"'); + expect(session.provider?.id).toBe(3); + expect(mocks.checkAndTrackProviderSession).toHaveBeenNthCalledWith(1, 2, "sess-hedge", 1); + expect(mocks.checkAndTrackProviderSession).toHaveBeenNthCalledWith(2, 3, "sess-hedge", 0); + expect(session.getProviderChain()).toEqual( + expect.arrayContaining([ + expect.objectContaining({ id: 2, reason: "concurrent_limit_failed" }), + expect.objectContaining({ id: 3, reason: "hedge_winner" }), + ]) + ); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); + expect(mocks.releaseProviderSession).not.toHaveBeenCalledWith(2, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -887,7 +1007,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); session.setHighConcurrencyModeEnabled(true); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -953,7 +1073,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { firstByteTimeoutStreamingMs: 100, }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -1022,7 +1142,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -1071,6 +1191,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { expect(mocks.recordFailure).not.toHaveBeenCalled(); expect(mocks.recordSuccess).not.toHaveBeenCalled(); expect(session.provider?.id).toBe(1); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(2, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -1084,7 +1205,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); const provider3 = createProvider({ id: 3, name: "p3", firstByteTimeoutStreamingMs: 100 }); const session = createSession(); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion .mockResolvedValueOnce(provider2) @@ -1148,6 +1269,9 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { expect(mocks.recordFailure).not.toHaveBeenCalled(); expect(mocks.recordSuccess).not.toHaveBeenCalled(); expect(session.provider?.id).toBe(3); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(2, "sess-hedge"); + expect(mocks.releaseProviderSession).not.toHaveBeenCalledWith(3, "sess-hedge"); } finally { vi.useRealTimers(); } @@ -1182,7 +1306,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { const session = createSession(clientAbortController.signal); session.request.model = requestedModel; session.request.message.model = requestedModel; - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); @@ -1322,7 +1446,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { }); const session = createSession(); session.requestUrl = new URL("https://example.com/v1/messages"); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); mocks.getPreferredProviderEndpoints.mockRejectedValueOnce(new Error("Redis connection lost")); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(null); @@ -1745,7 +1869,7 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { }); const session = createSession(); session.requestUrl = new URL("https://example.com/v1/messages"); - session.setProvider(provider1); + setProviderWithSessionRef(session, provider1); // Provider 1's strict endpoint resolution will fail mocks.getPreferredProviderEndpoints.mockRejectedValueOnce( @@ -1790,8 +1914,64 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { ); expect(winnerEntry).toBeDefined(); expect(winnerEntry!.reason).toBe("request_success"); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(1, "sess-hedge"); } finally { vi.useRealTimers(); } }); + + test("removes streaming hedge client abort listener after winner response is returned", async () => { + const clientAbortController = new AbortController(); + const addSpy = vi.spyOn(clientAbortController.signal, "addEventListener"); + const removeSpy = vi.spyOn(clientAbortController.signal, "removeEventListener"); + const provider = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(clientAbortController.signal); + setProviderWithSessionRef(session, provider); + session.forwardedRequestBody = "x".repeat(512 * 1024); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + const upstreamController = new AbortController(); + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = upstreamController; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p1", + firstChunkDelayMs: 0, + controller: upstreamController, + }); + }); + + const response = await ProxyForwarder.send(session); + expect(await response.text()).toContain('"provider":"p1"'); + + const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort"); + expect(abortAddCalls).toHaveLength(1); + expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]); + }); + + test("pre-aborted client signal should settle hedge without launching upstream attempt", async () => { + const clientAbortController = new AbortController(); + clientAbortController.abort(new Error("client_cancelled")); + const addSpy = vi.spyOn(clientAbortController.signal, "addEventListener"); + const provider = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(clientAbortController.signal); + setProviderWithSessionRef(session, provider); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + + await expect(ProxyForwarder.send(session)).rejects.toMatchObject({ statusCode: 499 }); + expect(doForward).not.toHaveBeenCalled(); + expect(addSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0); + }); }); diff --git a/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts new file mode 100644 index 000000000..98bf78997 --- /dev/null +++ b/tests/unit/proxy/proxy-forwarder-provider-session-release.test.ts @@ -0,0 +1,112 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { ProxySession } from "@/app/v1/_lib/proxy/session"; + +const mocks = vi.hoisted(() => ({ + releaseProviderSession: vi.fn(async (_providerId: number, _sessionId: string) => {}), +})); + +vi.mock("@/lib/rate-limit/service", () => ({ + RateLimitService: { + releaseProviderSession: mocks.releaseProviderSession, + }, +})); + +vi.mock("@/lib/rate-limit", () => ({ + RateLimitService: { + releaseProviderSession: mocks.releaseProviderSession, + }, +})); + +describe("ProxyForwarder provider failure session release", () => { + beforeEach(() => { + mocks.releaseProviderSession.mockClear(); + }); + + it("标记供应商失败时仅释放本请求已获取的 provider session ref", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => void; + }; + const consumeProviderSessionRef = vi.fn(() => true); + const session = { + sessionId: "sess_failed", + consumeProviderSessionRef, + } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(consumeProviderSessionRef).toHaveBeenCalledWith(42); + expect(mocks.releaseProviderSession).toHaveBeenCalledWith(42, "sess_failed"); + }); + + it("未获取 provider session ref 的 fallback/hedge provider 不应释放 Redis membership", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => void; + }; + const consumeProviderSessionRef = vi.fn(() => false); + const session = { + sessionId: "sess_failed", + consumeProviderSessionRef, + } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(consumeProviderSessionRef).toHaveBeenCalledWith(42); + expect(mocks.releaseProviderSession).not.toHaveBeenCalled(); + }); + + it("重复标记同一供应商时只释放一次,避免 hedge 路径重复 ZREM", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => void; + }; + const consumeProviderSessionRef = vi.fn(() => true); + const session = { + sessionId: "sess_failed", + consumeProviderSessionRef, + } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(consumeProviderSessionRef).toHaveBeenCalledTimes(1); + expect(mocks.releaseProviderSession).toHaveBeenCalledTimes(1); + }); + + it("没有 sessionId 时只记录失败供应商,不触发 Redis 释放", async () => { + const { ProxyForwarder } = await import("@/app/v1/_lib/proxy/forwarder"); + const forwarderInternals = ProxyForwarder as unknown as { + markProviderFailed: ( + session: ProxySession, + failedProviderIds: number[], + providerId: number + ) => void; + }; + const session = { sessionId: null } as unknown as ProxySession; + const failedProviderIds: number[] = []; + + forwarderInternals.markProviderFailed(session, failedProviderIds, 42); + + expect(failedProviderIds).toEqual([42]); + expect(mocks.releaseProviderSession).not.toHaveBeenCalled(); + }); +}); diff --git a/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts b/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts new file mode 100644 index 000000000..c5dda43b9 --- /dev/null +++ b/tests/unit/proxy/response-handler-abort-listener-cleanup.test.ts @@ -0,0 +1,283 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy"; +import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler"; +import type { ProxySession } from "@/app/v1/_lib/proxy/session"; +import type { Provider } from "@/types/provider"; + +const testState = vi.hoisted(() => ({ + asyncTasks: [] as Promise[], + cancelTask: vi.fn(), + cleanupTask: vi.fn(), +})); + +vi.mock("@/app/v1/_lib/proxy/response-fixer", () => ({ + ResponseFixer: { + process: async (_session: unknown, response: Response) => response, + }, +})); + +vi.mock("@/lib/async-task-manager", () => ({ + AsyncTaskManager: { + register: (_taskId: string, promise: Promise) => { + testState.asyncTasks.push(promise); + return new AbortController(); + }, + cleanup: testState.cleanupTask, + cancel: testState.cancelTask, + }, +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + trace: vi.fn(), + }, +})); + +vi.mock("@/lib/price-sync/cloud-price-updater", () => ({ + requestCloudPriceTableSync: vi.fn(), +})); + +vi.mock("@/lib/proxy-status-tracker", () => ({ + ProxyStatusTracker: { + getInstance: () => ({ + endRequest: vi.fn(), + }), + }, +})); + +vi.mock("@/lib/rate-limit", () => ({ + RateLimitService: { + trackCost: vi.fn(), + trackUserDailyCost: vi.fn(), + decrementLeaseBudget: vi.fn(), + }, +})); + +vi.mock("@/lib/redis/live-chain-store", () => ({ + deleteLiveChain: vi.fn(), +})); + +vi.mock("@/lib/session-manager", () => ({ + SessionManager: { + clearSessionProvider: vi.fn(), + storeSessionResponse: vi.fn(), + updateSessionUsage: vi.fn(), + storeSessionRequestPhaseSnapshot: vi.fn(), + storeSessionResponsePhaseSnapshot: vi.fn(), + storeSessionUpstreamRequestMeta: vi.fn(), + storeSessionSpecialSettings: vi.fn(), + storeSessionRequestHeaders: vi.fn(), + storeSessionResponseHeaders: vi.fn(), + storeSessionUpstreamResponseMeta: vi.fn(), + }, +})); + +vi.mock("@/lib/session-tracker", () => ({ + SessionTracker: { + refreshSession: vi.fn(), + }, +})); + +vi.mock("@/lib/circuit-breaker", () => ({ + recordFailure: vi.fn(), +})); + +vi.mock("@/lib/endpoint-circuit-breaker", () => ({ + recordEndpointFailure: vi.fn(), + recordEndpointSuccess: vi.fn(), + resetEndpointCircuit: vi.fn(), +})); + +vi.mock("@/repository/message", () => ({ + updateMessageRequestCostWithBreakdown: vi.fn(), + updateMessageRequestDetails: vi.fn(), + updateMessageRequestDuration: vi.fn(), +})); + +async function drainAsyncTasks(): Promise { + while (testState.asyncTasks.length > 0) { + const tasks = testState.asyncTasks.splice(0); + await Promise.allSettled(tasks); + await new Promise((resolve) => setTimeout(resolve, 0)); + } +} + +function makeProvider(overrides: Partial = {}): Provider { + return { + id: 99, + name: "test-provider", + providerType: "openai", + baseUrl: "https://api.test.invalid", + priority: 1, + weight: 1, + costMultiplier: 1, + groupTag: "default", + isEnabled: true, + models: [], + createdAt: new Date(), + updatedAt: new Date(), + streamingIdleTimeoutMs: 0, + ...overrides, + } as Provider; +} + +function makeSession(clientAbortSignal: AbortSignal | null, stream: boolean): ProxySession { + const endpointPolicy = resolveEndpointPolicy("/v1/chat/completions"); + const provider = makeProvider(); + const session = { + request: { + model: "gpt-5.4", + log: "", + message: { + model: "gpt-5.4", + stream, + messages: [{ role: "user", content: "hello" }], + }, + }, + startTime: Date.now(), + method: "POST", + requestUrl: new URL("http://localhost/v1/chat/completions"), + headers: new Headers(), + headerLog: "", + userAgent: null, + context: {}, + clientAbortSignal, + forwardedRequestBody: "", + userName: "test-user", + authState: { + success: true, + user: { id: 1, name: "test-user" }, + key: { id: 2, name: "test-key" }, + apiKey: "test-key", + }, + provider, + messageContext: { + id: 123, + user: { id: 1, name: "test-user" }, + key: { id: 2, name: "test-key" }, + isSystemPrompt: false, + requireAuth: true, + createdAt: new Date(), + }, + sessionId: null, + requestSequence: 1, + originalFormat: "openai", + providerType: "openai", + originalModelName: "gpt-5.4", + originalUrlPathname: "/v1/chat/completions", + providerChain: [], + cacheTtlResolved: null, + context1mApplied: false, + specialSettings: [], + cachedPriceData: undefined, + cachedBillingModelSource: undefined, + endpointPolicy, + isHeaderModified: () => false, + getEndpointPolicy: () => endpointPolicy, + getContext1mApplied: () => false, + getGroupCostMultiplier: () => 1, + getOriginalModel: () => "gpt-5.4", + getCurrentModel: () => "gpt-5.4", + getProviderChain: () => [], + getSpecialSettings: () => [], + shouldPersistSessionDebugArtifacts: () => false, + shouldTrackSessionObservability: () => false, + getResolvedPricingByBillingSource: async () => null, + recordTtfb: vi.fn(), + ttfbMs: null, + addProviderToChain: vi.fn(), + clearResponseTimeout: vi.fn(), + releaseAgent: vi.fn(), + }; + + return session as unknown as ProxySession; +} + +describe("ProxyResponseHandler client abort listener cleanup", () => { + beforeEach(() => { + testState.asyncTasks = []; + testState.cancelTask.mockClear(); + testState.cleanupTask.mockClear(); + vi.restoreAllMocks(); + }); + + it("removes non-stream client abort listener after response processing completes", async () => { + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, "addEventListener"); + const removeSpy = vi.spyOn(controller.signal, "removeEventListener"); + const session = makeSession(controller.signal, false); + const upstreamResponse = new Response( + JSON.stringify({ + choices: [{ message: { content: "ok" } }], + }), + { + headers: { "content-type": "application/json" }, + } + ); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort"); + expect(abortAddCalls).toHaveLength(1); + expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]); + }); + + it("removes stream client abort listener after stream processing completes", async () => { + const controller = new AbortController(); + const addSpy = vi.spyOn(controller.signal, "addEventListener"); + const removeSpy = vi.spyOn(controller.signal, "removeEventListener"); + const session = makeSession(controller.signal, true); + const upstreamResponse = new Response( + 'data: {"choices":[{"delta":{"content":"ok"}}]}\n\ndata: [DONE]\n\n', + { + headers: { "content-type": "text/event-stream" }, + } + ); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + const abortAddCalls = addSpy.mock.calls.filter(([type]) => type === "abort"); + expect(abortAddCalls).toHaveLength(1); + expect(removeSpy).toHaveBeenCalledWith("abort", abortAddCalls[0][1]); + }); + + it("uses no-op cleanup when client abort signal is null", async () => { + const session = makeSession(null, false); + const upstreamResponse = new Response(JSON.stringify({ choices: [] }), { + headers: { "content-type": "application/json" }, + }); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + expect(testState.cancelTask).not.toHaveBeenCalled(); + }); + + it("invokes cancel synchronously when client signal is already aborted", async () => { + const controller = new AbortController(); + controller.abort(); + const addSpy = vi.spyOn(controller.signal, "addEventListener"); + const removeSpy = vi.spyOn(controller.signal, "removeEventListener"); + const session = makeSession(controller.signal, false); + const upstreamResponse = new Response(JSON.stringify({ choices: [] }), { + headers: { "content-type": "application/json" }, + }); + + const response = await ProxyResponseHandler.dispatch(session, upstreamResponse); + await response.text(); + await drainAsyncTasks(); + + expect(addSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0); + expect(removeSpy.mock.calls.filter(([type]) => type === "abort")).toHaveLength(0); + expect(testState.cancelTask).toHaveBeenCalled(); + }); +}); diff --git a/tests/unit/public-status/config-publisher.test.ts b/tests/unit/public-status/config-publisher.test.ts index 63b0cc84a..0987bf0cd 100644 --- a/tests/unit/public-status/config-publisher.test.ts +++ b/tests/unit/public-status/config-publisher.test.ts @@ -116,7 +116,7 @@ describe("public-status config publisher", () => { }), }) ); - }, 20_000); + }, 40_000); it("uses shared model-prefix matching for vendor icons without changing request type badges", async () => { mockFindAllProviderGroups.mockResolvedValue([ @@ -170,7 +170,7 @@ describe("public-status config publisher", () => { }), }) ); - }); + }, 40_000); it("uses model price metadata to derive public labels and vendor icons", async () => { mockFindAllProviderGroups.mockResolvedValue([ @@ -224,7 +224,7 @@ describe("public-status config publisher", () => { }), }) ); - }); + }, 40_000); it("publishes internal snapshot sourceGroupName for default group while public snapshot keeps custom slug", async () => { mockFindAllProviderGroups.mockResolvedValue([ @@ -275,5 +275,55 @@ describe("public-status config publisher", () => { }), }) ); - }, 20_000); + }, 40_000); + + it("publishes a Redis config projection when stored legacy group slugs collide", async () => { + mockFindAllProviderGroups.mockResolvedValue([ + { + id: 10, + name: "cc特价", + description: JSON.stringify({ + version: 2, + publicStatus: { + displayName: "CC Special", + publicGroupSlug: "cc", + publicModels: [{ modelKey: "gpt-4.1" }], + }, + }), + }, + { + id: 11, + name: "cc逆向", + description: JSON.stringify({ + version: 2, + publicStatus: { + displayName: "CC Reverse", + publicGroupSlug: "cc", + publicModels: [{ modelKey: "gpt-4.1" }], + }, + }), + }, + ]); + + const mod = await import("@/lib/public-status/config-publisher"); + const result = await mod.publishCurrentPublicStatusConfigProjection({ + reason: "test", + configVersion: "cfg-test", + }); + + expect(result.written).toBe(true); + expect(mockPublishPublicStatusConfigSnapshot).toHaveBeenCalledWith( + expect.objectContaining({ + snapshot: expect.objectContaining({ + groups: expect.arrayContaining([ + expect.objectContaining({ slug: "cc", displayName: "CC Special" }), + expect.objectContaining({ + slug: expect.stringMatching(/^cc-[a-z0-9]{6}$/), + displayName: "CC Reverse", + }), + ]), + }), + }) + ); + }, 40_000); }); diff --git a/tests/unit/public-status/public-path.test.ts b/tests/unit/public-status/public-path.test.ts index 14613de24..95b0655af 100644 --- a/tests/unit/public-status/public-path.test.ts +++ b/tests/unit/public-status/public-path.test.ts @@ -1,5 +1,6 @@ import { NextRequest } from "next/server"; import { describe, expect, it, vi } from "vitest"; +import { localeCookieName } from "@/i18n/config"; const mockIntlMiddleware = vi.hoisted(() => vi.fn((request: NextRequest) => { @@ -56,6 +57,41 @@ describe("public status proxy path", () => { expect(location).toContain("from=%2Fdashboard"); }); + it("prefers NEXT_LOCALE when redirecting an unprefixed protected route", async () => { + const { default: proxyHandler } = await import("@/proxy"); + const request = new NextRequest("http://localhost/dashboard"); + request.cookies.set(localeCookieName, "en"); + const response = proxyHandler(request); + const location = response.headers.get("location"); + + expect(location).toContain("/en/login"); + expect(location).toContain("from=%2Fdashboard"); + }); + + it("falls back safely when NEXT_LOCALE cookie is malformed", async () => { + const { default: proxyHandler } = await import("@/proxy"); + const request = new NextRequest("http://localhost/dashboard"); + request.headers.set("cookie", `${localeCookieName}=%E0%A4%A`); + + expect(() => proxyHandler(request)).not.toThrow(); + + const response = proxyHandler(request); + const location = response.headers.get("location"); + + expect(location).toContain("/zh-CN/login"); + expect(location).toContain("from=%2Fdashboard"); + }); + + it("normalizes repeated locale prefixes in the login from parameter", async () => { + const { default: proxyHandler } = await import("@/proxy"); + const response = proxyHandler(new NextRequest("http://localhost/en/en/dashboard")); + const location = response.headers.get("location"); + + expect(location).toContain("/en/login"); + expect(location).toContain("from=%2Fdashboard"); + expect(location).not.toContain("from=%2Fen%2Fdashboard"); + }); + it("redirects locale root to login with a dashboard fallback", async () => { const { default: proxyHandler } = await import("@/proxy"); const response = proxyHandler(new NextRequest("http://localhost/en")); diff --git a/tests/unit/public-status/public-status-config.test.ts b/tests/unit/public-status/public-status-config.test.ts index 43a9610dc..df85870d3 100644 --- a/tests/unit/public-status/public-status-config.test.ts +++ b/tests/unit/public-status/public-status-config.test.ts @@ -17,6 +17,7 @@ interface PublicStatusConfigModule { }; serializePublicStatusDescription(input: unknown): string | null; collectEnabledPublicStatusGroups(input: unknown): unknown[]; + slugifyPublicGroup(input: string): string; } describe("public-status config", () => { @@ -142,4 +143,65 @@ describe("public-status config", () => { }, }); }); + + it("keeps non-English group names from collapsing into duplicate empty or ASCII-only slugs", async () => { + const mod = await importPublicStatusModule( + "@/lib/public-status/config" + ); + + expect(mod.slugifyPublicGroup("中文分组")).toMatch(/^group-[a-z0-9]{6}$/); + + const groups = mod.collectEnabledPublicStatusGroups([ + { + groupName: "cc特价", + note: null, + publicStatus: { + publicModels: [{ modelKey: "gpt-4.1" }], + }, + }, + { + groupName: "cc逆向", + note: null, + publicStatus: { + publicModels: [{ modelKey: "gpt-4.1" }], + }, + }, + ]) as Array<{ publicGroupSlug: string }>; + + expect(groups).toHaveLength(2); + expect(groups.map((group) => group.publicGroupSlug)).toEqual([ + expect.stringMatching(/^cc-[a-z0-9]{6}$/), + expect.stringMatching(/^cc-[a-z0-9]{6}$/), + ]); + expect(new Set(groups.map((group) => group.publicGroupSlug)).size).toBe(2); + }); + + it("throws by default when enabled groups share the same normalized custom slug", async () => { + const mod = await importPublicStatusModule( + "@/lib/public-status/config" + ); + + expect(() => + mod.collectEnabledPublicStatusGroups([ + { + groupName: "openai-primary", + note: null, + publicStatus: { + publicGroupSlug: "Open AI", + publicModels: [{ modelKey: "gpt-4.1" }], + }, + }, + { + groupName: "openai-fallback", + note: null, + publicStatus: { + publicGroupSlug: "open-ai", + publicModels: [{ modelKey: "gpt-4.1" }], + }, + }, + ]) + ).toThrowError( + 'Duplicate normalized publicGroupSlug "open-ai" for groups: openai-primary, openai-fallback' + ); + }); }); diff --git a/tests/unit/repository/error-rules-default-numeric-boundaries.test.ts b/tests/unit/repository/error-rules-default-numeric-boundaries.test.ts new file mode 100644 index 000000000..9a7eb4201 --- /dev/null +++ b/tests/unit/repository/error-rules-default-numeric-boundaries.test.ts @@ -0,0 +1,131 @@ +import { describe, expect, test, vi } from "vitest"; + +process.env.DSN = ""; +process.env.AUTO_CLEANUP_TEST_DATA = "false"; + +type CapturedDefaultRule = { + pattern: string; + matchType: "contains" | "exact" | "regex"; + category: string; +}; + +type MockTransaction = { + query: { + errorRules: { + findMany: () => Promise; + }; + }; + delete: () => { + where: () => Promise; + }; + insert: () => { + values: (rule: CapturedDefaultRule) => { + onConflictDoNothing: () => { + returning: () => Promise>; + }; + }; + }; + update: () => { + set: () => { + where: () => Promise; + }; + }; +}; + +const capturedInsertedRules: CapturedDefaultRule[] = []; + +vi.mock("drizzle-orm", () => ({ + desc: vi.fn((...args: unknown[]) => ({ args, op: "desc" })), + eq: vi.fn((...args: unknown[]) => ({ args, op: "eq" })), + inArray: vi.fn((...args: unknown[]) => ({ args, op: "inArray" })), +})); + +vi.mock("@/drizzle/schema", () => ({ + errorRules: { + id: "error_rules.id", + pattern: "error_rules.pattern", + isDefault: "error_rules.is_default", + }, +})); + +vi.mock("@/drizzle/db", () => ({ + db: { + transaction: vi.fn(async (fn: (tx: MockTransaction) => Promise) => { + const tx: MockTransaction = { + query: { + errorRules: { + findMany: vi.fn(async () => []), + }, + }, + delete: vi.fn(() => ({ + where: vi.fn(async () => []), + })), + insert: vi.fn(() => ({ + values: (rule: CapturedDefaultRule) => { + capturedInsertedRules.push(rule); + return { + onConflictDoNothing: () => ({ + returning: vi.fn(async () => [{ id: 1 }]), + }), + }; + }, + })), + update: vi.fn(() => ({ + set: vi.fn(() => ({ + where: vi.fn(async () => []), + })), + })), + }; + + await fn(tx); + }), + }, +})); + +vi.mock("@/lib/emit-event", () => ({ + emitErrorRulesUpdated: vi.fn(async () => {}), +})); + +vi.mock("@/lib/logger", () => ({ + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + trace: vi.fn(), + }, +})); + +async function loadDefaultRules(): Promise { + capturedInsertedRules.length = 0; + vi.resetModules(); + + const { syncDefaultErrorRules } = await import("@/repository/error-rules"); + await syncDefaultErrorRules(); + + return [...capturedInsertedRules]; +} + +function matchesRule(rule: CapturedDefaultRule, sample: string): boolean { + if (rule.matchType === "exact") return sample === rule.pattern; + if (rule.matchType === "contains") + return sample.toLowerCase().includes(rule.pattern.toLowerCase()); + + return new RegExp(rule.pattern, "i").test(sample); +} + +describe("syncDefaultErrorRules numeric boundaries", () => { + test("default rules do not match numeric substrings in request ids or prices", async () => { + const defaultRules = await loadDefaultRules(); + const samples = ["request id: 202604250550399959", "需要预扣费额度:¥0.352942"]; + + expect(defaultRules.length).toBeGreaterThan(0); + + const accidentalMatches = defaultRules.flatMap((rule) => { + return samples + .filter((sample) => matchesRule(rule, sample)) + .map((sample) => ({ category: rule.category, pattern: rule.pattern, sample })); + }); + + expect(accidentalMatches).toEqual([]); + }); +}); diff --git a/tests/unit/settings/status-page/public-status-settings-form.test.tsx b/tests/unit/settings/status-page/public-status-settings-form.test.tsx index 694701169..a265d72ab 100644 --- a/tests/unit/settings/status-page/public-status-settings-form.test.tsx +++ b/tests/unit/settings/status-page/public-status-settings-form.test.tsx @@ -6,7 +6,11 @@ import type { ReactNode } from "react"; import { act } from "react"; import { createRoot } from "react-dom/client"; import { beforeEach, describe, expect, it, vi } from "vitest"; -import { PublicStatusSettingsForm } from "@/app/[locale]/settings/status-page/_components/public-status-settings-form"; +import { + PublicStatusSettingsForm, + type PublicStatusSettingsFormGroup, +} from "@/app/[locale]/settings/status-page/_components/public-status-settings-form"; +import { toast } from "sonner"; const mockRefresh = vi.hoisted(() => vi.fn()); const mockSavePublicStatusSettings = vi.hoisted(() => vi.fn()); @@ -374,4 +378,234 @@ describe("public-status settings form", () => { unmount(); }); + + it("blocks submit and highlights slug inputs when enabled groups share the same slug", async () => { + const scrollIntoView = vi.fn(); + Object.defineProperty(HTMLElement.prototype, "scrollIntoView", { + configurable: true, + value: scrollIntoView, + }); + const requestAnimationFrame = vi + .spyOn(window, "requestAnimationFrame") + .mockImplementation((callback) => { + callback(0); + return 0; + }); + + const { container, unmount } = render( + + ); + + const submitButton = Array.from(container.querySelectorAll("button")).find((button) => + button.textContent?.includes("statusPage.form.save") + ); + expect(submitButton).toBeTruthy(); + + await act(async () => { + submitButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + await Promise.resolve(); + }); + + expect(mockSavePublicStatusSettings).not.toHaveBeenCalled(); + expect(toast.error).toHaveBeenCalledWith("statusPage.form.duplicateSlug"); + const invalidSlugInputs = container.querySelectorAll('[aria-invalid="true"]'); + expect(invalidSlugInputs).toHaveLength(2); + expect(scrollIntoView).toHaveBeenCalledWith({ behavior: "smooth", block: "center" }); + expect(document.activeElement).toBe(invalidSlugInputs[0]); + + unmount(); + requestAnimationFrame.mockRestore(); + }); + + it("expands collapsed conflicting groups before focusing the first duplicate slug input", async () => { + const scrollIntoView = vi.fn(); + Object.defineProperty(HTMLElement.prototype, "scrollIntoView", { + configurable: true, + value: scrollIntoView, + }); + let frameCallback: FrameRequestCallback | undefined; + const requestAnimationFrame = vi + .spyOn(window, "requestAnimationFrame") + .mockImplementation((callback) => { + frameCallback = callback; + return 1; + }); + + const { container, unmount } = render( + + ); + + const getInputByValue = (value: string) => + Array.from(container.querySelectorAll("input")).find((input) => input.value === value); + const getGroupToggleButton = (groupName: string) => + Array.from(container.querySelectorAll("button")).find((button) => + button.textContent?.includes(groupName) + ); + + await act(async () => { + getGroupToggleButton("openai-primary")?.dispatchEvent( + new MouseEvent("click", { bubbles: true }) + ); + getGroupToggleButton("openai-fallback")?.dispatchEvent( + new MouseEvent("click", { bubbles: true }) + ); + }); + + expect(getInputByValue("Open AI")).toBeUndefined(); + expect(getInputByValue("open-ai")).toBeUndefined(); + + const submitButton = Array.from(container.querySelectorAll("button")).find((button) => + button.textContent?.includes("statusPage.form.save") + ); + expect(submitButton).toBeTruthy(); + + await act(async () => { + submitButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + await Promise.resolve(); + }); + + expect(mockSavePublicStatusSettings).not.toHaveBeenCalled(); + expect(toast.error).toHaveBeenCalledWith("statusPage.form.duplicateSlug"); + + const invalidSlugInputs = Array.from( + container.querySelectorAll('[aria-invalid="true"]') + ); + expect(invalidSlugInputs).toHaveLength(2); + expect(invalidSlugInputs.map((input) => input.value)).toEqual(["Open AI", "open-ai"]); + + await act(async () => { + frameCallback?.(0); + }); + + expect(scrollIntoView).toHaveBeenCalledWith({ behavior: "smooth", block: "center" }); + expect(document.activeElement).toBe(invalidSlugInputs[0]); + + unmount(); + requestAnimationFrame.mockRestore(); + }); + + it("uses backend slug fallback semantics and highlights every conflicting group", async () => { + const scrollIntoView = vi.fn(); + Object.defineProperty(HTMLElement.prototype, "scrollIntoView", { + configurable: true, + value: scrollIntoView, + }); + const requestAnimationFrame = vi + .spyOn(window, "requestAnimationFrame") + .mockImplementation((callback) => { + callback(0); + return 0; + }); + + const conflictingGroups: PublicStatusSettingsFormGroup[] = [ + { + groupName: "Open AI", + enabled: true, + displayName: "Open AI", + publicGroupSlug: "!!!", + explanatoryCopy: "Primary public models", + sortOrder: 0, + publicModels: [{ modelKey: "gpt-4.1" }], + }, + { + groupName: "open-ai", + enabled: true, + displayName: "open-ai", + publicGroupSlug: "???", + explanatoryCopy: "Fallback public models", + sortOrder: 1, + publicModels: [{ modelKey: "gpt-4.1" }], + }, + { + groupName: "open ai", + enabled: true, + displayName: "open ai", + publicGroupSlug: "...", + explanatoryCopy: "Tertiary public models", + sortOrder: 2, + publicModels: [{ modelKey: "gpt-4.1" }], + }, + ]; + + const { container, unmount } = render( + + ); + + const submitButton = Array.from(container.querySelectorAll("button")).find((button) => + button.textContent?.includes("statusPage.form.save") + ); + expect(submitButton).toBeTruthy(); + + await act(async () => { + submitButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + await Promise.resolve(); + }); + + expect(mockSavePublicStatusSettings).not.toHaveBeenCalled(); + expect(toast.error).toHaveBeenCalledWith("statusPage.form.duplicateSlug"); + const invalidSlugInputs = Array.from( + container.querySelectorAll('[aria-invalid="true"]') + ); + expect(invalidSlugInputs).toHaveLength(3); + expect(invalidSlugInputs.map((input) => input.value)).toEqual(["!!!", "???", "..."]); + expect(scrollIntoView).toHaveBeenCalledWith({ behavior: "smooth", block: "center" }); + expect(document.activeElement).toBe(invalidSlugInputs[0]); + + unmount(); + requestAnimationFrame.mockRestore(); + }); }); diff --git a/tests/unit/settings/status-page/status-page-loader.test.tsx b/tests/unit/settings/status-page/status-page-loader.test.tsx index 47d9f99e8..8d71ba132 100644 --- a/tests/unit/settings/status-page/status-page-loader.test.tsx +++ b/tests/unit/settings/status-page/status-page-loader.test.tsx @@ -105,4 +105,70 @@ describe("status-page loader", () => { ], }); }); + + it("hydrates unique default slugs for non-English provider groups", async () => { + mockGetSystemSettings.mockResolvedValue({ + publicStatusWindowHours: 24, + publicStatusAggregationIntervalMinutes: 5, + }); + mockBootstrapProviderGroupsFromProviders.mockResolvedValue({ + groups: [ + { + id: 3, + name: "cc特价", + description: null, + }, + { + id: 4, + name: "cc逆向", + description: null, + }, + ], + groupCounts: new Map(), + }); + + const mod = await import("@/app/[locale]/settings/status-page/loader"); + const result = await mod.loadStatusPageSettings(); + + expect(result.initialGroups.map((group) => group.publicGroupSlug)).toEqual([ + expect.stringMatching(/^cc-[a-z0-9]{6}$/), + expect.stringMatching(/^cc-[a-z0-9]{6}$/), + ]); + expect(new Set(result.initialGroups.map((group) => group.publicGroupSlug)).size).toBe(2); + }); + + it("does not generate a default slug that collides with a later custom slug", async () => { + mockGetSystemSettings.mockResolvedValue({ + publicStatusWindowHours: 24, + publicStatusAggregationIntervalMinutes: 5, + }); + mockBootstrapProviderGroupsFromProviders.mockResolvedValue({ + groups: [ + { + id: 5, + name: "alpha", + description: null, + }, + { + id: 6, + name: "custom-alpha", + description: JSON.stringify({ + version: 2, + publicStatus: { + publicGroupSlug: "alpha", + }, + }), + }, + ], + groupCounts: new Map(), + }); + + const mod = await import("@/app/[locale]/settings/status-page/loader"); + const result = await mod.loadStatusPageSettings(); + + expect(result.initialGroups.map((group) => group.publicGroupSlug)).toEqual([ + expect.stringMatching(/^alpha-[a-z0-9]{6}$/), + "alpha", + ]); + }); }); diff --git a/tests/vitest.base.ts b/tests/vitest.base.ts index 9cf5d658d..5f7db290c 100644 --- a/tests/vitest.base.ts +++ b/tests/vitest.base.ts @@ -30,6 +30,28 @@ const resolveSnapshotPath = (testPath: string, snapExtension: string) => { return testPath.replace(/\.test\.([tj]sx?)$/, `${snapExtension}.$1`); }; +export function parsePositiveInt(value: string | undefined, fallback: number): number { + if (!value) return fallback; + const parsed = Number.parseInt(value.trim(), 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback; +} + +export function parseWorkerLimit( + value: string | undefined, + fallback: number | string +): number | string { + if (!value) return fallback; + const trimmed = value.trim(); + if (/^\d+%$/.test(trimmed)) return trimmed; + const parsed = Number.parseInt(trimmed, 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback; +} + +export function parseBoolean(value: string | undefined, fallback: boolean): boolean { + if (!value) return fallback; + return !["0", "false", "no", "off"].includes(value.trim().toLowerCase()); +} + const defaultTestExclude = [ "node_modules", ".next", @@ -103,6 +125,10 @@ interface TestRunnerConfigOptions { testFiles: string[]; testTimeout?: number; hookTimeout?: number; + maxWorkers?: number | string; + maxConcurrency?: number; + fileParallelism?: boolean; + pool?: "threads" | "forks" | "vmThreads" | "vmForks"; extraExclude?: string[]; api?: { host?: string; @@ -113,6 +139,8 @@ interface TestRunnerConfigOptions { export function createTestRunnerConfig(opts: TestRunnerConfigOptions) { const baseExclude = ["node_modules", ".next", "dist", "build", "coverage", "**/*.d.ts"]; + const maxWorkers = + opts.maxWorkers ?? parseWorkerLimit(process.env.VITEST_STATEFUL_MAX_WORKERS, 2); return defineConfig({ test: { @@ -122,8 +150,14 @@ export function createTestRunnerConfig(opts: TestRunnerConfigOptions) { ...(opts.api ? { api: opts.api, open: false } : {}), testTimeout: opts.testTimeout ?? 10000, hookTimeout: opts.hookTimeout ?? 10000, - maxConcurrency: 5, - pool: "threads", + teardownTimeout: parsePositiveInt(process.env.VITEST_TEARDOWN_TIMEOUT_MS, 15000), + slowTestThreshold: parsePositiveInt(process.env.VITEST_SLOW_TEST_THRESHOLD_MS, 1000), + maxConcurrency: + opts.maxConcurrency ?? parsePositiveInt(process.env.VITEST_STATEFUL_MAX_CONCURRENCY, 3), + pool: opts.pool ?? "threads", + maxWorkers, + fileParallelism: + opts.fileParallelism ?? parseBoolean(process.env.VITEST_FILE_PARALLELISM, true), include: opts.testFiles, exclude: [...baseExclude, ...(opts.extraExclude ?? [])], reporters: ["verbose"], diff --git a/vitest.config.ts b/vitest.config.ts index 59b30d224..520c48fcd 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -1,14 +1,14 @@ +import { availableParallelism } from "node:os"; import { defineConfig } from "vitest/config"; -import { sharedResolve } from "./tests/vitest.base"; +import { parsePositiveInt, parseWorkerLimit, sharedResolve } from "./tests/vitest.base"; const isIntegrationFileFilterRequested = process.argv.some((arg) => /tests[\\/]+integration(?:[\\/].+\.(test|spec)\.[cm]?[jt]sx?|[\\/]?$)/.test(arg) ); -function parsePositiveInt(value: string | undefined, fallback: number): number { - if (!value) return fallback; - const parsed = Number.parseInt(value, 10); - return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback; +function defaultMaxWorkers(): number { + const workerBudget = Math.floor(availableParallelism() * 0.75); + return Math.min(8, Math.max(2, workerBudget)); } export default defineConfig({ @@ -92,13 +92,15 @@ export default defineConfig({ // ==================== 超时配置 ==================== testTimeout: 10000, // 单个测试超时 10 秒 hookTimeout: 10000, // 钩子函数超时 10 秒 + teardownTimeout: parsePositiveInt(process.env.VITEST_TEARDOWN_TIMEOUT_MS, 15000), + slowTestThreshold: parsePositiveInt(process.env.VITEST_SLOW_TEST_THRESHOLD_MS, 1000), // ==================== 并发配置 ==================== maxConcurrency: 5, // 最大并发测试数 pool: "threads", // 使用线程池(推荐) - // 高核机器/Windows 下 threads worker 过多可能触发 EMFILE / 资源争用导致用例超时。 - // 允许通过环境变量覆盖:VITEST_MAX_WORKERS=... - maxWorkers: parsePositiveInt(process.env.VITEST_MAX_WORKERS, 8), + // 依据可用 CPU 自动调节,但上限保持 8,避免高核机器过度并行拖垮长尾测试。 + // 允许通过环境变量覆盖:VITEST_MAX_WORKERS=8 或 VITEST_MAX_WORKERS=75%。 + maxWorkers: parseWorkerLimit(process.env.VITEST_MAX_WORKERS, defaultMaxWorkers()), // ==================== 文件匹配 ==================== include: [