diff --git a/src/app/api/card/[username]/route.test.ts b/src/app/api/card/[username]/route.test.ts index cc8adb95..3793be9f 100644 --- a/src/app/api/card/[username]/route.test.ts +++ b/src/app/api/card/[username]/route.test.ts @@ -1,3 +1,4 @@ +import { NextRequest } from "next/server"; import { describe, expect, it, vi } from "vitest"; vi.mock("@/lib/cardDataFetcher", () => ({ @@ -23,7 +24,7 @@ describe("GET /api/card/[username] cache headers", () => { }); const { GET } = await import("./route"); - const req = new Request("http://localhost/api/card/alice"); + const req = new NextRequest("http://localhost/api/card/alice"); const response = await GET(req, { params: Promise.resolve({ username: "alice" }) }); expect(response.headers.get("Cache-Control")).toBe("public, s-maxage=1800, stale-while-revalidate=3600"); @@ -34,7 +35,7 @@ describe("GET /api/card/[username] cache headers", () => { vi.mocked(fetchCardData).mockResolvedValueOnce(null); const { GET } = await import("./route"); - const req = new Request("http://localhost/api/card/ghost"); + const req = new NextRequest("http://localhost/api/card/ghost"); const response = await GET(req, { params: Promise.resolve({ username: "ghost" }) }); expect(response.status).toBe(404); @@ -46,7 +47,7 @@ describe("GET /api/card/[username] cache headers", () => { vi.mocked(fetchCardData).mockRejectedValueOnce(new Error("API Error")); const { GET } = await import("./route"); - const req = new Request("http://localhost/api/card/erroruser"); + const req = new NextRequest("http://localhost/api/card/erroruser"); const response = await GET(req, { params: Promise.resolve({ username: "erroruser" }) }); expect(response.status).toBe(503); @@ -66,7 +67,7 @@ describe("GET /api/card/[username] error responses", () => { } const { GET } = await import("./route"); - const req = new Request(`http://localhost/api/card/${username}`); + const req = new NextRequest(`http://localhost/api/card/${username}`); await GET(req, { params: Promise.resolve({ username }) }); expect(renderErrorCardResponse).toHaveBeenCalledWith(expect.objectContaining({ @@ -92,7 +93,7 @@ describe("GET /api/card/[username] rate limiting", () => { const { fetchCardData } = await import("@/lib/cardDataFetcher"); const { renderErrorCardResponse } = await import("@/lib/cardRenderer"); - const req1 = new Request("http://localhost/api/card/testuser", { + const req1 = new NextRequest("http://localhost/api/card/testuser", { headers: { "x-forwarded-for": "127.0.0.1", }, diff --git a/src/app/api/card/[username]/route.ts b/src/app/api/card/[username]/route.ts index f8cd2111..3a2d8eff 100644 --- a/src/app/api/card/[username]/route.ts +++ b/src/app/api/card/[username]/route.ts @@ -1,3 +1,4 @@ +import { NextRequest } from "next/server"; import { RateLimiter } from "@/lib/rateLimit"; import { fetchCardData } from "@/lib/cardDataFetcher"; import { parseCardQueryParams, renderCardResponse, renderErrorCardResponse } from "@/lib/cardRenderer"; @@ -10,7 +11,7 @@ const SUCCESS_CACHE = "public, s-maxage=1800, stale-while-revalidate=3600"; const ERROR_CACHE = "public, s-maxage=60, stale-while-revalidate=120"; export async function GET( - request: Request, + request: NextRequest, { params }: { params: Promise<{ username: string }> } ): Promise { const { username } = await params; @@ -19,7 +20,7 @@ export async function GET( const allowedOrigin = process.env.APP_URL || "http://localhost:3000"; const fontUrl = `${allowedOrigin}/fonts/NotoSans-Regular.ttf`; - const ip = request.headers.get("x-forwarded-for") ?? "unknown"; + const ip = request.headers.get("x-real-ip") ?? request.headers.get("x-forwarded-for")?.split(",")[0] ?? "unknown"; const rateLimitResult = rateLimiter.check(ip); if (!rateLimitResult.success) { diff --git a/src/app/api/og/[username]/route.tsx b/src/app/api/og/[username]/route.tsx index bfc09e8b..64b08f46 100644 --- a/src/app/api/og/[username]/route.tsx +++ b/src/app/api/og/[username]/route.tsx @@ -17,8 +17,7 @@ export async function GET( ) { const { username } = await params; - const forwarded = request.headers.get("x-forwarded-for"); - const ip = forwarded ? forwarded.split(",").at(-1)?.trim() ?? "unknown" : "unknown"; + const ip = request.headers.get("x-real-ip") ?? request.headers.get("x-forwarded-for")?.split(",")[0] ?? "unknown"; const rateLimitResult = rateLimiter.check(ip); if (!rateLimitResult.success) {