From fc5caf9251b09ee28565dea53e2a4588804f88cf Mon Sep 17 00:00:00 2001 From: Lingtong Lu Date: Fri, 12 Sep 2025 13:36:42 -0700 Subject: [PATCH] Add admin user functionality (#228) --- .../versions/f45e46b231f3_add_admin_users.py | 29 +++ app/api/dependencies.py | 12 +- app/api/routes/admin.py | 64 +++++ app/api/routes/users.py | 6 + app/api/routes/webhooks.py | 223 ++++-------------- app/api/schemas/admin.py | 12 + app/api/schemas/user.py | 1 + app/main.py | 2 + app/models/admin_users.py | 10 + app/models/user.py | 1 + 10 files changed, 188 insertions(+), 172 deletions(-) create mode 100644 alembic/versions/f45e46b231f3_add_admin_users.py create mode 100644 app/api/routes/admin.py create mode 100644 app/api/schemas/admin.py create mode 100644 app/models/admin_users.py diff --git a/alembic/versions/f45e46b231f3_add_admin_users.py b/alembic/versions/f45e46b231f3_add_admin_users.py new file mode 100644 index 0000000..1f964f7 --- /dev/null +++ b/alembic/versions/f45e46b231f3_add_admin_users.py @@ -0,0 +1,29 @@ +"""add admin users + +Revision ID: f45e46b231f3 +Revises: 683fc811a969 +Create Date: 2025-09-11 13:14:17.066592 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f45e46b231f3' +down_revision = '683fc811a969' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "admin_users", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("user_id"), + ) + + +def downgrade() -> None: + op.drop_table("admin_users") diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 80d3c0b..6cc5981 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -9,7 +9,6 @@ from datetime import datetime import aiohttp -import requests from fastapi import Depends, HTTPException, Request, status from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from jose import JWTError, jwt @@ -113,6 +112,7 @@ async def get_current_user( result = await db.execute( select(User) .options(selectinload(User.api_keys)) # Eager load Forge API keys + .options(selectinload(User.admin_users)) # Eager load admin users .filter(User.username == token_data.username) ) user = result.scalar_one_or_none() @@ -393,6 +393,7 @@ async def get_current_user_from_clerk( result = await db.execute( select(User) .options(selectinload(User.api_keys)) # Eager load Forge API keys + .options(selectinload(User.admin_users)) # Eager load admin users .filter(User.clerk_user_id == clerk_user_id) ) user = result.scalar_one_or_none() @@ -512,3 +513,12 @@ async def get_current_active_user_from_clerk( if not current_user.is_active: raise HTTPException(status_code=400, detail="Inactive user") return current_user + + +async def get_current_active_admin_user_from_clerk( + current_user: User = Depends(get_current_active_user_from_clerk), +): + """Ensure the user from Clerk is an admin""" + if not current_user.admin_users: + raise HTTPException(status_code=401, detail="User is not an admin") + return current_user diff --git a/app/api/routes/admin.py b/app/api/routes/admin.py new file mode 100644 index 0000000..3643681 --- /dev/null +++ b/app/api/routes/admin.py @@ -0,0 +1,64 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from decimal import Decimal +from pydantic import BaseModel +import uuid + +from app.api.dependencies import get_current_active_admin_user_from_clerk +from app.core.database import get_async_db +from sqlalchemy.ext.asyncio import AsyncSession +from app.models.user import User +from app.models.stripe import StripePayment +from app.core.logger import get_logger +from app.api.schemas.admin import AddBalanceRequest +from app.services.wallet_service import WalletService + +logger = get_logger(name="admin") +router = APIRouter() + +class AddBalanceResponse(BaseModel): + balance: Decimal + blocked: bool + + +@router.post("/add-balance") +async def add_balance( + add_balance_request: AddBalanceRequest, + current_user: User = Depends(get_current_active_admin_user_from_clerk), + db: AsyncSession = Depends(get_async_db), +): + """Add balance to a user""" + user_id = add_balance_request.user_id + email = add_balance_request.email + amount = add_balance_request.amount + + result = await db.execute( + select(User) + .where( + user_id is None or User.id == user_id, + email is None or User.email == email, + ) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + amount_decimal = Decimal(amount / 100.0) + result = await WalletService.adjust(db, user.id, amount_decimal, f"Admin {current_user.id} added balance for user {user.id}") + if not result.get("success"): + raise HTTPException(status_code=400, detail=f"Failed to add balance for user {user.id}: {result.get('reason')}") + + # add the amount to the user's stripe payment + stripe_payment = StripePayment( + id=f"tb_admin_{uuid.uuid4().hex}", + user_id=user.id, + amount=amount, + currency="USD", + status="completed", + raw_data={"reason": f"Admin {current_user.id} added balance for user {user.id}"}, + ) + db.add(stripe_payment) + await db.commit() + logger.info(f"Added balance {amount_decimal} for user {user.id} by admin {current_user.id}") + + return AddBalanceResponse(balance=result.get("balance"), blocked=result.get("blocked")) diff --git a/app/api/routes/users.py b/app/api/routes/users.py index 6f918d1..8ae56cb 100644 --- a/app/api/routes/users.py +++ b/app/api/routes/users.py @@ -86,6 +86,9 @@ async def read_user_me( else: user_data["forge_api_keys"] = [] + if current_user.admin_users: + user_data["is_admin"] = True + return MaskedUser(**user_data) @@ -104,6 +107,9 @@ async def read_user_me_clerk( else: user_data["forge_api_keys"] = [] + if current_user.admin_users: + user_data["is_admin"] = True + return MaskedUser(**user_data) diff --git a/app/api/routes/webhooks.py b/app/api/routes/webhooks.py index cd54c81..6cd7068 100644 --- a/app/api/routes/webhooks.py +++ b/app/api/routes/webhooks.py @@ -1,18 +1,18 @@ import json import os -from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request, status import stripe -from sqlalchemy import select, update +from sqlalchemy import update, delete, select from sqlalchemy.ext.asyncio import AsyncSession from svix import Webhook, WebhookVerificationError +from sqlalchemy.dialects.postgresql import insert from app.core.database import get_async_db from app.core.logger import get_logger -from app.models.user import User from app.models.stripe import StripePayment -from app.services.provider_service import create_default_tensorblock_provider_for_user +from app.models.user import User +from app.models.admin_users import AdminUsers from app.services.wallet_service import WalletService logger = get_logger(name="webhooks") @@ -22,7 +22,7 @@ # Webhook signing secrets for verifying webhook authenticity CLERK_WEBHOOK_SECRET = os.getenv("CLERK_WEBHOOK_SECRET", "") STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "") - +CLERK_TENSORBLOCK_ORGANIZATION_ID = os.getenv("CLERK_TENSORBLOCK_ORGANIZATION_ID", "") @router.post("/clerk") async def clerk_webhook_handler(request: Request, db: AsyncSession = Depends(get_async_db)): @@ -30,9 +30,9 @@ async def clerk_webhook_handler(request: Request, db: AsyncSession = Depends(get Handle Clerk webhooks for user events. Key events to handle: - - user.created: Create a new user in our database - - user.updated: Update user details - - user.deleted: Optionally deactivate the user + - organizationMembership.created: Add user to admin users table + - organizationMembership.updated: Update user in admin users table + - organizationMembership.deleted: Remove user from admin users table """ # Get the request body payload = await request.body() @@ -56,13 +56,9 @@ async def clerk_webhook_handler(request: Request, db: AsyncSession = Depends(get # Verify webhook signature with Svix try: - if not CLERK_WEBHOOK_SECRET: - # For development only - should be removed in production - logger.warning("CLERK_WEBHOOK_SECRET is not set") - else: - wh = Webhook(CLERK_WEBHOOK_SECRET) - # This will throw an error if verification fails - wh.verify(payload.decode(), svix_headers) + wh = Webhook(CLERK_WEBHOOK_SECRET) + # This will throw an error if verification fails + wh.verify(payload.decode(), svix_headers) except WebhookVerificationError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -73,173 +69,57 @@ async def clerk_webhook_handler(request: Request, db: AsyncSession = Depends(get try: event_data = json.loads(payload) event_type = event_data.get("type") + logger.info(f"Received Clerk webhook: {event_type}") - # Extract user data - user_data = event_data.get("data", {}) - clerk_user_id = user_data.get("id") - - # Extract email from email_addresses array - email_addresses = user_data.get("email_addresses", []) - primary_email_id = user_data.get("primary_email_address_id") - - email = None - # Find primary email - for email_obj in email_addresses: - if email_obj.get("id") == primary_email_id: - email = email_obj.get("email_address") - break - - # If no primary email, use the first one - if not email and email_addresses: - email = email_addresses[0].get("email_address", "") - - # Get username or fallback to email prefix - username = user_data.get("username") - if not username and email: - username = email.split("@")[0] - - if not clerk_user_id or not email: - return {"status": "error", "message": "Missing required user data"} - - # Handle different event types - if event_type == "user.created": - await handle_user_created(event_data, db) - - elif event_type == "user.updated": - await handle_user_updated(event_data, db) - - elif event_type == "user.deleted": - await handle_user_deleted(event_data, db) - - return {"status": "success", "message": f"Event {event_type} processed"} - + if event_type == "organizationMembership.created" or event_type == "organizationMembership.updated": + await handle_organization_membership_created(event_data, db) + elif event_type == "organizationMembership.deleted": + await handle_organization_membership_deleted(event_data, db) + else: + logger.warning(f"Unhandled Clerk event type: {event_type}") except json.JSONDecodeError: + logger.error(f"Invalid JSON payload", exc_info=True) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON payload" ) except Exception as e: + logger.exception(f"Error processing Clerk webhook: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error processing webhook: {str(e)}", ) + + logger.info(f"Processed Clerk webhook: {event_type}") + return {"status": "success", "message": f"Event {event_type} processed"} -async def handle_user_created(event_data: dict, db: AsyncSession): - """Handle user.created event from Clerk""" - try: - clerk_user_id = event_data.get("id") - email = event_data.get("email_addresses", [{}])[0].get("email_address", "") - username = ( - event_data.get("username") - or event_data.get("first_name", "") - or email.split("@")[0] - ) - - logger.info(f"Creating user from Clerk webhook: {username} ({email})") - - # Check if user already exists by clerk_user_id - result = await db.execute( - select(User).filter(User.clerk_user_id == clerk_user_id) - ) - user = result.scalar_one_or_none() - if user: - logger.info(f"User {username} already exists with Clerk ID") - return - - # Check if user exists with this email - result = await db.execute( - select(User).filter(User.email == email) - ) - existing_user = result.scalar_one_or_none() - if existing_user: - # Link existing user to Clerk ID - existing_user.clerk_user_id = clerk_user_id - await db.commit() - logger.info(f"Linked existing user {existing_user.username} to Clerk ID") - return - - # Create new user - user = User( - username=username, - email=email, - clerk_user_id=clerk_user_id, - is_active=True, - hashed_password="", # Clerk handles authentication - ) - db.add(user) - await db.commit() - await db.refresh(user) - - # Create default provider for the user - create_default_tensorblock_provider_for_user(user.id, db) - - logger.info(f"Successfully created user {username} with ID {user.id}") - - except Exception as e: - await db.rollback() - logger.error(f"Failed to create user from webhook: {e}", exc_info=True) - raise - - -async def handle_user_updated(event_data: dict, db: AsyncSession): - """Handle user.updated event from Clerk""" - try: - clerk_user_id = event_data.get("id") - email = event_data.get("email_addresses", [{}])[0].get("email_address", "") - username = ( - event_data.get("username") - or event_data.get("first_name", "") - or email.split("@")[0] - ) - - logger.info(f"Updating user from Clerk webhook: {username} ({email})") - - result = await db.execute( - select(User).filter(User.clerk_user_id == clerk_user_id) - ) - user = result.scalar_one_or_none() - if not user: - logger.warning(f"User with Clerk ID {clerk_user_id} not found for update") - return - - # Update user information - user.username = username - user.email = email - await db.commit() - - logger.info(f"Successfully updated user {username}") - - except Exception as e: - await db.rollback() - logger.error(f"Failed to update user from webhook: {e}", exc_info=True) - raise - - -async def handle_user_deleted(event_data: dict, db: AsyncSession): - """Handle user.deleted event from Clerk""" - try: - clerk_user_id = event_data.get("id") - - logger.info(f"Deleting user from Clerk webhook: {clerk_user_id}") - - result = await db.execute( - select(User).filter(User.clerk_user_id == clerk_user_id) - ) - user = result.scalar_one_or_none() - if not user: - logger.warning(f"User with Clerk ID {clerk_user_id} not found for deletion") - return - - # Deactivate user instead of deleting to preserve data integrity - user.is_active = False - await db.commit() - - logger.info(f"Successfully deactivated user {user.username}") - - except Exception as e: - await db.rollback() - logger.error(f"Failed to delete user from webhook: {e}", exc_info=True) - raise +async def handle_organization_membership_created(event_data: dict, db: AsyncSession): + data = event_data['data'] + if data['organization']['id'] != CLERK_TENSORBLOCK_ORGANIZATION_ID: + logger.warning(f"Received organization membership created event for non-TensorBlock organization: {data['organization']['id']}") + return + + clerk_user_id = data['public_user_data']['user_id'] + role = data['role'] + if role != "org:admin": + # delete from admin users table, if present + await db.execute(delete(AdminUsers).where(AdminUsers.user_id.in_(select(User.id).where(User.clerk_user_id == clerk_user_id)))) + else: + # insert into admin users table, if not already present + await db.execute(insert(AdminUsers).from_select(['user_id'], select(User.id).where(User.clerk_user_id == clerk_user_id)).on_conflict_do_nothing()) + await db.commit() + + +async def handle_organization_membership_deleted(event_data: dict, db: AsyncSession): + data = event_data['data'] + if data['organization']['id'] != CLERK_TENSORBLOCK_ORGANIZATION_ID: + logger.warning(f"Received organization membership deleted event for non-TensorBlock organization: {data['organization']['id']}") + return + + clerk_user_id = data['public_user_data']['user_id'] + # delete from admin users table, if present + await db.execute(delete(AdminUsers).where(AdminUsers.user_id.in_(select(User.id).where(User.clerk_user_id == clerk_user_id)))) + await db.commit() @router.post("/stripe") @@ -347,13 +227,14 @@ async def handle_payment_succeeded(event: dict, db: AsyncSession): logger.info(f"Payment succeeded: {amount_decimal} {currency} for customer {user_id}") - await WalletService.adjust( + result = await WalletService.adjust( db, user_id, amount_decimal, f"deposit:stripe:{session_id}", currency ) + assert result.get("success"), f"Failed to adjust wallet balance for user {user_id}: {result.get('reason')}" except Exception as e: logger.error(f"Failed to process payment success: {e}", exc_info=True) diff --git a/app/api/schemas/admin.py b/app/api/schemas/admin.py new file mode 100644 index 0000000..ce27dd3 --- /dev/null +++ b/app/api/schemas/admin.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, field_validator + +class AddBalanceRequest(BaseModel): + user_id: int | None = None + email: str | None = None + amount: int # in cents + + @field_validator("amount") + def validate_amount(cls, value: float): + if value < 100: + raise ValueError("Amount must be greater than 100 cents") + return value diff --git a/app/api/schemas/user.py b/app/api/schemas/user.py index 0b5b866..6007fd3 100644 --- a/app/api/schemas/user.py +++ b/app/api/schemas/user.py @@ -36,6 +36,7 @@ class User(UserInDB): class MaskedUser(UserInDB): + is_admin: bool = False forge_api_keys: list[str] | None = Field( description="List of all API keys with all but last 4 digits masked", default=None, diff --git a/app/main.py b/app/main.py index f63da19..59b1d12 100644 --- a/app/main.py +++ b/app/main.py @@ -21,6 +21,7 @@ wallet, webhooks, stripe, + admin, ) from app.core.database import engine from app.core.logger import get_logger @@ -173,6 +174,7 @@ def create_app() -> FastAPI: v1_router.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"]) v1_router.include_router(statistic.router, prefix='/statistic', tags=["statistic"]) v1_router.include_router(stripe.router, prefix='/stripe', tags=["stripe"]) + v1_router.include_router(admin.router, prefix='/admin', tags=["admin"]) # Claude Code compatible API endpoints v1_router.include_router(claude_code.router, tags=["Claude Code API"]) diff --git a/app/models/admin_users.py b/app/models/admin_users.py new file mode 100644 index 0000000..61ef197 --- /dev/null +++ b/app/models/admin_users.py @@ -0,0 +1,10 @@ +from sqlalchemy import Column, Integer, ForeignKey +from sqlalchemy.orm import relationship + +from app.models.base import Base + +class AdminUsers(Base): + __tablename__ = "admin_users" + + user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) + user = relationship("User", back_populates="admin_users") diff --git a/app/models/user.py b/app/models/user.py index 4d2d1bc..687613d 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -35,3 +35,4 @@ class User(Base): wallet = relationship("Wallet", back_populates="user", uselist=False) # Optional: Add relationship to ApiRequestLog if needed # api_logs = relationship("ApiRequestLog") + admin_users = relationship("AdminUsers", back_populates="user", uselist=False) \ No newline at end of file