Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""add billable column to usage_tracker

Revision ID: 683fc811a969
Revises: 40e4b59f754d
Create Date: 2025-09-05 10:48:09.623668

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '683fc811a969'
down_revision = '40e4b59f754d'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column('usage_tracker', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE'))


def downgrade() -> None:
op.drop_column('usage_tracker', 'billable')
2 changes: 1 addition & 1 deletion app/api/routes/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def get_wallet_balance(
await WalletService.ensure_wallet(db, user.id)
return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD", total_spent=Decimal("0"), total_earned=Decimal("0"))

result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None)))
result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None), UsageTracker.billable))
total_spent = result.scalar_one_or_none() or "0"
result = await db.execute(select(func.sum(StripePayment.amount)).where(StripePayment.user_id == user.id, StripePayment.status == "completed"))
total_earned = result.scalar_one_or_none() or "0"
Expand Down
13 changes: 11 additions & 2 deletions app/api/schemas/stripe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import List, Literal

# https://docs.stripe.com/api/checkout/sessions/create
Expand All @@ -21,7 +21,16 @@ class CreateCheckoutSessionRequest(BaseModel):
line_items: List[StripeCheckoutSessionLineItem]
# Only allow payment mode for now
mode: Literal["payment"] = "payment"
# Attach the session_id to the success_url
# https://docs.stripe.com/payments/checkout/custom-success-page?payment-ui=stripe-hosted&utm_source=chatgpt.com#success-url
success_url: str | None = None
return_url: str | None = None
cancel_url: str | None = None
ui_mode: str = "hosted"
ui_mode: str = "hosted"

@field_validator("success_url")
@classmethod
def append_session_id_to_success_url(cls, value: str):
if value is None:
return None
return value.rstrip("/") + "?session_id={CHECKOUT_SESSION_ID}"
3 changes: 2 additions & 1 deletion app/models/usage_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import UTC
import uuid

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL, Boolean
from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import UUID
from .base import Base
Expand All @@ -25,5 +25,6 @@ class UsageTracker(Base):
cost = Column(DECIMAL(12, 8), nullable=True)
currency = Column(String(3), nullable=True)
pricing_source = Column(String(255), nullable=True)
billable = Column(Boolean, nullable=False, default=False)

provider_key = relationship("ProviderKey", back_populates="usage_tracker")
13 changes: 6 additions & 7 deletions app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
InvalidForgeKeyException,
)
from app.models.user import User
from app.models.provider_key import ProviderKey
from app.core.database import get_db_session
from app.services.wallet_service import WalletService

Expand Down Expand Up @@ -193,9 +194,6 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
f"Loading provider keys from database for user {self.user_id} (sync)"
)

# Query ProviderKey directly by user_id
from app.models.provider_key import ProviderKey

result = await self.db.execute(
select(ProviderKey).filter(
ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None
Expand Down Expand Up @@ -253,9 +251,6 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:
f"Loading provider keys from database for user {self.user_id} (async)"
)

# Query ProviderKey directly by user_id
from app.models.provider_key import ProviderKey

result = await self.db.execute(
select(ProviderKey).filter(
ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None
Expand Down Expand Up @@ -597,14 +592,18 @@ async def process_request(
# Process the request through the adapter
usage_tracker_id = None
if self.api_key_id is not None and provider_key_id is not None:
await WalletService.wallet_precheck(self.user_id, self.db, provider_key_id)
result = await self.db.execute(select(ProviderKey.billable).where(ProviderKey.id == provider_key_id))
billable = result.scalar_one_or_none() or False
if billable:
await WalletService.wallet_precheck(self.user_id, self.db)
usage_tracker_id = await UsageTrackerService.start_tracking_usage(
db=self.db,
user_id=self.user_id,
provider_key_id=provider_key_id,
forge_key_id=self.api_key_id,
model=actual_model,
endpoint=endpoint,
billable=billable,
)
else:
# For api like list models, we don't have usage tracking
Expand Down
4 changes: 3 additions & 1 deletion app/services/providers/usage_tracker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async def start_tracking_usage(
forge_key_id: int,
model: str,
endpoint: str,
billable: bool = False,
) -> int:
try:
usage_tracker = UsageTracker(
Expand All @@ -33,6 +34,7 @@ async def start_tracking_usage(
model=model,
endpoint=endpoint,
created_at=datetime.now(UTC),
billable=billable,
)
db.add(usage_tracker)
await db.commit()
Expand Down Expand Up @@ -83,7 +85,7 @@ async def update_usage_tracker(
usage_tracker.pricing_source = price_info['pricing_source']

# Deduct from wallet balance if the provider is not free
if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.provider_key.billable:
if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.billable:
try:
result = await WalletService.adjust(
db,
Expand Down
13 changes: 5 additions & 8 deletions app/services/wallet_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from app.core.logger import get_logger
from app.models.wallet import Wallet
from app.models.provider_key import ProviderKey

logger = get_logger(name="wallet_service")

Expand Down Expand Up @@ -62,6 +61,10 @@ async def adjust(
currency: str = "USD"
) -> Dict[str, any]:
"""Adjust wallet balance with optimistic locking and retry"""

# enforce delta to be Decimal
delta = Decimal(delta)

for attempt in range(MAX_RETRIES):
try:
# Read current wallet state including version
Expand Down Expand Up @@ -191,14 +194,8 @@ async def get(db: AsyncSession, account_id: int) -> Optional[Dict[str, any]]:
# Helper: perform wallet precheck
# -------------------------------------------------------------
@staticmethod
async def wallet_precheck(user_id: int, db: AsyncSession, provider_key_id: int) -> None:
async def wallet_precheck(user_id: int, db: AsyncSession) -> None:
"""Check wallet balance and ensure user can make requests"""
provider_key = await db.execute(select(ProviderKey).filter(ProviderKey.id == provider_key_id, ProviderKey.billable))
provider_key = provider_key.scalar_one_or_none()
# If the provider key is not billable, we don't need to check the wallet
if not provider_key:
return

await WalletService.ensure_wallet(db, user_id)
check_result = await WalletService.precheck(db, user_id)

Expand Down
Loading