diff --git a/.agents/skills/paykit-architecture/SKILL.md b/.agents/skills/paykit-architecture/SKILL.md index 5903d067..b5082deb 100644 --- a/.agents/skills/paykit-architecture/SKILL.md +++ b/.agents/skills/paykit-architecture/SKILL.md @@ -1,6 +1,6 @@ --- name: paykit-architecture -description: Use before architectural, API design, provider integration, billing lifecycle, database model, or product-scope decisions in PayKit. +description: Not use automatically. --- # PayKit Architecture diff --git a/apps/web/content/docs/providers/polar.mdx b/apps/web/content/docs/providers/polar.mdx index 4a653336..287199f3 100644 --- a/apps/web/content/docs/providers/polar.mdx +++ b/apps/web/content/docs/providers/polar.mdx @@ -36,7 +36,7 @@ POLAR_ACCESS_TOKEN=polar_oat_... POLAR_WEBHOOK_SECRET=... ``` -- `POLAR_ACCESS_TOKEN`: create one in [Polar Settings](https://polar.sh/settings) under **Access Tokens**. The token needs the following scopes: `products:read`, `products:write`, `customers:read`, `customers:write`, `customer_sessions:write`, `subscriptions:read`, `subscriptions:write`, `checkouts:write`, `organizations:read`, `organizations:write`. +- `POLAR_ACCESS_TOKEN`: create one in [Polar Settings](https://polar.sh/settings) under **Access Tokens**. The token needs the following scopes: `products:read`, `products:write`, `customers:read`, `customers:write`, `customer_sessions:write`, `subscriptions:read`, `subscriptions:write`, `checkouts:write`, `organizations:read`, `organizations:write`, `webhooks:read`, `webhooks:write`. - `POLAR_WEBHOOK_SECRET`: generated when you create a webhook endpoint. See the section below. ## Webhook setup @@ -60,16 +60,16 @@ You can also select all events. PayKit silently ignores any events it doesn't ne ## Local development -Use the Polar CLI to forward webhook events to your local server: +Use the PayKit CLI to create or update the Polar webhook endpoint and forward events to your local server: ```bash -polar listen http://localhost:3000/paykit/webhook +paykitjs listen --forward-to http://localhost:3000 ``` -The CLI prints a webhook signing secret at startup. Use that as `POLAR_WEBHOOK_SECRET` in your local `.env`. +The CLI prints the provider webhook URL and webhook signing secret at startup. Use that signing secret as `POLAR_WEBHOOK_SECRET` in your local `.env`. - Install the Polar CLI with `curl -fsSL https://polar.sh/install.sh | bash`. You'll need to run `polar login` once to authenticate. + `paykitjs listen` uses your configured `POLAR_ACCESS_TOKEN`, so the token must include `webhooks:read` and `webhooks:write` in addition to the normal PayKit provider scopes. ## Product syncing diff --git a/e2e/core/cancel/cancel-then-upgrade.test.ts b/e2e/core/cancel/cancel-then-upgrade.test.ts index b429c66c..f32c843c 100644 --- a/e2e/core/cancel/cancel-then-upgrade.test.ts +++ b/e2e/core/cancel/cancel-then-upgrade.test.ts @@ -7,93 +7,97 @@ import { expectProduct, expectSingleActivePlanInGroup, expectSingleScheduledPlanInGroup, + harness, subscribeCustomer, type TestPayKit, } from "../../test-utils"; -describe("cancel-then-upgrade: pro → free (scheduled) → ultra (upgrade)", () => { - let t: TestPayKit; - let customerId: string; +describe.skipIf(!harness.capabilities.repeatedHostedCheckout)( + "cancel-then-upgrade: pro → free (scheduled) → ultra (upgrade)", + () => { + let t: TestPayKit; + let customerId: string; - beforeAll(async () => { - t = await createTestPayKit(); - const customer = await createTestCustomerWithPM({ - t, - customer: { - id: "test_cancel_upgrade", - email: "cancel-upgrade@test.com", - name: "Cancel Then Upgrade Test", - }, - }); - customerId = customer.customerId; + beforeAll(async () => { + t = await createTestPayKit(); + const customer = await createTestCustomerWithPM({ + t, + customer: { + id: "test_cancel_upgrade", + email: "cancel-upgrade@test.com", + name: "Cancel Then Upgrade Test", + }, + }); + customerId = customer.customerId; - // Setup: subscribe to Pro, then schedule downgrade to Free - await subscribeCustomer({ t, customerId, planId: "pro" }); + // Setup: subscribe to Pro, then schedule downgrade to Free + await subscribeCustomer({ t, customerId, planId: "pro" }); - await subscribeCustomer({ t, customerId, planId: "free" }); - }); + await subscribeCustomer({ t, customerId, planId: "free" }); + }); - afterAll(async () => { - await t?.cleanup(); - }); + afterAll(async () => { + await t?.cleanup(); + }); - it("upgrading while cancellation is pending cancels the downgrade and activates the new plan", async () => { - try { - // Verify precondition - await expectProduct({ - database: t.database, - customerId, - planId: "pro", - expected: { status: "active", canceled: true }, - }); - await expectSingleActivePlanInGroup({ - database: t.database, - customerId, - group: "base", - planId: "pro", - }); - await expectSingleScheduledPlanInGroup({ - database: t.database, - customerId, - group: "base", - planId: "free", - }); + it("upgrading while cancellation is pending cancels the downgrade and activates the new plan", async () => { + try { + // Verify precondition + await expectProduct({ + database: t.database, + customerId, + planId: "pro", + expected: { status: "active", canceled: true }, + }); + await expectSingleActivePlanInGroup({ + database: t.database, + customerId, + group: "base", + planId: "pro", + }); + await expectSingleScheduledPlanInGroup({ + database: t.database, + customerId, + group: "base", + planId: "free", + }); - // Action: upgrade to Ultra - await subscribeCustomer({ t, customerId, planId: "ultra" }); + // Action: upgrade to Ultra + await subscribeCustomer({ t, customerId, planId: "ultra" }); - // Ultra is active - await expectProduct({ - database: t.database, - customerId, - planId: "ultra", - expected: { - status: "active", - hasPeriodEnd: true, - }, - }); - await expectSingleActivePlanInGroup({ - database: t.database, - customerId, - group: "base", - planId: "ultra", - }); + // Ultra is active + await expectProduct({ + database: t.database, + customerId, + planId: "ultra", + expected: { + status: "active", + hasPeriodEnd: true, + }, + }); + await expectSingleActivePlanInGroup({ + database: t.database, + customerId, + group: "base", + planId: "ultra", + }); - // Pro is ended - await expectProduct({ - database: t.database, - customerId, - planId: "pro", - expected: { status: "ended" }, - }); + // Pro is ended + await expectProduct({ + database: t.database, + customerId, + planId: "pro", + expected: { status: "ended" }, + }); - // TODO: scheduled Free should be deleted on upgrade, but the subscribe - // flow computes "switch" instead of "upgrade" when the current subscription - // has cancel_at_period_end=true. This is a known PayKit issue. - // await expectProductNotPresent(t.database, customerId, "free"); - } catch (error) { - await dumpStateOnFailure(t.database, t.dbPath); - throw error; - } - }); -}); + // TODO: scheduled Free should be deleted on upgrade, but the subscribe + // flow computes "switch" instead of "upgrade" when the current subscription + // has cancel_at_period_end=true. This is a known PayKit issue. + // await expectProductNotPresent(t.database, customerId, "free"); + } catch (error) { + await dumpStateOnFailure(t.database, t.dbPath); + throw error; + } + }); + }, +); diff --git a/e2e/core/checkout/resubscribe-after-cancel.test.ts b/e2e/core/checkout/resubscribe-after-cancel.test.ts index 05f497e6..16e00e2f 100644 --- a/e2e/core/checkout/resubscribe-after-cancel.test.ts +++ b/e2e/core/checkout/resubscribe-after-cancel.test.ts @@ -9,120 +9,124 @@ import { dumpStateOnFailure, expectProduct, expectSingleActivePlanInGroup, + harness, type TestPayKit, waitForWebhook, } from "../../test-utils"; -describe("resubscribe-after-cancel: checkout after full cancellation", () => { - let t: TestPayKit; - let customerId: string; +describe.skipIf(!harness.capabilities.testClocks)( + "resubscribe-after-cancel: checkout after full cancellation", + () => { + let t: TestPayKit; + let customerId: string; - beforeAll(async () => { - t = await createTestPayKit(); - const customer = await createTestCustomerWithPM({ - t, - customer: { - id: "test_resub", - email: "resub@test.com", - name: "Resubscribe Test", - }, - }); - customerId = customer.customerId; - - // Setup: subscribe Pro → cancel to Free → advance clock (full cancellation) - await t.paykit.subscribe({ - customerId, - planId: "pro", - successUrl: "https://example.com/success", - }); + beforeAll(async () => { + t = await createTestPayKit(); + const customer = await createTestCustomerWithPM({ + t, + customer: { + id: "test_resub", + email: "resub@test.com", + name: "Resubscribe Test", + }, + }); + customerId = customer.customerId; - await t.paykit.subscribe({ - customerId, - planId: "free", - successUrl: "https://example.com/success", - }); + // Setup: subscribe Pro → cancel to Free → advance clock (full cancellation) + await t.paykit.subscribe({ + customerId, + planId: "pro", + successUrl: "https://example.com/success", + }); - // Advance past period end so subscription fully cancels - const subRows = await t.database - .select({ currentPeriodEndAt: subscription.currentPeriodEndAt }) - .from(subscription) - .where(eq(subscription.customerId, customerId)) - .orderBy(desc(subscription.updatedAt)) - .limit(1); - const periodEnd = new Date(subRows[0]!.currentPeriodEndAt as unknown as string); - const advanceTo = new Date(periodEnd.getTime() + 86_400_000); - await advanceTestClock({ - t, - customerId, - frozenTime: advanceTo, - }); + await t.paykit.subscribe({ + customerId, + planId: "free", + successUrl: "https://example.com/success", + }); - // Wait for Free to activate - for (let i = 0; i < 60; i++) { - const rows = await t.database - .select({ status: subscription.status }) + // Advance past period end so subscription fully cancels + const subRows = await t.database + .select({ currentPeriodEndAt: subscription.currentPeriodEndAt }) .from(subscription) - .innerJoin(product, eq(product.internalId, subscription.productInternalId)) - .where( - and( - eq(subscription.customerId, customerId), - eq(product.id, "free"), - eq(subscription.status, "active"), - ), - ); - if (rows.length > 0) break; - if (i === 59) throw new Error("Free never activated in setup"); - await new Promise((resolve) => setTimeout(resolve, 2000)); - } + .where(eq(subscription.customerId, customerId)) + .orderBy(desc(subscription.updatedAt)) + .limit(1); + const periodEnd = new Date(subRows[0]!.currentPeriodEndAt as unknown as string); + const advanceTo = new Date(periodEnd.getTime() + 86_400_000); + await advanceTestClock({ + t, + customerId, + frozenTime: advanceTo, + }); - // Clear stale payment method (Stripe removes it on full cancellation) - await t.database.delete(paymentMethod).where(eq(paymentMethod.customerId, customerId)); - }); + // Wait for Free to activate + for (let i = 0; i < 60; i++) { + const rows = await t.database + .select({ status: subscription.status }) + .from(subscription) + .innerJoin(product, eq(product.internalId, subscription.productInternalId)) + .where( + and( + eq(subscription.customerId, customerId), + eq(product.id, "free"), + eq(subscription.status, "active"), + ), + ); + if (rows.length > 0) break; + if (i === 59) throw new Error("Free never activated in setup"); + await new Promise((resolve) => setTimeout(resolve, 2000)); + } - afterAll(async () => { - await t?.cleanup(); - }); + // Clear stale payment method (Stripe removes it on full cancellation) + await t.database.delete(paymentMethod).where(eq(paymentMethod.customerId, customerId)); + }); - it("resubscribing after full cancellation requires checkout", async () => { - try { - const beforeCheckout = new Date(); + afterAll(async () => { + await t?.cleanup(); + }); - const result = await t.paykit.subscribe({ - customerId, - planId: "pro", - successUrl: "https://example.com/success", - }); + it("resubscribing after full cancellation requires checkout", async () => { + try { + const beforeCheckout = new Date(); - // Should require checkout (payment method was cleared) - if (!result.paymentUrl) { - throw new Error("Expected checkout URL but got direct subscription"); - } + const result = await t.paykit.subscribe({ + customerId, + planId: "pro", + successUrl: "https://example.com/success", + }); - await t.harness.completeCheckout(result.paymentUrl); + // Should require checkout (payment method was cleared) + if (!result.paymentUrl) { + throw new Error("Expected checkout URL but got direct subscription"); + } - await waitForWebhook({ - database: t.database, - eventType: "checkout.completed", - after: beforeCheckout, - timeout: 120_000, - }); + await t.harness.completeCheckout(result.paymentUrl); - // Pro is active again - await expectProduct({ - database: t.database, - customerId, - planId: "pro", - expected: { status: "active", hasPeriodEnd: true }, - }); - await expectSingleActivePlanInGroup({ - database: t.database, - customerId, - group: "base", - planId: "pro", - }); - } catch (error) { - await dumpStateOnFailure(t.database, t.dbPath); - throw error; - } - }); -}); + await waitForWebhook({ + database: t.database, + eventType: "checkout.completed", + after: beforeCheckout, + timeout: 120_000, + }); + + // Pro is active again + await expectProduct({ + database: t.database, + customerId, + planId: "pro", + expected: { status: "active", hasPeriodEnd: true }, + }); + await expectSingleActivePlanInGroup({ + database: t.database, + customerId, + group: "base", + planId: "pro", + }); + } catch (error) { + await dumpStateOnFailure(t.database, t.dbPath); + throw error; + } + }); + }, +); diff --git a/e2e/core/checkout/subscribe-paid-checkout.test.ts b/e2e/core/checkout/subscribe-paid-checkout.test.ts index 17ea1d66..52ea1ccc 100644 --- a/e2e/core/checkout/subscribe-paid-checkout.test.ts +++ b/e2e/core/checkout/subscribe-paid-checkout.test.ts @@ -1,3 +1,5 @@ +import { eq } from "drizzle-orm"; +import { product, subscription } from "paykitjs/database"; import { afterAll, beforeAll, describe, it } from "vitest"; import { @@ -8,7 +10,6 @@ import { expectSingleActivePlanInGroup, expectSubscription, type TestPayKit, - waitForWebhook, } from "../../test-utils"; describe("subscribe-paid-checkout: free → pro via checkout", () => { @@ -50,12 +51,7 @@ describe("subscribe-paid-checkout: free → pro via checkout", () => { await t.harness.completeCheckout(result.paymentUrl); - await waitForWebhook({ - database: t.database, - eventType: "checkout.completed", - after: beforeCheckout, - timeout: 120_000, - }); + await waitForActiveSubscription(t, customerId, 120_000, beforeCheckout, "pro"); // Pro is active await expectProduct({ @@ -91,3 +87,44 @@ describe("subscribe-paid-checkout: free → pro via checkout", () => { } }); }); + +async function waitForActiveSubscription( + t: TestPayKit, + customerId: string, + timeout: number, + after: Date, + planId: string, +): Promise { + const start = Date.now(); + + while (Date.now() - start < timeout) { + const rows = await t.database + .select({ + currentPeriodEndAt: subscription.currentPeriodEndAt, + planId: product.id, + providerData: subscription.providerData, + startedAt: subscription.startedAt, + status: subscription.status, + }) + .from(subscription) + .innerJoin(product, eq(product.internalId, subscription.productInternalId)) + .where(eq(subscription.customerId, customerId)); + + if ( + rows.some( + (row) => + row.status === "active" && + row.planId === planId && + row.currentPeriodEndAt !== null && + row.startedAt > after && + row.providerData !== null, + ) + ) { + return; + } + + await new Promise((resolve) => setTimeout(resolve, 500)); + } + + throw new Error("Timed out waiting for active checkout subscription"); +} diff --git a/e2e/core/entitlements/stacked-metered.test.ts b/e2e/core/entitlements/stacked-metered.test.ts index 3279a1e4..2723a523 100644 --- a/e2e/core/entitlements/stacked-metered.test.ts +++ b/e2e/core/entitlements/stacked-metered.test.ts @@ -5,80 +5,84 @@ import { createTestPayKit, dumpStateOnFailure, expectExactMeteredBalance, + harness, subscribeCustomer, type TestPayKit, } from "../../test-utils"; -describe("stacked-metered: cross-group entitlement aggregation", () => { - let t: TestPayKit; - let customerId: string; +describe.skipIf(!harness.capabilities.repeatedHostedCheckout)( + "stacked-metered: cross-group entitlement aggregation", + () => { + let t: TestPayKit; + let customerId: string; - beforeAll(async () => { - t = await createTestPayKit(); - const customer = await createTestCustomerWithPM({ - t, - customer: { - id: "test_stacked_metered", - email: "stacked-metered@test.com", - name: "Stacked Metered Test", - }, - }); - customerId = customer.customerId; + beforeAll(async () => { + t = await createTestPayKit(); + const customer = await createTestCustomerWithPM({ + t, + customer: { + id: "test_stacked_metered", + email: "stacked-metered@test.com", + name: "Stacked Metered Test", + }, + }); + customerId = customer.customerId; - // Pro (base group): 500 messages/month - await subscribeCustomer({ t, customerId, planId: "pro" }); + // Pro (base group): 500 messages/month + await subscribeCustomer({ t, customerId, planId: "pro" }); - // Extra Messages (addons group): 200 messages/month - await subscribeCustomer({ t, customerId, planId: "extra_messages" }); - }); + // Extra Messages (addons group): 200 messages/month + await subscribeCustomer({ t, customerId, planId: "extra_messages" }); + }); - afterAll(async () => { - await t?.cleanup(); - }); + afterAll(async () => { + await t?.cleanup(); + }); - it("check returns aggregated balance across groups", async () => { - try { - await expectExactMeteredBalance({ - paykit: t.paykit, - customerId, - featureId: "messages", - limit: 700, - remaining: 700, - }); - } catch (error) { - await dumpStateOnFailure(t.database, t.dbPath); - throw error; - } - }); + it("check returns aggregated balance across groups", async () => { + try { + await expectExactMeteredBalance({ + paykit: t.paykit, + customerId, + featureId: "messages", + limit: 700, + remaining: 700, + }); + } catch (error) { + await dumpStateOnFailure(t.database, t.dbPath); + throw error; + } + }); - it("report succeeds when amount exceeds any single row but fits the total", async () => { - try { - // 600 exceeds Pro's 500 and addon's 200 individually, but 700 total covers it - const report = await t.paykit.report({ customerId, featureId: "messages", amount: 600 }); - expect(report.success).toBe(true); - expect(report.balance!.remaining).toBe(100); + it("report succeeds when amount exceeds any single row but fits the total", async () => { + try { + // 600 exceeds Pro's 500 and addon's 200 individually, but 700 total covers it + const report = await t.paykit.report({ customerId, featureId: "messages", amount: 600 }); + expect(report.success).toBe(true); + expect(report.balance!.remaining).toBe(100); - await expectExactMeteredBalance({ - paykit: t.paykit, - customerId, - featureId: "messages", - limit: 700, - remaining: 100, - }); - } catch (error) { - await dumpStateOnFailure(t.database, t.dbPath); - throw error; - } - }); + await expectExactMeteredBalance({ + paykit: t.paykit, + customerId, + featureId: "messages", + limit: 700, + remaining: 100, + }); + } catch (error) { + await dumpStateOnFailure(t.database, t.dbPath); + throw error; + } + }); - it("report fails when amount exceeds the combined balance", async () => { - try { - // Only 100 remaining after previous test - const report = await t.paykit.report({ customerId, featureId: "messages", amount: 200 }); - expect(report.success).toBe(false); - } catch (error) { - await dumpStateOnFailure(t.database, t.dbPath); - throw error; - } - }); -}); + it("report fails when amount exceeds the combined balance", async () => { + try { + // Only 100 remaining after previous test + const report = await t.paykit.report({ customerId, featureId: "messages", amount: 200 }); + expect(report.success).toBe(false); + } catch (error) { + await dumpStateOnFailure(t.database, t.dbPath); + throw error; + } + }); + }, +); diff --git a/e2e/core/lifecycle/subscription.test.ts b/e2e/core/lifecycle/subscription.test.ts index e66b4d40..f54e890f 100644 --- a/e2e/core/lifecycle/subscription.test.ts +++ b/e2e/core/lifecycle/subscription.test.ts @@ -16,11 +16,11 @@ import { expectNoScheduledPlanInGroup, expectSingleActivePlanInGroup, expectSingleScheduledPlanInGroup, + harness, type TestPayKit, - waitForWebhook, } from "../../test-utils"; -describe("subscription lifecycle", () => { +describe.skipIf(!harness.capabilities.repeatedHostedCheckout)("subscription lifecycle", () => { let t: TestPayKit; let customerId: string; @@ -171,11 +171,7 @@ describe("subscription lifecycle", () => { await t.harness.completeCheckout(result.paymentUrl!); - await waitForWebhook({ - database: t.database, - eventType: "checkout.completed", - timeout: 120_000, - }); + await waitForSubscriptionActive({ customerId, planId: "pro", timeout: 120_000 }); const products = await getCustomerProducts(); const pro = products.find((p) => p.plan_id === "pro" && p.status === "active"); @@ -302,6 +298,10 @@ describe("subscription lifecycle", () => { }); }); + if (!harness.capabilities.testClocks) { + return; + } + // ─── Step 6: Downgrade Ultra → Free + advance clock ─── await step("downgrade ultra → free + advance clock", async () => { await t.paykit.subscribe({ @@ -373,8 +373,6 @@ describe("subscription lifecycle", () => { await t.database.delete(paymentMethod).where(eq(paymentMethod.customerId, customerId)); await step("upgrade free → pro", async () => { - const beforeCheckout = new Date(); - const result = await t.paykit.subscribe({ customerId, planId: "pro", @@ -385,12 +383,7 @@ describe("subscription lifecycle", () => { // Customer must go through checkout again. expect(result.paymentUrl).not.toBeNull(); await t.harness.completeCheckout(result.paymentUrl!); - await waitForWebhook({ - database: t.database, - eventType: "checkout.completed", - after: beforeCheckout, - timeout: 120_000, - }); + await waitForSubscriptionActive({ customerId, planId: "pro", timeout: 120_000 }); const products = await getCustomerProducts(); const activePro = products.find((p) => p.plan_id === "pro" && p.status === "active"); diff --git a/e2e/core/subscribe/downgrade-scheduled.test.ts b/e2e/core/subscribe/downgrade-scheduled.test.ts index 2b8a8c12..645cb91b 100644 --- a/e2e/core/subscribe/downgrade-scheduled.test.ts +++ b/e2e/core/subscribe/downgrade-scheduled.test.ts @@ -7,67 +7,71 @@ import { expectProduct, expectSingleActivePlanInGroup, expectSingleScheduledPlanInGroup, + harness, subscribeCustomer, type TestPayKit, } from "../../test-utils"; -describe("downgrade-scheduled: ultra → pro", () => { - let t: TestPayKit; - let customerId: string; +describe.skipIf(!harness.capabilities.repeatedHostedCheckout)( + "downgrade-scheduled: ultra → pro", + () => { + let t: TestPayKit; + let customerId: string; - beforeAll(async () => { - t = await createTestPayKit(); - const customer = await createTestCustomerWithPM({ - t, - customer: { - id: "test_downgrade", - email: "downgrade@test.com", - name: "Downgrade Test", - }, - }); - customerId = customer.customerId; + beforeAll(async () => { + t = await createTestPayKit(); + const customer = await createTestCustomerWithPM({ + t, + customer: { + id: "test_downgrade", + email: "downgrade@test.com", + name: "Downgrade Test", + }, + }); + customerId = customer.customerId; - // Setup: subscribe to Pro then upgrade to Ultra - await subscribeCustomer({ t, customerId, planId: "pro" }); + // Setup: subscribe to Pro then upgrade to Ultra + await subscribeCustomer({ t, customerId, planId: "pro" }); - await subscribeCustomer({ t, customerId, planId: "ultra" }); - }); + await subscribeCustomer({ t, customerId, planId: "ultra" }); + }); - afterAll(async () => { - await t?.cleanup(); - }); + afterAll(async () => { + await t?.cleanup(); + }); - it("downgrading to a lower tier schedules the change at period end", async () => { - try { - await subscribeCustomer({ t, customerId, planId: "pro" }); + it("downgrading to a lower tier schedules the change at period end", async () => { + try { + await subscribeCustomer({ t, customerId, planId: "pro" }); - // Ultra is still active but marked as canceled - await expectProduct({ - database: t.database, - customerId, - planId: "ultra", - expected: { - status: "active", - canceled: true, - }, - }); - await expectSingleActivePlanInGroup({ - database: t.database, - customerId, - group: "base", - planId: "ultra", - }); + // Ultra is still active but marked as canceled + await expectProduct({ + database: t.database, + customerId, + planId: "ultra", + expected: { + status: "active", + canceled: true, + }, + }); + await expectSingleActivePlanInGroup({ + database: t.database, + customerId, + group: "base", + planId: "ultra", + }); - // Pro is scheduled for activation at period end - await expectSingleScheduledPlanInGroup({ - database: t.database, - customerId, - group: "base", - planId: "pro", - }); - } catch (error) { - await dumpStateOnFailure(t.database, t.dbPath); - throw error; - } - }); -}); + // Pro is scheduled for activation at period end + await expectSingleScheduledPlanInGroup({ + database: t.database, + customerId, + group: "base", + planId: "pro", + }); + } catch (error) { + await dumpStateOnFailure(t.database, t.dbPath); + throw error; + } + }); + }, +); diff --git a/e2e/core/webhook/duplicate-webhook.test.ts b/e2e/core/webhook/duplicate-webhook.test.ts index a614bba3..029272d6 100644 --- a/e2e/core/webhook/duplicate-webhook.test.ts +++ b/e2e/core/webhook/duplicate-webhook.test.ts @@ -9,11 +9,12 @@ import { expectExactMeteredBalance, expectNoScheduledPlanInGroup, expectProduct, - expectSingleActivePlanInGroup, + harness, replayWebhookRequest, subscribeCustomer, type TestPayKit, waitForForwardedWebhookRequest, + waitForSingleActivePlanInGroup, waitForWebhook, } from "../../test-utils"; @@ -56,11 +57,13 @@ describe("duplicate-webhook: same event delivered twice", () => { requests: t.webhookRequests, }); - await waitForWebhook({ - after: beforeSubscribe, - database: t.database, - eventType: "invoice.updated", - }); + if (harness.capabilities.invoiceWebhooks) { + await waitForWebhook({ + after: beforeSubscribe, + database: t.database, + eventType: "invoice.updated", + }); + } const webhookCountBeforeRows = await t.database .select({ count: count() }) @@ -83,7 +86,7 @@ describe("duplicate-webhook: same event delivered twice", () => { planId: "pro", expected: { status: "active" }, }); - await expectSingleActivePlanInGroup({ + await waitForSingleActivePlanInGroup({ database: t.database, customerId, group: "base", diff --git a/e2e/paykit.config.ts b/e2e/paykit.config.ts new file mode 100644 index 00000000..3cb48e37 --- /dev/null +++ b/e2e/paykit.config.ts @@ -0,0 +1,39 @@ +import { polar } from "@paykitjs/polar"; +import { stripe } from "@paykitjs/stripe"; +import { createPayKit } from "paykitjs"; +import { Pool } from "pg"; + +import { env } from "./test-utils/env"; +import { allProducts } from "./test-utils/products"; + +function createProvider() { + if (env.PROVIDER === "polar") { + if (!env.E2E_POLAR_ACCESS_TOKEN || !env.E2E_POLAR_WHSEC) { + throw new Error("E2E_POLAR_ACCESS_TOKEN and E2E_POLAR_WHSEC must be set"); + } + + return polar({ + accessToken: env.E2E_POLAR_ACCESS_TOKEN, + webhookSecret: env.E2E_POLAR_WHSEC, + server: "sandbox", + }); + } + + if (!env.E2E_STRIPE_SK || !env.E2E_STRIPE_WHSEC) { + throw new Error("E2E_STRIPE_SK and E2E_STRIPE_WHSEC must be set"); + } + + return stripe({ + secretKey: env.E2E_STRIPE_SK, + webhookSecret: env.E2E_STRIPE_WHSEC, + }); +} + +export const paykit = createPayKit({ + database: new Pool({ connectionString: env.TEST_DATABASE_URL }), + products: allProducts, + provider: createProvider(), + testing: { enabled: true }, +}); + +export default paykit; diff --git a/e2e/test-utils/env.ts b/e2e/test-utils/env.ts index bf644fec..d768dd53 100644 --- a/e2e/test-utils/env.ts +++ b/e2e/test-utils/env.ts @@ -22,6 +22,7 @@ export const env = createEnv({ // Polar E2E_POLAR_ACCESS_TOKEN: z.string().optional(), + E2E_POLAR_CHECKOUT_MODE: z.enum(["auto", "direct", "manual"]).default("auto"), E2E_POLAR_WHSEC: z.string().optional(), }, runtimeEnv: process.env, diff --git a/e2e/test-utils/harness/polar.ts b/e2e/test-utils/harness/polar.ts index 86185cfa..77f82331 100644 --- a/e2e/test-utils/harness/polar.ts +++ b/e2e/test-utils/harness/polar.ts @@ -1,9 +1,18 @@ +import { mkdir, readFile, writeFile } from "node:fs/promises"; +import { dirname } from "node:path"; + import { polar } from "@paykitjs/polar"; -import { chromium } from "playwright"; +import { chromium, type Frame, type Locator, type Page } from "playwright"; import { env } from "../env"; import type { ProviderHarness } from "./types"; +const POLAR_CHECKOUT_THROTTLE_FILE = "/tmp/paykit-polar-checkout-throttle"; +const POLAR_MIN_CHECKOUT_BEFORE_SUBMIT_MS = 15_000; +const POLAR_MIN_CHECKOUT_API_INTERVAL_MS = 75_000; + +const checkoutDetailsBySecret = new Map(); + export function createPolarHarness(): ProviderHarness { const accessToken = env.E2E_POLAR_ACCESS_TOKEN; const webhookSecret = env.E2E_POLAR_WHSEC; @@ -16,51 +25,50 @@ export function createPolarHarness(): ProviderHarness { capabilities: { testClocks: false, directSubscription: false, + invoiceWebhooks: false, + repeatedHostedCheckout: false, }, - createProviderConfig() { + createProvider() { return polar({ accessToken, webhookSecret, server: "sandbox" }); }, + applyTestingOverrides(ctx) { + const createSubscriptionCheckout = ctx.provider.createSubscriptionCheckout?.bind( + ctx.provider, + ); + if (!createSubscriptionCheckout) return; + + ctx.provider.createSubscriptionCheckout = async (data) => { + const result = await createSubscriptionCheckout(data); + if (env.E2E_POLAR_CHECKOUT_MODE === "direct") { + await rememberCheckoutDetails({ + accessToken, + clientSecret: getPolarCheckoutClientSecret(result.paymentUrl), + providerCheckoutSessionId: result.providerCheckoutSessionId, + }); + } + return result; + }; + }, + async setupCustomerForDirectSubscription(_providerCustomerId: string) { // Polar doesn't support direct subscription — always goes through checkout. // This is a no-op; tests will get a paymentUrl and call completeCheckout. }, async completeCheckout(url: string) { - const browser = await chromium.launch({ headless: true }); - const page = await browser.newPage(); - - try { - await page.goto(url, { waitUntil: "networkidle" }); - - // Polar sandbox checkout — fill test card details - await page.fill( - '[data-testid="card-number"], input[name="cardNumber"], input[placeholder*="card number" i]', - "4242424242424242", - ); - await page.fill( - '[data-testid="card-expiry"], input[name="cardExpiry"], input[placeholder*="MM" i]', - "12/30", - ); - await page.fill( - '[data-testid="card-cvc"], input[name="cardCvc"], input[placeholder*="CVC" i]', - "123", - ); - - // Submit payment - const submitButton = page.locator( - 'button[type="submit"], button:has-text("Pay"), button:has-text("Subscribe")', - ); - await submitButton.click(); - - // Wait for redirect to success URL or confirmation - await page.waitForURL("**/success**", { timeout: 30_000 }).catch(() => { - // Some checkouts show a confirmation page rather than redirecting - }); - } finally { - await browser.close(); + if (env.E2E_POLAR_CHECKOUT_MODE === "manual") { + await waitForManualCheckout(url); + return; + } + + if (env.E2E_POLAR_CHECKOUT_MODE === "direct") { + await completeCheckoutViaClientApi(url); + return; } + + await completeHostedCheckout(url); }, async cleanup(_ctx) { @@ -75,3 +83,679 @@ export function createPolarHarness(): ProviderHarness { }, }; } + +async function completeCheckoutViaClientApi(url: string): Promise { + const clientSecret = getPolarCheckoutClientSecret(url); + const checkout = + checkoutDetailsBySecret.get(clientSecret) ?? (await getPolarClientCheckout(clientSecret)); + const publishableKey = checkout.paymentProcessorMetadata.publishable_key; + if (!publishableKey) throw new Error("Polar checkout is missing Stripe publishable key"); + + const confirmationTokenId = await createStripeConfirmationToken({ + amount: checkout.totalAmount || checkout.amount, + currency: checkout.currency, + email: checkout.customerEmail ?? `checkout-${Date.now()}@e2e.paykit.sh`, + name: checkout.customerName ?? "Test Customer", + publishableKey, + }); + + await confirmPolarCheckout(clientSecret, { + confirmationTokenId, + customerBillingAddress: { + city: "San Francisco", + country: "US", + line1: "1 Test St", + line2: "", + postal_code: "94105", + state: "CA", + }, + customerBillingName: checkout.customerName ?? "Test Customer", + customerEmail: checkout.customerEmail ?? `checkout-${Date.now()}@e2e.paykit.sh`, + customerName: checkout.customerName ?? "Test Customer", + }); +} + +async function rememberCheckoutDetails(input: { + accessToken: string; + clientSecret: string; + providerCheckoutSessionId: string; +}): Promise { + const checkout = await getAuthenticatedPolarCheckout({ + accessToken: input.accessToken, + checkoutId: input.providerCheckoutSessionId, + }); + checkoutDetailsBySecret.set(input.clientSecret, checkout); +} + +interface PolarClientCheckout { + amount: number; + currency: string; + customerEmail: string | null; + customerName: string | null; + paymentProcessorMetadata: Record; + totalAmount: number; +} + +async function getPolarClientCheckout(clientSecret: string): Promise { + const response = await fetchPolarCheckoutClient( + `https://sandbox-api.polar.sh/v1/checkouts/client/${clientSecret}`, + ); + if (!response.ok) { + throw new Error( + `Polar client checkout lookup failed: ${String(response.status)} ${await response.text()}`, + ); + } + return parsePolarCheckout(await response.json()); +} + +async function getAuthenticatedPolarCheckout(input: { + accessToken: string; + checkoutId: string; +}): Promise { + const response = await fetch(`https://sandbox-api.polar.sh/v1/checkouts/${input.checkoutId}`, { + headers: { authorization: `Bearer ${input.accessToken}` }, + }); + if (!response.ok) { + throw new Error( + `Polar checkout lookup failed: ${String(response.status)} ${await response.text()}`, + ); + } + return parsePolarCheckout(await response.json()); +} + +function parsePolarCheckout(data: unknown): PolarClientCheckout { + const checkout = data as { + amount?: unknown; + currency?: unknown; + customerEmail?: unknown; + customerName?: unknown; + customer_email?: unknown; + customer_name?: unknown; + paymentProcessorMetadata?: unknown; + payment_processor_metadata?: unknown; + totalAmount?: unknown; + total_amount?: unknown; + }; + + const paymentProcessorMetadata = + checkout.payment_processor_metadata ?? checkout.paymentProcessorMetadata; + return { + amount: typeof checkout.amount === "number" ? checkout.amount : 0, + currency: typeof checkout.currency === "string" ? checkout.currency : "usd", + customerEmail: + typeof checkout.customer_email === "string" + ? checkout.customer_email + : typeof checkout.customerEmail === "string" + ? checkout.customerEmail + : null, + customerName: + typeof checkout.customer_name === "string" + ? checkout.customer_name + : typeof checkout.customerName === "string" + ? checkout.customerName + : null, + paymentProcessorMetadata: + paymentProcessorMetadata && typeof paymentProcessorMetadata === "object" + ? (paymentProcessorMetadata as Record) + : {}, + totalAmount: + typeof checkout.total_amount === "number" + ? checkout.total_amount + : typeof checkout.totalAmount === "number" + ? checkout.totalAmount + : 0, + }; +} + +async function confirmPolarCheckout( + clientSecret: string, + input: { + confirmationTokenId: string; + customerBillingAddress: { + city: string; + country: string; + line1: string; + line2: string; + postal_code: string; + state: string; + }; + customerBillingName: string; + customerEmail: string; + customerName: string; + }, +): Promise { + const response = await fetchPolarCheckoutClient( + `https://sandbox-api.polar.sh/v1/checkouts/client/${clientSecret}/confirm`, + { + body: JSON.stringify({ + confirmation_token_id: input.confirmationTokenId, + customer_billing_address: input.customerBillingAddress, + customer_billing_name: input.customerBillingName, + customer_email: input.customerEmail, + customer_name: input.customerName, + }), + headers: { "content-type": "application/json" }, + method: "POST", + }, + ); + if (!response.ok) { + throw new Error( + `Polar client checkout confirm failed: ${String(response.status)} ${await response.text()}`, + ); + } +} + +async function fetchPolarCheckoutClient(url: string, init?: RequestInit): Promise { + let lastResponse: Response | null = null; + + for (let attempt = 0; attempt < 4; attempt += 1) { + await waitForPolarCheckoutApiSlot(); + const response = await fetch(url, init); + await recordPolarCheckoutApiAttempt(); + + if (response.status !== 429) return response; + + lastResponse = response; + console.warn( + `Polar checkout client API returned 429 for ${new URL(url).pathname}; waiting before retry ${String(attempt + 1)}/4`, + ); + await new Promise((resolve) => + setTimeout(resolve, getRetryAfterMs(response) ?? POLAR_MIN_CHECKOUT_API_INTERVAL_MS), + ); + } + + return lastResponse ?? fetch(url, init); +} + +async function createStripeConfirmationToken(input: { + amount: number; + currency: string; + email: string; + name: string; + publishableKey: string; +}): Promise { + const browser = await chromium.launch({ headless: true }); + const page = await browser.newPage({ viewport: { width: 640, height: 720 } }); + + try { + await page.route("https://paykit.local/stripe-confirm", async (route) => { + await route.fulfill({ + body: ` + + +
+ `, + contentType: "text/html", + }); + }); + await page.goto("https://paykit.local/stripe-confirm", { waitUntil: "domcontentloaded" }); + await page.waitForFunction(() => typeof window.Stripe === "function", null, { + timeout: 30_000, + }); + + await page.evaluate(({ amount, currency, publishableKey }) => { + const stripe = window.Stripe(publishableKey); + const elements = stripe.elements({ amount, currency, mode: "payment" }); + const payment = elements.create("payment"); + payment.mount("#payment-element"); + window.__paykitStripe = { elements, stripe }; + }, input); + + const paymentFrame = page.frameLocator('iframe[title="Secure payment input frame"]'); + + await fillRequired( + paymentFrame.locator( + 'input[name="number"], input[autocomplete="cc-number"], input[placeholder*="card number" i]', + ), + "4242424242424242", + 60_000, + ); + await fillRequired( + paymentFrame.locator( + 'input[name="expiry"], input[autocomplete="cc-exp"], input[placeholder*="MM" i]', + ), + "1230", + ); + await fillRequired( + paymentFrame.locator( + 'input[name="cvc"], input[autocomplete="cc-csc"], input[placeholder*="CVC" i]', + ), + "123", + ); + + const result = await page.evaluate(async ({ email, name }) => { + const stripeState = window.__paykitStripe; + const submit = await stripeState.elements.submit(); + if (submit.error) throw new Error(submit.error.message); + const token = await stripeState.stripe.createConfirmationToken({ + elements: stripeState.elements, + params: { + payment_method_data: { + billing_details: { + address: { + city: "San Francisco", + country: "US", + line1: "1 Test St", + postal_code: "94105", + state: "CA", + }, + email, + name, + }, + }, + }, + }); + if (token.error) throw new Error(token.error.message); + return token.confirmationToken.id; + }, input); + + return result; + } catch (error) { + await captureCheckoutFailure(page, "direct-confirm"); + throw error; + } finally { + await browser.close(); + } +} + +function getPolarCheckoutClientSecret(url: string): string { + const parsed = new URL(url); + const secret = parsed.pathname.split("/").filter(Boolean).at(-1); + if (!secret) throw new Error(`Unable to read Polar checkout client secret from URL: ${url}`); + return secret; +} + +async function completeHostedCheckout(url: string): Promise { + const browser = await chromium.launch({ headless: true }); + const page = await browser.newPage({ viewport: { width: 1280, height: 900 } }); + const checkoutDiagnostics: string[] = []; + const openedAt = Date.now(); + + page.on("requestfailed", (request) => { + checkoutDiagnostics.push( + `request failed ${request.method()} ${request.url()} ${request.failure()?.errorText ?? ""}`, + ); + }); + page.on("response", (response) => { + if (response.status() >= 400) { + checkoutDiagnostics.push(`response ${String(response.status())} ${response.url()}`); + } + }); + + try { + await waitForPolarCheckoutApiSlot(); + await page.goto(url, { timeout: 90_000, waitUntil: "domcontentloaded" }); + await recordPolarCheckoutApiAttempt(); + + const paymentFrame = page.frameLocator('iframe[src*="elements-inner-accessory-target"]'); + + await fillRequired( + paymentFrame.locator( + '[data-testid="card-number"], input[name="cardNumber"], input[autocomplete="cc-number"], input[placeholder*="card number" i], input[placeholder="1234 1234 1234 1234"]', + ), + "4242424242424242", + 60_000, + ); + await fillRequired( + paymentFrame.locator( + '[data-testid="card-expiry"], input[name="cardExpiry"], input[autocomplete="cc-exp"], input[placeholder*="MM" i], input[placeholder="MM / YY"]', + ), + "12/30", + ); + await fillRequired( + paymentFrame.locator( + '[data-testid="card-cvc"], input[name="cardCvc"], input[autocomplete="cc-csc"], input[placeholder*="CVC" i]', + ), + "123", + ); + await fillOptional( + page.locator( + '[data-testid="cardholder-name"], input[name="cardholderName"], input[autocomplete="cc-name"], input[placeholder*="name" i]', + ), + "Test Customer", + ); + await fillOptional( + page.locator('input[type="email"], input[name="email"], input[autocomplete="email"]'), + `checkout-${Date.now()}@e2e.paykit.sh`, + ); + await selectBillingCountry(page); + await fillCheckoutInput(page, 'input[name="customer_billing_address.line1"]', "1 Test St"); + await fillCheckoutInput(page, 'input[name="customer_billing_address.postal_code"]', "94105"); + await fillCheckoutInput(page, 'input[name="customer_billing_address.city"]', "San Francisco"); + await selectBillingState(page); + await page.keyboard.press("Escape").catch(() => undefined); + await page.keyboard.press("Tab").catch(() => undefined); + await page.waitForTimeout(2_000); + + const submitButton = + (await findVisibleLocator(page, 'button:has-text("Subscribe now")')) ?? + (await findVisibleLocator(page, 'button:has-text("Pay")')) ?? + (await findVisibleLocator(page, 'button:has-text("Subscribe")')); + if (!submitButton) throw new Error("Polar checkout submit button was not found"); + await submitButton.waitFor({ state: "visible", timeout: 30_000 }); + await waitForEnabled(submitButton, 15_000); + checkoutDiagnostics.push(await describeCheckoutSubmit(page, submitButton)); + + await submitCheckoutOnce(page, submitButton, checkoutDiagnostics, openedAt); + } catch (error) { + await captureCheckoutFailure(page); + throw error; + } finally { + await browser.close(); + } +} + +declare global { + interface Window { + Stripe: (publishableKey: string) => { + elements: (options: { amount: number; currency: string; mode: "payment" }) => StripeElements; + createConfirmationToken: (input: { + elements: StripeElements; + params: { + payment_method_data: { + billing_details: { + address: { + city: string; + country: string; + line1: string; + postal_code: string; + state: string; + }; + email: string; + name: string; + }; + }; + }; + }) => Promise< + | { confirmationToken: { id: string }; error?: undefined } + | { confirmationToken?: undefined; error: { message?: string } } + >; + }; + __paykitStripe: { + elements: StripeElements; + stripe: ReturnType; + }; + } +} + +interface StripeElements { + create: (type: "payment") => { mount: (selector: string) => void }; + submit: () => Promise<{ error?: { message?: string } }>; +} + +async function waitForManualCheckout(url: string): Promise { + console.info("\nPolar manual checkout mode enabled."); + console.info(`Open and complete this checkout:\n${url}`); + console.info("Waiting for PayKit to observe the checkout webhooks.\n"); +} + +async function submitCheckoutOnce( + page: Page, + submitButton: Locator, + diagnostics: string[], + openedAt: number, +): Promise { + const remaining = POLAR_MIN_CHECKOUT_BEFORE_SUBMIT_MS - (Date.now() - openedAt); + if (remaining > 0) await page.waitForTimeout(remaining); + await waitForPolarCheckoutApiSlot(); + + await page.keyboard.press("Escape").catch(() => undefined); + await centerLocator(submitButton); + diagnostics.push(`before click: ${JSON.stringify(await describeLocator(submitButton))}`); + + const clickAttempts: Array<() => Promise> = [ + () => submitButton.click({ timeout: 10_000 }), + () => clickLikeUser(page, submitButton), + () => submitButton.evaluate((el) => (el as HTMLElement).click()), + ]; + + for (let attempt = 0; attempt < clickAttempts.length; attempt += 1) { + const confirmResponse = page + .waitForResponse( + (response) => + response.url().includes("/checkouts/client/") && response.url().includes("/confirm"), + { timeout: 15_000 }, + ) + .catch(() => null); + + await clickAttempts[attempt]!(); + const response = await confirmResponse; + diagnostics.push( + `after click ${String(attempt + 1)}: ${await describeCheckoutSubmit(page, submitButton)}`, + ); + + if (!response) { + diagnostics.push(`confirm request not observed after click ${String(attempt + 1)}`); + continue; + } + + await recordPolarCheckoutApiAttempt(); + + const responseText = await response.text().catch(() => ""); + diagnostics.push( + `confirm response ${String(response.status())} ${response.url()} ${responseText.slice(0, 1_000)}`, + ); + if (response.status() === 429) { + await page.waitForTimeout(getRetryAfterMs(response) ?? POLAR_MIN_CHECKOUT_API_INTERVAL_MS); + continue; + } + + if (response.status() >= 400) { + throw new Error(`Polar checkout confirm failed:\n${diagnostics.join("\n")}`); + } + + await Promise.race([ + page.waitForURL((u) => u.toString().startsWith("https://example.com/success"), { + timeout: 15_000, + }), + page + .getByText(/thank you|payment complete|subscription active|subscribed/i) + .first() + .waitFor({ timeout: 15_000 }), + ]).catch(() => undefined); + return; + } + + throw new Error(`Polar checkout confirm request was not observed:\n${diagnostics.join("\n")}`); +} + +async function waitForPolarCheckoutApiSlot(): Promise { + const last = Number(await readFile(POLAR_CHECKOUT_THROTTLE_FILE, "utf8").catch(() => "0")); + const wait = POLAR_MIN_CHECKOUT_API_INTERVAL_MS - (Date.now() - last); + if (Number.isFinite(wait) && wait > 0) { + await new Promise((resolve) => setTimeout(resolve, wait)); + } +} + +async function recordPolarCheckoutApiAttempt(): Promise { + await mkdir(dirname(POLAR_CHECKOUT_THROTTLE_FILE), { recursive: true }).catch(() => undefined); + await writeFile(POLAR_CHECKOUT_THROTTLE_FILE, String(Date.now())).catch(() => undefined); +} + +function getRetryAfterMs( + response: { headers: Headers } | { headers(): Record }, +): number | null { + const retryAfter = + typeof response.headers === "function" + ? response.headers()["retry-after"] + : response.headers.get("retry-after"); + if (!retryAfter) return null; + + const seconds = Number(retryAfter); + if (Number.isFinite(seconds)) + return Math.max(seconds * 1_000, POLAR_MIN_CHECKOUT_API_INTERVAL_MS); + + const date = Date.parse(retryAfter); + if (Number.isNaN(date)) return null; + return Math.max(date - Date.now(), POLAR_MIN_CHECKOUT_API_INTERVAL_MS); +} + +async function selectBillingCountry(page: Page): Promise { + const select = page.locator('select[autocomplete="billing country"]').first(); + await select.waitFor({ state: "visible", timeout: 15_000 }); + await select.selectOption("US"); + await expectInputValue(select, "US", "Polar checkout billing country"); + await page.locator('input[name="customer_billing_address.line1"]').first().waitFor({ + state: "visible", + timeout: 15_000, + }); +} + +async function selectBillingState(page: Page): Promise { + const select = page.locator('select[autocomplete="billing address-level1"]').first(); + await select.waitFor({ state: "visible", timeout: 15_000 }); + await select.selectOption("US-CA"); + await expectInputValue(select, "US-CA", "Polar checkout billing state"); +} + +async function fillCheckoutInput(page: Page, selector: string, value: string): Promise { + const input = page.locator(selector).first(); + await input.waitFor({ state: "visible", timeout: 15_000 }); + await input.click({ force: true }); + await input.fill(value).catch(async () => { + await input.pressSequentially(value); + }); + await expectInputValue(input, value, `Polar checkout field ${selector}`); +} + +async function expectInputValue(locator: Locator, expected: string, label: string): Promise { + const actual = await locator.inputValue().catch(() => ""); + if (actual !== expected) + throw new Error(`${label} was not filled; expected ${expected}, got ${actual}`); +} + +function locatorRoots(page: Page): Array { + return [page, ...page.frames()]; +} + +async function findVisibleLocator(page: Page, selector: string): Promise { + const deadline = Date.now() + 5_000; + + while (Date.now() < deadline) { + for (const root of locatorRoots(page)) { + const locator = root.locator(selector); + const count = await locator.count().catch(() => 0); + for (let i = 0; i < count; i += 1) { + const candidate = locator.nth(i); + if (!(await candidate.isVisible().catch(() => false))) continue; + if (!(await candidate.isEnabled().catch(() => false))) continue; + return candidate; + } + } + await page.waitForTimeout(100); + } + + return null; +} + +async function fillRequired(locator: Locator, value: string, timeout = 5_000): Promise { + const input = locator.first(); + await input.waitFor({ state: "visible", timeout }); + await input.click({ force: true }); + await input.fill("").catch(() => undefined); + await input.pressSequentially(value, { delay: 20 }); +} + +async function fillOptional(locator: Locator, value: string, timeout = 5_000): Promise { + await locator + .first() + .waitFor({ state: "visible", timeout }) + .catch(() => undefined); + if ( + (await locator + .first() + .isVisible() + .catch(() => false)) && + (await locator + .first() + .isEnabled() + .catch(() => false)) + ) { + const input = locator.first(); + const existing = await input.inputValue().catch(() => ""); + if (existing.trim()) return; + await input.click({ force: true }); + await input.fill("").catch(() => undefined); + await input.pressSequentially(value, { delay: 10 }); + } +} + +async function clickLikeUser(page: Page, locator: Locator): Promise { + const first = locator.first(); + await centerLocator(first); + const box = await first.boundingBox(); + if (!box) { + await first.click({ timeout: 10_000 }); + return; + } + + await page.mouse.move(box.x + box.width / 2, box.y + box.height / 2, { steps: 8 }); + await page.waitForTimeout(250); + await page.mouse.down(); + await page.waitForTimeout(80); + await page.mouse.up(); +} + +async function centerLocator(locator: Locator): Promise { + await locator + .evaluate((el) => el.scrollIntoView({ block: "center", inline: "center" })) + .catch(async () => { + await locator.scrollIntoViewIfNeeded(); + }); +} + +async function waitForEnabled(locator: Locator, timeout: number): Promise { + const deadline = Date.now() + timeout; + + while (Date.now() < deadline) { + if ( + await locator + .first() + .isEnabled() + .catch(() => false) + ) + return; + await new Promise((resolve) => setTimeout(resolve, 200)); + } + + throw new Error("Polar checkout submit button did not become enabled"); +} + +async function describeCheckoutSubmit(page: Page, submitButton: Locator): Promise { + const [button, activeElement, visibleAlerts, url] = await Promise.all([ + describeLocator(submitButton), + page + .evaluate(() => { + const el = document.activeElement; + if (!el) return "none"; + return `${el.tagName.toLowerCase()} ${el.getAttribute("name") ?? ""} ${el.getAttribute("placeholder") ?? ""} ${el.textContent?.slice(0, 80) ?? ""}`; + }) + .catch((error: unknown) => `active element unavailable: ${String(error)}`), + page + .locator('[role="alert"], [aria-live], .error, [data-testid*="error" i]') + .evaluateAll((els) => + els.map((el) => el.textContent?.trim()).filter((text): text is string => Boolean(text)), + ) + .catch(() => []), + Promise.resolve(page.url()), + ]); + + return JSON.stringify({ activeElement, button, url, visibleAlerts }); +} + +async function describeLocator(locator: Locator): Promise> { + const first = locator.first(); + const [box, enabled, text, visible] = await Promise.all([ + first.boundingBox().catch(() => null), + first.isEnabled().catch(() => false), + first.innerText({ timeout: 1_000 }).catch(() => ""), + first.isVisible().catch(() => false), + ]); + + return { box, enabled, text, visible }; +} + +async function captureCheckoutFailure(page: Page, label = "failure"): Promise { + const path = `test-results/polar-checkout-${label}-${Date.now()}.png`; + await page.screenshot({ path, fullPage: true }).catch(() => undefined); +} diff --git a/e2e/test-utils/harness/stripe.ts b/e2e/test-utils/harness/stripe.ts index 2c0c5ff6..ebbe818d 100644 --- a/e2e/test-utils/harness/stripe.ts +++ b/e2e/test-utils/harness/stripe.ts @@ -1,5 +1,5 @@ import { stripe } from "@paykitjs/stripe"; -import { chromium } from "playwright"; +import { chromium, type Locator, type Page } from "playwright"; import { default as Stripe } from "stripe"; import type { PaymentProvider } from "../../../packages/paykit/src/providers/provider"; @@ -18,14 +18,16 @@ export function createStripeHarness(): ProviderHarness { capabilities: { testClocks: true, directSubscription: true, + invoiceWebhooks: true, + repeatedHostedCheckout: true, }, - createProviderConfig() { + createProvider() { return stripe({ secretKey, webhookSecret }); }, applyTestingOverrides(ctx) { - // Stripe's real createSubscription uses payment_behavior: "default_incomplete", + // Stripe's real subscription create uses payment_behavior: "default_incomplete", // which requires client-side confirmation via Stripe.js. In tests we want the // subscription to activate straight away from the server after a PM is attached. const provider = ctx.provider as PaymentProvider; @@ -88,31 +90,23 @@ export function createStripeHarness(): ProviderHarness { async completeCheckout(url: string) { const browser = await chromium.launch({ headless: true }); + const page = await browser.newPage(); try { - const page = await browser.newPage(); await page.goto(url, { waitUntil: "domcontentloaded" }); // Stripe's hosted checkout uses custom inputs that require per-key events; // fill() does not dispatch them correctly, so use pressSequentially. - const cardNumber = page.locator("#cardNumber"); - await cardNumber.waitFor({ timeout: 60_000 }); - await cardNumber.pressSequentially("4242424242424242"); - - const cardExpiry = page.locator("#cardExpiry"); - await cardExpiry.waitFor({ timeout: 30_000 }); - await cardExpiry.pressSequentially("1234"); - - const cardCvc = page.locator("#cardCvc"); - await cardCvc.waitFor({ timeout: 30_000 }); - await cardCvc.pressSequentially("123"); - - const billingName = page.locator("#billingName"); - if (await billingName.isVisible().catch(() => false)) { - await billingName.pressSequentially("Test Customer"); - } - - const submitBtn = page.locator(".SubmitButton-TextContainer").first(); + await pressIfVisible(page.locator("#cardNumber"), "4242424242424242", 60_000); + await pressIfVisible(page.locator("#cardExpiry"), "1234"); + await pressIfVisible(page.locator("#cardCvc"), "123"); + await pressIfVisible(page.locator("#billingName"), "Test Customer"); + await pressIfVisible(page.locator("#email"), `checkout-${Date.now()}@e2e.paykit.sh`); + await pressIfVisible(page.locator("#billingPostalCode"), "10001"); + + const submitBtn = page.locator('button[type="submit"]').first(); + await submitBtn.waitFor({ state: "visible", timeout: 30_000 }); + await submitBtn.scrollIntoViewIfNeeded(); await submitBtn.click(); // Wait for Stripe to navigate away from the checkout page (success redirect @@ -123,6 +117,9 @@ export function createStripeHarness(): ProviderHarness { timeout: 60_000, }) .catch(() => {}); + } catch (error) { + await captureCheckoutFailure(page); + throw error; } finally { await browser.close(); } @@ -150,6 +147,18 @@ export function createStripeHarness(): ProviderHarness { }; } +async function pressIfVisible(locator: Locator, value: string, timeout = 5_000): Promise { + await locator.waitFor({ state: "visible", timeout }).catch(() => undefined); + if (await locator.isVisible().catch(() => false)) { + await locator.pressSequentially(value); + } +} + +async function captureCheckoutFailure(page: Page): Promise { + const path = `test-results/stripe-checkout-${Date.now()}.png`; + await page.screenshot({ path, fullPage: true }).catch(() => undefined); +} + function validateStripeEnv(): void { if (!env.E2E_STRIPE_SK || !env.E2E_STRIPE_WHSEC) { throw new Error("E2E_STRIPE_SK and E2E_STRIPE_WHSEC must be set"); diff --git a/e2e/test-utils/harness/types.ts b/e2e/test-utils/harness/types.ts index e249eef2..510efdff 100644 --- a/e2e/test-utils/harness/types.ts +++ b/e2e/test-utils/harness/types.ts @@ -1,17 +1,20 @@ -import type { PayKitProviderConfig } from "paykitjs"; +import type { PayKitProvider } from "paykitjs"; import type { PayKitContext } from "../../../packages/paykit/src/core/context"; export interface ProviderCapabilities { testClocks: boolean; directSubscription: boolean; + invoiceWebhooks: boolean; + /** Hosted checkout can be used repeatedly inside provider-agnostic billing tests. */ + repeatedHostedCheckout: boolean; } export interface ProviderHarness { id: string; capabilities: ProviderCapabilities; - createProviderConfig(): PayKitProviderConfig; + createProvider(): PayKitProvider; /** * Apply testing-only overrides to the PayKit provider (e.g., Stripe's diff --git a/e2e/test-utils/hub.ts b/e2e/test-utils/hub.ts index 05d68c3d..6fc2d4bd 100644 --- a/e2e/test-utils/hub.ts +++ b/e2e/test-utils/hub.ts @@ -15,14 +15,30 @@ interface BufferedEvent { } /** - * Extract the provider customer ID from a Stripe event body. + * Extract the provider customer ID from provider webhook bodies. * Returns null for events that aren't keyed by customer (e.g. product.created). */ -function extractStripeCustomerId(body: string): string | null { +function extractProviderCustomerId(body: string): string | null { try { const parsed = JSON.parse(body) as { - data?: { object?: { id?: string; customer?: string | null; object?: string } }; + data?: { + customer?: { id?: string | null } | null; + customerId?: string | null; + customer_id?: string | null; + object?: { id?: string; customer?: string | null; object?: string }; + }; }; + + if (typeof parsed.data?.customerId === "string") { + return parsed.data.customerId; + } + if (typeof parsed.data?.customer_id === "string") { + return parsed.data.customer_id; + } + if (typeof parsed.data?.customer?.id === "string") { + return parsed.data.customer.id; + } + const obj = parsed.data?.object; if (!obj) return null; if (obj.object === "customer" && typeof obj.id === "string") return obj.id; @@ -87,7 +103,7 @@ export function startHub(): Promise { } // Otherwise: route webhook by customer ID - const customerId = extractStripeCustomerId(body); + const customerId = extractProviderCustomerId(body); if (!customerId) { // No customer → setup artifacts (product.created, price.created). Drop. res.writeHead(204); diff --git a/e2e/test-utils/setup.ts b/e2e/test-utils/setup.ts index 6e210ff5..ade72bc3 100644 --- a/e2e/test-utils/setup.ts +++ b/e2e/test-utils/setup.ts @@ -30,7 +30,7 @@ type TestPayKitInstance = ReturnType< typeof createPayKit<{ database: Pool; products: typeof allProducts; - provider: ReturnType; + provider: ReturnType; testing: { enabled: true }; }> >; @@ -86,11 +86,11 @@ export async function createTestPayKit(): Promise { await migrateDatabase(pool); // 3. Create PayKit instance with the active provider - const providerConfig = harness.createProviderConfig(); + const provider = harness.createProvider(); const paykit = createPayKit({ database: pool, products: allProducts, - provider: providerConfig, + provider, testing: { enabled: true }, }); @@ -236,8 +236,6 @@ export async function subscribeCustomer(input: { customerId: string; planId: Parameters[0]["planId"]; }): Promise { - const beforeSubscribe = new Date(); - const result = await input.t.paykit.subscribe({ customerId: input.customerId, planId: input.planId, @@ -247,15 +245,71 @@ export async function subscribeCustomer(input: { if (result.paymentUrl) { // Checkout-based flow — automate checkout completion await input.t.harness.completeCheckout(result.paymentUrl); + } - // Wait for the subscription to become active via webhook - await waitForWebhook({ - database: input.t.database, - eventType: "subscription.updated", - after: beforeSubscribe, - timeout: 60_000, - }); + await waitForPlanPresent({ + database: input.t.database, + customerId: input.customerId, + planId: input.planId, + timeout: result.paymentUrl ? 120_000 : 30_000, + }); +} + +/** Waits until the requested plan has materialized in PayKit's DB. */ +export async function waitForPlanPresent(input: { + database: PayKitDatabase; + customerId: string; + planId: string; + timeout?: number; +}): Promise { + const timeout = input.timeout ?? 30_000; + const start = Date.now(); + + while (Date.now() - start < timeout) { + const rows = await input.database + .select({ status: subscription.status }) + .from(subscription) + .innerJoin(product, eq(product.internalId, subscription.productInternalId)) + .where( + and( + eq(subscription.customerId, input.customerId), + eq(product.id, input.planId), + inArray(subscription.status, [...presentSubscriptionStatuses]), + ), + ) + .limit(1); + + if (rows.length > 0) return; + + await new Promise((resolve) => setTimeout(resolve, 500)); } + + throw new Error( + `Timed out waiting for plan "${input.planId}" for customer "${input.customerId}"`, + ); +} + +/** Waits until a specific plan is the single active plan in its group. */ +export async function waitForSingleActivePlanInGroup(input: { + database: PayKitDatabase; + customerId: string; + group: string; + planId: string; + timeout?: number; +}): Promise { + const timeout = input.timeout ?? 30_000; + const start = Date.now(); + + while (Date.now() - start < timeout) { + try { + await expectSingleActivePlanInGroup(input); + return; + } catch { + await new Promise((resolve) => setTimeout(resolve, 500)); + } + } + + await expectSingleActivePlanInGroup(input); } export async function expectProduct(input: { @@ -642,8 +696,12 @@ export async function waitForForwardedWebhookRequest(input: { try { const payload = JSON.parse(request.body) as { id?: string; type?: string }; + const webhookIdHeader = Object.entries(request.headers).find( + ([key]) => key.toLowerCase() === "webhook-id", + )?.[1]; const matchesProviderEventId = - input.providerEventId !== undefined && payload.id === input.providerEventId; + input.providerEventId !== undefined && + (payload.id === input.providerEventId || webhookIdHeader === input.providerEventId); const matchesEventType = input.eventType !== undefined && payload.type === input.eventType; if (matchesProviderEventId || matchesEventType) { return request; diff --git a/packages/paykit/src/api/__tests__/define-route.test.ts b/packages/paykit/src/api/__tests__/define-route.test.ts index 0a370ca4..2934cedb 100644 --- a/packages/paykit/src/api/__tests__/define-route.test.ts +++ b/packages/paykit/src/api/__tests__/define-route.test.ts @@ -9,9 +9,6 @@ function createTestContext(trustedOrigins?: string[]) { options: { database: "postgres://paykit:test@localhost:5432/paykit", provider: { - createAdapter: () => { - throw new Error("not used in test"); - }, id: "stripe", name: "Stripe", }, diff --git a/packages/paykit/src/api/__tests__/methods.test.ts b/packages/paykit/src/api/__tests__/methods.test.ts index 73847740..43ae09be 100644 --- a/packages/paykit/src/api/__tests__/methods.test.ts +++ b/packages/paykit/src/api/__tests__/methods.test.ts @@ -31,16 +31,15 @@ function createTestContext() { }, ], provider: { - createAdapter: vi.fn(), id: "stripe", name: "Stripe", }, }, products: { plans: [] }, provider: { - handleWebhook, id: "stripe", name: "Stripe", + parseWebhook: handleWebhook, }, } as unknown as PayKitContext; diff --git a/packages/paykit/src/cli/commands/listen.ts b/packages/paykit/src/cli/commands/listen.ts index 125554e7..3492e72e 100644 --- a/packages/paykit/src/cli/commands/listen.ts +++ b/packages/paykit/src/cli/commands/listen.ts @@ -35,14 +35,14 @@ interface DeliveryResponse { receivedAt: string; } -interface TunnelCapableProvider extends PaymentProvider { - disableTunnelWebhook(data: { endpointId: string }): Promise; - ensureTunnelWebhook(data: { existingEndpointId?: string | null; url: string }): Promise<{ +interface WebhookEndpointCapableProvider extends PaymentProvider { + deleteWebhookEndpoint(data: { endpointId: string }): Promise; + ensureWebhookEndpoint(data: { existingEndpointId?: string | null; url: string }): Promise<{ created: boolean; endpointId: string; webhookSecret?: string; }>; - getTunnelAccount(): Promise<{ + getWebhookEndpointAccount(): Promise<{ displayName?: string; environment: string; providerAccountId: string; @@ -80,7 +80,7 @@ interface RelayRuntimeContext { account: TunnelAccountSummary; config: Awaited>; deviceToken: string; - provider: TunnelCapableProvider; + provider: WebhookEndpointCapableProvider; } function sleep(ms: number): Promise { @@ -230,16 +230,16 @@ function printRetrySummary( devLogger.print(`Retried ${label} ${picocolors.dim(id)}.`); } -function assertTunnelProvider(provider: PaymentProvider): TunnelCapableProvider { +function assertWebhookEndpointProvider(provider: PaymentProvider): WebhookEndpointCapableProvider { if ( - typeof provider.getTunnelAccount !== "function" || - typeof provider.ensureTunnelWebhook !== "function" || - typeof provider.disableTunnelWebhook !== "function" + typeof provider.getWebhookEndpointAccount !== "function" || + typeof provider.ensureWebhookEndpoint !== "function" || + typeof provider.deleteWebhookEndpoint !== "function" ) { throw new Error(`Provider "${provider.name}" does not support paykitjs listen yet.`); } - return provider as TunnelCapableProvider; + return provider as WebhookEndpointCapableProvider; } function sanitizeReplayHeaders(headers: Record): Headers { @@ -614,10 +614,10 @@ async function deliverWebhook(params: { async function syncProviderWebhook(params: { deviceToken: string; - provider: TunnelCapableProvider; + provider: WebhookEndpointCapableProvider; tunnel: TunnelResponse; }): Promise<{ webhookSecret?: string }> { - const providerWebhook = await params.provider.ensureTunnelWebhook({ + const providerWebhook = await params.provider.ensureWebhookEndpoint({ existingEndpointId: params.tunnel.providerWebhookEndpointId, url: params.tunnel.webhookUrl, }); @@ -648,11 +648,11 @@ async function loadRelayRuntimeContext(params: { }): Promise { params.devLogger.start("Loading PayKit config"); const config = await getPayKitConfig({ configPath: params.configPath, cwd: params.cwd }); - const provider = assertTunnelProvider(config.options.provider.createAdapter()); + const provider = assertWebhookEndpointProvider(config.options.provider); const deviceToken = getOrCreateDeviceToken(); params.devLogger.update("Connecting to Stripe"); - const account = await provider.getTunnelAccount(); + const account = await provider.getWebhookEndpointAccount(); params.devLogger.update("Connecting to PayKit"); return { @@ -827,7 +827,7 @@ async function disableAction(options: { config?: string; cwd: string }): Promise if (tunnel.providerWebhookEndpointId) { try { - await provider.disableTunnelWebhook({ endpointId: tunnel.providerWebhookEndpointId }); + await provider.deleteWebhookEndpoint({ endpointId: tunnel.providerWebhookEndpointId }); } catch (error) { if (!isMissingWebhookEndpointError(error)) { const message = error instanceof Error ? error.message : String(error); diff --git a/packages/paykit/src/cli/utils/shared.ts b/packages/paykit/src/cli/utils/shared.ts index 970d1c3c..4eefdb1d 100644 --- a/packages/paykit/src/cli/utils/shared.ts +++ b/packages/paykit/src/cli/utils/shared.ts @@ -3,7 +3,7 @@ import type { Pool } from "pg"; import type { createContext, PayKitContext } from "../../core/context"; import type { getPendingMigrationCount, migrateDatabase } from "../../database/index"; import type { dryRunSyncProducts, syncProducts } from "../../product/product-sync.service"; -import type { PayKitProviderConfig } from "../../providers/provider"; +import type { PayKitProvider } from "../../providers/provider"; import type { PayKitOptions } from "../../types/options"; import type { NormalizedPlan } from "../../types/schema"; import type { detectPackageManager, getInstallCommand, getRunCommand } from "./detect"; @@ -108,16 +108,13 @@ export interface ProviderCheckResult { webhookEndpoints: Array<{ url: string; status: string }> | null; } -export async function checkProvider( - providerConfig: PayKitProviderConfig, -): Promise { +export async function checkProvider(provider: PayKitProvider): Promise { try { - const adapter = providerConfig.createAdapter(); - const result = await adapter.check?.(); + const result = await provider.check?.(); if (!result) { return { - account: { ok: true, displayName: providerConfig.name, mode: "unknown" }, + account: { ok: true, displayName: provider.name, mode: "unknown" }, customerSample: [], errors: [], webhookEndpoints: null, diff --git a/packages/paykit/src/core/__tests__/context.test.ts b/packages/paykit/src/core/__tests__/context.test.ts index 89e99f5d..bb4b43ea 100644 --- a/packages/paykit/src/core/__tests__/context.test.ts +++ b/packages/paykit/src/core/__tests__/context.test.ts @@ -1,7 +1,25 @@ import type { Pool } from "pg"; import { beforeEach, describe, expect, it, vi } from "vitest"; -import type { PayKitProviderConfig, PaymentProvider } from "../../providers/provider"; +import type { PayKitProvider } from "../../providers/provider"; + +const testCapabilities: PayKitProvider["capabilities"] = { + subscriptionProducts: true, + subscriptionCheckout: true, + customerPortal: true, + createInvoices: true, + detachPaymentMethods: true, + setupPaymentMethods: true, + cancelSubscriptionsAtPeriodEnd: true, + createSubscriptions: true, + changeSubscriptionProducts: true, + listActiveSubscriptions: true, + pendingSubscriptionProductChanges: false, + resumeSubscriptionsAtPeriodEnd: true, + subscriptionSchedules: false, + testClocks: false, + manageWebhookEndpoints: false, +}; const mocks = vi.hoisted(() => ({ createDatabase: vi.fn(), @@ -31,13 +49,11 @@ describe("core/context", () => { level: "debug", } as const; const database = {} as Pool; - const adapter = { id: "test", name: "Test" } as unknown as PaymentProvider; - const provider: PayKitProviderConfig = { - capabilities: { testClocks: false }, + const provider = { + capabilities: testCapabilities, id: "test", name: "Test", - createAdapter: () => adapter, - }; + } as unknown as PayKitProvider; const context = await createContext({ database, @@ -48,6 +64,6 @@ describe("core/context", () => { expect(mocks.createDatabase).toHaveBeenCalledWith(database); expect(mocks.createPayKitLogger).toHaveBeenCalledWith(logging); expect(context.logger).toEqual({ kind: "logger" }); - expect(context.provider).toBe(adapter); + expect(context.provider).toBe(provider); }); }); diff --git a/packages/paykit/src/core/context.ts b/packages/paykit/src/core/context.ts index a83fa37e..dd4f2173 100644 --- a/packages/paykit/src/core/context.ts +++ b/packages/paykit/src/core/context.ts @@ -1,7 +1,7 @@ import { Pool } from "pg"; import { createDatabase, type PayKitDatabase } from "../database/index"; -import type { PaymentProvider } from "../providers/provider"; +import type { PayKitProvider } from "../providers/provider"; import type { PayKitOptions } from "../types/options"; import { normalizeSchema, type NormalizedSchema } from "../types/schema"; import { PayKitError, PAYKIT_ERROR_CODES } from "./errors"; @@ -11,7 +11,7 @@ export interface PayKitContext { options: PayKitOptions; basePath: string; database: PayKitDatabase; - provider: PaymentProvider; + provider: PayKitProvider; products: NormalizedSchema; logger: PayKitInternalLogger; } @@ -34,14 +34,13 @@ export async function createContext(options: PayKitOptions): Promise = {}) { + return { + capabilities: testCapabilities, + upsertSubscriptionProduct: vi.fn(), + check: vi.fn(), + createCustomer: vi.fn(), + deleteCustomer: vi.fn(), + updateCustomer: vi.fn(), + id: "stripe", + createInvoice: vi.fn(), + name: "Stripe", + createPaymentMethodSetupSession: vi.fn(), + detachPaymentMethod: vi.fn(), + createCustomerPortalSession: vi.fn(), + cancelSubscriptionAtPeriodEnd: vi.fn(), + changeSubscriptionProduct: vi.fn(), + changeSubscriptionProductAtPeriodEnd: vi.fn(), + createSubscription: vi.fn(), + createSubscriptionCheckout: vi.fn(), + getSubscription: vi.fn(), + listActiveSubscriptions: vi.fn(), + resumeSubscriptionAtPeriodEnd: vi.fn(), + advanceTestClock: vi.fn(), + getTestClock: vi.fn(), + parseWebhook: vi.fn(), + ...overrides, + } as PaymentProvider; +} + const emptyProducts: NormalizedSchema = { features: [], plans: [], @@ -68,32 +116,16 @@ describe("customer/service", () => { .mockResolvedValueOnce(createCustomerRow()) .mockResolvedValueOnce(syncedCustomer) .mockResolvedValueOnce(syncedCustomer); - const stripe = { - advanceTestClock: vi.fn(), - attachPaymentMethod: vi.fn(), - cancelSubscription: vi.fn(), - createInvoice: vi.fn(), - createPortalSession: vi.fn(), - createSubscription: vi.fn(), - createSubscriptionCheckout: vi.fn(), - deleteCustomer: vi.fn(), - detachPaymentMethod: vi.fn(), - getTestClock: vi.fn(), - handleWebhook: vi.fn(), - listActiveSubscriptions: vi.fn(), - resumeSubscription: vi.fn(), - scheduleSubscriptionChange: vi.fn(), - syncProducts: vi.fn(), - updateSubscription: vi.fn(), - createCustomer: vi.fn().mockResolvedValue({ - providerCustomer: { - frozenTime: "2024-01-01T00:00:00.000Z", - id: "cus_123", - testClockId: "clock_123", - }, - }), - updateCustomer: vi.fn(), - }; + const createCustomer = vi.fn().mockResolvedValue({ + providerCustomer: { + frozenTime: "2024-01-01T00:00:00.000Z", + id: "cus_123", + testClockId: "clock_123", + }, + }); + const provider = createProviderMock({ + createCustomer, + }); const ctx = { database: { query: { @@ -113,17 +145,11 @@ describe("customer/service", () => { provider: { id: "stripe", name: "Stripe", - createAdapter: vi.fn(), }, testing: { enabled: true }, }, products: emptyProducts, - provider: { - capabilities: { testClocks: true }, - id: "stripe", - name: "Stripe", - ...stripe, - }, + provider, } as unknown as PayKitContext; const customer = await upsertCustomer(ctx, { @@ -133,7 +159,7 @@ describe("customer/service", () => { }); expect(customer).toEqual(syncedCustomer); - expect(stripe.createCustomer).toHaveBeenCalledWith({ + expect(createCustomer).toHaveBeenCalledWith({ createTestClock: true, email: "test@example.com", id: "customer_123", @@ -167,30 +193,14 @@ describe("customer/service", () => { .mockResolvedValueOnce(createCustomerRow()) .mockResolvedValueOnce(syncedCustomer) .mockResolvedValueOnce(syncedCustomer); - const stripe = { - advanceTestClock: vi.fn(), - attachPaymentMethod: vi.fn(), - cancelSubscription: vi.fn(), - createInvoice: vi.fn(), - createPortalSession: vi.fn(), - createSubscription: vi.fn(), - createSubscriptionCheckout: vi.fn(), - deleteCustomer: vi.fn(), - detachPaymentMethod: vi.fn(), - getTestClock: vi.fn(), - handleWebhook: vi.fn(), - listActiveSubscriptions: vi.fn(), - resumeSubscription: vi.fn(), - scheduleSubscriptionChange: vi.fn(), - syncProducts: vi.fn(), - updateSubscription: vi.fn(), - createCustomer: vi.fn().mockResolvedValue({ - providerCustomer: { - id: "cus_456", - }, - }), - updateCustomer: vi.fn(), - }; + const createCustomer = vi.fn().mockResolvedValue({ + providerCustomer: { + id: "cus_456", + }, + }); + const provider = createProviderMock({ + createCustomer, + }); const ctx = { database: { query: { @@ -210,16 +220,10 @@ describe("customer/service", () => { provider: { id: "stripe", name: "Stripe", - createAdapter: vi.fn(), }, }, products: emptyProducts, - provider: { - capabilities: { testClocks: true }, - id: "stripe", - name: "Stripe", - ...stripe, - }, + provider, } as unknown as PayKitContext; await upsertCustomer(ctx, { @@ -228,7 +232,7 @@ describe("customer/service", () => { upsertProviderCustomer: true, }); - expect(stripe.createCustomer).toHaveBeenCalledWith({ + expect(createCustomer).toHaveBeenCalledWith({ createTestClock: false, email: "prod@example.com", id: "customer_123", @@ -374,12 +378,12 @@ describe("customer/service", () => { .fn() .mockResolvedValueOnce(existingCustomer) .mockResolvedValueOnce(existingCustomer); - const providerMock = { - id: "stripe", - name: "Stripe", - createCustomer: vi.fn(), - updateCustomer: vi.fn(), - }; + const createCustomer = vi.fn(); + const updateCustomer = vi.fn(); + const providerMock = createProviderMock({ + createCustomer, + updateCustomer, + }); const ctx = { database: { query: { customer: { findFirst } }, @@ -387,7 +391,7 @@ describe("customer/service", () => { }, logger: { warn: vi.fn() }, options: { - provider: { id: "stripe", name: "Stripe", createAdapter: vi.fn() }, + provider: { id: "stripe", name: "Stripe" }, }, products: emptyProducts, provider: providerMock, @@ -399,8 +403,8 @@ describe("customer/service", () => { upsertProviderCustomer: true, }); - expect(providerMock.createCustomer).not.toHaveBeenCalled(); - expect(providerMock.updateCustomer).not.toHaveBeenCalled(); + expect(createCustomer).not.toHaveBeenCalled(); + expect(updateCustomer).not.toHaveBeenCalled(); expect(result.provider).toEqual(existingCustomer.provider); }); @@ -424,12 +428,10 @@ describe("customer/service", () => { .mockResolvedValueOnce(createCustomerRow({ email: "old@example.com" })) .mockResolvedValueOnce(existingCustomer) .mockResolvedValueOnce(existingCustomer); - const providerMock = { - id: "stripe", - name: "Stripe", - createCustomer: vi.fn(), - updateCustomer: vi.fn(), - }; + const updateCustomer = vi.fn(); + const providerMock = createProviderMock({ + updateCustomer, + }); const ctx = { database: { query: { customer: { findFirst } }, @@ -440,7 +442,7 @@ describe("customer/service", () => { }, logger: { warn: vi.fn() }, options: { - provider: { id: "stripe", name: "Stripe", createAdapter: vi.fn() }, + provider: { id: "stripe", name: "Stripe" }, }, products: emptyProducts, provider: providerMock, @@ -452,7 +454,7 @@ describe("customer/service", () => { upsertProviderCustomer: true, }); - expect(providerMock.updateCustomer).toHaveBeenCalledWith( + expect(updateCustomer).toHaveBeenCalledWith( expect.objectContaining({ providerCustomerId: "cus_existing", email: "new@example.com" }), ); }); @@ -471,12 +473,10 @@ describe("customer/service", () => { .mockResolvedValueOnce(createCustomerRow()) .mockResolvedValueOnce(existingCustomer) .mockResolvedValueOnce(existingCustomer); - const providerMock = { - id: "stripe", - name: "Stripe", - createCustomer: vi.fn(), - updateCustomer: vi.fn(), - }; + const updateCustomer = vi.fn(); + const providerMock = createProviderMock({ + updateCustomer, + }); const ctx = { database: { query: { customer: { findFirst } }, @@ -487,7 +487,7 @@ describe("customer/service", () => { }, logger: { warn: vi.fn() }, options: { - provider: { id: "stripe", name: "Stripe", createAdapter: vi.fn() }, + provider: { id: "stripe", name: "Stripe" }, }, products: emptyProducts, provider: providerMock, @@ -499,7 +499,7 @@ describe("customer/service", () => { upsertProviderCustomer: true, }); - expect(providerMock.updateCustomer).toHaveBeenCalled(); + expect(updateCustomer).toHaveBeenCalled(); expect(providerUpdate.set).toHaveBeenCalledWith( expect.objectContaining({ provider: expect.objectContaining({ diff --git a/packages/paykit/src/customer/customer.api.ts b/packages/paykit/src/customer/customer.api.ts index d3ef8fcc..3ccbf9fe 100644 --- a/packages/paykit/src/customer/customer.api.ts +++ b/packages/paykit/src/customer/customer.api.ts @@ -1,6 +1,7 @@ import * as z from "zod"; import { definePayKitMethod, returnUrl } from "../api/define-route"; +import { assertProviderHasCapability } from "../providers/capabilities"; import { getCustomerWithDetails, hardDeleteCustomer, @@ -60,11 +61,12 @@ export const customerPortal = definePayKitMethod( }, }, async (ctx) => { + assertProviderHasCapability(ctx.paykit.provider, "customerPortal"); const { providerCustomerId } = await upsertProviderCustomer(ctx.paykit, { customerId: ctx.customer.id, }); - const { url } = await ctx.paykit.provider.createPortalSession({ + const { url } = await ctx.paykit.provider.createCustomerPortalSession({ providerCustomerId, returnUrl: ctx.input.returnUrl, }); diff --git a/packages/paykit/src/customer/customer.service.ts b/packages/paykit/src/customer/customer.service.ts index a0743c15..c997cc37 100644 --- a/packages/paykit/src/customer/customer.service.ts +++ b/packages/paykit/src/customer/customer.service.ts @@ -12,6 +12,7 @@ import { subscription, } from "../database/schema"; import { getProductByHash } from "../product/product.service"; +import { assertProviderHasCapability } from "../providers/capabilities"; import type { ProviderCustomer, ProviderCustomerMap } from "../providers/provider"; import { getActiveSubscriptionInGroup, @@ -457,11 +458,13 @@ export async function hardDeleteCustomer(ctx: PayKitContext, customerId: string) const providerCustomerId = getProviderCustomerId(existingCustomer, ctx.provider.id); if (providerCustomerId) { try { + assertProviderHasCapability(ctx.provider, "listActiveSubscriptions"); + assertProviderHasCapability(ctx.provider, "cancelSubscriptionsAtPeriodEnd"); const activeSubscriptions = await ctx.provider.listActiveSubscriptions({ providerCustomerId, }); for (const sub of activeSubscriptions) { - await ctx.provider.cancelSubscription({ + await ctx.provider.cancelSubscriptionAtPeriodEnd({ providerSubscriptionId: sub.providerSubscriptionId, }); } diff --git a/packages/paykit/src/index.ts b/packages/paykit/src/index.ts index 36130ea6..65748db6 100644 --- a/packages/paykit/src/index.ts +++ b/packages/paykit/src/index.ts @@ -26,12 +26,20 @@ export type { ReportResult, } from "./entitlement/entitlement.service"; export type { - PayKitProviderConfig, + PayKitProvider, PaymentProvider, ProviderCustomer, ProviderCustomerMap, - ProviderTunnelAccount, - ProviderTunnelWebhook, + ProviderHealthResult, + ProviderWebhookEndpointAccount, + ProviderWebhookEndpointResult, + ProviderInvoice, + ProviderProductInput, + ProviderProductResult, + ProviderSubscription, + ProviderSubscriptionResult, + ProviderSubscriptionSchedule, + ProviderSubscriptionSchedulePhase, ProviderTestClock, } from "./providers/provider"; export type { @@ -78,5 +86,11 @@ export { PayKitError, PAYKIT_ERROR_CODES } from "./core/errors"; export type { PayKitErrorCode } from "./core/errors"; export { defineErrorCodes } from "./core/error-codes"; export type { RawError } from "./core/error-codes"; +export { defineProvider } from "./providers/provider"; +export { + assertProviderCapability, + assertProviderHasCapability, + unsupportedProviderCapability, +} from "./providers/capabilities"; export { feature, plan } from "./types/schema"; export { createPayKitEndpoint, definePayKitMethod, returnUrl } from "./api/define-route"; diff --git a/packages/paykit/src/product/product-sync.service.ts b/packages/paykit/src/product/product-sync.service.ts index 7e65bfa5..f6a4d9ad 100644 --- a/packages/paykit/src/product/product-sync.service.ts +++ b/packages/paykit/src/product/product-sync.service.ts @@ -1,5 +1,6 @@ import type { PayKitContext } from "../core/context"; import { PayKitError, PAYKIT_ERROR_CODES } from "../core/errors"; +import { assertProviderHasCapability } from "../providers/capabilities"; import type { StoredProductFeature } from "../types/models"; import type { NormalizedPlan, NormalizedPlanFeature } from "../types/schema"; import { @@ -178,46 +179,27 @@ export async function syncProducts(ctx: PayKitContext): Promise 0) { - const providerResults = await ctx.provider.syncProducts({ - products: paidPlansToSync.map((p) => ({ - existingProviderProduct: p.existingProviderProduct, - id: p.id, - name: p.name, - priceAmount: p.priceAmount, - priceInterval: p.priceInterval, - })), - }); - - const requestedIds = new Set(paidPlansToSync.map((p) => p.id)); - const resultById = new Map(); - for (const r of providerResults.results) { - if (resultById.has(r.id)) { - throw PayKitError.from( - "INTERNAL_SERVER_ERROR", - PAYKIT_ERROR_CODES.PLAN_SYNC_FAILED, - `Provider syncProducts returned duplicate mapping for id: ${r.id}`, - ); + assertProviderHasCapability(ctx.provider, "subscriptionProducts"); + const activeProviderProductIds: string[] = []; + for (const plan of paidPlansToSync) { + const providerResult = await ctx.provider.upsertSubscriptionProduct({ + existingProviderProduct: plan.existingProviderProduct, + id: plan.id, + name: plan.name, + priceAmount: plan.priceAmount, + priceInterval: plan.priceInterval, + }); + const providerProductId = providerResult.providerProduct.productId; + if (providerProductId) { + activeProviderProductIds.push(providerProductId); } - resultById.set(r.id, r); - } - const missingIds = [...requestedIds].filter((id) => !resultById.has(id)); - if (missingIds.length > 0 || resultById.size !== requestedIds.size) { - throw PayKitError.from( - "INTERNAL_SERVER_ERROR", - PAYKIT_ERROR_CODES.PLAN_SYNC_FAILED, - `Provider syncProducts returned invalid mapping: missing=[${missingIds.join(", ")}], expected=${String(requestedIds.size)}, got=${String(resultById.size)}`, - ); - } - - for (const [, providerResult] of resultById) { - const plan = paidPlansToSync.find((p) => p.id === providerResult.id); - if (!plan) continue; await upsertProviderProduct(ctx.database, { productInternalId: plan.storedProductInternalId, providerId, providerProduct: providerResult.providerProduct, }); } + await ctx.provider.cleanupSubscriptionProducts?.({ activeProviderProductIds }); } return results; diff --git a/packages/paykit/src/providers/__tests__/capabilities.test.ts b/packages/paykit/src/providers/__tests__/capabilities.test.ts new file mode 100644 index 00000000..b78ee4a0 --- /dev/null +++ b/packages/paykit/src/providers/__tests__/capabilities.test.ts @@ -0,0 +1,25 @@ +import { describe, expect, it } from "vitest"; + +import { PAYKIT_ERROR_CODES } from "../../core/errors"; +import { assertProviderCapability, unsupportedProviderCapability } from "../capabilities"; + +const provider = { + capabilities: {} as never, + id: "test", + name: "Test Provider", +}; + +describe("provider/capabilities", () => { + it("throws a stable unsupported capability error", () => { + const error = unsupportedProviderCapability(provider, "subscriptions.schedules"); + + expect(error).toMatchObject({ + code: PAYKIT_ERROR_CODES.PROVIDER_CAPABILITY_UNSUPPORTED.code, + message: 'Test Provider does not support provider capability "subscriptions.schedules"', + }); + }); + + it("passes when the capability is supported", () => { + expect(() => assertProviderCapability(provider, "subscriptions.schedules", true)).not.toThrow(); + }); +}); diff --git a/packages/paykit/src/providers/capabilities.ts b/packages/paykit/src/providers/capabilities.ts new file mode 100644 index 00000000..aed3e037 --- /dev/null +++ b/packages/paykit/src/providers/capabilities.ts @@ -0,0 +1,37 @@ +import { PayKitError, PAYKIT_ERROR_CODES } from "../core/errors"; +import type { PayKitProvider, PayKitProviderCapabilities } from "./provider"; + +type ProviderWithCapability = PayKitProvider< + PayKitProviderCapabilities & Record +>; + +/** Raises a stable error when a billing flow needs an unsupported provider feature. */ +export function unsupportedProviderCapability( + provider: Pick, + capability: string, +): PayKitError { + return PayKitError.from( + "BAD_REQUEST", + PAYKIT_ERROR_CODES.PROVIDER_CAPABILITY_UNSUPPORTED, + `${provider.name} does not support provider capability "${capability}"`, + ); +} + +/** Ensures a capability flag is enabled before composing provider primitives. */ +export function assertProviderCapability( + provider: Pick, + capability: string, + supported: boolean, +): void { + if (!supported) { + throw unsupportedProviderCapability(provider, capability); + } +} + +/** Ensures a provider capability flag is enabled and narrows required methods. */ +export function assertProviderHasCapability( + provider: PayKitProvider, + capability: TKey, +): asserts provider is ProviderWithCapability { + assertProviderCapability(provider, capability, provider.capabilities[capability]); +} diff --git a/packages/paykit/src/providers/provider.ts b/packages/paykit/src/providers/provider.ts index 73a5a0c9..386efd5f 100644 --- a/packages/paykit/src/providers/provider.ts +++ b/packages/paykit/src/providers/provider.ts @@ -28,17 +28,94 @@ export interface ProviderPaymentMethod { } export interface PayKitProviderCapabilities { + /** + * Provider-hosted subscription products/prices can be created or updated. + * Requires {@link PayKitProvider.upsertSubscriptionProduct}. + * May also provide {@link PayKitProvider.cleanupSubscriptionProducts}. + */ + subscriptionProducts: boolean; + /** + * Provider can create hosted checkout sessions for subscriptions. + * Requires {@link PayKitProvider.createSubscriptionCheckout}. + */ + subscriptionCheckout: boolean; + /** + * Provider can open a customer self-service billing portal. + * Requires {@link PayKitProvider.createCustomerPortalSession}. + */ + customerPortal: boolean; + /** + * Provider can create one-off invoices through its API. + * Requires {@link PayKitProvider.createInvoice}. + */ + createInvoices: boolean; + /** + * Provider can detach reusable payment methods. + * Requires {@link PayKitProvider.detachPaymentMethod}. + */ + detachPaymentMethods: boolean; + /** + * Provider can collect reusable payment methods without creating a subscription. + * Requires {@link PayKitProvider.createPaymentMethodSetupSession}. + */ + setupPaymentMethods: boolean; + /** + * Provider can mark a subscription to cancel at the current period end. + * Requires {@link PayKitProvider.cancelSubscriptionAtPeriodEnd}. + */ + cancelSubscriptionsAtPeriodEnd: boolean; + /** + * Provider can create a subscription directly without hosted checkout. + * Requires {@link PayKitProvider.createSubscription}. + */ + createSubscriptions: boolean; + /** + * Provider can change a subscription product immediately. + * Requires {@link PayKitProvider.changeSubscriptionProduct}. + */ + changeSubscriptionProducts: boolean; + /** + * Provider can list active subscriptions for a customer. + * Requires {@link PayKitProvider.listActiveSubscriptions}. + */ + listActiveSubscriptions: boolean; + /** + * Provider can schedule a product change at period end without explicit schedule resources. + * Requires {@link PayKitProvider.changeSubscriptionProductAtPeriodEnd}. + */ + pendingSubscriptionProductChanges: boolean; + /** + * Provider can clear a pending cancel-at-period-end state. + * Requires {@link PayKitProvider.resumeSubscriptionAtPeriodEnd}. + */ + resumeSubscriptionsAtPeriodEnd: boolean; + /** + * Provider has explicit subscription schedule resources. + * Requires {@link PayKitProvider.getOrCreateSubscriptionSchedule} and + * {@link PayKitProvider.updateSubscriptionSchedulePhases}. + */ + subscriptionSchedules: boolean; + /** + * Provider supports deterministic test clocks. + * Requires {@link PayKitProvider.getTestClock} and {@link PayKitProvider.advanceTestClock}. + */ testClocks: boolean; + /** + * Provider webhook endpoints can be managed through its API. + * Requires {@link PayKitProvider.getWebhookEndpointAccount}, + * {@link PayKitProvider.ensureWebhookEndpoint}, and {@link PayKitProvider.deleteWebhookEndpoint}. + */ + manageWebhookEndpoints: boolean; } -export interface ProviderTunnelAccount { +export interface ProviderWebhookEndpointAccount { displayName?: string; environment: string; providerAccountId: string; providerId: string; } -export interface ProviderTunnelWebhook { +export interface ProviderWebhookEndpointResult { created: boolean; endpointId: string; webhookSecret?: string; @@ -66,6 +143,8 @@ export interface ProviderSubscription { currentPeriodEndAt?: Date | null; currentPeriodStartAt?: Date | null; endedAt?: Date | null; + /** Current provider product/price reference when the provider exposes it. */ + providerProduct?: Record | null; providerSubscriptionId: string; providerSubscriptionScheduleId?: string | null; status: string; @@ -79,11 +158,96 @@ export interface ProviderSubscriptionResult { subscription?: ProviderSubscription | null; } -export interface PaymentProvider { +export interface ProviderProductInput { + existingProviderProduct?: Record | null; + id: string; + name: string; + priceAmount: number; + priceInterval?: string | null; +} + +export interface ProviderProductResult { + /** Provider-specific product/price reference stored by PayKit. */ + providerProduct: Record; +} + +export interface ProviderSubscriptionSchedulePhase { + endAt?: Date | null; + providerProduct: Record; + startAt: Date; +} + +export interface ProviderSubscriptionSchedule { + /** Start of the currently active provider schedule phase. */ + currentPhaseStartAt: Date; + id: string; +} + +export interface ProviderHealthResult { + ok: boolean; + displayName: string; + mode: string; + webhookEndpoints?: Array<{ url: string; status: string }>; + errors?: string[]; + customerSample?: Array<{ providerEmail: string; paykitCustomerId: string | null }>; + error?: string; +} + +type CapabilityMethodKeys = + | "advanceTestClock" + | "cancelSubscriptionAtPeriodEnd" + | "changeSubscriptionProduct" + | "changeSubscriptionProductAtPeriodEnd" + | "cleanupSubscriptionProducts" + | "createCustomerPortalSession" + | "createInvoice" + | "createPaymentMethodSetupSession" + | "createSubscription" + | "createSubscriptionCheckout" + | "deleteWebhookEndpoint" + | "detachPaymentMethod" + | "ensureWebhookEndpoint" + | "getOrCreateSubscriptionSchedule" + | "getTestClock" + | "getWebhookEndpointAccount" + | "listActiveSubscriptions" + | "resumeSubscriptionAtPeriodEnd" + | "updateSubscriptionSchedulePhases" + | "upsertSubscriptionProduct"; + +type CapabilityRequiredMethodKeys = + | (TCapabilities["subscriptionProducts"] extends true ? "upsertSubscriptionProduct" : never) + | (TCapabilities["subscriptionCheckout"] extends true ? "createSubscriptionCheckout" : never) + | (TCapabilities["customerPortal"] extends true ? "createCustomerPortalSession" : never) + | (TCapabilities["createInvoices"] extends true ? "createInvoice" : never) + | (TCapabilities["detachPaymentMethods"] extends true ? "detachPaymentMethod" : never) + | (TCapabilities["setupPaymentMethods"] extends true ? "createPaymentMethodSetupSession" : never) + | (TCapabilities["cancelSubscriptionsAtPeriodEnd"] extends true + ? "cancelSubscriptionAtPeriodEnd" + : never) + | (TCapabilities["createSubscriptions"] extends true ? "createSubscription" : never) + | (TCapabilities["changeSubscriptionProducts"] extends true ? "changeSubscriptionProduct" : never) + | (TCapabilities["listActiveSubscriptions"] extends true ? "listActiveSubscriptions" : never) + | (TCapabilities["pendingSubscriptionProductChanges"] extends true + ? "changeSubscriptionProductAtPeriodEnd" + : never) + | (TCapabilities["resumeSubscriptionsAtPeriodEnd"] extends true + ? "resumeSubscriptionAtPeriodEnd" + : never) + | (TCapabilities["subscriptionSchedules"] extends true + ? "getOrCreateSubscriptionSchedule" | "updateSubscriptionSchedulePhases" + : never) + | (TCapabilities["testClocks"] extends true ? "advanceTestClock" | "getTestClock" : never) + | (TCapabilities["manageWebhookEndpoints"] extends true + ? "deleteWebhookEndpoint" | "ensureWebhookEndpoint" | "getWebhookEndpointAccount" + : never); + +interface PayKitProviderBase { readonly id: string; readonly name: string; - readonly capabilities: PayKitProviderCapabilities; + readonly capabilities: TCapabilities; + /** Creates a provider customer for a PayKit customer. */ createCustomer(data: { createTestClock?: boolean; id: string; @@ -91,7 +255,9 @@ export interface PaymentProvider { name?: string; metadata?: Record; }): Promise<{ providerCustomer: ProviderCustomer }>; - + /** Deletes a provider customer. */ + deleteCustomer(data: { providerCustomerId: string }): Promise; + /** Updates mutable provider customer profile fields. */ updateCustomer(data: { providerCustomerId: string; email?: string; @@ -99,113 +265,151 @@ export interface PaymentProvider { metadata?: Record; }): Promise; - deleteCustomer(data: { providerCustomerId: string }): Promise; + /** Retrieves the provider's current normalized subscription snapshot. */ + getSubscription(data: { providerSubscriptionId: string }): Promise; - getTestClock(data: { testClockId: string }): Promise; + /** Verifies and normalizes provider webhook events. */ + parseWebhook(data: { + allowStaleSignatures?: boolean; + body: string; + headers: Record; + }): Promise; - advanceTestClock(data: { testClockId: string; frozenTime: Date }): Promise; + /** Checks provider account health and optional diagnostic details. */ + check?(): Promise; +} + +interface PayKitProviderCapabilityMethodMap { + /** Creates or updates one subscription product and returns its provider reference. */ + upsertSubscriptionProduct(data: ProviderProductInput): Promise; + /** Optional provider cleanup after billing core has upserted all active products. */ + cleanupSubscriptionProducts?(data: { activeProviderProductIds: string[] }): Promise; + + /** Creates and returns a provider invoice. */ + createInvoice(data: { + providerCustomerId: string; + lines: Array<{ amount: number; description: string }>; + autoAdvance?: boolean; + }): Promise; - attachPaymentMethod(data: { + /** Creates a hosted setup session for collecting a reusable payment method. */ + createPaymentMethodSetupSession(data: { providerCustomerId: string; returnURL: string; }): Promise<{ url: string }>; - createSubscriptionCheckout(data: { - providerCustomerId: string; - providerProduct: Record; - successUrl: string; - cancelUrl?: string; - metadata?: Record; - }): Promise<{ paymentUrl: string; providerCheckoutSessionId: string }>; + /** Detaches or removes a reusable provider payment method. */ + detachPaymentMethod(data: { providerMethodId: string }): Promise; - createSubscription(data: { + /** Creates a provider-hosted customer portal session. */ + createCustomerPortalSession(data: { providerCustomerId: string; - providerProduct: Record; + returnUrl: string; + }): Promise<{ url: string }>; + + /** Marks the provider subscription to cancel at the current period end. */ + cancelSubscriptionAtPeriodEnd(data: { + providerSubscriptionId: string; + providerSubscriptionScheduleId?: string | null; }): Promise; - updateSubscription(data: { + /** Creates a provider subscription directly without hosted checkout. */ + createSubscription(data: { + providerCustomerId: string; providerProduct: Record; - providerSubscriptionId: string; }): Promise; - createInvoice(data: { + /** Creates a provider-hosted checkout session for a subscription product. */ + createSubscriptionCheckout(data: { providerCustomerId: string; - lines: Array<{ amount: number; description: string }>; - autoAdvance?: boolean; - }): Promise; + providerProduct: Record; + successUrl: string; + cancelUrl?: string; + metadata?: Record; + }): Promise<{ paymentUrl: string; providerCheckoutSessionId: string }>; - scheduleSubscriptionChange(data: { - providerProduct?: Record | null; - providerSubscriptionScheduleId?: string | null; + /** Changes the subscription product immediately. */ + changeSubscriptionProduct(data: { + providerProduct: Record; providerSubscriptionId: string; }): Promise; - cancelSubscription(data: { - currentPeriodEndAt?: Date | null; + /** Schedules or records a provider-native product change at period end. */ + changeSubscriptionProductAtPeriodEnd(data: { + providerProduct: Record; providerSubscriptionId: string; - providerSubscriptionScheduleId?: string | null; }): Promise; + /** Lists active provider subscriptions for a customer. */ listActiveSubscriptions(data: { providerCustomerId: string; }): Promise>; - resumeSubscription(data: { + /** Clears pending cancellation and provider-native scheduled product changes if possible. */ + resumeSubscriptionAtPeriodEnd(data: { providerSubscriptionId: string; providerSubscriptionScheduleId?: string | null; }): Promise; - detachPaymentMethod(data: { providerMethodId: string }): Promise; - - syncProducts(data: { - products: Array<{ - id: string; - name: string; - priceAmount: number; - priceInterval?: string | null; - existingProviderProduct?: Record | null; - }>; - }): Promise<{ - results: Array<{ - id: string; - providerProduct: Record; - }>; - }>; - - handleWebhook(data: { - allowStaleSignatures?: boolean; - body: string; - headers: Record; - }): Promise; + /** Retrieves or creates the provider schedule resource for a subscription. */ + getOrCreateSubscriptionSchedule(data: { + providerSubscriptionId: string; + providerSubscriptionScheduleId?: string | null; + }): Promise; + /** Replaces the provider schedule phases with the provided normalized phases. */ + updateSubscriptionSchedulePhases(data: { + phases: ProviderSubscriptionSchedulePhase[]; + providerSubscriptionScheduleId: string; + }): Promise; - createPortalSession(data: { - providerCustomerId: string; - returnUrl: string; - }): Promise<{ url: string }>; + /** Advances a deterministic provider test clock. */ + advanceTestClock(data: { testClockId: string; frozenTime: Date }): Promise; + /** Retrieves a deterministic provider test clock. */ + getTestClock(data: { testClockId: string }): Promise; - getTunnelAccount?(): Promise; + /** Returns the provider account identity used to scope webhook endpoint management. */ + getWebhookEndpointAccount(): Promise; - ensureTunnelWebhook?(data: { + /** Creates or updates the provider webhook endpoint for a PayKit webhook URL. */ + ensureWebhookEndpoint(data: { existingEndpointId?: string | null; url: string; - }): Promise; - - disableTunnelWebhook?(data: { endpointId: string }): Promise; + }): Promise; - check?(): Promise<{ - ok: boolean; - displayName: string; - mode: string; - webhookEndpoints?: Array<{ url: string; status: string }>; - errors?: string[]; - customerSample?: Array<{ providerEmail: string; paykitCustomerId: string | null }>; - error?: string; - }>; + /** Deletes a provider webhook endpoint by provider endpoint ID. */ + deleteWebhookEndpoint(data: { endpointId: string }): Promise; } -export interface PayKitProviderConfig { - id: string; - name: string; - capabilities: PayKitProviderCapabilities; - createAdapter(): PaymentProvider; +type PayKitProviderCapabilityMethods = Pick< + PayKitProviderCapabilityMethodMap, + CapabilityRequiredMethodKeys & CapabilityMethodKeys +> & + Partial< + Omit< + PayKitProviderCapabilityMethodMap, + CapabilityRequiredMethodKeys & CapabilityMethodKeys + > + >; + +export type PayKitProvider< + TCapabilities extends PayKitProviderCapabilities = PayKitProviderCapabilities, +> = PayKitProviderBase & PayKitProviderCapabilityMethods; + +type ExactProviderCapabilities = TCapabilities & + Record, never>; + +/** + * Defines a PayKit payment provider while preserving the provider's exact type. + * @param provider Provider implementation. + */ +export function defineProvider< + const TCapabilities extends PayKitProviderCapabilities, + const TProvider extends PayKitProvider>, +>( + provider: TProvider & { readonly capabilities: ExactProviderCapabilities }, +): TProvider { + return provider; } + +/** @deprecated Use {@link PayKitProvider}. */ +export type PaymentProvider = PayKitProvider; diff --git a/packages/paykit/src/subscription/subscription.service.ts b/packages/paykit/src/subscription/subscription.service.ts index ba4c28e3..5eb09fe3 100644 --- a/packages/paykit/src/subscription/subscription.service.ts +++ b/packages/paykit/src/subscription/subscription.service.ts @@ -19,7 +19,12 @@ import { getProductFeatures, withProviderInfo, } from "../product/product.service"; -import type { ProviderRequiredAction, ProviderSubscription } from "../providers/provider"; +import { assertProviderCapability, assertProviderHasCapability } from "../providers/capabilities"; +import type { + ProviderRequiredAction, + ProviderSubscription, + ProviderSubscriptionResult, +} from "../providers/provider"; import type { DeleteSubscriptionAction, NormalizedSubscription, @@ -243,8 +248,8 @@ async function cancelExistingProviderSubscriptionForCheckout( ); } - await ctx.provider.cancelSubscription({ - currentPeriodEndAt: completion.subCtx.activeSubscription.currentPeriodEndAt, + assertProviderHasCapability(ctx.provider, "cancelSubscriptionsAtPeriodEnd"); + await ctx.provider.cancelSubscriptionAtPeriodEnd({ providerSubscriptionId: activeSubscriptionRef.subscriptionId, providerSubscriptionScheduleId: activeSubscriptionRef.subscriptionScheduleId, }); @@ -744,6 +749,92 @@ function getProviderSubscriptionRef(subscription: ActiveSubscription): { }; } +async function scheduleProviderProductChange( + ctx: PayKitContext, + input: { + currentPeriodEndAt: Date | null; + currentPeriodStartAt: Date | null; + providerProduct: Record; + providerSubscriptionId: string; + providerSubscriptionScheduleId: string | null; + }, +): Promise { + if (ctx.provider.capabilities.subscriptionSchedules) { + return scheduleWithProviderSchedule(ctx, input); + } + + if (ctx.provider.capabilities.pendingSubscriptionProductChanges) { + return scheduleWithPendingProductChange(ctx, input); + } + + assertProviderCapability(ctx.provider, "scheduledSubscriptionProductChange", false); + throw new Error("unreachable"); +} + +async function scheduleWithProviderSchedule( + ctx: PayKitContext, + input: Parameters[1], +): Promise { + assertProviderCapability( + ctx.provider, + "subscriptionSchedules", + ctx.provider.capabilities.subscriptionSchedules && + ctx.provider.getOrCreateSubscriptionSchedule != null && + ctx.provider.updateSubscriptionSchedulePhases != null, + ); + assertProviderHasCapability(ctx.provider, "subscriptionSchedules"); + const getOrCreateSchedule = ctx.provider.getOrCreateSubscriptionSchedule!; + const updateSchedulePhases = ctx.provider.updateSubscriptionSchedulePhases!; + const currentSubscription = await ctx.provider.getSubscription({ + providerSubscriptionId: input.providerSubscriptionId, + }); + const currentProviderProduct = currentSubscription.providerProduct; + const periodStartAt = input.currentPeriodStartAt ?? currentSubscription.currentPeriodStartAt; + const periodEndAt = input.currentPeriodEndAt ?? currentSubscription.currentPeriodEndAt; + if (!currentProviderProduct || !periodStartAt || !periodEndAt) { + throw PayKitError.from("BAD_REQUEST", PAYKIT_ERROR_CODES.PROVIDER_SUBSCRIPTION_MISSING_PERIOD); + } + + const schedule = await getOrCreateSchedule({ + providerSubscriptionId: input.providerSubscriptionId, + providerSubscriptionScheduleId: input.providerSubscriptionScheduleId, + }); + await updateSchedulePhases({ + providerSubscriptionScheduleId: schedule.id, + phases: [ + { + endAt: periodEndAt, + providerProduct: currentProviderProduct, + startAt: schedule.currentPhaseStartAt ?? periodStartAt, + }, + { + providerProduct: input.providerProduct, + startAt: periodEndAt, + }, + ], + }); + const subscription = await ctx.provider.getSubscription({ + providerSubscriptionId: input.providerSubscriptionId, + }); + return { paymentUrl: null, requiredAction: null, subscription }; +} + +async function scheduleWithPendingProductChange( + ctx: PayKitContext, + input: Parameters[1], +): Promise { + assertProviderCapability( + ctx.provider, + "pendingSubscriptionProductChanges", + ctx.provider.capabilities.pendingSubscriptionProductChanges, + ); + assertProviderHasCapability(ctx.provider, "pendingSubscriptionProductChanges"); + return ctx.provider.changeSubscriptionProductAtPeriodEnd({ + providerProduct: input.providerProduct, + providerSubscriptionId: input.providerSubscriptionId, + }); +} + /** Returns a noop or resumes the current provider subscription. */ async function handleSamePlanSubscribe( ctx: PayKitContext, @@ -760,7 +851,8 @@ async function handleSamePlanSubscribe( } const activeSubscriptionRef = getProviderSubscriptionRef(activeSubscription); - const providerResult = await ctx.provider.resumeSubscription({ + assertProviderHasCapability(ctx.provider, "resumeSubscriptionsAtPeriodEnd"); + const providerResult = await ctx.provider.resumeSubscriptionAtPeriodEnd({ providerSubscriptionId: activeSubscriptionRef.subscriptionId!, providerSubscriptionScheduleId: activeSubscriptionRef.subscriptionScheduleId, }); @@ -822,6 +914,7 @@ async function handleInitialSubscribe( return createCheckoutSubscribe(ctx, subCtx); } + assertProviderHasCapability(ctx.provider, "createSubscriptions"); const providerResult = await ctx.provider.createSubscription({ providerCustomerId: subCtx.providerCustomerId, providerProduct: subCtx.storedPlan.providerProduct!, @@ -879,6 +972,7 @@ async function handleLocalPlanSwitch( return buildSubscribeResult({ paymentUrl: null }); } + assertProviderHasCapability(ctx.provider, "createSubscriptions"); const providerResult = await ctx.provider.createSubscription({ providerCustomerId: subCtx.providerCustomerId, providerProduct: subCtx.storedPlan.providerProduct!, @@ -921,8 +1015,8 @@ async function handleCancelToFree( throw PayKitError.from("INTERNAL_SERVER_ERROR", PAYKIT_ERROR_CODES.SUBSCRIPTION_CREATE_FAILED); } - const providerResult = await ctx.provider.cancelSubscription({ - currentPeriodEndAt: activeSubscription.currentPeriodEndAt, + assertProviderHasCapability(ctx.provider, "cancelSubscriptionsAtPeriodEnd"); + const providerResult = await ctx.provider.cancelSubscriptionAtPeriodEnd({ providerSubscriptionId: activeSubscriptionRef.subscriptionId, providerSubscriptionScheduleId: activeSubscriptionRef.subscriptionScheduleId, }); @@ -975,7 +1069,9 @@ async function handleScheduledDowngrade( throw PayKitError.from("INTERNAL_SERVER_ERROR", PAYKIT_ERROR_CODES.SUBSCRIPTION_CREATE_FAILED); } - const providerResult = await ctx.provider.scheduleSubscriptionChange({ + const providerResult = await scheduleProviderProductChange(ctx, { + currentPeriodEndAt: activeSubscription.currentPeriodEndAt, + currentPeriodStartAt: activeSubscription.currentPeriodStartAt, providerProduct: subCtx.storedPlan.providerProduct!, providerSubscriptionId: activeSubscriptionRef.subscriptionId, providerSubscriptionScheduleId: activeSubscriptionRef.subscriptionScheduleId, @@ -1029,7 +1125,8 @@ async function handleUpgrade( throw PayKitError.from("INTERNAL_SERVER_ERROR", PAYKIT_ERROR_CODES.SUBSCRIPTION_CREATE_FAILED); } - const providerResult = await ctx.provider.updateSubscription({ + assertProviderHasCapability(ctx.provider, "changeSubscriptionProducts"); + const providerResult = await ctx.provider.changeSubscriptionProduct({ providerProduct: subCtx.storedPlan.providerProduct!, providerSubscriptionId: activeSubscriptionRef.subscriptionId, }); @@ -1066,6 +1163,7 @@ async function createCheckoutSubscribe( ctx: PayKitContext, subCtx: SubscribeContext, ): Promise { + assertProviderHasCapability(ctx.provider, "subscriptionCheckout"); const checkoutResult = await ctx.provider.createSubscriptionCheckout({ cancelUrl: subCtx.cancelUrl, metadata: { diff --git a/packages/paykit/src/testing/testing.service.ts b/packages/paykit/src/testing/testing.service.ts index e3435c0b..d9036d0b 100644 --- a/packages/paykit/src/testing/testing.service.ts +++ b/packages/paykit/src/testing/testing.service.ts @@ -5,6 +5,7 @@ import { getProviderCustomer, setProviderCustomer, } from "../customer/customer.service"; +import { assertProviderHasCapability } from "../providers/capabilities"; import type { Customer } from "../types/models"; function assertTestingEnabled(ctx: PayKitContext): void { @@ -15,10 +16,13 @@ function assertTestingEnabled(ctx: PayKitContext): void { if (!ctx.provider.capabilities.testClocks) { throw PayKitError.from("BAD_REQUEST", PAYKIT_ERROR_CODES.TESTING_NOT_ENABLED); } + + assertProviderHasCapability(ctx.provider, "testClocks"); } export async function getCustomerTestClock(ctx: PayKitContext, customerId: string) { assertTestingEnabled(ctx); + assertProviderHasCapability(ctx.provider, "testClocks"); const customer = await getCustomerByIdOrThrow(ctx.database, customerId); const providerCustomer = getProviderCustomer(customer, ctx.provider.id); if (!providerCustomer) { @@ -29,7 +33,9 @@ export async function getCustomerTestClock(ctx: PayKitContext, customerId: strin throw PayKitError.from("NOT_FOUND", PAYKIT_ERROR_CODES.TEST_CLOCK_NOT_FOUND); } - const testClock = await ctx.provider.getTestClock({ testClockId: providerCustomer.testClockId }); + const testClock = await ctx.provider.getTestClock({ + testClockId: providerCustomer.testClockId, + }); await setProviderCustomer(ctx.database, { customerId, providerCustomer: { @@ -59,6 +65,7 @@ export async function advanceCustomerTestClock( input: { customerId: string; frozenTime: Date }, ) { assertTestingEnabled(ctx); + assertProviderHasCapability(ctx.provider, "testClocks"); const customer = await getCustomerByIdOrThrow(ctx.database, input.customerId); const providerCustomer = getProviderCustomer(customer, ctx.provider.id); if (!providerCustomer) { diff --git a/packages/paykit/src/types/events.ts b/packages/paykit/src/types/events.ts index a9289996..79091c23 100644 --- a/packages/paykit/src/types/events.ts +++ b/packages/paykit/src/types/events.ts @@ -153,9 +153,7 @@ export interface NormalizedWebhookEventMap { export type NormalizedWebhookEventName = keyof NormalizedWebhookEventMap; export type AnyNormalizedWebhookEvent = { - [TName in NormalizedWebhookEventName]: EventByName & { - actions?: WebhookApplyAction[]; - }; + [TName in NormalizedWebhookEventName]: EventByName; }[NormalizedWebhookEventName]; export type NormalizedWebhookEvent< diff --git a/packages/paykit/src/types/options.ts b/packages/paykit/src/types/options.ts index c8c25bdb..4a43f810 100644 --- a/packages/paykit/src/types/options.ts +++ b/packages/paykit/src/types/options.ts @@ -1,7 +1,7 @@ import type { Pool } from "pg"; import type { LevelWithSilent, Logger } from "pino"; -import type { PayKitProviderConfig } from "../providers/provider"; +import type { PayKitProvider } from "../providers/provider"; import type { PayKitEventHandlers } from "./events"; import type { PayKitPlugin } from "./plugin"; import type { PayKitProductsModule } from "./schema"; @@ -17,7 +17,7 @@ export interface PayKitTestingOptions { export interface PayKitOptions { database: Pool | string; - provider: PayKitProviderConfig; + provider: PayKitProvider; products?: PayKitProductsModule; /** * PayKit root path, e.g. `/paykit` or `/billing`. diff --git a/packages/paykit/src/webhook/webhook.service.ts b/packages/paykit/src/webhook/webhook.service.ts index b9ec0064..690cd02a 100644 --- a/packages/paykit/src/webhook/webhook.service.ts +++ b/packages/paykit/src/webhook/webhook.service.ts @@ -132,6 +132,80 @@ async function applyAction(ctx: PayKitContext, action: WebhookApplyAction): Prom } } +function getWebhookApplyActions(event: AnyNormalizedWebhookEvent): WebhookApplyAction[] { + switch (event.name) { + case "checkout.completed": { + return []; + } + case "payment_method.attached": { + return [ + { + data: { + paymentMethod: event.payload.paymentMethod, + providerCustomerId: event.payload.providerCustomerId, + }, + type: "payment_method.upsert", + }, + ]; + } + case "payment_method.detached": { + return [ + { + data: { + providerMethodId: event.payload.providerMethodId, + }, + type: "payment_method.delete", + }, + ]; + } + case "payment.succeeded": { + return [ + { + data: { + payment: event.payload.payment, + providerCustomerId: event.payload.providerCustomerId, + }, + type: "payment.upsert", + }, + ]; + } + case "subscription.updated": { + return [ + { + data: { + providerCustomerId: event.payload.providerCustomerId, + subscription: event.payload.subscription, + }, + type: "subscription.upsert", + }, + ]; + } + case "subscription.deleted": { + return [ + { + data: { + providerCustomerId: event.payload.providerCustomerId, + providerSubscriptionId: event.payload.providerSubscriptionId, + }, + type: "subscription.delete", + }, + ]; + } + case "invoice.updated": { + return [ + { + data: { + invoice: event.payload.invoice, + providerCustomerId: event.payload.providerCustomerId, + providerSubscriptionId: event.payload.providerSubscriptionId, + }, + type: "invoice.upsert", + }, + ]; + } + } +} + async function processWebhookEvent( ctx: PayKitContext, event: AnyNormalizedWebhookEvent, @@ -166,7 +240,7 @@ async function processWebhookEvent( } } - for (const action of event.actions ?? []) { + for (const action of getWebhookApplyActions(event)) { ctx.logger.info({ actionType: action.type }, "applying action"); const customerId = await applyAction(txCtx, action); if (customerId) { @@ -205,7 +279,7 @@ export async function handleWebhook( input: HandleWebhookInput, ): Promise<{ received: true }> { return ctx.logger.trace.run("wh", async () => { - const events = await ctx.provider.handleWebhook({ + const events = await ctx.provider.parseWebhook({ allowStaleSignatures: input.allowStaleSignatures, body: input.body, headers: input.headers, diff --git a/packages/polar/src/__tests__/polar-provider.test.ts b/packages/polar/src/__tests__/polar-provider.test.ts new file mode 100644 index 00000000..9f20713b --- /dev/null +++ b/packages/polar/src/__tests__/polar-provider.test.ts @@ -0,0 +1,98 @@ +import { describe, expect, it, vi } from "vitest"; + +const polarSdkMock = vi.hoisted(() => ({ + client: null as unknown, + Polar: vi.fn(function Polar() { + return polarSdkMock.client; + }), +})); + +vi.mock("@polar-sh/sdk", () => ({ + Polar: polarSdkMock.Polar, +})); + +import { polar } from "../polar-provider"; + +function createProvider(client: unknown) { + polarSdkMock.client = client; + return polar({ + accessToken: "polar_test_123", + server: "sandbox", + webhookSecret: "whsec_123", + }); +} + +describe("@paykitjs/polar", () => { + it("manages webhook endpoints", async () => { + const createWebhookEndpoint = vi.fn().mockResolvedValue({ + id: "webhook_123", + secret: "whsec_created", + }); + const client = { + organizations: { + list: vi.fn().mockResolvedValue({ result: { items: [{ id: "org_123", name: "Acme" }] } }), + }, + webhooks: { + createWebhookEndpoint, + }, + }; + const provider = createProvider(client); + + const account = await provider.getWebhookEndpointAccount?.(); + const endpoint = await provider.ensureWebhookEndpoint?.({ url: "https://example.com/webhook" }); + + expect(provider.capabilities.manageWebhookEndpoints).toBe(true); + expect(account).toEqual({ + displayName: "Acme", + environment: "sandbox", + providerAccountId: "org_123", + providerId: "polar", + }); + expect(createWebhookEndpoint).toHaveBeenCalledWith( + expect.objectContaining({ + format: "raw", + name: "PayKit", + url: "https://example.com/webhook", + }), + ); + expect(createWebhookEndpoint).toHaveBeenCalledWith( + expect.not.objectContaining({ organizationId: expect.any(String) }), + ); + expect(endpoint).toEqual({ + created: true, + endpointId: "webhook_123", + webhookSecret: "whsec_created", + }); + }); + + it("updates an existing webhook endpoint", async () => { + const updateWebhookEndpoint = vi.fn().mockResolvedValue({ id: "webhook_existing" }); + const provider = createProvider({ + webhooks: { + updateWebhookEndpoint, + }, + }); + + const endpoint = await provider.ensureWebhookEndpoint?.({ + existingEndpointId: "webhook_existing", + url: "https://example.com/webhook", + }); + + expect(updateWebhookEndpoint).toHaveBeenCalledWith( + expect.objectContaining({ + id: "webhook_existing", + webhookEndpointUpdate: expect.objectContaining({ + enabled: true, + format: "raw", + name: "PayKit", + url: "https://example.com/webhook", + }), + }), + ); + expect(endpoint).toEqual({ + created: false, + endpointId: "webhook_existing", + webhookSecret: "whsec_123", + }); + }); +}); diff --git a/packages/polar/src/polar-provider.ts b/packages/polar/src/polar-provider.ts index ccb668fc..af96b3f4 100644 --- a/packages/polar/src/polar-provider.ts +++ b/packages/polar/src/polar-provider.ts @@ -1,8 +1,10 @@ import { Polar } from "@polar-sh/sdk"; +import { WebhookEventType } from "@polar-sh/sdk/models/components/webhookeventtype"; +import { WebhookFormat } from "@polar-sh/sdk/models/components/webhookformat"; import { SDKValidationError } from "@polar-sh/sdk/models/errors/sdkvalidationerror"; import { validateEvent, WebhookVerificationError } from "@polar-sh/sdk/webhooks"; import { PayKitError, PAYKIT_ERROR_CODES } from "paykitjs"; -import type { NormalizedWebhookEvent, PayKitProviderConfig, PaymentProvider } from "paykitjs"; +import type { NormalizedWebhookEvent, PayKitProvider } from "paykitjs"; export interface PolarOptions { accessToken: string; @@ -10,30 +12,91 @@ export interface PolarOptions { server?: "production" | "sandbox"; } -export type PolarProviderConfig = PayKitProviderConfig & { - capabilities: { testClocks: false }; -}; - type PolarWebhookEvent = ReturnType; type PolarSubscriptionEvent = Extract; type PolarCheckoutEvent = Extract; +type PolarWebhookData = Record; +type PolarClient = InstanceType; + +const polarCapabilities = { + subscriptionProducts: true, + subscriptionCheckout: true, + customerPortal: true, + createInvoices: false, + detachPaymentMethods: false, + setupPaymentMethods: false, + cancelSubscriptionsAtPeriodEnd: true, + createSubscriptions: false, + changeSubscriptionProducts: true, + listActiveSubscriptions: true, + pendingSubscriptionProductChanges: true, + resumeSubscriptionsAtPeriodEnd: true, + subscriptionSchedules: false, + testClocks: false, + manageWebhookEndpoints: true, +} as const satisfies PayKitProvider["capabilities"]; + +export type PolarProvider = PayKitProvider; + +const POLAR_WEBHOOK_EVENTS = [ + WebhookEventType.CheckoutCreated, + WebhookEventType.CheckoutUpdated, + WebhookEventType.SubscriptionCreated, + WebhookEventType.SubscriptionUpdated, + WebhookEventType.SubscriptionActive, + WebhookEventType.SubscriptionUncanceled, + WebhookEventType.SubscriptionCanceled, + WebhookEventType.SubscriptionPastDue, + WebhookEventType.SubscriptionRevoked, +]; function toDate(value: Date | string | null | undefined): Date | null { if (!value) return null; return value instanceof Date ? value : new Date(value); } -function normalizePolarSubscription(sub: PolarSubscriptionEvent["data"]) { +function getString(data: PolarWebhookData, camelKey: string, snakeKey: string): string | null { + const value = data[camelKey] ?? data[snakeKey]; + return typeof value === "string" ? value : null; +} + +function getEntityId(data: PolarWebhookData, key: string): string | null { + const value = data[key]; + if (!value || typeof value !== "object") return null; + const id = (value as { id?: unknown }).id; + return typeof id === "string" ? id : null; +} + +function getBoolean(data: PolarWebhookData, camelKey: string, snakeKey: string): boolean { + return data[camelKey] === true || data[snakeKey] === true; +} + +function getDateValue(data: PolarWebhookData, camelKey: string, snakeKey: string) { + const value = data[camelKey] ?? data[snakeKey]; + return value instanceof Date || typeof value === "string" ? value : null; +} + +function getMetadata(data: PolarWebhookData): Record | undefined { + const metadata = data.metadata; + if (!metadata || typeof metadata !== "object") return undefined; + return Object.fromEntries(Object.entries(metadata).map(([key, value]) => [key, String(value)])); +} + +function normalizePolarSubscription(sub: PolarWebhookData) { + const data = sub as PolarWebhookData; + return { - cancelAtPeriodEnd: sub.cancelAtPeriodEnd, - canceledAt: toDate(sub.canceledAt), - currentPeriodEndAt: toDate(sub.currentPeriodEnd), - currentPeriodStartAt: toDate(sub.currentPeriodStart), - endedAt: toDate(sub.endedAt), - providerProduct: { productId: sub.productId }, - providerSubscriptionId: sub.id, + cancelAtPeriodEnd: getBoolean(data, "cancelAtPeriodEnd", "cancel_at_period_end"), + canceledAt: toDate(getDateValue(data, "canceledAt", "canceled_at")), + currentPeriodEndAt: toDate(getDateValue(data, "currentPeriodEnd", "current_period_end")), + currentPeriodStartAt: toDate(getDateValue(data, "currentPeriodStart", "current_period_start")), + endedAt: toDate(getDateValue(data, "endedAt", "ended_at")), + providerProduct: { + productId: getString(data, "productId", "product_id") ?? getEntityId(data, "product") ?? "", + }, + providerSubscriptionId: getString(data, "id", "id") ?? "", providerSubscriptionScheduleId: null, - status: sub.status, + status: getString(data, "status", "status") ?? "unknown", }; } @@ -42,26 +105,23 @@ function createSubscriptionEvents( webhookId: string, ): NormalizedWebhookEvent[] { const sub = event.data; + const data = sub as PolarWebhookData; + const providerCustomerId = + getString(data, "customerId", "customer_id") ?? getEntityId(data, "customer"); + const providerSubscriptionId = getString(data, "id", "id"); + + if (!providerCustomerId || !providerSubscriptionId) return []; // `subscription.revoked` = immediately terminated (like Stripe delete) // `subscription.canceled` = will cancel at period end (like Stripe cancel_at_period_end) if (event.type === "subscription.revoked") { return [ { - actions: [ - { - data: { - providerCustomerId: sub.customerId, - providerSubscriptionId: sub.id, - }, - type: "subscription.delete", - }, - ], name: "subscription.deleted", payload: { - providerCustomerId: sub.customerId, + providerCustomerId, providerEventId: webhookId, - providerSubscriptionId: sub.id, + providerSubscriptionId, }, }, ]; @@ -70,18 +130,9 @@ function createSubscriptionEvents( const normalized = normalizePolarSubscription(sub); return [ { - actions: [ - { - data: { - providerCustomerId: sub.customerId, - subscription: normalized, - }, - type: "subscription.upsert", - }, - ], name: "subscription.updated", payload: { - providerCustomerId: sub.customerId, + providerCustomerId, providerEventId: webhookId, subscription: normalized, }, @@ -92,45 +143,261 @@ function createSubscriptionEvents( function createCheckoutEvents( event: { type?: string; data: PolarCheckoutEvent["data"] }, webhookId: string, + subscriptionData?: PolarSubscriptionEvent["data"], ): NormalizedWebhookEvent[] { const checkout = event.data; - if (checkout.status !== "succeeded") return []; + const data = checkout as PolarWebhookData; - const providerCustomerId = checkout.customerId; + const providerCustomerId = + getString(data, "customerId", "customer_id") ?? getEntityId(data, "customer"); + const providerSubscriptionId = + getString(data, "subscriptionId", "subscription_id") ?? getEntityId(data, "subscription"); + const status = getString(data, "status", "status") ?? "unknown"; if (!providerCustomerId) return []; + if (!providerSubscriptionId) return []; + + const subscription = subscriptionData + ? normalizePolarSubscription(subscriptionData) + : normalizePolarSubscription({ + cancelAtPeriodEnd: false, + currentPeriodEnd: null, + currentPeriodStart: null, + id: providerSubscriptionId, + productId: getString(data, "productId", "product_id") ?? getEntityId(data, "product"), + status: "active", + }); return [ { name: "checkout.completed", payload: { - checkoutSessionId: checkout.id, + checkoutSessionId: getString(data, "id", "id") ?? "", mode: "subscription", paymentStatus: "paid", providerCustomerId, providerEventId: webhookId, - providerSubscriptionId: checkout.subscriptionId ?? undefined, - status: checkout.status, - metadata: checkout.metadata - ? Object.fromEntries(Object.entries(checkout.metadata).map(([k, v]) => [k, String(v)])) - : undefined, + providerSubscriptionId: providerSubscriptionId ?? undefined, + status, + subscription, + metadata: getMetadata(data), }, }, ]; } -function notSupported(method: string): never { - throw PayKitError.from( - "BAD_REQUEST", - PAYKIT_ERROR_CODES.PROVIDER_WEBHOOK_INVALID, - `${method} is not supported by the Polar provider`, +async function createCheckoutEventsFromSdk( + client: PolarClient, + event: { type?: string; data: PolarCheckoutEvent["data"] }, + webhookId: string, +): Promise { + const events = createCheckoutEvents(event, webhookId); + if (events.length > 0) return events; + + const data = event.data as PolarWebhookData; + const status = getString(data, "status", "status") ?? "unknown"; + if (status !== "confirmed" && status !== "succeeded") return []; + + const checkoutId = getString(data, "id", "id"); + if (!checkoutId) return []; + + const checkout = await client.checkouts.get({ id: checkoutId }); + const checkoutData = checkout as PolarCheckoutEvent["data"]; + const subscriptionId = getString( + checkout as PolarWebhookData, + "subscriptionId", + "subscription_id", ); + if (subscriptionId) { + const subscription = await client.subscriptions.get({ id: subscriptionId }); + return createCheckoutEvents({ type: event.type, data: checkoutData }, webhookId, subscription); + } + + return []; } -export function createPolarProvider(client: Polar, options: PolarOptions): PaymentProvider { +function createRawWebhookEvents(body: string, webhookId: string): NormalizedWebhookEvent[] { + const event = JSON.parse(body) as { type?: string; data?: unknown }; + if (!event.data || typeof event.data !== "object") return []; + + switch (event.type) { + case "subscription.created": + case "subscription.updated": + case "subscription.active": + case "subscription.uncanceled": + case "subscription.canceled": + case "subscription.past_due": + case "subscription.revoked": + return createSubscriptionEvents( + { type: event.type, data: event.data as PolarSubscriptionEvent["data"] }, + webhookId, + ); + case "checkout.created": + case "checkout.updated": + return createCheckoutEvents( + { type: event.type, data: event.data as PolarCheckoutEvent["data"] }, + webhookId, + ); + default: + return []; + } +} + +function normalizeProviderSubscription(sub: { + cancelAtPeriodEnd: boolean; + currentPeriodEnd?: Date | string | null; + currentPeriodStart?: Date | string | null; + id: string; + productId?: string | null; + status: string; +}) { + return { + cancelAtPeriodEnd: sub.cancelAtPeriodEnd, + currentPeriodEndAt: toDate(sub.currentPeriodEnd), + currentPeriodStartAt: toDate(sub.currentPeriodStart), + providerProduct: sub.productId ? { productId: sub.productId } : null, + providerSubscriptionId: sub.id, + status: sub.status, + }; +} + +export function polar(options: PolarOptions): PolarProvider { + let client: Polar | null = null; + const getPolar = () => { + client ??= new Polar({ + accessToken: options.accessToken, + server: options.server ?? "production", + }); + return client; + }; + + let polarProductMapPromise: Promise< + Map + > | null = null; + + const getPolarProductMap = async () => { + polarProductMapPromise ??= getPolar() + .products.list({ isArchived: false, limit: 100 }) + .then((result) => new Map((result.result.items ?? []).map((p) => [p.id, p]))); + return polarProductMapPromise; + }; + const getWebhookEndpointAccount = async () => { + const orgs = await getPolar().organizations.list({ limit: 1 }); + const org = orgs.result.items?.[0]; + + if (!org) { + throw PayKitError.from( + "INTERNAL_SERVER_ERROR", + PAYKIT_ERROR_CODES.PROVIDER_INVALID_CONFIG, + "No Polar organization found for the configured access token", + ); + } + + return { + displayName: org.name, + environment: options.server === "sandbox" ? "sandbox" : "production", + providerAccountId: org.id, + providerId: "polar", + }; + }; + return { id: "polar", name: "Polar", - capabilities: { testClocks: false }, + capabilities: polarCapabilities, + + async upsertSubscriptionProduct(product) { + const polarProductMap = await getPolarProductMap(); + const existingProductId = product.existingProviderProduct?.productId ?? null; + const existingPolarProduct = existingProductId + ? polarProductMap.get(existingProductId) + : null; + + if (existingPolarProduct) { + const intervalMatches = + existingPolarProduct.recurringInterval === (product.priceInterval ?? null); + + if (intervalMatches) { + const updated = await getPolar().products.update({ + id: existingPolarProduct.id, + productUpdate: { + name: product.name, + visibility: "private", + prices: [ + { + amountType: "fixed" as const, + priceAmount: product.priceAmount, + priceCurrency: "usd", + }, + ], + }, + }); + return { providerProduct: { productId: updated.id } }; + } + + await getPolar().products.update({ + id: existingPolarProduct.id, + productUpdate: { isArchived: true }, + }); + } + + const created = await getPolar().products.create({ + name: product.name, + visibility: "private", + recurringInterval: (product.priceInterval as "month" | "year") ?? null, + prices: [ + { + amountType: "fixed" as const, + priceAmount: product.priceAmount, + priceCurrency: "usd", + }, + ], + }); + return { providerProduct: { productId: created.id } }; + }, + + async cleanupSubscriptionProducts(data) { + const [polarProductMap, orgs] = await Promise.all([ + getPolarProductMap(), + getPolar().organizations.list({ limit: 1 }), + ]); + const activeProductIds = new Set(data.activeProviderProductIds); + const cleanup: Promise[] = []; + + for (const [polarId] of polarProductMap) { + if (!activeProductIds.has(polarId)) { + cleanup.push( + getPolar().products.update({ + id: polarId, + productUpdate: { isArchived: true }, + }), + ); + } + } + + const org = orgs.result.items?.[0]; + if (org) { + cleanup.push( + getPolar().organizations.update({ + id: org.id, + organizationUpdate: { + subscriptionSettings: { + allowMultipleSubscriptions: true, + allowCustomerUpdates: false, + prorationBehavior: "invoice", + benefitRevocationGracePeriod: org.subscriptionSettings.benefitRevocationGracePeriod, + preventTrialAbuse: org.subscriptionSettings.preventTrialAbuse, + }, + customerPortalSettings: { + subscription: { updateSeats: false, updatePlan: false }, + usage: org.customerPortalSettings.usage, + }, + }, + }), + ); + } + + await Promise.all(cleanup); + }, async createCustomer(data) { if (!data.email) { @@ -147,7 +414,7 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme }; try { - const customer = await client.customers.create({ + const customer = await getPolar().customers.create({ email: data.email, name: data.name, metadata: customerMetadata, @@ -160,7 +427,7 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme if (!(error instanceof SDKValidationError)) throw error; // Duplicate email — find and re-link the existing customer. - const list = await client.customers.list({ query: data.email, limit: 1 }); + const list = await getPolar().customers.list({ query: data.email, limit: 1 }); const existing = list.result.items[0]; if (!existing) { @@ -171,7 +438,7 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme ); } - await client.customers.update({ + await getPolar().customers.update({ id: existing.id, customerUpdate: { name: data.name, @@ -185,8 +452,12 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme } }, + async deleteCustomer(data) { + await getPolar().customers.delete({ id: data.providerCustomerId }); + }, + async updateCustomer(data) { - await client.customers.update({ + await getPolar().customers.update({ id: data.providerCustomerId, customerUpdate: { email: data.email, @@ -196,24 +467,18 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme }); }, - async deleteCustomer(data) { - await client.customers.delete({ id: data.providerCustomerId }); - }, - - getTestClock() { - return notSupported("getTestClock"); - }, - - advanceTestClock() { - return notSupported("advanceTestClock"); - }, + async createCustomerPortalSession(data) { + const session = await getPolar().customerSessions.create({ + customerId: data.providerCustomerId, + }); - attachPaymentMethod() { - return notSupported("attachPaymentMethod"); + return { + url: session.customerPortalUrl, + }; }, async createSubscriptionCheckout(data) { - const checkout = await client.checkouts.create({ + const checkout = await getPolar().checkouts.create({ products: [data.providerProduct.productId!], customerId: data.providerCustomerId, metadata: data.metadata, @@ -230,12 +495,8 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme }; }, - createSubscription() { - return notSupported("createSubscription (use checkout instead)"); - }, - - async updateSubscription(data) { - const sub = await client.subscriptions.update({ + async changeSubscriptionProduct(data) { + const sub = await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { productId: data.providerProduct.productId!, @@ -245,33 +506,23 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme return { paymentUrl: null, - subscription: { - cancelAtPeriodEnd: sub.cancelAtPeriodEnd, - currentPeriodEndAt: sub.currentPeriodEnd ? new Date(sub.currentPeriodEnd) : null, - currentPeriodStartAt: sub.currentPeriodStart ? new Date(sub.currentPeriodStart) : null, - providerSubscriptionId: sub.id, - status: sub.status, - }, + subscription: normalizeProviderSubscription(sub), }; }, - createInvoice() { - return notSupported("createInvoice"); - }, - - async scheduleSubscriptionChange(data) { - const current = await client.subscriptions.get({ id: data.providerSubscriptionId }); + async changeSubscriptionProductAtPeriodEnd(data) { + const current = await getPolar().subscriptions.get({ id: data.providerSubscriptionId }); const wasCanceled = current.cancelAtPeriodEnd; // Un-cancel to allow product update (Polar rejects updates on canceled subs) if (wasCanceled) { - await client.subscriptions.update({ + await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { cancelAtPeriodEnd: false }, }); } - await client.subscriptions.update({ + await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { productId: data.providerProduct!.productId!, @@ -281,28 +532,22 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme // Re-cancel if it was previously canceled (preserve cancel-at-period-end intent) if (wasCanceled) { - await client.subscriptions.update({ + await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { cancelAtPeriodEnd: true }, }); } - const sub = await client.subscriptions.get({ id: data.providerSubscriptionId }); + const sub = await getPolar().subscriptions.get({ id: data.providerSubscriptionId }); return { paymentUrl: null, - subscription: { - cancelAtPeriodEnd: sub.cancelAtPeriodEnd, - currentPeriodEndAt: sub.currentPeriodEnd ? new Date(sub.currentPeriodEnd) : null, - currentPeriodStartAt: sub.currentPeriodStart ? new Date(sub.currentPeriodStart) : null, - providerSubscriptionId: sub.id, - status: sub.status, - }, + subscription: normalizeProviderSubscription(sub), }; }, - async cancelSubscription(data) { - const sub = await client.subscriptions.update({ + async cancelSubscriptionAtPeriodEnd(data) { + const sub = await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { cancelAtPeriodEnd: true, @@ -311,18 +556,18 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme return { paymentUrl: null, - subscription: { - cancelAtPeriodEnd: sub.cancelAtPeriodEnd, - currentPeriodEndAt: sub.currentPeriodEnd ? new Date(sub.currentPeriodEnd) : null, - currentPeriodStartAt: sub.currentPeriodStart ? new Date(sub.currentPeriodStart) : null, - providerSubscriptionId: sub.id, - status: sub.status, - }, + subscription: normalizeProviderSubscription(sub), }; }, + async getSubscription(data) { + return normalizeProviderSubscription( + await getPolar().subscriptions.get({ id: data.providerSubscriptionId }), + ); + }, + async listActiveSubscriptions(data) { - const result = await client.subscriptions.list({ + const result = await getPolar().subscriptions.list({ customerId: data.providerCustomerId, }); @@ -331,12 +576,12 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme .map((sub) => ({ providerSubscriptionId: sub.id })); }, - async resumeSubscription(data) { - const current = await client.subscriptions.get({ id: data.providerSubscriptionId }); + async resumeSubscriptionAtPeriodEnd(data) { + const current = await getPolar().subscriptions.get({ id: data.providerSubscriptionId }); // Un-cancel first if pending cancellation if (current.cancelAtPeriodEnd) { - await client.subscriptions.update({ + await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { cancelAtPeriodEnd: false }, }); @@ -344,134 +589,19 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme // Clear pending product change if any const sub = current.pendingUpdate - ? await client.subscriptions.update({ + ? await getPolar().subscriptions.update({ id: data.providerSubscriptionId, subscriptionUpdate: { productId: current.productId }, }) - : await client.subscriptions.get({ id: data.providerSubscriptionId }); + : await getPolar().subscriptions.get({ id: data.providerSubscriptionId }); return { paymentUrl: null, - subscription: { - cancelAtPeriodEnd: sub.cancelAtPeriodEnd, - currentPeriodEndAt: sub.currentPeriodEnd ? new Date(sub.currentPeriodEnd) : null, - currentPeriodStartAt: sub.currentPeriodStart ? new Date(sub.currentPeriodStart) : null, - providerSubscriptionId: sub.id, - status: sub.status, - }, + subscription: normalizeProviderSubscription(sub), }; }, - detachPaymentMethod() { - return notSupported("detachPaymentMethod"); - }, - - async syncProducts(data) { - const [allPolarProducts, orgs] = await Promise.all([ - client.products.list({ isArchived: false, limit: 100 }), - client.organizations.list({ limit: 1 }), - ]); - - const org = orgs.result.items?.[0]; - const polarProductMap = new Map((allPolarProducts.result.items ?? []).map((p) => [p.id, p])); - - const activeProductIds = new Set(); - - const results = await Promise.all( - data.products.map(async (product) => { - const existingProductId = product.existingProviderProduct?.productId ?? null; - const existingPolarProduct = existingProductId - ? polarProductMap.get(existingProductId) - : null; - - if (existingPolarProduct) { - const intervalMatches = - existingPolarProduct.recurringInterval === (product.priceInterval ?? null); - - if (intervalMatches) { - const updated = await client.products.update({ - id: existingPolarProduct.id, - productUpdate: { - name: product.name, - visibility: "private", - prices: [ - { - amountType: "fixed" as const, - priceAmount: product.priceAmount, - priceCurrency: "usd", - }, - ], - }, - }); - activeProductIds.add(updated.id); - return { id: product.id, providerProduct: { productId: updated.id } }; - } - - // Interval changed — archive old, create new - await client.products.update({ - id: existingPolarProduct.id, - productUpdate: { isArchived: true }, - }); - } - - const created = await client.products.create({ - name: product.name, - visibility: "private", - recurringInterval: (product.priceInterval as "month" | "year") ?? null, - prices: [ - { - amountType: "fixed" as const, - priceAmount: product.priceAmount, - priceCurrency: "usd", - }, - ], - }); - activeProductIds.add(created.id); - return { id: product.id, providerProduct: { productId: created.id } }; - }), - ); - - // Archive orphans + configure org settings in parallel - const cleanup: Promise[] = []; - - for (const [polarId] of polarProductMap) { - if (!activeProductIds.has(polarId)) { - cleanup.push( - client.products.update({ - id: polarId, - productUpdate: { isArchived: true }, - }), - ); - } - } - - if (org) { - cleanup.push( - client.organizations.update({ - id: org.id, - organizationUpdate: { - subscriptionSettings: { - allowMultipleSubscriptions: true, - allowCustomerUpdates: false, - prorationBehavior: "invoice", - benefitRevocationGracePeriod: org.subscriptionSettings.benefitRevocationGracePeriod, - preventTrialAbuse: org.subscriptionSettings.preventTrialAbuse, - }, - customerPortalSettings: { - subscription: { updateSeats: false, updatePlan: false }, - usage: org.customerPortalSettings.usage, - }, - }, - }), - ); - } - - await Promise.all(cleanup); - - return { results }; - }, - - async handleWebhook(data): Promise { + async parseWebhook(data): Promise { const webhookIdKey = Object.keys(data.headers).find((k) => k.toLowerCase() === "webhook-id"); const webhookId = webhookIdKey ? data.headers[webhookIdKey]! : ""; @@ -486,9 +616,8 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme "Invalid Polar webhook signature", ); } - // Unknown event types (e.g. member.created) — ignore silently if (error instanceof SDKValidationError) { - return []; + return createRawWebhookEvents(data.body, webhookId); } throw error; } @@ -500,31 +629,78 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme case "subscription.uncanceled": case "subscription.canceled": case "subscription.past_due": - case "subscription.revoked": - return createSubscriptionEvents(event, webhookId); + case "subscription.revoked": { + const events = createSubscriptionEvents(event, webhookId); + return events.length > 0 ? events : createRawWebhookEvents(data.body, webhookId); + } case "checkout.created": - case "checkout.updated": - return createCheckoutEvents(event, webhookId); + case "checkout.updated": { + const events = await createCheckoutEventsFromSdk(getPolar(), event, webhookId); + return events.length > 0 ? events : createRawWebhookEvents(data.body, webhookId); + } default: return []; } }, - async createPortalSession(data) { - const session = await client.customerSessions.create({ - customerId: data.providerCustomerId, + async getWebhookEndpointAccount() { + return getWebhookEndpointAccount(); + }, + + async ensureWebhookEndpoint(data) { + if (data.existingEndpointId) { + try { + const endpoint = await getPolar().webhooks.updateWebhookEndpoint({ + id: data.existingEndpointId, + webhookEndpointUpdate: { + enabled: true, + events: POLAR_WEBHOOK_EVENTS, + format: WebhookFormat.Raw, + name: "PayKit", + url: data.url, + }, + }); + return { + created: false, + endpointId: endpoint.id, + webhookSecret: options.webhookSecret, + }; + } catch (error) { + if (!(error instanceof Error) || !/not found|404/i.test(error.message)) { + throw error; + } + } + } + + const endpoint = await getPolar().webhooks.createWebhookEndpoint({ + events: POLAR_WEBHOOK_EVENTS, + format: WebhookFormat.Raw, + name: "PayKit", + url: data.url, }); return { - url: session.customerPortalUrl, + created: true, + endpointId: endpoint.id, + webhookSecret: endpoint.secret, }; }, + async deleteWebhookEndpoint(data) { + await getPolar().webhooks.deleteWebhookEndpoint({ id: data.endpointId }); + }, + async check() { try { - await client.products.list({ limit: 1 }); + await getPolar().products.list({ limit: 1 }); + + const endpoints = await getPolar().webhooks.listWebhookEndpoints({ limit: 100 }); + const webhookEndpoints = (endpoints.result.items ?? []).map((endpoint) => ({ + status: endpoint.enabled ? "enabled" : "disabled", + url: endpoint.url, + })); - const customers = await client.customers.list({ + const customers = await getPolar().customers.list({ limit: 5, sorting: ["created_at"], }); @@ -537,7 +713,7 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme ok: true, displayName: "Polar", mode: options.server === "sandbox" ? "sandbox" : "production", - webhookEndpoints: [], + webhookEndpoints, customerSample, }; } catch (error) { @@ -552,18 +728,3 @@ export function createPolarProvider(client: Polar, options: PolarOptions): Payme }, }; } - -export function polar(polarOptions: PolarOptions): PolarProviderConfig { - return { - id: "polar", - name: "Polar", - capabilities: { testClocks: false }, - createAdapter(): PaymentProvider { - const client = new Polar({ - accessToken: polarOptions.accessToken, - server: polarOptions.server ?? "production", - }); - return createPolarProvider(client, polarOptions); - }, - }; -} diff --git a/packages/stripe/src/__tests__/stripe-provider.test.ts b/packages/stripe/src/__tests__/stripe-provider.test.ts index 0ae26b87..5a5e4339 100644 --- a/packages/stripe/src/__tests__/stripe-provider.test.ts +++ b/packages/stripe/src/__tests__/stripe-provider.test.ts @@ -3,28 +3,16 @@ import { describe, expect, it } from "vitest"; import { stripe } from "../stripe-provider"; describe("@paykitjs/stripe", () => { - it("should return a provider config with createAdapter", () => { - const config = stripe({ + it("should return a provider", () => { + const provider = stripe({ secretKey: "sk_test_123", webhookSecret: "whsec_test_123", }); - expect(config.id).toBe("stripe"); - expect(config.name).toBe("Stripe"); - expect(typeof config.createAdapter).toBe("function"); - }); - - it("should create a PaymentProvider adapter", () => { - const config = stripe({ - secretKey: "sk_test_123", - webhookSecret: "whsec_test_123", - }); - - const adapter = config.createAdapter(); - expect(adapter.id).toBe("stripe"); - expect(adapter.name).toBe("Stripe"); - expect(typeof adapter.createCustomer).toBe("function"); - expect(typeof adapter.updateCustomer).toBe("function"); - expect(typeof adapter.handleWebhook).toBe("function"); + expect(provider.id).toBe("stripe"); + expect(provider.name).toBe("Stripe"); + expect(typeof provider.createCustomer).toBe("function"); + expect(typeof provider.updateCustomer).toBe("function"); + expect(typeof provider.parseWebhook).toBe("function"); }); }); diff --git a/packages/stripe/src/__tests__/stripe.test.ts b/packages/stripe/src/__tests__/stripe.test.ts index 3592659e..cdf2d69c 100644 --- a/packages/stripe/src/__tests__/stripe.test.ts +++ b/packages/stripe/src/__tests__/stripe.test.ts @@ -1,9 +1,34 @@ import { PAYKIT_ERROR_CODES } from "paykitjs"; -import { describe, expect, it, vi } from "vitest"; - -import { createStripeProvider, stripe } from "../stripe-provider"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const stripeSdkMock = vi.hoisted(() => ({ + client: null as unknown, + Stripe: vi.fn(function Stripe() { + return stripeSdkMock.client; + }), +})); + +vi.mock("stripe", () => ({ + default: stripeSdkMock.Stripe, +})); + +import { stripe, type StripeOptions } from "../stripe-provider"; + +function createRuntime(client: unknown, options: Partial = {}) { + stripeSdkMock.client = client; + return stripe({ + secretKey: "sk_test_123", + webhookSecret: "whsec_123", + ...options, + }); +} describe("providers/stripe", () => { + beforeEach(() => { + stripeSdkMock.Stripe.mockClear(); + stripeSdkMock.client = null; + }); + it("creates a test clock and stores its id on the provider customer", async () => { const createClock = vi.fn().mockResolvedValue({ frozen_time: 1_700_000_000, @@ -12,22 +37,16 @@ describe("providers/stripe", () => { status: "ready", }); const createCustomer = vi.fn().mockResolvedValue({ id: "cus_123" }); - const runtime = createStripeProvider( - { - customers: { create: createCustomer }, - testHelpers: { - testClocks: { - advance: vi.fn(), - create: createClock, - retrieve: vi.fn(), - }, + const runtime = createRuntime({ + customers: { create: createCustomer }, + testHelpers: { + testClocks: { + advance: vi.fn(), + create: createClock, + retrieve: vi.fn(), }, - } as never, - { - secretKey: "sk_test_123", - webhookSecret: "whsec_123", }, - ); + }); const result = await runtime.createCustomer({ createTestClock: true, @@ -60,7 +79,7 @@ describe("providers/stripe", () => { }); it("throws a clear error when testing mode uses a live Stripe key", async () => { - const runtime = createStripeProvider( + const runtime = createRuntime( { customers: { create: vi.fn() }, testHelpers: { @@ -70,10 +89,9 @@ describe("providers/stripe", () => { retrieve: vi.fn(), }, }, - } as never, + }, { secretKey: "sk_live_123", - webhookSecret: "whsec_123", }, ); @@ -95,22 +113,16 @@ describe("providers/stripe", () => { name: "customer_123", status: "ready", }); - const runtime = createStripeProvider( - { - customers: { create: vi.fn() }, - testHelpers: { - testClocks: { - advance: advanceClock, - create: vi.fn(), - retrieve: retrieveClock, - }, + const runtime = createRuntime({ + customers: { create: vi.fn() }, + testHelpers: { + testClocks: { + advance: advanceClock, + create: vi.fn(), + retrieve: retrieveClock, }, - } as never, - { - secretKey: "sk_test_123", - webhookSecret: "whsec_123", }, - ); + }); const frozenTime = new Date("2024-01-02T00:00:00.000Z"); const result = await runtime.advanceTestClock({ @@ -135,14 +147,13 @@ describe("providers/stripe", () => { createSession: ReturnType, managedPayments: boolean, ) { - return createStripeProvider( + return createRuntime( { checkout: { sessions: { create: createSession } }, - } as never, + }, { + apiVersion: managedPayments ? "2026-03-04.preview" : undefined, managedPayments, - secretKey: "sk_test_123", - webhookSecret: "whsec_123", }, ); } diff --git a/packages/stripe/src/stripe-provider.ts b/packages/stripe/src/stripe-provider.ts index 0375f5c2..26543709 100644 --- a/packages/stripe/src/stripe-provider.ts +++ b/packages/stripe/src/stripe-provider.ts @@ -1,10 +1,5 @@ -import { PayKitError, PAYKIT_ERROR_CODES } from "paykitjs"; -import type { - NormalizedWebhookEvent, - PayKitProviderConfig, - PaymentProvider, - ProviderTestClock, -} from "paykitjs"; +import { PayKitError, PAYKIT_ERROR_CODES, unsupportedProviderCapability } from "paykitjs"; +import type { NormalizedWebhookEvent, PayKitProvider, ProviderTestClock } from "paykitjs"; import StripeSdk from "stripe"; /** @@ -27,6 +22,24 @@ const STRIPE_WEBHOOK_EVENTS: StripeSdk.WebhookEndpointCreateParams.EnabledEvent[ "payment_method.detached", ]; +const stripeCapabilities = { + subscriptionProducts: true, + subscriptionCheckout: true, + customerPortal: true, + createInvoices: true, + detachPaymentMethods: true, + setupPaymentMethods: true, + cancelSubscriptionsAtPeriodEnd: true, + createSubscriptions: true, + changeSubscriptionProducts: true, + listActiveSubscriptions: true, + pendingSubscriptionProductChanges: false, + resumeSubscriptionsAtPeriodEnd: true, + subscriptionSchedules: true, + testClocks: true, + manageWebhookEndpoints: true, +} as const satisfies PayKitProvider["capabilities"]; + export interface StripeOptions { secretKey: string; webhookSecret: string; @@ -36,9 +49,7 @@ export interface StripeOptions { managedPayments?: boolean; } -export type StripeProviderConfig = PayKitProviderConfig & { - capabilities: { testClocks: true }; -}; +export type StripeProvider = PayKitProvider; type StripeInvoiceWithExtras = StripeSdk.Invoice & { payment_intent?: StripeSdk.PaymentIntent | string | null; @@ -370,15 +381,6 @@ async function createCheckoutCompletedEvents( isDefault: session.mode === "subscription", }; events.push({ - actions: [ - { - data: { - paymentMethod: normalizedPaymentMethod, - providerCustomerId, - }, - type: "payment_method.upsert", - }, - ], name: "payment_method.attached", payload: { paymentMethod: normalizedPaymentMethod, @@ -390,15 +392,6 @@ async function createCheckoutCompletedEvents( if (session.mode === "payment" && paymentIntent?.status === "succeeded") { const normalizedPayment = normalizeStripePaymentIntent(paymentIntent); events.push({ - actions: [ - { - data: { - payment: normalizedPayment, - providerCustomerId, - }, - type: "payment.upsert", - }, - ], name: "payment.succeeded", payload: { payment: normalizedPayment, @@ -454,15 +447,6 @@ async function createSubscriptionEvents(event: StripeSdk.Event): Promise = { - actions: [ - { - data: { - providerCustomerId, - subscription: normalizedSubscription, - }, - type: "subscription.upsert", - }, - ], name: "subscription.updated", payload: { providerCustomerId, @@ -518,16 +493,6 @@ function createInvoiceEvents(event: StripeSdk.Event): NormalizedWebhookEvent[] { const normalizedInvoice = normalizeStripeInvoice(invoice); const normalizedEvent: NormalizedWebhookEvent<"invoice.updated"> = { - actions: [ - { - data: { - invoice: normalizedInvoice, - providerCustomerId, - providerSubscriptionId, - }, - type: "invoice.upsert", - }, - ], name: "invoice.updated", payload: { invoice: normalizedInvoice, @@ -548,14 +513,6 @@ function createDetachedPaymentMethodEvents(event: StripeSdk.Event): NormalizedWe return [ { - actions: [ - { - data: { - providerMethodId: paymentMethod.id, - }, - type: "payment_method.delete", - }, - ], name: "payment_method.detached", payload: { providerEventId: event.id, @@ -565,26 +522,76 @@ function createDetachedPaymentMethodEvents(event: StripeSdk.Event): NormalizedWe ]; } -export function createStripeProvider(client: StripeSdk, options: StripeOptions): PaymentProvider { +export function stripe(options: StripeOptions): StripeProvider { + const apiVersion = options.apiVersion ?? PAYKIT_STRIPE_API_VERSION; + if (options.managedPayments) { + if (!apiVersion.endsWith(".preview") || apiVersion < STRIPE_MANAGED_PAYMENTS_MIN_VERSION) { + throw PayKitError.from( + "BAD_REQUEST", + PAYKIT_ERROR_CODES.PROVIDER_INVALID_CONFIG, + `managedPayments requires apiVersion >= ${STRIPE_MANAGED_PAYMENTS_MIN_VERSION} (got "${apiVersion}")`, + ); + } + } + + let client: StripeSdk | null = null; + const getStripe = () => { + client ??= new StripeSdk(options.secretKey, { + apiVersion: apiVersion as StripeSdk.LatestApiVersion, + maxNetworkRetries: 3, + }); + return client; + }; + const currency = "usd"; return { id: "stripe", name: "Stripe", - capabilities: { testClocks: true }, + capabilities: stripeCapabilities, + + async upsertSubscriptionProduct(product) { + let productId = product.existingProviderProduct?.productId ?? null; + if (!productId) { + const stripeProduct = await getStripe().products.create({ + metadata: { paykit_product_id: product.id }, + name: product.name, + }); + productId = stripeProduct.id; + } else { + await getStripe().products.update(productId, { name: product.name }); + } + + const existingPriceId = product.existingProviderProduct?.priceId ?? null; + if (existingPriceId) { + return { providerProduct: { productId, priceId: existingPriceId } }; + } + + const priceParams: StripeSdk.PriceCreateParams = { + currency, + product: productId, + unit_amount: product.priceAmount, + }; + if (product.priceInterval) { + priceParams.recurring = { interval: product.priceInterval as "month" | "year" }; + } + const stripePrice = await getStripe().prices.create(priceParams); + + return { providerProduct: { productId, priceId: stripePrice.id } }; + }, async createCustomer(data) { let testClock: ProviderTestClock | undefined; if (data.createTestClock) { assertStripeTestKey(options); - const clock = await client.testHelpers.testClocks.create({ + const clock = await getStripe().testHelpers.testClocks.create({ frozen_time: Math.floor(Date.now() / 1000), name: data.id, }); testClock = normalizeStripeTestClock(clock); } - const customer = await client.customers.create({ + const customer = await getStripe().customers.create({ email: data.email, metadata: { customerId: data.id, @@ -603,43 +610,42 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }; }, + async deleteCustomer(data) { + await getStripe().customers.del(data.providerCustomerId); + }, + async updateCustomer(data) { - await client.customers.update(data.providerCustomerId, { + await getStripe().customers.update(data.providerCustomerId, { email: data.email, metadata: data.metadata, name: data.name, }); }, - async deleteCustomer(data) { - await client.customers.del(data.providerCustomerId); - }, - - async getTestClock(data) { - const clock = await client.testHelpers.testClocks.retrieve(data.testClockId); - return normalizeStripeTestClock(clock); - }, - - async advanceTestClock(data) { - assertStripeTestKey(options); - - await client.testHelpers.testClocks.advance(data.testClockId, { - frozen_time: Math.floor(data.frozenTime.getTime() / 1000), + async createInvoice(data) { + const stripeInvoice = await getStripe().invoices.create({ + auto_advance: data.autoAdvance ?? true, + collection_method: "charge_automatically", + customer: data.providerCustomerId, + currency, }); - for (let i = 0; i < 60; i++) { - const clock = await client.testHelpers.testClocks.retrieve(data.testClockId); - if (clock.status === "ready") { - return normalizeStripeTestClock(clock); - } - await new Promise((resolve) => setTimeout(resolve, 2000)); + if (data.lines.length > 0) { + await getStripe().invoices.addLines(stripeInvoice.id, { + lines: data.lines.map((line) => ({ + amount: line.amount, + description: line.description, + })), + }); } - throw new Error(`Test clock ${data.testClockId} did not reach 'ready' status`); + const finalizedInvoice = await getStripe().invoices.finalizeInvoice(stripeInvoice.id); + + return normalizeStripeInvoice(finalizedInvoice); }, - async attachPaymentMethod(data) { - const session = await client.checkout.sessions.create({ + async createPaymentMethodSetupSession(data) { + const session = await getStripe().checkout.sessions.create({ cancel_url: data.returnURL, client_reference_id: data.providerCustomerId, customer: data.providerCustomerId, @@ -654,30 +660,47 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): return { url: session.url }; }, - async createSubscriptionCheckout(data) { - const sessionParams: StripeSdk.Checkout.SessionCreateParams & { - managed_payments?: { enabled: boolean }; - } = { - cancel_url: data.cancelUrl ?? data.successUrl, - client_reference_id: data.providerCustomerId, + async detachPaymentMethod(data) { + await getStripe().paymentMethods.detach(data.providerMethodId); + }, + + async createCustomerPortalSession(data) { + const session = await getStripe().billingPortal.sessions.create({ customer: data.providerCustomerId, - line_items: [{ price: data.providerProduct.priceId, quantity: 1 }], - metadata: data.metadata, - mode: "subscription", - success_url: data.successUrl, - }; - if (options.managedPayments) { - sessionParams.managed_payments = { enabled: true }; - } - const session = await client.checkout.sessions.create(sessionParams); + return_url: data.returnUrl, + }); + return { url: session.url }; + }, - if (!session.url) { - throw PayKitError.from("BAD_REQUEST", PAYKIT_ERROR_CODES.PROVIDER_SESSION_INVALID); + async cancelSubscriptionAtPeriodEnd(data) { + let scheduleId = data.providerSubscriptionScheduleId ?? null; + if (!scheduleId) { + const currentSubscription = (await getStripe().subscriptions.retrieve( + data.providerSubscriptionId, + )) as StripeSubscriptionWithExtras; + scheduleId = + typeof currentSubscription.schedule === "string" + ? currentSubscription.schedule + : (currentSubscription.schedule?.id ?? null); + } + if (scheduleId) { + const schedule = await getStripe().subscriptionSchedules.retrieve(scheduleId); + if (schedule.status !== "released" && schedule.status !== "canceled") { + await getStripe().subscriptionSchedules.release(scheduleId); + } } + const updatedSubscription = (await getStripe().subscriptions.update( + data.providerSubscriptionId, + { + cancel_at_period_end: true, + }, + )) as StripeSubscriptionWithExtras; + return { - paymentUrl: session.url, - providerCheckoutSessionId: session.id, + paymentUrl: null, + requiredAction: null, + subscription: normalizeStripeSubscription(updatedSubscription), }; }, @@ -688,7 +711,7 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): payment_behavior: "default_incomplete", expand: ["latest_invoice.payment_intent"], }; - const createdSubscription = (await client.subscriptions.create( + const createdSubscription = (await getStripe().subscriptions.create( createParams, )) as StripeSubscriptionWithExtras; @@ -710,9 +733,36 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }; }, - async updateSubscription(data) { + async createSubscriptionCheckout(data) { + const sessionParams: StripeSdk.Checkout.SessionCreateParams & { + managed_payments?: { enabled: boolean }; + } = { + cancel_url: data.cancelUrl ?? data.successUrl, + client_reference_id: data.providerCustomerId, + customer: data.providerCustomerId, + line_items: [{ price: data.providerProduct.priceId, quantity: 1 }], + metadata: data.metadata, + mode: "subscription", + success_url: data.successUrl, + }; + if (options.managedPayments) { + sessionParams.managed_payments = { enabled: true }; + } + const session = await getStripe().checkout.sessions.create(sessionParams); + + if (!session.url) { + throw PayKitError.from("BAD_REQUEST", PAYKIT_ERROR_CODES.PROVIDER_SESSION_INVALID); + } + + return { + paymentUrl: session.url, + providerCheckoutSessionId: session.id, + }; + }, + + async changeSubscriptionProduct(data) { const currentSubscription = await retrieveExpandedSubscription( - client, + getStripe(), data.providerSubscriptionId, ); const currentItem = currentSubscription.items.data[0]; @@ -723,17 +773,20 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): ); } - const updatedSubscription = (await client.subscriptions.update(data.providerSubscriptionId, { - items: [ - { - id: currentItem.id, - price: data.providerProduct.priceId, - }, - ], - payment_behavior: "pending_if_incomplete", - proration_behavior: "always_invoice", - expand: ["latest_invoice.payment_intent"], - })) as StripeSubscriptionWithExtras; + const updatedSubscription = (await getStripe().subscriptions.update( + data.providerSubscriptionId, + { + items: [ + { + id: currentItem.id, + price: data.providerProduct.priceId, + }, + ], + payment_behavior: "pending_if_incomplete", + proration_behavior: "always_invoice", + expand: ["latest_invoice.payment_intent"], + }, + )) as StripeSubscriptionWithExtras; const latestInvoice = updatedSubscription.latest_invoice; const invoice = @@ -753,105 +806,21 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }; }, - async scheduleSubscriptionChange(data) { - if (!data.providerProduct?.priceId) { - throw PayKitError.from("BAD_REQUEST", PAYKIT_ERROR_CODES.PROVIDER_PRICE_REQUIRED); - } - - const currentSub = (await client.subscriptions.retrieve(data.providerSubscriptionId, { - expand: ["items"], - })) as StripeSubscriptionWithExtras; - const periodEndSeconds = getLatestPeriodEnd(currentSub); - if (typeof periodEndSeconds !== "number") { - throw PayKitError.from( - "BAD_REQUEST", - PAYKIT_ERROR_CODES.PROVIDER_SUBSCRIPTION_MISSING_PERIOD, - ); - } - - const currentItems = currentSub.items.data.map((item: { price: { id: string } }) => ({ - price: item.price.id, - quantity: 1, - })); - - let schedule: StripeSdk.SubscriptionSchedule; - if (data.providerSubscriptionScheduleId) { - schedule = await client.subscriptionSchedules.retrieve(data.providerSubscriptionScheduleId); - } else { - const existingScheduleId = - typeof currentSub.schedule === "string" - ? currentSub.schedule - : (currentSub.schedule?.id ?? null); - schedule = existingScheduleId - ? await client.subscriptionSchedules.retrieve(existingScheduleId) - : await client.subscriptionSchedules.create({ - from_subscription: data.providerSubscriptionId, - }); - } - const scheduleId = schedule.id; - - const currentPhase = schedule.phases[0]; - const currentPhaseStart = currentPhase?.start_date ?? Math.floor(Date.now() / 1000); - - await client.subscriptionSchedules.update(scheduleId, { - end_behavior: "release", - phases: [ - { - items: currentItems, - start_date: currentPhaseStart, - end_date: periodEndSeconds, - }, - { - items: [{ price: data.providerProduct.priceId, quantity: 1 }], - start_date: periodEndSeconds, - }, - ], - }); - - const updatedSubscription = await retrieveExpandedSubscription( - client, - data.providerSubscriptionId, + async changeSubscriptionProductAtPeriodEnd() { + throw unsupportedProviderCapability( + { id: "stripe", name: "Stripe" }, + "subscriptions.pendingProductChange", ); - - return { - paymentUrl: null, - requiredAction: null, - subscription: normalizeStripeSubscription(updatedSubscription), - }; }, - async cancelSubscription(data) { - const currentSubscription = (await client.subscriptions.retrieve( - data.providerSubscriptionId, - )) as StripeSubscriptionWithExtras; - - let scheduleId = data.providerSubscriptionScheduleId ?? null; - if (!scheduleId) { - scheduleId = - typeof currentSubscription.schedule === "string" - ? currentSubscription.schedule - : (currentSubscription.schedule?.id ?? null); - } - if (scheduleId) { - const schedule = await client.subscriptionSchedules.retrieve(scheduleId); - if (schedule.status !== "released" && schedule.status !== "canceled") { - await client.subscriptionSchedules.release(scheduleId); - } - } - - const updatedSubscription = (await client.subscriptions.update(data.providerSubscriptionId, { - cancel_at_period_end: true, - })) as StripeSubscriptionWithExtras; - - return { - paymentUrl: null, - requiredAction: null, - subscription: normalizeStripeSubscription(updatedSubscription), - }; + async getSubscription(data) { + return normalizeStripeSubscription( + await retrieveExpandedSubscription(getStripe(), data.providerSubscriptionId), + ); }, async listActiveSubscriptions(data) { - const subscriptions = await client.subscriptions.list({ + const subscriptions = await getStripe().subscriptions.list({ customer: data.providerCustomerId, status: "active", }); @@ -860,22 +829,25 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): })); }, - async resumeSubscription(data) { + async resumeSubscriptionAtPeriodEnd(data) { let scheduleId = data.providerSubscriptionScheduleId ?? null; if (!scheduleId) { - const sub = await client.subscriptions.retrieve(data.providerSubscriptionId); + const sub = await getStripe().subscriptions.retrieve(data.providerSubscriptionId); scheduleId = typeof sub.schedule === "string" ? sub.schedule : (sub.schedule?.id ?? null); } if (scheduleId) { - const schedule = await client.subscriptionSchedules.retrieve(scheduleId); + const schedule = await getStripe().subscriptionSchedules.retrieve(scheduleId); if (schedule.status !== "released" && schedule.status !== "canceled") { - await client.subscriptionSchedules.release(scheduleId); + await getStripe().subscriptionSchedules.release(scheduleId); } } - const updatedSubscription = (await client.subscriptions.update(data.providerSubscriptionId, { - cancel_at_period_end: false, - })) as StripeSubscriptionWithExtras; + const updatedSubscription = (await getStripe().subscriptions.update( + data.providerSubscriptionId, + { + cancel_at_period_end: false, + }, + )) as StripeSubscriptionWithExtras; return { paymentUrl: null, @@ -884,71 +856,64 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }; }, - async detachPaymentMethod(data) { - await client.paymentMethods.detach(data.providerMethodId); + async getOrCreateSubscriptionSchedule(data) { + const currentSub = (await getStripe().subscriptions.retrieve(data.providerSubscriptionId, { + expand: ["items", "schedule"], + })) as StripeSubscriptionWithExtras; + const existingScheduleId = + data.providerSubscriptionScheduleId ?? + (typeof currentSub.schedule === "string" + ? currentSub.schedule + : (currentSub.schedule?.id ?? null)); + const schedule = existingScheduleId + ? await getStripe().subscriptionSchedules.retrieve(existingScheduleId) + : await getStripe().subscriptionSchedules.create({ + from_subscription: data.providerSubscriptionId, + }); + const currentPhase = schedule.phases[0]; + return { + currentPhaseStartAt: new Date( + (currentPhase?.start_date ?? Math.floor(Date.now() / 1000)) * 1000, + ), + id: schedule.id, + }; }, - async syncProducts(data) { - const results = await Promise.all( - data.products.map(async (product) => { - let productId = product.existingProviderProduct?.productId ?? null; - if (!productId) { - const stripeProduct = await client.products.create({ - metadata: { paykit_product_id: product.id }, - name: product.name, - }); - productId = stripeProduct.id; - } else { - await client.products.update(productId, { name: product.name }); - } - - const existingPriceId = product.existingProviderProduct?.priceId ?? null; - if (existingPriceId) { - return { id: product.id, providerProduct: { productId, priceId: existingPriceId } }; - } - - const priceParams: StripeSdk.PriceCreateParams = { - currency, - product: productId, - unit_amount: product.priceAmount, - }; - if (product.priceInterval) { - priceParams.recurring = { - interval: product.priceInterval as "month" | "year", - }; - } - const stripePrice = await client.prices.create(priceParams); - - return { id: product.id, providerProduct: { productId, priceId: stripePrice.id } }; - }), - ); - - return { results }; + async updateSubscriptionSchedulePhases(data) { + await getStripe().subscriptionSchedules.update(data.providerSubscriptionScheduleId, { + end_behavior: "release", + phases: data.phases.map((phase) => ({ + items: [{ price: phase.providerProduct.priceId, quantity: 1 }], + start_date: Math.floor(phase.startAt.getTime() / 1000), + ...(phase.endAt ? { end_date: Math.floor(phase.endAt.getTime() / 1000) } : {}), + })), + }); }, - async createInvoice(data) { - const stripeInvoice = await client.invoices.create({ - auto_advance: data.autoAdvance ?? true, - collection_method: "charge_automatically", - customer: data.providerCustomerId, - currency, + async advanceTestClock(data) { + assertStripeTestKey(options); + + await getStripe().testHelpers.testClocks.advance(data.testClockId, { + frozen_time: Math.floor(data.frozenTime.getTime() / 1000), }); - if (data.lines.length > 0) { - await client.invoices.addLines(stripeInvoice.id, { - lines: data.lines.map((line) => ({ - amount: line.amount, - description: line.description, - })), - }); + for (let i = 0; i < 60; i++) { + const clock = await getStripe().testHelpers.testClocks.retrieve(data.testClockId); + if (clock.status === "ready") { + return normalizeStripeTestClock(clock); + } + await new Promise((resolve) => setTimeout(resolve, 2000)); } - const finalizedInvoice = await client.invoices.finalizeInvoice(stripeInvoice.id); + throw new Error(`Test clock ${data.testClockId} did not reach 'ready' status`); + }, - return normalizeStripeInvoice(finalizedInvoice); + async getTestClock(data) { + const clock = await getStripe().testHelpers.testClocks.retrieve(data.testClockId); + return normalizeStripeTestClock(clock); }, - async handleWebhook(data) { + async parseWebhook(data) { const headerKey = Object.keys(data.headers).find( (k) => k.toLowerCase() === "stripe-signature", ); @@ -958,22 +923,22 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): } const tolerance = data.allowStaleSignatures ? Number.POSITIVE_INFINITY : undefined; - const event = await client.webhooks.constructEventAsync( + const event = await getStripe().webhooks.constructEventAsync( data.body, signature, options.webhookSecret, tolerance, ); return [ - ...(await createCheckoutCompletedEvents(client, event)), + ...(await createCheckoutCompletedEvents(getStripe(), event)), ...(await createSubscriptionEvents(event)), ...createInvoiceEvents(event), ...createDetachedPaymentMethodEvents(event), ]; }, - async getTunnelAccount() { - const account = await client.accounts.retrieve(); + async getWebhookEndpointAccount() { + const account = await getStripe().accounts.retrieve(); const displayName = getStripeDisplayName(account); return { displayName, @@ -983,10 +948,10 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }; }, - async ensureTunnelWebhook(data) { + async ensureWebhookEndpoint(data) { if (data.existingEndpointId) { try { - const endpoint = await client.webhookEndpoints.update(data.existingEndpointId, { + const endpoint = await getStripe().webhookEndpoints.update(data.existingEndpointId, { enabled_events: STRIPE_WEBHOOK_EVENTS, url: data.url, }); @@ -1004,7 +969,7 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): } } - const endpoint = await client.webhookEndpoints.create({ + const endpoint = await getStripe().webhookEndpoints.create({ enabled_events: STRIPE_WEBHOOK_EVENTS, url: data.url, }); @@ -1016,27 +981,19 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }; }, - async disableTunnelWebhook(data) { - await client.webhookEndpoints.del(data.endpointId); - }, - - async createPortalSession(data) { - const session = await client.billingPortal.sessions.create({ - customer: data.providerCustomerId, - return_url: data.returnUrl, - }); - return { url: session.url }; + async deleteWebhookEndpoint(data) { + await getStripe().webhookEndpoints.del(data.endpointId); }, async check() { const mode = getStripeEnvironment(options.secretKey) === "test" ? "test mode" : "live mode"; try { - const account = await client.accounts.retrieve(); + const account = await getStripe().accounts.retrieve(); const displayName = getStripeDisplayName(account); let webhookEndpoints: Array<{ url: string; status: string }> = []; try { - const endpoints = await client.webhookEndpoints.list({ limit: 100 }); + const endpoints = await getStripe().webhookEndpoints.list({ limit: 100 }); webhookEndpoints = endpoints.data .filter((ep) => ep.status === "enabled") .map((ep) => ({ url: ep.url, status: ep.status })); @@ -1052,29 +1009,3 @@ export function createStripeProvider(client: StripeSdk, options: StripeOptions): }, }; } - -export function stripe(options: StripeOptions): StripeProviderConfig { - const apiVersion = options.apiVersion ?? PAYKIT_STRIPE_API_VERSION; - if (options.managedPayments) { - if (!apiVersion.endsWith(".preview") || apiVersion < STRIPE_MANAGED_PAYMENTS_MIN_VERSION) { - throw PayKitError.from( - "BAD_REQUEST", - PAYKIT_ERROR_CODES.PROVIDER_INVALID_CONFIG, - `managedPayments requires apiVersion >= ${STRIPE_MANAGED_PAYMENTS_MIN_VERSION} (got "${apiVersion}")`, - ); - } - } - const client = new StripeSdk(options.secretKey, { - apiVersion: apiVersion as StripeSdk.LatestApiVersion, - maxNetworkRetries: 3, - }); - - return { - id: "stripe", - name: "Stripe", - capabilities: { testClocks: true }, - createAdapter(): PaymentProvider { - return createStripeProvider(client, options); - }, - }; -}