Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions app/api/routes/statistic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc, func
from sqlalchemy import select, desc, func, case
from sqlalchemy.sql.functions import coalesce
from sqlalchemy import or_
from datetime import datetime, timedelta, UTC
Expand All @@ -27,12 +27,12 @@


# Pagination is driven by the query parameters "page_index: <int>" and "page_size: <int>"
@router.get("/usage/realtime", response_model=list[UsageRealtimeResponse])
@router.get("/usage/realtime", response_model=UsageRealtimeResponse)
async def get_usage_realtime(
current_user: User = Depends(get_user_by_api_key),
db: AsyncSession = Depends(get_async_db),
offset: int = Query(0, ge=0),
limit: int = Query(10, ge=1),
page_index: int = Query(0, ge=0),
page_size: int = Query(10, ge=1),
forge_key: str = Query(None, min_length=1),
provider_name: str = Query(None, min_length=1),
model_name: str = Query(None, min_length=1),
Expand Down Expand Up @@ -66,9 +66,11 @@ async def get_usage_realtime(
UsageTracker.output_tokens.label("output_tokens"),
UsageTracker.cached_tokens.label("cached_tokens"),
UsageTracker.cost.label("cost"),
UsageTracker.billable.label("billable"),
func.extract(
"epoch", UsageTracker.updated_at - UsageTracker.created_at
).label("duration"),
func.count().over().label("total"),
)
.join(ProviderKey, UsageTracker.provider_key_id == ProviderKey.id)
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
Expand All @@ -87,18 +89,19 @@ async def get_usage_realtime(
UsageTracker.updated_at.is_not(None),
)
.order_by(desc(UsageTracker.created_at))
.offset(offset)
.limit(limit)
.offset(page_index * page_size)
.limit(page_size)
)

# Execute the query
result = await db.execute(query)
rows = result.fetchall()

