diff --git a/alembic/versions/683fc811a969_add_billable_column_to_usage_tracker.py b/alembic/versions/683fc811a969_add_billable_column_to_usage_tracker.py new file mode 100644 index 0000000..fc94f16 --- /dev/null +++ b/alembic/versions/683fc811a969_add_billable_column_to_usage_tracker.py @@ -0,0 +1,24 @@ +"""add billable column to usage_tracker + +Revision ID: 683fc811a969 +Revises: 40e4b59f754d +Create Date: 2025-09-05 10:48:09.623668 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '683fc811a969' +down_revision = '40e4b59f754d' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('usage_tracker', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE')) + + +def downgrade() -> None: + op.drop_column('usage_tracker', 'billable') diff --git a/app/api/routes/wallet.py b/app/api/routes/wallet.py index 6886aca..3b49c22 100644 --- a/app/api/routes/wallet.py +++ b/app/api/routes/wallet.py @@ -34,7 +34,7 @@ async def get_wallet_balance( await WalletService.ensure_wallet(db, user.id) return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD", total_spent=Decimal("0"), total_earned=Decimal("0")) - result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None))) + result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None), UsageTracker.billable)) total_spent = result.scalar_one_or_none() or "0" result = await db.execute(select(func.sum(StripePayment.amount)).where(StripePayment.user_id == user.id, StripePayment.status == "completed")) total_earned = result.scalar_one_or_none() or "0" diff --git a/app/api/schemas/stripe.py b/app/api/schemas/stripe.py index 4f30d61..abe77b9 100644 --- a/app/api/schemas/stripe.py +++ b/app/api/schemas/stripe.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from typing import List, Literal # https://docs.stripe.com/api/checkout/sessions/create @@ -21,7 +21,16 @@ class CreateCheckoutSessionRequest(BaseModel): line_items: List[StripeCheckoutSessionLineItem] # Only allow payment mode for now mode: Literal["payment"] = "payment" + # Attach the session_id to the success_url + # https://docs.stripe.com/payments/checkout/custom-success-page?payment-ui=stripe-hosted&utm_source=chatgpt.com#success-url success_url: str | None = None return_url: str | None = None cancel_url: str | None = None - ui_mode: str = "hosted" \ No newline at end of file + ui_mode: str = "hosted" + + @field_validator("success_url") + @classmethod + def append_session_id_to_success_url(cls, value: str): + if value is None: + return None + return value.rstrip("/") + "?session_id={CHECKOUT_SESSION_ID}" \ No newline at end of file diff --git a/app/models/usage_tracker.py b/app/models/usage_tracker.py index c1798a3..40d4ab8 100644 --- a/app/models/usage_tracker.py +++ b/app/models/usage_tracker.py @@ -2,7 +2,7 @@ from datetime import UTC import uuid -from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL, Boolean from sqlalchemy.orm import relationship from sqlalchemy.dialects.postgresql import UUID from .base import Base @@ -25,5 +25,6 @@ class UsageTracker(Base): cost = Column(DECIMAL(12, 8), nullable=True) currency = Column(String(3), nullable=True) pricing_source = Column(String(255), nullable=True) + billable = Column(Boolean, nullable=False, default=False) provider_key = relationship("ProviderKey", back_populates="usage_tracker") diff --git a/app/services/provider_service.py b/app/services/provider_service.py index 1ef8af3..7c9e763 100644 --- a/app/services/provider_service.py +++ b/app/services/provider_service.py @@ -18,6 +18,7 @@ InvalidForgeKeyException, ) from app.models.user import User +from app.models.provider_key import ProviderKey from app.core.database import get_db_session from app.services.wallet_service import WalletService @@ -193,9 +194,6 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]: f"Loading provider keys from database for user {self.user_id} (sync)" ) - # Query ProviderKey directly by user_id - from app.models.provider_key import ProviderKey - result = await self.db.execute( select(ProviderKey).filter( ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None @@ -253,9 +251,6 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]: f"Loading provider keys from database for user {self.user_id} (async)" ) - # Query ProviderKey directly by user_id - from app.models.provider_key import ProviderKey - result = await self.db.execute( select(ProviderKey).filter( ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None @@ -597,7 +592,10 @@ async def process_request( # Process the request through the adapter usage_tracker_id = None if self.api_key_id is not None and provider_key_id is not None: - await WalletService.wallet_precheck(self.user_id, self.db, provider_key_id) + result = await self.db.execute(select(ProviderKey.billable).where(ProviderKey.id == provider_key_id)) + billable = result.scalar_one_or_none() or False + if billable: + await WalletService.wallet_precheck(self.user_id, self.db) usage_tracker_id = await UsageTrackerService.start_tracking_usage( db=self.db, user_id=self.user_id, @@ -605,6 +603,7 @@ async def process_request( forge_key_id=self.api_key_id, model=actual_model, endpoint=endpoint, + billable=billable, ) else: # For api like list models, we don't have usage tracking diff --git a/app/services/providers/usage_tracker_service.py b/app/services/providers/usage_tracker_service.py index 187f0e5..cf9e6a7 100644 --- a/app/services/providers/usage_tracker_service.py +++ b/app/services/providers/usage_tracker_service.py @@ -24,6 +24,7 @@ async def start_tracking_usage( forge_key_id: int, model: str, endpoint: str, + billable: bool = False, ) -> int: try: usage_tracker = UsageTracker( @@ -33,6 +34,7 @@ async def start_tracking_usage( model=model, endpoint=endpoint, created_at=datetime.now(UTC), + billable=billable, ) db.add(usage_tracker) await db.commit() @@ -83,7 +85,7 @@ async def update_usage_tracker( usage_tracker.pricing_source = price_info['pricing_source'] # Deduct from wallet balance if the provider is not free - if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.provider_key.billable: + if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.billable: try: result = await WalletService.adjust( db, diff --git a/app/services/wallet_service.py b/app/services/wallet_service.py index 1f96e8c..3ed92fe 100644 --- a/app/services/wallet_service.py +++ b/app/services/wallet_service.py @@ -8,7 +8,6 @@ from app.core.logger import get_logger from app.models.wallet import Wallet -from app.models.provider_key import ProviderKey logger = get_logger(name="wallet_service") @@ -62,6 +61,10 @@ async def adjust( currency: str = "USD" ) -> Dict[str, any]: """Adjust wallet balance with optimistic locking and retry""" + + # enforce delta to be Decimal + delta = Decimal(delta) + for attempt in range(MAX_RETRIES): try: # Read current wallet state including version @@ -191,14 +194,8 @@ async def get(db: AsyncSession, account_id: int) -> Optional[Dict[str, any]]: # Helper: perform wallet precheck # ------------------------------------------------------------- @staticmethod - async def wallet_precheck(user_id: int, db: AsyncSession, provider_key_id: int) -> None: + async def wallet_precheck(user_id: int, db: AsyncSession) -> None: """Check wallet balance and ensure user can make requests""" - provider_key = await db.execute(select(ProviderKey).filter(ProviderKey.id == provider_key_id, ProviderKey.billable)) - provider_key = provider_key.scalar_one_or_none() - # If the provider key is not billable, we don't need to check the wallet - if not provider_key: - return - await WalletService.ensure_wallet(db, user_id) check_result = await WalletService.precheck(db, user_id)