diff --git a/app/api/routes/statistic.py b/app/api/routes/statistic.py index 7a2e15c..e5f0698 100644 --- a/app/api/routes/statistic.py +++ b/app/api/routes/statistic.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, Query from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, desc, func +from sqlalchemy import select, desc, func, case from sqlalchemy.sql.functions import coalesce from sqlalchemy import or_ from datetime import datetime, timedelta, UTC @@ -27,12 +27,12 @@ # I want a query parameter called "offset: " and "limit: " -@router.get("/usage/realtime", response_model=list[UsageRealtimeResponse]) +@router.get("/usage/realtime", response_model=UsageRealtimeResponse) async def get_usage_realtime( current_user: User = Depends(get_user_by_api_key), db: AsyncSession = Depends(get_async_db), - offset: int = Query(0, ge=0), - limit: int = Query(10, ge=1), + page_index: int = Query(0, ge=0), + page_size: int = Query(10, ge=1), forge_key: str = Query(None, min_length=1), provider_name: str = Query(None, min_length=1), model_name: str = Query(None, min_length=1), @@ -66,9 +66,11 @@ async def get_usage_realtime( UsageTracker.output_tokens.label("output_tokens"), UsageTracker.cached_tokens.label("cached_tokens"), UsageTracker.cost.label("cost"), + UsageTracker.billable.label("billable"), func.extract( "epoch", UsageTracker.updated_at - UsageTracker.created_at ).label("duration"), + func.count().over().label("total"), ) .join(ProviderKey, UsageTracker.provider_key_id == ProviderKey.id) .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id) @@ -87,8 +89,8 @@ async def get_usage_realtime( UsageTracker.updated_at.is_not(None), ) .order_by(desc(UsageTracker.created_at)) - .offset(offset) - .limit(limit) + .offset(page_index * page_size) + .limit(page_size) ) # Execute the query @@ -96,9 +98,10 @@ async def get_usage_realtime( rows = result.fetchall() # Convert to list of dictionaries - usage_stats = [] + items = [] + total = 0 for row in rows: - usage_stats.append( + items.append( { "timestamp": row.timestamp, "forge_key": row.forge_key, @@ -109,20 +112,27 @@ async def get_usage_realtime( "output_tokens": row.output_tokens, "cached_tokens": row.cached_tokens, "cost": decimal.Decimal(row.cost).normalize(), + "billable": row.billable, "duration": round(float(row.duration), 2) if row.duration is not None else 0.0, } ) - return [UsageRealtimeResponse(**usage_stat) for usage_stat in usage_stats] + total = row.total + return UsageRealtimeResponse( + items=items, + total=total, + page_size=page_size, + page_index=page_index, + ) -@router.get("/usage/realtime/clerk", response_model=list[UsageRealtimeResponse]) +@router.get("/usage/realtime/clerk", response_model=UsageRealtimeResponse) async def get_usage_realtime_clerk( current_user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db), - offset: int = Query(0, ge=0), - limit: int = Query(10, ge=1), + page_index: int = Query(0, ge=0), + page_size: int = Query(10, ge=1), forge_key: str = Query(None, min_length=1), provider_name: str = Query(None, min_length=1), model_name: str = Query(None, min_length=1), @@ -132,8 +142,8 @@ async def get_usage_realtime_clerk( return await get_usage_realtime( current_user, db, - offset, - limit, + page_index, + page_size, forge_key, provider_name, model_name, @@ -186,6 +196,7 @@ async def get_usage_summary( func.sum(UsageTracker.output_tokens).label("output_tokens"), func.sum(UsageTracker.cached_tokens).label("cached_tokens"), func.sum(UsageTracker.cost).label("cost"), + func.sum(case((UsageTracker.billable, UsageTracker.cost), else_=0)).label("charged_cost"), ) .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id) .where( @@ -208,6 +219,7 @@ async def get_usage_summary( "breakdown": [], "total_tokens": 0, "total_cost": 0, + "total_charged_cost": 0, "total_input_tokens": 0, "total_output_tokens": 0, "total_cached_tokens": 0, @@ -217,6 +229,7 @@ async def get_usage_summary( "forge_key": row.forge_key, "tokens": row.tokens, "cost": decimal.Decimal(row.cost).normalize(), + "charged_cost": decimal.Decimal(row.charged_cost).normalize(), "input_tokens": row.input_tokens, "output_tokens": row.output_tokens, "cached_tokens": row.cached_tokens, @@ -226,6 +239,9 @@ async def get_usage_summary( data_points[row.time_point]["total_cost"] += decimal.Decimal( row.cost ).normalize() + data_points[row.time_point]["total_charged_cost"] += decimal.Decimal( + row.charged_cost + ).normalize() data_points[row.time_point]["total_input_tokens"] += row.input_tokens data_points[row.time_point]["total_output_tokens"] += row.output_tokens data_points[row.time_point]["total_cached_tokens"] += row.cached_tokens @@ -236,6 +252,7 @@ async def get_usage_summary( breakdown=data_point["breakdown"], total_tokens=data_point["total_tokens"], total_cost=data_point["total_cost"], + total_charged_cost=data_point["total_charged_cost"], total_input_tokens=data_point["total_input_tokens"], total_output_tokens=data_point["total_output_tokens"], total_cached_tokens=data_point["total_cached_tokens"], @@ -292,6 +309,7 @@ async def get_forge_keys_usage( func.sum(UsageTracker.output_tokens).label("output_tokens"), func.sum(UsageTracker.cached_tokens).label("cached_tokens"), func.sum(UsageTracker.cost).label("cost"), + func.sum(case((UsageTracker.billable, UsageTracker.cost), else_=0)).label("charged_cost"), ) .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id) .where( @@ -311,6 +329,7 @@ async def get_forge_keys_usage( forge_key=row.forge_key, tokens=row.tokens, cost=decimal.Decimal(row.cost).normalize(), + charged_cost=decimal.Decimal(row.charged_cost).normalize(), input_tokens=row.input_tokens, output_tokens=row.output_tokens, cached_tokens=row.cached_tokens, diff --git a/app/api/routes/wallet.py b/app/api/routes/wallet.py index 7bb89f5..c9296be 100644 --- a/app/api/routes/wallet.py +++ b/app/api/routes/wallet.py @@ -51,6 +51,7 @@ async def get_wallet_balance_clerk( return await get_wallet_balance(user, db) class TransactionHistoryItem(BaseModel): + transaction_id: str currency: str amount: Decimal status: str @@ -75,6 +76,7 @@ async def get_wallet_transactions_history( # I would also want to get the total count of the transactions within one sql query query = ( select( + StripePayment.id, StripePayment.currency, StripePayment.amount, StripePayment.status, @@ -92,6 +94,7 @@ async def get_wallet_transactions_history( return TransactionHistoryResponse( items=[ TransactionHistoryItem( + transaction_id=transaction.id, currency=transaction.currency, # Convert cents to dollars for USD amount=transaction.amount / 100.0 if transaction.currency == "USD" else transaction.amount, diff --git a/app/api/schemas/statistic.py b/app/api/schemas/statistic.py index c39d733..2eaacb6 100644 --- a/app/api/schemas/statistic.py +++ b/app/api/schemas/statistic.py @@ -12,7 +12,7 @@ def mask_forge_name_or_key(v: str) -> str: # Otherwise, return the original value (user customized name) return v -class UsageRealtimeResponse(BaseModel): +class UsageRealtimeItem(BaseModel): timestamp: datetime | str forge_key: str provider_name: str @@ -23,6 +23,7 @@ class UsageRealtimeResponse(BaseModel): cached_tokens: int duration: float cost: decimal.Decimal + billable: bool @field_validator('forge_key') @classmethod @@ -36,11 +37,17 @@ def convert_timestamp_to_iso(cls, v: datetime | str) -> str: return v return v.isoformat() +class UsageRealtimeResponse(BaseModel): + total: int + items: list[UsageRealtimeItem] + page_size: int + page_index: int class UsageSummaryBreakdown(BaseModel): forge_key: str tokens: int cost: decimal.Decimal + charged_cost: decimal.Decimal input_tokens: int output_tokens: int cached_tokens: int @@ -56,6 +63,7 @@ class UsageSummaryResponse(BaseModel): breakdown: list[UsageSummaryBreakdown] total_tokens: int total_cost: decimal.Decimal + total_charged_cost: decimal.Decimal total_input_tokens: int total_output_tokens: int total_cached_tokens: int @@ -72,6 +80,7 @@ class ForgeKeysUsageSummaryResponse(BaseModel): forge_key: str tokens: int cost: decimal.Decimal + charged_cost: decimal.Decimal input_tokens: int output_tokens: int cached_tokens: int