Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 149 additions & 47 deletions src/actions/my-usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,68 @@ export interface MyStatsSummary extends UsageLogSummary {
currencyCode: CurrencyCode;
}

type StatsSummaryRow = {
model: string | null;
userRequests: number;
userCost: string | number | null;
userInputTokens: number | null;
userOutputTokens: number | null;
userCacheCreationTokens: number | null;
userCacheReadTokens: number | null;
userCacheCreation5mTokens: number | null;
userCacheCreation1hTokens: number | null;
keyRequests: number;
keyCost: string | number | null;
keyInputTokens: number | null;
keyOutputTokens: number | null;
keyCacheCreationTokens: number | null;
keyCacheReadTokens: number | null;
keyCacheCreation5mTokens: number | null;
keyCacheCreation1hTokens: number | null;
};

function mergeStatsSummaryRows(rows: StatsSummaryRow[]): StatsSummaryRow[] {
const merged = new Map<string, StatsSummaryRow>();

for (const row of rows) {
const key = row.model ?? "\0";
const current = merged.get(key);
if (!current) {
merged.set(key, { ...row });
continue;
}

current.userRequests += row.userRequests ?? 0;
current.userCost = Number(current.userCost ?? 0) + Number(row.userCost ?? 0);
current.userInputTokens = (current.userInputTokens ?? 0) + (row.userInputTokens ?? 0);
current.userOutputTokens = (current.userOutputTokens ?? 0) + (row.userOutputTokens ?? 0);
current.userCacheCreationTokens =
(current.userCacheCreationTokens ?? 0) + (row.userCacheCreationTokens ?? 0);
current.userCacheReadTokens =
(current.userCacheReadTokens ?? 0) + (row.userCacheReadTokens ?? 0);
current.userCacheCreation5mTokens =
(current.userCacheCreation5mTokens ?? 0) + (row.userCacheCreation5mTokens ?? 0);
current.userCacheCreation1hTokens =
(current.userCacheCreation1hTokens ?? 0) + (row.userCacheCreation1hTokens ?? 0);

current.keyRequests += row.keyRequests ?? 0;
current.keyCost = Number(current.keyCost ?? 0) + Number(row.keyCost ?? 0);
current.keyInputTokens = (current.keyInputTokens ?? 0) + (row.keyInputTokens ?? 0);
current.keyOutputTokens = (current.keyOutputTokens ?? 0) + (row.keyOutputTokens ?? 0);
current.keyCacheCreationTokens =
(current.keyCacheCreationTokens ?? 0) + (row.keyCacheCreationTokens ?? 0);
current.keyCacheReadTokens = (current.keyCacheReadTokens ?? 0) + (row.keyCacheReadTokens ?? 0);
current.keyCacheCreation5mTokens =
(current.keyCacheCreation5mTokens ?? 0) + (row.keyCacheCreation5mTokens ?? 0);
current.keyCacheCreation1hTokens =
(current.keyCacheCreation1hTokens ?? 0) + (row.keyCacheCreation1hTokens ?? 0);
}

return Array.from(merged.values()).sort(
(a, b) => Number(b.userCost ?? 0) - Number(a.userCost ?? 0)
);
}