# Convert to list of dictionaries
usage_stats = []
items = []
total = 0
for row in rows:
usage_stats.append(
items.append(
{
"timestamp": row.timestamp,
"forge_key": row.forge_key,
Expand All @@ -109,20 +112,27 @@ async def get_usage_realtime(
"output_tokens": row.output_tokens,
"cached_tokens": row.cached_tokens,
"cost": decimal.Decimal(row.cost).normalize(),
"billable": row.billable,
"duration": round(float(row.duration), 2)
if row.duration is not None
else 0.0,
}
)
return [UsageRealtimeResponse(**usage_stat) for usage_stat in usage_stats]
total = row.total
return UsageRealtimeResponse(
items=items,
total=total,
page_size=page_size,
page_index=page_index,
)


@router.get("/usage/realtime/clerk", response_model=list[UsageRealtimeResponse])
@router.get("/usage/realtime/clerk", response_model=UsageRealtimeResponse)
async def get_usage_realtime_clerk(
current_user: User = Depends(get_current_active_user_from_clerk),
db: AsyncSession = Depends(get_async_db),
offset: int = Query(0, ge=0),
limit: int = Query(10, ge=1),
page_index: int = Query(0, ge=0),
page_size: int = Query(10, ge=1),
forge_key: str = Query(None, min_length=1),
provider_name: str = Query(None, min_length=1),
model_name: str = Query(None, min_length=1),
Expand All @@ -132,8 +142,8 @@ async def get_usage_realtime_clerk(
return await get_usage_realtime(
current_user,
db,
offset,
limit,
page_index,
page_size,
forge_key,
provider_name,
model_name,
Expand Down Expand Up @@ -186,6 +196,7 @@ async def get_usage_summary(
func.sum(UsageTracker.output_tokens).label("output_tokens"),
func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
func.sum(UsageTracker.cost).label("cost"),
func.sum(case((UsageTracker.billable, UsageTracker.cost), else_=0)).label("charged_cost"),
)
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
.where(
Expand All @@ -208,6 +219,7 @@ async def get_usage_summary(
"breakdown": [],
"total_tokens": 0,
"total_cost": 0,
"total_charged_cost": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_cached_tokens": 0,
Expand All @@ -217,6 +229,7 @@ async def get_usage_summary(
"forge_key": row.forge_key,
"tokens": row.tokens,
"cost": decimal.Decimal(row.cost).normalize(),
"charged_cost": decimal.Decimal(row.charged_cost).normalize(),
"input_tokens": row.input_tokens,
"output_tokens": row.output_tokens,
"cached_tokens": row.cached_tokens,
Expand All @@ -226,6 +239,9 @@ async def get_usage_summary(
data_points[row.time_point]["total_cost"] += decimal.Decimal(
row.cost
).normalize()
data_points[row.time_point]["total_charged_cost"] += decimal.Decimal(
row.charged_cost
).normalize()
data_points[row.time_point]["total_input_tokens"] += row.input_tokens
data_points[row.time_point]["total_output_tokens"] += row.output_tokens
data_points[row.time_point]["total_cached_tokens"] += row.cached_tokens
Expand All @@ -236,6 +252,7 @@ async def get_usage_summary(
breakdown=data_point["breakdown"],
total_tokens=data_point["total_tokens"],
total_cost=data_point["total_cost"],
total_charged_cost=data_point["total_charged_cost"],
total_input_tokens=data_point["total_input_tokens"],
total_output_tokens=data_point["total_output_tokens"],
total_cached_tokens=data_point["total_cached_tokens"],
Expand Down Expand Up @@ -292,6 +309,7 @@ async def get_forge_keys_usage(
func.sum(UsageTracker.output_tokens).label("output_tokens"),
func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
func.sum(UsageTracker.cost).label("cost"),
func.sum(case((UsageTracker.billable, UsageTracker.cost), else_=0)).label("charged_cost"),
)
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
.where(
Expand All @@ -311,6 +329,7 @@ async def get_forge_keys_usage(
forge_key=row.forge_key,
tokens=row.tokens,
cost=decimal.Decimal(row.cost).normalize(),
charged_cost=decimal.Decimal(row.charged_cost).normalize(),
input_tokens=row.input_tokens,
output_tokens=row.output_tokens,
cached_tokens=row.cached_tokens,
Expand Down
3 changes: 3 additions & 0 deletions app/api/routes/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def get_wallet_balance_clerk(
return await get_wallet_balance(user, db)

class TransactionHistoryItem(BaseModel):
transaction_id: str
currency: str
amount: Decimal
status: str
Expand All @@ -75,6 +76,7 @@ async def get_wallet_transactions_history(
# I would also want to get the total count of the transactions within one sql query
query = (
select(
StripePayment.id,
StripePayment.currency,
StripePayment.amount,
StripePayment.status,
Expand All @@ -92,6 +94,7 @@ async def get_wallet_transactions_history(
return TransactionHistoryResponse(
items=[
TransactionHistoryItem(
transaction_id=transaction.id,
currency=transaction.currency,
# Convert cents to dollars for USD
amount=transaction.amount / 100.0 if transaction.currency == "USD" else transaction.amount,
Expand Down
11 changes: 10 additions & 1 deletion app/api/schemas/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def mask_forge_name_or_key(v: str) -> str:
# Otherwise, return the original value (user customized name)
return v

class UsageRealtimeResponse(BaseModel):
class UsageRealtimeItem(BaseModel):
timestamp: datetime | str
forge_key: str
provider_name: str
Expand All @@ -23,6 +23,7 @@ class UsageRealtimeResponse(BaseModel):
cached_tokens: int
duration: float
cost: decimal.Decimal
billable: bool

@field_validator('forge_key')
@classmethod
Expand All @@ -36,11 +37,17 @@ def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
return v
return v.isoformat()

class UsageRealtimeResponse(BaseModel):
    """Paginated envelope: one page of realtime usage items plus paging metadata."""

    # Overall row count reported with the page; presumably the number of
    # matching records across all pages rather than len(items) - confirm
    # against the route that populates it.
    total: int
    # The usage records that make up this page.
    items: list[UsageRealtimeItem]
    # Page size used to produce this page.
    page_size: int
    # Index of this page.
    page_index: int

class UsageSummaryBreakdown(BaseModel):
forge_key: str
tokens: int
cost: decimal.Decimal
charged_cost: decimal.Decimal
input_tokens: int
output_tokens: int
cached_tokens: int
Expand All @@ -56,6 +63,7 @@ class UsageSummaryResponse(BaseModel):
breakdown: list[UsageSummaryBreakdown]
total_tokens: int
total_cost: decimal.Decimal
total_charged_cost: decimal.Decimal
total_input_tokens: int
total_output_tokens: int
total_cached_tokens: int
Expand All @@ -72,6 +80,7 @@ class ForgeKeysUsageSummaryResponse(BaseModel):
forge_key: str
tokens: int
cost: decimal.Decimal
charged_cost: decimal.Decimal
input_tokens: int
output_tokens: int
cached_tokens: int
Expand Down
Loading