/**
* Get aggregated statistics for a date range
* 通过 model breakdown 聚合,避免额外的 summary 聚合查询
Expand Down Expand Up @@ -992,40 +1054,80 @@ export async function getMyStatsSummary(
const userId = session.user.id;
const keyString = session.key.key;

// Key 维度是 User 维度的子集:用一条聚合 SQL 扫描 userId 范围即可同时算出两套 breakdown。
const modelBreakdown = await db
.select({
model: usageLedger.model,
// User breakdown(跨所有 Key)
userRequests: sql<number>`count(*)::int`,
userCost: sql<string>`COALESCE(sum(${usageLedger.costUsd}), 0)`,
userInputTokens: sql<number>`COALESCE(sum(${usageLedger.inputTokens}), 0)::double precision`,
userOutputTokens: sql<number>`COALESCE(sum(${usageLedger.outputTokens}), 0)::double precision`,
userCacheCreationTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreationInputTokens}), 0)::double precision`,
userCacheReadTokens: sql<number>`COALESCE(sum(${usageLedger.cacheReadInputTokens}), 0)::double precision`,
userCacheCreation5mTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation5mInputTokens}), 0)::double precision`,
userCacheCreation1hTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation1hInputTokens}), 0)::double precision`,
// Key breakdown(FILTER 聚合)
keyRequests: sql<number>`count(*) FILTER (WHERE ${usageLedger.key} = ${keyString})::int`,
keyCost: sql<string>`COALESCE(sum(${usageLedger.costUsd}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)`,
keyInputTokens: sql<number>`COALESCE(sum(${usageLedger.inputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyOutputTokens: sql<number>`COALESCE(sum(${usageLedger.outputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheCreationTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreationInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheReadTokens: sql<number>`COALESCE(sum(${usageLedger.cacheReadInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheCreation5mTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation5mInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheCreation1hTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation1hInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
})
.from(usageLedger)
.where(
and(
eq(usageLedger.userId, userId),
LEDGER_BILLING_CONDITION,
startDate ? gte(usageLedger.createdAt, startDate) : undefined,
endDate ? lt(usageLedger.createdAt, endDate) : undefined
// Key 维度是 User 维度的子集。迁移期同时存在 message_request 与 usage_ledger:
// 活跃 message_request 是权威实时来源,ledger-only 行补充导入/归档数据,并用 not exists 去重。
const [messageBreakdown, ledgerBreakdown] = await Promise.all([
db
.select({
model: messageRequest.model,
userRequests: sql<number>`count(*)::int`,
userCost: sql<string>`COALESCE(sum(${messageRequest.costUsd}), 0)`,
userInputTokens: sql<number>`COALESCE(sum(${messageRequest.inputTokens}), 0)::double precision`,
userOutputTokens: sql<number>`COALESCE(sum(${messageRequest.outputTokens}), 0)::double precision`,
userCacheCreationTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreationInputTokens}), 0)::double precision`,
userCacheReadTokens: sql<number>`COALESCE(sum(${messageRequest.cacheReadInputTokens}), 0)::double precision`,
userCacheCreation5mTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreation5mInputTokens}), 0)::double precision`,
userCacheCreation1hTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreation1hInputTokens}), 0)::double precision`,
keyRequests: sql<number>`count(*) FILTER (WHERE ${messageRequest.key} = ${keyString})::int`,
keyCost: sql<string>`COALESCE(sum(${messageRequest.costUsd}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)`,
keyInputTokens: sql<number>`COALESCE(sum(${messageRequest.inputTokens}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)::double precision`,
keyOutputTokens: sql<number>`COALESCE(sum(${messageRequest.outputTokens}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)::double precision`,
keyCacheCreationTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreationInputTokens}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)::double precision`,
keyCacheReadTokens: sql<number>`COALESCE(sum(${messageRequest.cacheReadInputTokens}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)::double precision`,
keyCacheCreation5mTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreation5mInputTokens}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)::double precision`,
keyCacheCreation1hTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreation1hInputTokens}) FILTER (WHERE ${messageRequest.key} = ${keyString}), 0)::double precision`,
})
.from(messageRequest)
.where(
and(
eq(messageRequest.userId, userId),
isNull(messageRequest.deletedAt),
EXCLUDE_WARMUP_CONDITION,
startDate ? gte(messageRequest.createdAt, startDate) : undefined,
endDate ? lt(messageRequest.createdAt, endDate) : undefined
)
)
)
.groupBy(usageLedger.model)
.orderBy(sql`sum(${usageLedger.costUsd}) DESC`);
.groupBy(messageRequest.model),
db
.select({
model: usageLedger.model,
userRequests: sql<number>`count(*)::int`,
userCost: sql<string>`COALESCE(sum(${usageLedger.costUsd}), 0)`,
userInputTokens: sql<number>`COALESCE(sum(${usageLedger.inputTokens}), 0)::double precision`,
userOutputTokens: sql<number>`COALESCE(sum(${usageLedger.outputTokens}), 0)::double precision`,
userCacheCreationTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreationInputTokens}), 0)::double precision`,
userCacheReadTokens: sql<number>`COALESCE(sum(${usageLedger.cacheReadInputTokens}), 0)::double precision`,
userCacheCreation5mTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation5mInputTokens}), 0)::double precision`,
userCacheCreation1hTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation1hInputTokens}), 0)::double precision`,
keyRequests: sql<number>`count(*) FILTER (WHERE ${usageLedger.key} = ${keyString})::int`,
keyCost: sql<string>`COALESCE(sum(${usageLedger.costUsd}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)`,
keyInputTokens: sql<number>`COALESCE(sum(${usageLedger.inputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyOutputTokens: sql<number>`COALESCE(sum(${usageLedger.outputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheCreationTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreationInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheReadTokens: sql<number>`COALESCE(sum(${usageLedger.cacheReadInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheCreation5mTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation5mInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
keyCacheCreation1hTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreation1hInputTokens}) FILTER (WHERE ${usageLedger.key} = ${keyString}), 0)::double precision`,
})
.from(usageLedger)
.where(
and(
eq(usageLedger.userId, userId),
LEDGER_BILLING_CONDITION,
sql`not exists (
select 1
from "message_request" as mr_active
where mr_active.id = ${usageLedger.requestId}
and mr_active.deleted_at is null
and mr_active.key = ${usageLedger.key}
)`,
startDate ? gte(usageLedger.createdAt, startDate) : undefined,
endDate ? lt(usageLedger.createdAt, endDate) : undefined
)
)
.groupBy(usageLedger.model),
]);

const modelBreakdown = mergeStatsSummaryRows([...messageBreakdown, ...ledgerBreakdown]);

const keyOnlyBreakdown = modelBreakdown.filter((row) => (row.keyRequests ?? 0) > 0);

Expand Down Expand Up @@ -1077,26 +1179,26 @@ export async function getMyStatsSummary(
keyModelBreakdown: keyOnlyBreakdown
.map((row) => ({
model: row.model,
requests: row.keyRequests,
requests: row.keyRequests ?? 0,
cost: Number(row.keyCost ?? 0),
inputTokens: row.keyInputTokens,
outputTokens: row.keyOutputTokens,
cacheCreationTokens: row.keyCacheCreationTokens,
cacheReadTokens: row.keyCacheReadTokens,
cacheCreation5mTokens: row.keyCacheCreation5mTokens,
cacheCreation1hTokens: row.keyCacheCreation1hTokens,
inputTokens: row.keyInputTokens ?? 0,
outputTokens: row.keyOutputTokens ?? 0,
cacheCreationTokens: row.keyCacheCreationTokens ?? 0,
cacheReadTokens: row.keyCacheReadTokens ?? 0,
cacheCreation5mTokens: row.keyCacheCreation5mTokens ?? 0,
cacheCreation1hTokens: row.keyCacheCreation1hTokens ?? 0,
}))
.sort((a, b) => b.cost - a.cost),
userModelBreakdown: modelBreakdown.map((row) => ({
model: row.model,
requests: row.userRequests,
requests: row.userRequests ?? 0,
cost: Number(row.userCost ?? 0),
inputTokens: row.userInputTokens,
outputTokens: row.userOutputTokens,
cacheCreationTokens: row.userCacheCreationTokens,
cacheReadTokens: row.userCacheReadTokens,
cacheCreation5mTokens: row.userCacheCreation5mTokens,
cacheCreation1hTokens: row.userCacheCreation1hTokens,
inputTokens: row.userInputTokens ?? 0,
outputTokens: row.userOutputTokens ?? 0,
cacheCreationTokens: row.userCacheCreationTokens ?? 0,
cacheReadTokens: row.userCacheReadTokens ?? 0,
cacheCreation5mTokens: row.userCacheCreation5mTokens ?? 0,
cacheCreation1hTokens: row.userCacheCreation1hTokens ?? 0,
})),
currencyCode,
};
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/actions/my-usage-token-aggregation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,8 @@ describe("my-usage token aggregation", () => {
const res = await getMyStatsSummary({ startDate: "2024-01-01", endDate: "2024-01-01" });
expect(res.ok).toBe(true);

expect(capturedSelections).toHaveLength(1);
expect(capturedSelections).toHaveLength(2);

const selection = capturedSelections[0];
const tokenFields = [
"userInputTokens",
"userOutputTokens",
Expand All @@ -211,8 +210,10 @@ describe("my-usage token aggregation", () => {
"keyCacheCreation1hTokens",
];

for (const field of tokenFields) {
expectNoIntTokenSum(selection, field);
for (const selection of capturedSelections) {
for (const field of tokenFields) {
expectNoIntTokenSum(selection, field);
}
}
});
});
Loading