From e072a12b0a249ecc1284a936d5270e667d59e72a Mon Sep 17 00:00:00 2001 From: ali asaria Date: Thu, 13 Nov 2025 14:45:46 -0500 Subject: [PATCH 01/26] This is a first example of using fastapi-auth. Test instructions are in the users router --- api.py | 31 ++++++++++ transformerlab/models/users.py | 70 ++++++++++++++++++++++ transformerlab/routers/test_users.py | 31 ++++++++++ transformerlab/shared/models/user_model.py | 33 ++++++++++ 4 files changed, 165 insertions(+) create mode 100644 transformerlab/models/users.py create mode 100644 transformerlab/routers/test_users.py create mode 100644 transformerlab/shared/models/user_model.py diff --git a/api.py b/api.py index 0271c8baa..9fe04d436 100644 --- a/api.py +++ b/api.py @@ -81,6 +81,17 @@ from dotenv import load_dotenv +from transformerlab.models.users import ( + fastapi_users, + auth_backend, + current_active_user, + UserRead, + UserCreate, + UserUpdate, +) +from transformerlab.routers.test_users import router as users_router +from transformerlab.shared.models.user_model import create_db_and_tables, User + load_dotenv() # The following environment variable can be used by other scripts @@ -109,6 +120,7 @@ async def lifespan(app: FastAPI): galleries.update_gallery_cache() spawn_fastchat_controller_subprocess() await db.init() + await create_db_and_tables() print("✅ SEED DATA") # Initialize experiments and cancel any running jobs seed_default_experiments() @@ -230,6 +242,25 @@ async def validation_exception_handler(request, exc): app.include_router(remote.router) app.include_router(fastchat_openai_api.router) +# Include Auth and Registration Routers +app.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth/jwt", + tags=["auth"], +) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +# Include User Management Router (allows authenticated users to view/update their profile) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) +app.include_router(users_router) + # Authentication and session management routes if os.getenv("TFL_MULTITENANT") == "true": from transformerlab.routers import auth # noqa: E402 diff --git a/transformerlab/models/users.py b/transformerlab/models/users.py new file mode 100644 index 000000000..9534e95eb --- /dev/null +++ b/transformerlab/models/users.py @@ -0,0 +1,70 @@ +# users.py +import uuid +from typing import Optional, AsyncGenerator +from fastapi import Depends, Request +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas +from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy +from fastapi_users.db import SQLAlchemyUserDatabase +from transformerlab.shared.models.user_model import User, get_async_session +from sqlalchemy.ext.asyncio import AsyncSession + + +# --- Pydantic Schemas for API interactions --- +class UserRead(schemas.BaseUser[uuid.UUID]): + # Add your custom fields here if you added them to the User model + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass + + +# --- User Manager (Handles registration, password reset, etc.) --- +SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !! + + +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + reset_password_token_secret = SECRET + verification_token_secret = SECRET + + # Optional: Define custom logic after registration + async def on_after_register(self, user: User, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + + +async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(session, User) + + +async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): + yield UserManager(user_db) + + +# --- Authentication Backend (JWT/Bearer Token) --- +bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") + + +def get_jwt_strategy() -> JWTStrategy: + # Token lasts for 3600 seconds (1 hour) + return JWTStrategy(secret=SECRET, lifetime_seconds=3600) + + +auth_backend = AuthenticationBackend( + name="jwt", + transport=bearer_transport, + get_strategy=get_jwt_strategy, +) + +# --- FastAPIUsers Instance (The main utility) --- +fastapi_users = FastAPIUsers[User, uuid.UUID]( + get_user_manager, + [auth_backend], # Add more backends (like Google OAuth) here +) + +# --- Dependency for Protected Routes --- +# This is what you'll use in your route decorators +current_active_user = fastapi_users.current_user(active=True) diff --git a/transformerlab/routers/test_users.py b/transformerlab/routers/test_users.py new file mode 100644 index 000000000..5901e87d0 --- /dev/null +++ b/transformerlab/routers/test_users.py @@ -0,0 +1,31 @@ +from fastapi import APIRouter, Depends +from transformerlab.shared.models.user_model import User +from transformerlab.models.users import ( + current_active_user, +) + + +router = APIRouter(prefix="/test_users", tags=["users"]) + + +@router.get("/authenticated-route") +async def authenticated_route(user: User = Depends(current_active_user)): + return {"message": f"Hello, {user.email}! You are authenticated."} + + +# To test this, register a new user via /auth/register +# curl -X POST 'http://127.0.0.1:8338/auth/register' \ +# -H 'Content-Type: application/json' \ +# -d '{ +# "email": "test@example.com", +# "password": "password123" +# }' + +# Then +# curl -X POST 'http://127.0.0.1:8338/auth/jwt/login' \ +# -H 'Content-Type: application/x-www-form-urlencoded' \ +# -d 'username=test@example.com&password=password123' + +# This will return a token, which you can use to access the authenticated route: +# curl -X GET 'http://127.0.0.1:8338/authenticated-route' \ +# -H 'Authorization: Bearer ' diff --git a/transformerlab/shared/models/user_model.py b/transformerlab/shared/models/user_model.py new file mode 100644 index 000000000..75e6c2cc7 --- /dev/null +++ b/transformerlab/shared/models/user_model.py @@ -0,0 +1,33 @@ +# database.py +from typing import AsyncGenerator +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base +from sqlalchemy.orm import sessionmaker +from fastapi_users.db import SQLAlchemyBaseUserTableUUID + +# Replace with your actual database URL (e.g., PostgreSQL, SQLite) +from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL + +Base: DeclarativeMeta = declarative_base() + + +# 1. Define your User Model (inherits from a FastAPI Users base class) +class User(SQLAlchemyBaseUserTableUUID, Base): + pass # You can add custom fields here later, like 'first_name: str' + + +# 2. Setup the Async Engine and Session +engine = create_async_engine(DATABASE_URL) +AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +# 3. Utility to create tables (run this on app startup) +async def create_db_and_tables(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + +# 4. Database session dependency +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSessionLocal() as session: + yield session From d634d83cc8969406977f242c7195f71f83a26b2d Mon Sep 17 00:00:00 2001 From: ali asaria Date: Fri, 14 Nov 2025 11:30:54 -0500 Subject: [PATCH 02/26] first auth support --- api.py | 31 +--------- transformerlab/models/users.py | 70 +++++++++++++++++++++- transformerlab/routers/auth2.py | 88 ++++++++++++++++++++++++++++ transformerlab/routers/test_users.py | 31 ---------- 4 files changed, 160 insertions(+), 60 deletions(-) create mode 100644 transformerlab/routers/auth2.py delete mode 100644 transformerlab/routers/test_users.py diff --git a/api.py b/api.py index 9fe04d436..dc96caa7a 100644 --- a/api.py +++ b/api.py @@ -50,6 +50,7 @@ batched_prompts, recipes, remote, + auth2, ) import torch @@ -81,15 +82,7 @@ from dotenv import load_dotenv -from transformerlab.models.users import ( - fastapi_users, - auth_backend, - current_active_user, - UserRead, - UserCreate, - UserUpdate, -) -from transformerlab.routers.test_users import router as users_router + from transformerlab.shared.models.user_model import create_db_and_tables, User load_dotenv() @@ -241,25 +234,7 @@ async def validation_exception_handler(request, exc): app.include_router(batched_prompts.router) app.include_router(remote.router) app.include_router(fastchat_openai_api.router) - -# Include Auth and Registration Routers -app.include_router( - fastapi_users.get_auth_router(auth_backend), - prefix="/auth/jwt", - tags=["auth"], -) -app.include_router( - fastapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - tags=["auth"], -) -# Include User Management Router (allows authenticated users to view/update their profile) -app.include_router( - fastapi_users.get_users_router(UserRead, UserUpdate), - prefix="/users", - tags=["users"], -) -app.include_router(users_router) +app.include_router(auth2.router) # Authentication and session management routes if os.getenv("TFL_MULTITENANT") == "true": diff --git a/transformerlab/models/users.py b/transformerlab/models/users.py index 9534e95eb..5dee51a77 100644 --- a/transformerlab/models/users.py +++ b/transformerlab/models/users.py @@ -1,12 +1,14 @@ # users.py import uuid -from typing import Optional, AsyncGenerator +from typing import Optional from fastapi import Depends, Request from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy from fastapi_users.db import SQLAlchemyUserDatabase from transformerlab.shared.models.user_model import User, get_async_session from sqlalchemy.ext.asyncio import AsyncSession +from jose import jwt as _jose_jwt +from datetime import datetime, timedelta # --- Pydantic Schemas for API interactions --- @@ -25,6 +27,8 @@ class UserUpdate(schemas.BaseUserUpdate): # --- User Manager (Handles registration, password reset, etc.) --- SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !! +REFRESH_SECRET = "YOUR_REFRESH_TOKEN_SECRET" # !! USE A DIFFERENT SECRET !! +REFRESH_LIFETIME = 60 * 60 * 24 * 7 # 7 days class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): @@ -68,3 +72,67 @@ def get_jwt_strategy() -> JWTStrategy: # --- Dependency for Protected Routes --- # This is what you'll use in your route decorators current_active_user = fastapi_users.current_user(active=True) + + +def get_refresh_strategy() -> JWTStrategy: + return JWTStrategy(secret=REFRESH_SECRET, lifetime_seconds=REFRESH_LIFETIME) + + +# --- Small helper to create access + refresh tokens for manual flows (e.g. refresh endpoint) --- + + +class _JWTAuthenticationHelper: + """Minimal helper that mirrors a login response (access + refresh token). + + We keep this small and explicit so callers (like the `refresh` endpoint in + `routers/auth.py`) can create new access tokens when given a valid + refresh token. + """ + + def __init__( + self, + access_secret: str, + refresh_secret: str, + access_lifetime: int = 3600, + refresh_lifetime: int = REFRESH_LIFETIME, + ): + self.access_secret = access_secret + self.refresh_secret = refresh_secret + self.access_lifetime = access_lifetime + self.refresh_lifetime = refresh_lifetime + + def _create_token(self, user, secret: str, lifetime_seconds: int) -> str: + now = datetime.utcnow() + exp = now + timedelta(seconds=lifetime_seconds) + payload = { + "sub": str(user.id), + "email": getattr(user, "email", None), + "exp": int(exp.timestamp()), + } + return _jose_jwt.encode(payload, secret, algorithm="HS256") + + def get_login_response(self, user) -> dict: + """Return a dict similar to what FastAPI-Users returns on login. + + Keys: + - access_token: short-lived JWT + - refresh_token: long-lived JWT (can be validated with refresh strategy) + - token_type: 'bearer' + - expires_in: seconds until access token expiry + """ + access = self._create_token(user, self.access_secret, self.access_lifetime) + refresh = self._create_token(user, self.refresh_secret, self.refresh_lifetime) + return { + "access_token": access, + "refresh_token": refresh, + "token_type": "bearer", + "expires_in": self.access_lifetime, + } + + +# Module-level helpers for imports elsewhere +jwt_authentication = _JWTAuthenticationHelper( + SECRET, REFRESH_SECRET, access_lifetime=3600, refresh_lifetime=REFRESH_LIFETIME +) +access_strategy = get_jwt_strategy() +refresh_strategy = get_refresh_strategy() diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py new file mode 100644 index 000000000..755fb1d55 --- /dev/null +++ b/transformerlab/routers/auth2.py @@ -0,0 +1,88 @@ +from fastapi import APIRouter, Depends, HTTPException +from transformerlab.shared.models.user_model import User +from transformerlab.models.users import ( + fastapi_users, + auth_backend, + current_active_user, + UserRead, + UserCreate, + UserUpdate, + get_user_manager, + get_refresh_strategy, + jwt_authentication, +) + +from jose import jwt, JWTError + +router = APIRouter(tags=["users"]) + + +# Include Auth and Registration Routers +router.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth/jwt", + tags=["auth"], +) +router.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +# Include User Management Router (allows authenticated users to view/update their profile) +router.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) + + +@router.get("/test-users/authenticated-route") +async def authenticated_route(user: User = Depends(current_active_user)): + return {"message": f"Hello, {user.email}! You are authenticated."} + + +# To test this, register a new user via /auth/register +# curl -X POST 'http://127.0.0.1:8338/auth/register' \ +# -H 'Content-Type: application/json' \ +# -d '{ +# "email": "test@example.com", +# "password": "password123" +# }' + +# Then +# curl -X POST 'http://127.0.0.1:8338/auth/jwt/login' \ +# -H 'Content-Type: application/x-www-form-urlencoded' \ +# -d 'username=test@example.com&password=password123' + +# This will return a token, which you can use to access the authenticated route: +# curl -X GET 'http://127.0.0.1:8338/authenticated-route' \ +# -H 'Authorization: Bearer ' + + +@router.post("/auth/refresh") +async def refresh_access_token( + refresh_token: str, # Sent by the client in the request body + user_manager=Depends(get_user_manager), +): + try: + # 1. Decode and Validate the Refresh Token + # Get a fresh refresh strategy instance and use its secret to decode + refresh_strategy = get_refresh_strategy() + payload = jwt.decode(refresh_token, str(refresh_strategy.secret), algorithms=["HS256"]) + user_id = payload.get("sub") + + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid refresh token payload") + + # 2. Get the user object from the database + user = await user_manager.get(user_id) + if user is None or not user.is_active: + raise HTTPException(status_code=401, detail="User inactive or not found") + + # 3. Create a NEW Access Token (using the short-lived strategy from the main JWT) + new_access_token = jwt_authentication.get_login_response(user) # Needs custom helper + + return {"access_token": new_access_token["access_token"], "token_type": "bearer"} + + except JWTError: + raise HTTPException(status_code=401, detail="Expired or invalid refresh token") diff --git a/transformerlab/routers/test_users.py b/transformerlab/routers/test_users.py deleted file mode 100644 index 5901e87d0..000000000 --- a/transformerlab/routers/test_users.py +++ /dev/null @@ -1,31 +0,0 @@ -from fastapi import APIRouter, Depends -from transformerlab.shared.models.user_model import User -from transformerlab.models.users import ( - current_active_user, -) - - -router = APIRouter(prefix="/test_users", tags=["users"]) - - -@router.get("/authenticated-route") -async def authenticated_route(user: User = Depends(current_active_user)): - return {"message": f"Hello, {user.email}! You are authenticated."} - - -# To test this, register a new user via /auth/register -# curl -X POST 'http://127.0.0.1:8338/auth/register' \ -# -H 'Content-Type: application/json' \ -# -d '{ -# "email": "test@example.com", -# "password": "password123" -# }' - -# Then -# curl -X POST 'http://127.0.0.1:8338/auth/jwt/login' \ -# -H 'Content-Type: application/x-www-form-urlencoded' \ -# -d 'username=test@example.com&password=password123' - -# This will return a token, which you can use to access the authenticated route: -# curl -X GET 'http://127.0.0.1:8338/authenticated-route' \ -# -H 'Authorization: Bearer ' From 4481d2d3a48988ca18d0cbd63a21c7aef1b88e93 Mon Sep 17 00:00:00 2001 From: ali asaria Date: Mon, 17 Nov 2025 07:05:26 -0500 Subject: [PATCH 03/26] add more logging --- transformerlab/models/users.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformerlab/models/users.py b/transformerlab/models/users.py index 5dee51a77..ba7404e8c 100644 --- a/transformerlab/models/users.py +++ b/transformerlab/models/users.py @@ -39,6 +39,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") + async def on_after_forgot_password(self, user: User, token: str, request: Request | None = None): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify(self, user: User, token: str, request: Request | None = None): + print(f"Verification requested for user {user.id}. Verification token: {token}") + async def get_user_db(session: AsyncSession = Depends(get_async_session)): yield SQLAlchemyUserDatabase(session, User) From 29f6567e09363ed4d1843c5eee4b07ff4b53e0c1 Mon Sep 17 00:00:00 2001 From: ali asaria Date: Mon, 17 Nov 2025 10:40:45 -0500 Subject: [PATCH 04/26] placeholder to get a user's teams --- transformerlab/routers/auth2.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index 755fb1d55..7756dcb69 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -86,3 +86,16 @@ async def refresh_access_token( except JWTError: raise HTTPException(status_code=401, detail="Expired or invalid refresh token") + + +@router.get("/users/me/teams") +async def get_user_teams(user: User = Depends(current_active_user)): + # Placeholder implementation + # In a real application, fetch teams from the database + return { + "user_id": user.id, + "teams": [ + {"id": "550e8400-e29b-41d4-a716-446655440000", "name": "Transformer Lab"}, + {"id": "6ba7b810-9dad-11d1-80b4-00c04fd430c8", "name": "Team 2"}, + ], + } From 2b5827610ed6723a1627642940de4a982629286d Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 11:04:07 -0500 Subject: [PATCH 05/26] Add fastapi-users[sqlalchemy] --- requirements-no-gpu-uv.txt | 31 +++++++++++++++-- requirements-rocm-uv.txt | 31 +++++++++++++++-- requirements-rocm.in | 1 + requirements-uv.txt | 69 ++++++++++++++++---------------------- requirements.in | 1 + 5 files changed, 88 insertions(+), 45 deletions(-) diff --git a/requirements-no-gpu-uv.txt b/requirements-no-gpu-uv.txt index 507218292..205407127 100644 --- a/requirements-no-gpu-uv.txt +++ b/requirements-no-gpu-uv.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements.in -o requirements-no-gpu-uv.txt --index-strategy unsafe-best-match +# uv pip compile requirements.in -o requirements-no-gpu-uv.txt absl-py==2.1.0 # via tensorboard accelerate==1.3.0 @@ -27,6 +27,10 @@ anyio==4.8.0 # sse-starlette # starlette # watchfiles +argon2-cffi==23.1.0 + # via pwdlib +argon2-cffi-bindings==25.1.0 + # via argon2-cffi attrs==25.1.0 # via aiohttp audioread==3.0.1 @@ -39,6 +43,8 @@ azure-core==1.33.0 # azure-identity azure-identity==1.21.0 # via markitdown +bcrypt==4.3.0 + # via pwdlib beautifulsoup4==4.13.3 # via # markdownify @@ -51,6 +57,7 @@ certifi==2022.12.7 # sentry-sdk cffi==1.17.1 # via + # argon2-cffi-bindings # cryptography # soundfile charset-normalizer==2.1.1 @@ -93,12 +100,16 @@ dill==0.3.8 # datasets # evaluate # multiprocess +dnspython==2.8.0 + # via email-validator docker-pycreds==0.4.0 # via wandb einops==0.8.0 # via # -r requirements.in # controlnet-aux +email-validator==2.3.0 + # via fastapi-users et-xmlfile==2.0.0 # via openpyxl evaluate==0.4.3 @@ -106,7 +117,14 @@ evaluate==0.4.3 fastapi==0.115.7 # via # -r requirements.in + # fastapi-users # transformerlab-inference +fastapi-users==15.0.1 + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy +fastapi-users-db-sqlalchemy==7.0.0 + # via fastapi-users filelock==3.13.1 # via # controlnet-aux @@ -169,6 +187,7 @@ humanfriendly==10.0 idna==3.4 # via # anyio + # email-validator # httpx # requests # yarl @@ -205,6 +224,8 @@ macmon-python==0.1.2 # via -r requirements.in magika==0.6.1 # via markitdown +makefun==1.16.0 + # via fastapi-users mammoth==1.9.0 # via markitdown markdown==3.7 @@ -352,6 +373,8 @@ psutil==6.1.1 # peft # transformerlab-inference # wandb +pwdlib==0.2.1 + # via fastapi-users pyarrow==19.0.0 # via datasets pycparser==2.22 @@ -381,6 +404,7 @@ pygments==2.19.1 # rich pyjwt==2.10.1 # via + # fastapi-users # msal # workos pytest==8.4.2 @@ -396,6 +420,7 @@ python-dotenv==1.0.1 python-multipart==0.0.20 # via # -r requirements.in + # fastapi-users # mcp python-pptx==1.0.2 # via markitdown @@ -496,7 +521,9 @@ soxr==0.5.0.post1 speechrecognition==3.14.2 # via markitdown sqlalchemy==2.0.38 - # via -r requirements.in + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp starlette==0.45.3 diff --git a/requirements-rocm-uv.txt b/requirements-rocm-uv.txt index 84ff54a64..5bbec84e4 100644 --- a/requirements-rocm-uv.txt +++ b/requirements-rocm-uv.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements-rocm.in -o requirements-rocm-uv.txt --index-strategy unsafe-best-match +# uv pip compile requirements-rocm.in -o requirements-rocm-uv.txt --index-strategy unsafe-best-match --python-platform linux absl-py==2.2.2 # via tensorboard accelerate==1.6.0 @@ -27,6 +27,10 @@ anyio==4.9.0 # sse-starlette # starlette # watchfiles +argon2-cffi==23.1.0 + # via pwdlib +argon2-cffi-bindings==25.1.0 + # via argon2-cffi attrs==25.3.0 # via aiohttp audioread==3.0.1 @@ -39,6 +43,8 @@ azure-core==1.33.0 # azure-identity azure-identity==1.21.0 # via markitdown +bcrypt==4.3.0 + # via pwdlib beautifulsoup4==4.13.4 # via # markdownify @@ -51,6 +57,7 @@ certifi==2022.12.7 # sentry-sdk cffi==1.17.1 # via + # argon2-cffi-bindings # cryptography # soundfile charset-normalizer==2.1.1 @@ -93,12 +100,16 @@ dill==0.3.8 # datasets # evaluate # multiprocess +dnspython==2.8.0 + # via email-validator docker-pycreds==0.4.0 # via wandb einops==0.8.1 # via # -r requirements-rocm.in # controlnet-aux +email-validator==2.3.0 + # via fastapi-users et-xmlfile==2.0.0 # via openpyxl evaluate==0.4.3 @@ -106,7 +117,14 @@ evaluate==0.4.3 fastapi==0.115.12 # via # -r requirements-rocm.in + # fastapi-users # transformerlab-inference +fastapi-users==15.0.1 + # via + # -r requirements-rocm.in + # fastapi-users-db-sqlalchemy +fastapi-users-db-sqlalchemy==7.0.0 + # via fastapi-users filelock==3.13.1 # via # controlnet-aux @@ -169,6 +187,7 @@ humanfriendly==10.0 idna==3.4 # via # anyio + # email-validator # httpx # requests # yarl @@ -205,6 +224,8 @@ macmon-python==0.1.2 # via -r requirements-rocm.in magika==0.6.1 # via markitdown +makefun==1.16.0 + # via fastapi-users mammoth==1.9.0 # via markitdown markdown==3.8 @@ -350,6 +371,8 @@ psutil==7.0.0 # peft # transformerlab-inference # wandb +pwdlib==0.2.1 + # via fastapi-users pyarrow==20.0.0 # via datasets pycparser==2.22 @@ -379,6 +402,7 @@ pygments==2.19.1 # rich pyjwt==2.10.1 # via + # fastapi-users # msal # workos pyrsmi==0.2.0 @@ -396,6 +420,7 @@ python-dotenv==1.1.0 python-multipart==0.0.20 # via # -r requirements-rocm.in + # fastapi-users # mcp python-pptx==1.0.2 # via markitdown @@ -499,7 +524,9 @@ soxr==0.5.0.post1 speechrecognition==3.14.2 # via markitdown sqlalchemy==2.0.40 - # via -r requirements-rocm.in + # via + # -r requirements-rocm.in + # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp starlette==0.46.2 diff --git a/requirements-rocm.in b/requirements-rocm.in index 3b567d787..532ecc0d4 100644 --- a/requirements-rocm.in +++ b/requirements-rocm.in @@ -5,6 +5,7 @@ datasets==3.6.0 einops evaluate fastapi +fastapi-users[sqlalchemy] packaging psutil python-multipart diff --git a/requirements-uv.txt b/requirements-uv.txt index 875fa6130..be67b1e91 100644 --- a/requirements-uv.txt +++ b/requirements-uv.txt @@ -27,6 +27,10 @@ anyio==4.8.0 # sse-starlette # starlette # watchfiles +argon2-cffi==23.1.0 + # via pwdlib +argon2-cffi-bindings==25.1.0 + # via argon2-cffi attrs==25.1.0 # via aiohttp audioread==3.0.1 @@ -39,6 +43,8 @@ azure-core==1.33.0 # azure-identity azure-identity==1.21.0 # via markitdown +bcrypt==4.3.0 + # via pwdlib beautifulsoup4==4.13.3 # via # markdownify @@ -51,6 +57,7 @@ certifi==2022.12.7 # sentry-sdk cffi==1.17.1 # via + # argon2-cffi-bindings # cryptography # soundfile charset-normalizer==2.1.1 @@ -93,12 +100,16 @@ dill==0.3.8 # datasets # evaluate # multiprocess +dnspython==2.8.0 + # via email-validator docker-pycreds==0.4.0 # via wandb einops==0.8.0 # via # -r requirements.in # controlnet-aux +email-validator==2.3.0 + # via fastapi-users et-xmlfile==2.0.0 # via openpyxl evaluate==0.4.3 @@ -106,7 +117,14 @@ evaluate==0.4.3 fastapi==0.115.7 # via # -r requirements.in + # fastapi-users # transformerlab-inference +fastapi-users==15.0.1 + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy +fastapi-users-db-sqlalchemy==7.0.0 + # via fastapi-users filelock==3.13.1 # via # controlnet-aux @@ -169,6 +187,7 @@ humanfriendly==10.0 idna==3.4 # via # anyio + # email-validator # httpx # requests # yarl @@ -205,6 +224,8 @@ macmon-python==0.1.2 # via -r requirements.in magika==0.6.1 # via markitdown +makefun==1.16.0 + # via fastapi-users mammoth==1.9.0 # via markitdown markdown==3.7 @@ -279,45 +300,8 @@ numpy==2.1.2 # torchvision # transformerlab-inference # transformers -nvidia-cublas-cu12==12.8.4.1 - # via - # nvidia-cudnn-cu12 - # nvidia-cusolver-cu12 - # torch -nvidia-cuda-cupti-cu12==12.8.90 - # via torch -nvidia-cuda-nvrtc-cu12==12.8.93 - # via torch -nvidia-cuda-runtime-cu12==12.8.90 - # via torch -nvidia-cudnn-cu12==9.10.2.21 - # via torch -nvidia-cufft-cu12==11.3.3.83 - # via torch -nvidia-cufile-cu12==1.13.1.3 - # via torch -nvidia-curand-cu12==10.3.9.90 - # via torch -nvidia-cusolver-cu12==11.7.3.90 - # via torch -nvidia-cusparse-cu12==12.5.8.93 - # via - # nvidia-cusolver-cu12 - # torch -nvidia-cusparselt-cu12==0.7.1 - # via torch nvidia-ml-py==12.575.51 # via -r requirements.in -nvidia-nccl-cu12==2.27.3 - # via torch -nvidia-nvjitlink-cu12==12.8.93 - # via - # nvidia-cufft-cu12 - # nvidia-cusolver-cu12 - # nvidia-cusparse-cu12 - # torch -nvidia-nvtx-cu12==12.8.90 - # via torch olefile==0.47 # via markitdown onnxruntime==1.21.0 @@ -389,6 +373,8 @@ psutil==6.1.1 # peft # transformerlab-inference # wandb +pwdlib==0.2.1 + # via fastapi-users pyarrow==19.0.0 # via datasets pycparser==2.22 @@ -418,6 +404,7 @@ pygments==2.19.1 # rich pyjwt==2.10.1 # via + # fastapi-users # msal # workos pytest==8.4.2 @@ -433,6 +420,7 @@ python-dotenv==1.0.1 python-multipart==0.0.20 # via # -r requirements.in + # fastapi-users # mcp python-pptx==1.0.2 # via markitdown @@ -505,7 +493,6 @@ setproctitle==1.3.5 setuptools==70.2.0 # via # tensorboard - # triton # wandb shellingham==1.5.4 # via typer @@ -534,7 +521,9 @@ soxr==0.5.0.post1 speechrecognition==3.14.2 # via markitdown sqlalchemy==2.0.38 - # via -r requirements.in + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp starlette==0.45.3 @@ -605,8 +594,6 @@ transformers==4.57.1 # -r requirements.in # peft # sentence-transformers -triton==3.4.0 - # via torch typer==0.15.4 # via mcp typing-extensions==4.12.2 diff --git a/requirements.in b/requirements.in index bb0b00989..d5247ef4f 100644 --- a/requirements.in +++ b/requirements.in @@ -5,6 +5,7 @@ datasets==3.6.0 einops evaluate fastapi +fastapi-users[sqlalchemy] packaging psutil python-dotenv From f29275e77b6db18e7e8297b2cfda640a7fca73ce Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 11:08:40 -0500 Subject: [PATCH 06/26] Add python-jose --- requirements-no-gpu-uv.txt | 12 ++++++++++++ requirements-rocm-uv.txt | 12 ++++++++++++ requirements-rocm.in | 1 + requirements-uv.txt | 12 ++++++++++++ requirements.in | 1 + 5 files changed, 38 insertions(+) diff --git a/requirements-no-gpu-uv.txt b/requirements-no-gpu-uv.txt index 205407127..815fbc11d 100644 --- a/requirements-no-gpu-uv.txt +++ b/requirements-no-gpu-uv.txt @@ -84,6 +84,7 @@ cryptography==44.0.2 # msal # pdfminer-six # pyjwt + # python-jose # workos datasets==3.6.0 # via @@ -104,6 +105,8 @@ dnspython==2.8.0 # via email-validator docker-pycreds==0.4.0 # via wandb +ecdsa==0.19.1 + # via python-jose einops==0.8.0 # via # -r requirements.in @@ -377,6 +380,10 @@ pwdlib==0.2.1 # via fastapi-users pyarrow==19.0.0 # via datasets +pyasn1==0.6.1 + # via + # python-jose + # rsa pycparser==2.22 # via cffi pydantic==2.11.7 @@ -417,6 +424,8 @@ python-dotenv==1.0.1 # magika # mcp # pydantic-settings +python-jose==3.5.0 + # via -r requirements.in python-multipart==0.0.20 # via # -r requirements.in @@ -461,6 +470,8 @@ rich==13.9.4 # via # transformerlab-inference # typer +rsa==4.9.1 + # via python-jose safetensors==0.5.3 # via # accelerate @@ -502,6 +513,7 @@ six==1.17.0 # via # azure-core # docker-pycreds + # ecdsa # markdownify # python-dateutil # tensorboard diff --git a/requirements-rocm-uv.txt b/requirements-rocm-uv.txt index 5bbec84e4..054d372b7 100644 --- a/requirements-rocm-uv.txt +++ b/requirements-rocm-uv.txt @@ -84,6 +84,7 @@ cryptography==44.0.2 # msal # pdfminer-six # pyjwt + # python-jose # workos datasets==3.6.0 # via @@ -104,6 +105,8 @@ dnspython==2.8.0 # via email-validator docker-pycreds==0.4.0 # via wandb +ecdsa==0.19.1 + # via python-jose einops==0.8.1 # via # -r requirements-rocm.in @@ -375,6 +378,10 @@ pwdlib==0.2.1 # via fastapi-users pyarrow==20.0.0 # via datasets +pyasn1==0.6.1 + # via + # python-jose + # rsa pycparser==2.22 # via cffi pydantic==2.11.7 @@ -417,6 +424,8 @@ python-dotenv==1.1.0 # magika # mcp # pydantic-settings +python-jose==3.5.0 + # via -r requirements-rocm.in python-multipart==0.0.20 # via # -r requirements-rocm.in @@ -463,6 +472,8 @@ rich==14.0.0 # via # transformerlab-inference # typer +rsa==4.9.1 + # via python-jose safetensors==0.5.3 # via # accelerate @@ -505,6 +516,7 @@ six==1.17.0 # via # azure-core # docker-pycreds + # ecdsa # markdownify # python-dateutil # tensorboard diff --git a/requirements-rocm.in b/requirements-rocm.in index 532ecc0d4..ab3ab9372 100644 --- a/requirements-rocm.in +++ b/requirements-rocm.in @@ -10,6 +10,7 @@ packaging psutil python-multipart python-dotenv +python-jose[cryptography] pydantic>= 2.0 nltk==3.9.1 scipy diff --git a/requirements-uv.txt b/requirements-uv.txt index be67b1e91..8df2b5d3c 100644 --- a/requirements-uv.txt +++ b/requirements-uv.txt @@ -84,6 +84,7 @@ cryptography==44.0.2 # msal # pdfminer-six # pyjwt + # python-jose # workos datasets==3.6.0 # via @@ -104,6 +105,8 @@ dnspython==2.8.0 # via email-validator docker-pycreds==0.4.0 # via wandb +ecdsa==0.19.1 + # via python-jose einops==0.8.0 # via # -r requirements.in @@ -377,6 +380,10 @@ pwdlib==0.2.1 # via fastapi-users pyarrow==19.0.0 # via datasets +pyasn1==0.6.1 + # via + # python-jose + # rsa pycparser==2.22 # via cffi pydantic==2.11.7 @@ -417,6 +424,8 @@ python-dotenv==1.0.1 # magika # mcp # pydantic-settings +python-jose==3.5.0 + # via -r requirements.in python-multipart==0.0.20 # via # -r requirements.in @@ -461,6 +470,8 @@ rich==13.9.4 # via # transformerlab-inference # typer +rsa==4.9.1 + # via python-jose safetensors==0.5.3 # via # accelerate @@ -502,6 +513,7 @@ six==1.17.0 # via # azure-core # docker-pycreds + # ecdsa # markdownify # python-dateutil # tensorboard diff --git a/requirements.in b/requirements.in index d5247ef4f..0fd3760da 100644 --- a/requirements.in +++ b/requirements.in @@ -9,6 +9,7 @@ fastapi-users[sqlalchemy] packaging psutil python-dotenv +python-jose[cryptography] python-multipart pydantic>= 2.0 nltk==3.9.1 From fe58699c10b56ad343fe7d9747c01996e36df5be Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 12:10:56 -0500 Subject: [PATCH 07/26] Create team and user_team table --- transformerlab/shared/models/user_model.py | 31 ++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/transformerlab/shared/models/user_model.py b/transformerlab/shared/models/user_model.py index 75e6c2cc7..063544d13 100644 --- a/transformerlab/shared/models/user_model.py +++ b/transformerlab/shared/models/user_model.py @@ -4,6 +4,8 @@ from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import sessionmaker from fastapi_users.db import SQLAlchemyBaseUserTableUUID +from sqlalchemy import Column, String, ForeignKey +import uuid # Replace with your actual database URL (e.g., PostgreSQL, SQLite) from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL @@ -16,6 +18,22 @@ class User(SQLAlchemyBaseUserTableUUID, Base): pass # You can add custom fields here later, like 'first_name: str' +# 2. Define Team Model +class Team(Base): + __tablename__ = "teams" + + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + name = Column(String, nullable=False) + + +# 3. Define User-Team Association Model +class UserTeam(Base): + __tablename__ = "users_teams" + + user_id = Column(String, ForeignKey("users.id"), primary_key=True) + team_id = Column(String, ForeignKey("teams.id"), primary_key=True) + + # 2. Setup the Async Engine and Session engine = create_async_engine(DATABASE_URL) AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) @@ -31,3 +49,16 @@ async def create_db_and_tables(): async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async with AsyncSessionLocal() as session: yield session + + +# 5. Create default team if not exists +async def create_default_team(session: AsyncSession) -> Team: + stmt = select(Team).where(Team.name == "Default Team") + result = await session.execute(stmt) + team = result.scalar_one_or_none() + if not team: + team = Team(name="Default Team") + session.add(team) + await session.commit() + await session.refresh(team) + return team From ae487f67dafa6dd01ed2c999f11eef15425dc72e Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 12:20:40 -0500 Subject: [PATCH 08/26] Add endpoint for team create update and delete --- transformerlab/routers/teams.py | 104 ++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 transformerlab/routers/teams.py diff --git a/transformerlab/routers/teams.py b/transformerlab/routers/teams.py new file mode 100644 index 000000000..ef67c2746 --- /dev/null +++ b/transformerlab/routers/teams.py @@ -0,0 +1,104 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session +from transformerlab.models.users import current_active_user +from pydantic import BaseModel +from sqlalchemy import select, delete, update + + +class TeamCreate(BaseModel): + name: str + + +class TeamUpdate(BaseModel): + name: str + + +class TeamResponse(BaseModel): + id: str + name: str + + +router = APIRouter(tags=["teams"]) + + +@router.post("/teams", response_model=TeamResponse) +async def create_team( + team_data: TeamCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + # Create team + team = Team(name=team_data.name) + session.add(team) + await session.commit() + await session.refresh(team) + + # Add user to the team + user_team = UserTeam(user_id=str(user.id), team_id=team.id) + session.add(user_team) + await session.commit() + + return TeamResponse(id=team.id, name=team.name) + + +@router.put("/teams/{team_id}", response_model=TeamResponse) +async def update_team( + team_id: str, + team_data: TeamUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + # Check if user is in the team + stmt = select(UserTeam).where(UserTeam.user_id == str(user.id), UserTeam.team_id == team_id) + result = await session.execute(stmt) + if not result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Not authorized to update this team") + + # Update + stmt = update(Team).where(Team.id == team_id).values(name=team_data.name) + await session.execute(stmt) + await session.commit() + + # Fetch updated + stmt = select(Team).where(Team.id == team_id) + result = await session.execute(stmt) + team = result.scalar_one() + + return TeamResponse(id=team.id, name=team.name) + + +@router.delete("/teams/{team_id}") +async def delete_team( + team_id: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + # Check if user is in the team + stmt = select(UserTeam).where(UserTeam.user_id == str(user.id), UserTeam.team_id == team_id) + result = await session.execute(stmt) + if not result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Not authorized to delete this team") + + # Check if user has other teams + stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)) + result = await session.execute(stmt) + user_teams = result.scalars().all() + if len(user_teams) <= 1: + raise HTTPException(status_code=400, detail="Cannot delete the last team") + + # Check if team has only this user + stmt = select(UserTeam).where(UserTeam.team_id == team_id) + result = await session.execute(stmt) + team_users = result.scalars().all() + if len(team_users) > 1: + raise HTTPException(status_code=400, detail="Cannot delete team with multiple users") + + # Delete associations and team + stmt = delete(UserTeam).where(UserTeam.team_id == team_id) + await session.execute(stmt) + stmt = delete(Team).where(Team.id == team_id) + await session.execute(stmt) + await session.commit() + + return {"message": "Team deleted"} \ No newline at end of file From 42b773bc5dadaea454e3ccc3c0a328b3cbfe080d Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 12:28:55 -0500 Subject: [PATCH 09/26] Update get_user_teams --- transformerlab/routers/auth2.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index 7756dcb69..d0096ad03 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, Depends, HTTPException -from transformerlab.shared.models.user_model import User +from transformerlab.shared.models.user_model import User, Team, UserTeam from transformerlab.models.users import ( fastapi_users, auth_backend, @@ -11,6 +11,8 @@ get_refresh_strategy, jwt_authentication, ) +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from jose import jwt, JWTError @@ -89,13 +91,8 @@ async def refresh_access_token( @router.get("/users/me/teams") -async def get_user_teams(user: User = Depends(current_active_user)): - # Placeholder implementation - # In a real application, fetch teams from the database - return { - "user_id": user.id, - "teams": [ - {"id": "550e8400-e29b-41d4-a716-446655440000", "name": "Transformer Lab"}, - {"id": "6ba7b810-9dad-11d1-80b4-00c04fd430c8", "name": "Team 2"}, - ], - } +async def get_user_teams(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)): + stmt = select(Team).join(UserTeam).where(UserTeam.user_id == str(user.id)) + result = await session.execute(stmt) + teams = result.scalars().all() + return {"user_id": str(user.id), "teams": [{"id": team.id, "name": team.name} for team in teams]} From c7105218fce6383cd11d7ec35abcfa0559727f6c Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 12:30:27 -0500 Subject: [PATCH 10/26] Add default team creation and auto-assignment for new user registrations --- transformerlab/models/users.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformerlab/models/users.py b/transformerlab/models/users.py index ba7404e8c..998dbcb1f 100644 --- a/transformerlab/models/users.py +++ b/transformerlab/models/users.py @@ -5,7 +5,7 @@ from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy from fastapi_users.db import SQLAlchemyUserDatabase -from transformerlab.shared.models.user_model import User, get_async_session +from transformerlab.shared.models.user_model import User, get_async_session, create_default_team, UserTeam from sqlalchemy.ext.asyncio import AsyncSession from jose import jwt as _jose_jwt from datetime import datetime, timedelta @@ -38,6 +38,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # Optional: Define custom logic after registration async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") + # Add to default team + async with self.user_db.session as session: + team = await create_default_team(session) + user_team = UserTeam(user_id=str(user.id), team_id=team.id) + session.add(user_team) + await session.commit() async def on_after_forgot_password(self, user: User, token: str, request: Request | None = None): print(f"User {user.id} has forgot their password. Reset token: {token}") From 5c6bd9f43a68a8fd10f30f0e953a3eff4e847ab1 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 12:31:06 -0500 Subject: [PATCH 11/26] Add teams router to main API router --- api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api.py b/api.py index dc96caa7a..fa3600694 100644 --- a/api.py +++ b/api.py @@ -51,6 +51,7 @@ recipes, remote, auth2, + teams, ) import torch @@ -234,6 +235,7 @@ async def validation_exception_handler(request, exc): app.include_router(batched_prompts.router) app.include_router(remote.router) app.include_router(fastchat_openai_api.router) +app.include_router(teams.router) app.include_router(auth2.router) # Authentication and session management routes From cffeec59c2379c371465600d21994386c604e070 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 12:46:39 -0500 Subject: [PATCH 12/26] Ruff --- api.py | 2 +- transformerlab/routers/auth2.py | 2 +- transformerlab/shared/models/user_model.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api.py b/api.py index fa3600694..a5f093d6f 100644 --- a/api.py +++ b/api.py @@ -84,7 +84,7 @@ from dotenv import load_dotenv -from transformerlab.shared.models.user_model import create_db_and_tables, User +from transformerlab.shared.models.user_model import create_db_and_tables load_dotenv() diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index d0096ad03..b35be0198 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, Depends, HTTPException -from transformerlab.shared.models.user_model import User, Team, UserTeam +from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session from transformerlab.models.users import ( fastapi_users, auth_backend, diff --git a/transformerlab/shared/models/user_model.py b/transformerlab/shared/models/user_model.py index 063544d13..3998799ef 100644 --- a/transformerlab/shared/models/user_model.py +++ b/transformerlab/shared/models/user_model.py @@ -4,11 +4,11 @@ from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import sessionmaker from fastapi_users.db import SQLAlchemyBaseUserTableUUID -from sqlalchemy import Column, String, ForeignKey +from sqlalchemy import Column, String, ForeignKey, select import uuid # Replace with your actual database URL (e.g., PostgreSQL, SQLite) -from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL +from transformerlab.db.constants import DATABASE_URL Base: DeclarativeMeta = declarative_base() From ae755dc0af97abcc85d9fff807d363b068fe89c1 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 13:20:00 -0500 Subject: [PATCH 13/26] Bug --- api.py | 2 +- transformerlab/shared/models/user_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api.py b/api.py index a5f093d6f..fa3600694 100644 --- a/api.py +++ b/api.py @@ -84,7 +84,7 @@ from dotenv import load_dotenv -from transformerlab.shared.models.user_model import create_db_and_tables +from transformerlab.shared.models.user_model import create_db_and_tables, User load_dotenv() diff --git a/transformerlab/shared/models/user_model.py b/transformerlab/shared/models/user_model.py index 3998799ef..116f9b3ef 100644 --- a/transformerlab/shared/models/user_model.py +++ b/transformerlab/shared/models/user_model.py @@ -30,7 +30,7 @@ class Team(Base): class UserTeam(Base): __tablename__ = "users_teams" - user_id = Column(String, ForeignKey("users.id"), primary_key=True) + user_id = Column(String, ForeignKey("user.id"), primary_key=True) team_id = Column(String, ForeignKey("teams.id"), primary_key=True) From 9179078717a2bd19c73ffc90f26d7436ab4454db Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 13:20:17 -0500 Subject: [PATCH 14/26] Ruff --- api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api.py b/api.py index fa3600694..a5f093d6f 100644 --- a/api.py +++ b/api.py @@ -84,7 +84,7 @@ from dotenv import load_dotenv -from transformerlab.shared.models.user_model import create_db_and_tables, User +from transformerlab.shared.models.user_model import create_db_and_tables load_dotenv() From 6594be624d11fad65a16d006d297352b7e324d74 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 14:02:05 -0500 Subject: [PATCH 15/26] Fix the bug in get_user_teams --- transformerlab/routers/auth2.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index b35be0198..34d7261a7 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, Depends, HTTPException -from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session +from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session, create_default_team from transformerlab.models.users import ( fastapi_users, auth_backend, @@ -92,7 +92,24 @@ async def refresh_access_token( @router.get("/users/me/teams") async def get_user_teams(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)): - stmt = select(Team).join(UserTeam).where(UserTeam.user_id == str(user.id)) + # Check if user has any team associations + stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)) + result = await session.execute(stmt) + user_teams = result.scalars().all() + + # If user has no team associations, assign them to default team + if not user_teams: + default_team = await create_default_team(session) + user_team = UserTeam(user_id=str(user.id), team_id=default_team.id) + session.add(user_team) + await session.commit() + await session.refresh(user_team) + return {"user_id": str(user.id), "teams": [{"id": default_team.id, "name": default_team.name}]} + + # User has team associations, get the actual team objects + team_ids = [ut.team_id for ut in user_teams] + stmt = select(Team).where(Team.id.in_(team_ids)) result = await session.execute(stmt) teams = result.scalars().all() + return {"user_id": str(user.id), "teams": [{"id": team.id, "name": team.name} for team in teams]} From 3a448e95b08c3fddf6ca7c1105bddec4b380fc53 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 14:43:07 -0500 Subject: [PATCH 16/26] Move teams and user_teams into shared/models --- transformerlab/models/users.py | 3 ++- transformerlab/routers/auth2.py | 3 ++- transformerlab/routers/teams.py | 3 ++- transformerlab/shared/models/models.py | 20 ++++++++++++++++++- transformerlab/shared/models/user_model.py | 23 ++-------------------- 5 files changed, 27 insertions(+), 25 deletions(-) diff --git a/transformerlab/models/users.py b/transformerlab/models/users.py index 998dbcb1f..4170e8dcd 100644 --- a/transformerlab/models/users.py +++ b/transformerlab/models/users.py @@ -5,7 +5,8 @@ from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy from fastapi_users.db import SQLAlchemyUserDatabase -from transformerlab.shared.models.user_model import User, get_async_session, create_default_team, UserTeam +from transformerlab.shared.models.user_model import User, get_async_session, create_default_team +from transformerlab.shared.models.models import UserTeam from sqlalchemy.ext.asyncio import AsyncSession from jose import jwt as _jose_jwt from datetime import datetime, timedelta diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index 34d7261a7..37cd27d19 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -1,5 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException -from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session, create_default_team +from transformerlab.shared.models.user_model import User, get_async_session, create_default_team +from transformerlab.shared.models.models import Team, UserTeam from transformerlab.models.users import ( fastapi_users, auth_backend, diff --git a/transformerlab/routers/teams.py b/transformerlab/routers/teams.py index ef67c2746..88cc1d3b1 100644 --- a/transformerlab/routers/teams.py +++ b/transformerlab/routers/teams.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from transformerlab.shared.models.user_model import User, Team, UserTeam, get_async_session +from transformerlab.shared.models.user_model import User, get_async_session +from transformerlab.shared.models.models import Team, UserTeam from transformerlab.models.users import current_active_user from pydantic import BaseModel from sqlalchemy import select, delete, update diff --git a/transformerlab/shared/models/models.py b/transformerlab/shared/models/models.py index 70584c775..bfe98dd66 100644 --- a/transformerlab/shared/models/models.py +++ b/transformerlab/shared/models/models.py @@ -1,6 +1,7 @@ from typing import Optional -from sqlalchemy import String, JSON, DateTime, func, Integer, Index +from sqlalchemy import String, JSON, DateTime, func, Integer, Index, ForeignKey from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +import uuid class Base(DeclarativeBase): @@ -82,3 +83,20 @@ class WorkflowRun(Base): updated_at: Mapped[DateTime] = mapped_column( DateTime, server_default=func.now(), onupdate=func.now(), nullable=False ) + +class Team(Base): + """Team model.""" + + __tablename__ = "teams" + + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + name: Mapped[str] = mapped_column(String, nullable=False) + + +class UserTeam(Base): + """User-Team association model.""" + + __tablename__ = "users_teams" + + user_id: Mapped[str] = mapped_column(String, ForeignKey("user.id"), primary_key=True) + team_id: Mapped[str] = mapped_column(String, ForeignKey("teams.id"), primary_key=True) \ No newline at end of file diff --git a/transformerlab/shared/models/user_model.py b/transformerlab/shared/models/user_model.py index 116f9b3ef..8ff244a94 100644 --- a/transformerlab/shared/models/user_model.py +++ b/transformerlab/shared/models/user_model.py @@ -1,16 +1,13 @@ # database.py from typing import AsyncGenerator from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import sessionmaker from fastapi_users.db import SQLAlchemyBaseUserTableUUID -from sqlalchemy import Column, String, ForeignKey, select -import uuid +from sqlalchemy import select # Replace with your actual database URL (e.g., PostgreSQL, SQLite) from transformerlab.db.constants import DATABASE_URL - -Base: DeclarativeMeta = declarative_base() +from .models import Base, Team # 1. Define your User Model (inherits from a FastAPI Users base class) @@ -18,22 +15,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base): pass # You can add custom fields here later, like 'first_name: str' -# 2. Define Team Model -class Team(Base): - __tablename__ = "teams" - - id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - name = Column(String, nullable=False) - - -# 3. Define User-Team Association Model -class UserTeam(Base): - __tablename__ = "users_teams" - - user_id = Column(String, ForeignKey("user.id"), primary_key=True) - team_id = Column(String, ForeignKey("teams.id"), primary_key=True) - - # 2. Setup the Async Engine and Session engine = create_async_engine(DATABASE_URL) AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) From 5ecf50e7449a75bbc1ac378f52006e1c60aabd91 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 15:08:36 -0500 Subject: [PATCH 17/26] Update authenticated_route --- transformerlab/routers/auth2.py | 34 ++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index 37cd27d19..89e3a86c3 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Header from transformerlab.shared.models.user_model import User, get_async_session, create_default_team from transformerlab.shared.models.models import Team, UserTeam from transformerlab.models.users import ( @@ -39,9 +39,37 @@ ) +async def get_user_and_team( + user: User = Depends(current_active_user), + x_team: str | None = Header(None, alias="X-Team"), + session: AsyncSession = Depends(get_async_session), +): + """ + Dependency to validate user authentication and team membership. + Extracts X-Team header and verifies user belongs to that team. + """ + if not x_team: + raise HTTPException(status_code=400, detail="X-Team header missing") + + # Verify user is associated with the provided team id + stmt = select(UserTeam).where( + UserTeam.user_id == str(user.id), + UserTeam.team_id == x_team, + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if user_team is None: + raise HTTPException(status_code=403, detail="User is not a member of the specified team") + + return {"user": user, "team_id": x_team} + + @router.get("/test-users/authenticated-route") -async def authenticated_route(user: User = Depends(current_active_user)): - return {"message": f"Hello, {user.email}! You are authenticated."} +async def authenticated_route(user_and_team=Depends(get_user_and_team)): + user = user_and_team["user"] + team_id = user_and_team["team_id"] + return {"message": f"Hello, {user.email}! You are authenticated and acting as part of team {team_id}."} # To test this, register a new user via /auth/register From 357279bb1673a4462cbef24444637f3d4a943c0d Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 15:15:56 -0500 Subject: [PATCH 18/26] Renaming --- transformerlab/routers/auth2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index 89e3a86c3..ad373f651 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -41,15 +41,15 @@ async def get_user_and_team( user: User = Depends(current_active_user), - x_team: str | None = Header(None, alias="X-Team"), + x_team: str | None = Header(None, alias="X-Team-Id"), session: AsyncSession = Depends(get_async_session), ): """ Dependency to validate user authentication and team membership. - Extracts X-Team header and verifies user belongs to that team. + Extracts X-Team-Id header and verifies user belongs to that team. """ if not x_team: - raise HTTPException(status_code=400, detail="X-Team header missing") + raise HTTPException(status_code=400, detail="X-Team-Id header missing") # Verify user is associated with the provided team id stmt = select(UserTeam).where( From 4cf4ea54a093e969887e9b65604e85a593ce299b Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 16:17:10 -0500 Subject: [PATCH 19/26] Add role to the user_team --- transformerlab/shared/models/models.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transformerlab/shared/models/models.py b/transformerlab/shared/models/models.py index bfe98dd66..78aa9ae84 100644 --- a/transformerlab/shared/models/models.py +++ b/transformerlab/shared/models/models.py @@ -1,7 +1,8 @@ from typing import Optional -from sqlalchemy import String, JSON, DateTime, func, Integer, Index, ForeignKey +from sqlalchemy import String, JSON, DateTime, func, Integer, Index, ForeignKey, Enum from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column import uuid +import enum class Base(DeclarativeBase): @@ -93,10 +94,17 @@ class Team(Base): name: Mapped[str] = mapped_column(String, nullable=False) +class TeamRole(str, enum.Enum): + """Enum for user roles within a team.""" + OWNER = "owner" + MEMBER = "member" + + class UserTeam(Base): """User-Team association model.""" __tablename__ = "users_teams" user_id: Mapped[str] = mapped_column(String, ForeignKey("user.id"), primary_key=True) - team_id: Mapped[str] = mapped_column(String, ForeignKey("teams.id"), primary_key=True) \ No newline at end of file + team_id: Mapped[str] = mapped_column(String, ForeignKey("teams.id"), primary_key=True) + role: Mapped[str] = mapped_column(String, nullable=False, default=TeamRole.MEMBER.value) \ No newline at end of file From b30e6becdd677c4a881fdaade1dea22782b27169 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 16:18:09 -0500 Subject: [PATCH 20/26] Add Added require_team_owner() --- transformerlab/routers/teams.py | 232 +++++++++++++++++++++++++++++--- 1 file changed, 214 insertions(+), 18 deletions(-) diff --git a/transformerlab/routers/teams.py b/transformerlab/routers/teams.py index 88cc1d3b1..796d44780 100644 --- a/transformerlab/routers/teams.py +++ b/transformerlab/routers/teams.py @@ -1,10 +1,11 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from transformerlab.shared.models.user_model import User, get_async_session -from transformerlab.shared.models.models import Team, UserTeam +from transformerlab.shared.models.models import Team, UserTeam, TeamRole from transformerlab.models.users import current_active_user +from transformerlab.routers.auth2 import require_team_owner, get_user_and_team from pydantic import BaseModel -from sqlalchemy import select, delete, update +from sqlalchemy import select, delete, update, func class TeamCreate(BaseModel): @@ -20,6 +21,21 @@ class TeamResponse(BaseModel): name: str +class InviteMemberRequest(BaseModel): + email: str + role: str = TeamRole.MEMBER.value + + +class UpdateMemberRoleRequest(BaseModel): + role: str + + +class MemberResponse(BaseModel): + user_id: str + email: str + role: str + + router = APIRouter(tags=["teams"]) @@ -35,8 +51,8 @@ async def create_team( await session.commit() await session.refresh(team) - # Add user to the team - user_team = UserTeam(user_id=str(user.id), team_id=team.id) + # Add user to the team as owner + user_team = UserTeam(user_id=str(user.id), team_id=team.id, role=TeamRole.OWNER.value) session.add(user_team) await session.commit() @@ -47,14 +63,13 @@ async def create_team( async def update_team( team_id: str, team_data: TeamUpdate, + owner_info=Depends(require_team_owner), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), ): - # Check if user is in the team - stmt = select(UserTeam).where(UserTeam.user_id == str(user.id), UserTeam.team_id == team_id) - result = await session.execute(stmt) - if not result.scalar_one_or_none(): - raise HTTPException(status_code=403, detail="Not authorized to update this team") + """Update team name. Only team owners can update the team.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") # Update stmt = update(Team).where(Team.id == team_id).values(name=team_data.name) @@ -72,14 +87,15 @@ async def update_team( @router.delete("/teams/{team_id}") async def delete_team( team_id: str, + owner_info=Depends(require_team_owner), session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), ): - # Check if user is in the team - stmt = select(UserTeam).where(UserTeam.user_id == str(user.id), UserTeam.team_id == team_id) - result = await session.execute(stmt) - if not result.scalar_one_or_none(): - raise HTTPException(status_code=403, detail="Not authorized to delete this team") + """Delete a team. Only team owners can delete the team.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + user = owner_info["user"] # Check if user has other teams stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)) @@ -93,7 +109,7 @@ async def delete_team( result = await session.execute(stmt) team_users = result.scalars().all() if len(team_users) > 1: - raise HTTPException(status_code=400, detail="Cannot delete team with multiple users") + raise HTTPException(status_code=400, detail="Cannot delete team with multiple users. Remove other members first.") # Delete associations and team stmt = delete(UserTeam).where(UserTeam.team_id == team_id) @@ -102,4 +118,184 @@ async def delete_team( await session.execute(stmt) await session.commit() - return {"message": "Team deleted"} \ No newline at end of file + return {"message": "Team deleted"} + + +@router.get("/teams/{team_id}/members") +async def get_team_members( + team_id: str, + user_and_team=Depends(get_user_and_team), + session: AsyncSession = Depends(get_async_session), +): + """Get all members of a team. Any team member can view this.""" + # Verify team_id matches the one in header + if team_id != user_and_team["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Get all members of the team + stmt = select(UserTeam).where(UserTeam.team_id == team_id) + result = await session.execute(stmt) + user_teams = result.scalars().all() + + # Get user details + user_ids = [ut.user_id for ut in user_teams] + stmt = select(User).where(User.id.in_(user_ids)) + result = await session.execute(stmt) + users = result.scalars().all() + + # Create a mapping + users_dict = {str(user.id): user for user in users} + + members = [ + MemberResponse( + user_id=ut.user_id, + email=users_dict[ut.user_id].email if ut.user_id in users_dict else "unknown", + role=ut.role + ) + for ut in user_teams + ] + + return {"team_id": team_id, "members": members} + + +@router.post("/teams/{team_id}/members") +async def invite_member( + team_id: str, + invite_data: InviteMemberRequest, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Invite a user to the team. Only team owners can invite members.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Validate role + if invite_data.role not in [TeamRole.OWNER.value, TeamRole.MEMBER.value]: + raise HTTPException(status_code=400, detail="Invalid role. Must be 'owner' or 'member'") + + # Find user by email + stmt = select(User).where(User.email == invite_data.email) + result = await session.execute(stmt) + invited_user = result.scalar_one_or_none() + + if not invited_user: + raise HTTPException(status_code=404, detail="User not found") + + # Check if user is already in the team + stmt = select(UserTeam).where( + UserTeam.user_id == str(invited_user.id), + UserTeam.team_id == team_id + ) + result = await session.execute(stmt) + if result.scalar_one_or_none(): + raise HTTPException(status_code=400, detail="User is already a member of this team") + + # Add user to team + user_team = UserTeam(user_id=str(invited_user.id), team_id=team_id, role=invite_data.role) + session.add(user_team) + await session.commit() + + return { + "message": "User invited successfully", + "user_id": str(invited_user.id), + "email": invited_user.email, + "role": invite_data.role + } + + +@router.delete("/teams/{team_id}/members/{user_id}") +async def remove_member( + team_id: str, + user_id: str, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Remove a member from the team. Only team owners can remove members.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Check if the user to be removed exists in the team + stmt = select(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if not user_team: + raise HTTPException(status_code=404, detail="User is not a member of this team") + + # If removing an owner, check that there's at least one other owner + if user_team.role == TeamRole.OWNER.value: + stmt = select(func.count()).select_from(UserTeam).where( + UserTeam.team_id == team_id, + UserTeam.role == TeamRole.OWNER.value + ) + result = await session.execute(stmt) + owner_count = result.scalar() + + if owner_count <= 1: + raise HTTPException(status_code=400, detail="Cannot remove the last owner from the team") + + # Remove user from team + stmt = delete(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ) + await session.execute(stmt) + await session.commit() + + return {"message": "Member removed successfully"} + + +@router.put("/teams/{team_id}/members/{user_id}/role") +async def update_member_role( + team_id: str, + user_id: str, + role_data: UpdateMemberRoleRequest, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Update a member's role. Only team owners can change roles.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Validate role + if role_data.role not in [TeamRole.OWNER.value, TeamRole.MEMBER.value]: + raise HTTPException(status_code=400, detail="Invalid role. Must be 'owner' or 'member'") + + # Check if the user exists in the team + stmt = select(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if not user_team: + raise HTTPException(status_code=404, detail="User is not a member of this team") + + # If demoting from owner to member, check that there's at least one other owner + if user_team.role == TeamRole.OWNER.value and role_data.role == TeamRole.MEMBER.value: + stmt = select(func.count()).select_from(UserTeam).where( + UserTeam.team_id == team_id, + UserTeam.role == TeamRole.OWNER.value + ) + result = await session.execute(stmt) + owner_count = result.scalar() + + if owner_count <= 1: + raise HTTPException(status_code=400, detail="Cannot demote the last owner") + + # Update role + stmt = update(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ).values(role=role_data.role) + await session.execute(stmt) + await session.commit() + + return {"message": "Role updated successfully", "user_id": user_id, "new_role": role_data.role} \ No newline at end of file From f8c27a7a228a109845389a13fb3208ad03c34e7e Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 16:19:23 -0500 Subject: [PATCH 21/26] Add new member management, invite, list all memebers, remove members --- transformerlab/routers/auth2.py | 53 ++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py index ad373f651..91fd0f71a 100644 --- a/transformerlab/routers/auth2.py +++ b/transformerlab/routers/auth2.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Depends, HTTPException, Header from transformerlab.shared.models.user_model import User, get_async_session, create_default_team -from transformerlab.shared.models.models import Team, UserTeam +from transformerlab.shared.models.models import Team, UserTeam, TeamRole from transformerlab.models.users import ( fastapi_users, auth_backend, @@ -47,6 +47,7 @@ async def get_user_and_team( """ Dependency to validate user authentication and team membership. Extracts X-Team-Id header and verifies user belongs to that team. + Returns user, team_id, and role. """ if not x_team: raise HTTPException(status_code=400, detail="X-Team-Id header missing") @@ -62,7 +63,36 @@ async def get_user_and_team( if user_team is None: raise HTTPException(status_code=403, detail="User is not a member of the specified team") - return {"user": user, "team_id": x_team} + return {"user": user, "team_id": x_team, "role": user_team.role} + + +async def require_team_owner( + user: User = Depends(current_active_user), + x_team: str | None = Header(None, alias="X-Team-Id"), + session: AsyncSession = Depends(get_async_session), +): + """ + Dependency to validate user authentication and ensure user is an owner of the team. + Extracts X-Team-Id header and verifies user has owner role. + """ + if not x_team: + raise HTTPException(status_code=400, detail="X-Team-Id header missing") + + # Verify user is an owner of the team + stmt = select(UserTeam).where( + UserTeam.user_id == str(user.id), + UserTeam.team_id == x_team, + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if user_team is None: + raise HTTPException(status_code=403, detail="User is not a member of the specified team") + + if user_team.role != TeamRole.OWNER.value: + raise HTTPException(status_code=403, detail="Only team owners can perform this action") + + return {"user": user, "team_id": x_team, "role": user_team.role} @router.get("/test-users/authenticated-route") @@ -126,19 +156,28 @@ async def get_user_teams(user: User = Depends(current_active_user), session: Asy result = await session.execute(stmt) user_teams = result.scalars().all() - # If user has no team associations, assign them to default team + # If user has no team associations, assign them to default team as owner if not user_teams: default_team = await create_default_team(session) - user_team = UserTeam(user_id=str(user.id), team_id=default_team.id) + user_team = UserTeam(user_id=str(user.id), team_id=default_team.id, role=TeamRole.OWNER.value) session.add(user_team) await session.commit() await session.refresh(user_team) - return {"user_id": str(user.id), "teams": [{"id": default_team.id, "name": default_team.name}]} + return {"user_id": str(user.id), "teams": [{"id": default_team.id, "name": default_team.name, "role": TeamRole.OWNER.value}]} # User has team associations, get the actual team objects team_ids = [ut.team_id for ut in user_teams] stmt = select(Team).where(Team.id.in_(team_ids)) result = await session.execute(stmt) teams = result.scalars().all() - - return {"user_id": str(user.id), "teams": [{"id": team.id, "name": team.name} for team in teams]} + + # Create a mapping of team_id to team + teams_dict = {team.id: team for team in teams} + + # Return teams with role information + teams_with_roles = [ + {"id": ut.team_id, "name": teams_dict[ut.team_id].name, "role": ut.role} + for ut in user_teams if ut.team_id in teams_dict + ] + + return {"user_id": str(user.id), "teams": teams_with_roles} From cee92d269b00e5b2d6cd49b665cf5fb01c4fc85c Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 17:11:24 -0500 Subject: [PATCH 22/26] Ruff --- transformerlab/shared/models/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformerlab/shared/models/models.py b/transformerlab/shared/models/models.py index 78aa9ae84..254f9d928 100644 --- a/transformerlab/shared/models/models.py +++ b/transformerlab/shared/models/models.py @@ -1,5 +1,5 @@ from typing import Optional -from sqlalchemy import String, JSON, DateTime, func, Integer, Index, ForeignKey, Enum +from sqlalchemy import String, JSON, DateTime, func, Integer, Index, ForeignKey from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column import uuid import enum From 6512e0b03ddf1cc3133f63fbb586de857b698580 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Mon, 17 Nov 2025 17:11:40 -0500 Subject: [PATCH 23/26] Add test for team endpoints --- test/api/test_teams.py | 451 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 test/api/test_teams.py diff --git a/test/api/test_teams.py b/test/api/test_teams.py new file mode 100644 index 000000000..c229edf87 --- /dev/null +++ b/test/api/test_teams.py @@ -0,0 +1,451 @@ +import pytest + + +@pytest.fixture(scope="module") +def owner_user(client): + """Create and authenticate an owner user""" + # Register + user_data = { + "email": "owner@test.com", + "password": "password123", + } + resp = client.post("/auth/register", json=user_data) + assert resp.status_code in (200, 201, 400) # 400 if already exists + + # Login + login_data = { + "username": "owner@test.com", + "password": "password123", + } + resp = client.post("/auth/jwt/login", data=login_data) + assert resp.status_code == 200 + token = resp.json()["access_token"] + + return {"email": "owner@test.com", "token": token} + + +@pytest.fixture(scope="module") +def member_user(client): + """Create and authenticate a member user""" + # Register + user_data = { + "email": "member@test.com", + "password": "password123", + } + resp = client.post("/auth/register", json=user_data) + assert resp.status_code in (200, 201, 400) # 400 if already exists + + # Login + login_data = { + "username": "member@test.com", + "password": "password123", + } + resp = client.post("/auth/jwt/login", data=login_data) + assert resp.status_code == 200 + token = resp.json()["access_token"] + + return {"email": "member@test.com", "token": token} + + +@pytest.fixture(scope="module") +def test_team(client, owner_user): + """Create a test team""" + headers = {"Authorization": f"Bearer {owner_user['token']}"} + team_data = {"name": "Test Team"} + resp = client.post("/teams", json=team_data, headers=headers) + assert resp.status_code == 200 + team = resp.json() + assert "id" in team + assert team["name"] == "Test Team" + return team + + +def test_create_team(client, owner_user): + """Test creating a new team""" + headers = {"Authorization": f"Bearer {owner_user['token']}"} + team_data = {"name": "New Team"} + resp = client.post("/teams", json=team_data, headers=headers) + + assert resp.status_code == 200 + team = resp.json() + assert "id" in team + assert team["name"] == "New Team" + + +def test_get_user_teams(client, owner_user, test_team): + """Test getting user's teams""" + headers = {"Authorization": f"Bearer {owner_user['token']}"} + resp = client.get("/users/me/teams", headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert "teams" in data + assert len(data["teams"]) > 0 + + # Check that test_team is in the list + team_ids = [t["id"] for t in data["teams"]] + assert test_team["id"] in team_ids + + # Check that the user has owner role for the test team + test_team_data = next(t for t in data["teams"] if t["id"] == test_team["id"]) + assert test_team_data["role"] == "owner" + + +def test_list_team_members(client, owner_user, test_team): + """Test listing team members""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert "members" in data + assert len(data["members"]) >= 1 + + # Check owner is in the list + emails = [m["email"] for m in data["members"]] + assert owner_user["email"] in emails + + # Check owner has owner role + owner_data = next(m for m in data["members"] if m["email"] == owner_user["email"]) + assert owner_data["role"] == "owner" + + +def test_invite_member(client, owner_user, member_user, test_team): + """Test inviting a member to the team""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": member_user["email"], + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "User invited successfully" + assert data["email"] == member_user["email"] + assert data["role"] == "member" + + +def test_invite_duplicate_member(client, owner_user, member_user, test_team): + """Test inviting a member who is already in the team""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": member_user["email"], + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 400 + assert "already a member" in resp.json()["detail"] + + +def test_invite_nonexistent_user(client, owner_user, test_team): + """Test inviting a user that doesn't exist""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": "nonexistent@test.com", + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 404 + assert "not found" in resp.json()["detail"].lower() + + +def test_member_can_view_members(client, member_user, test_team): + """Test that a member can view team members""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert "members" in data + assert len(data["members"]) >= 2 # owner and member + + +def test_update_member_role_to_owner(client, owner_user, member_user, test_team): + """Test promoting a member to owner""" + # First get the member's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + member_data = next(m for m in members if m["email"] == member_user["email"]) + member_id = member_data["user_id"] + + # Promote to owner + role_data = {"role": "owner"} + resp = client.put( + f"/teams/{test_team['id']}/members/{member_id}/role", + json=role_data, + headers=headers + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Role updated successfully" + assert data["new_role"] == "owner" + + +def test_update_member_role_to_member(client, owner_user, member_user, test_team): + """Test demoting an owner to member""" + # First get the member's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + member_data = next(m for m in members if m["email"] == member_user["email"]) + member_id = member_data["user_id"] + + # Demote to member + role_data = {"role": "member"} + resp = client.put( + f"/teams/{test_team['id']}/members/{member_id}/role", + json=role_data, + headers=headers + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Role updated successfully" + assert data["new_role"] == "member" + + +def test_cannot_demote_last_owner(client, owner_user, test_team): + """Test that the last owner cannot be demoted""" + # Get owner's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to demote + role_data = {"role": "member"} + resp = client.put( + f"/teams/{test_team['id']}/members/{owner_id}/role", + json=role_data, + headers=headers + ) + + assert resp.status_code == 400 + assert "last owner" in resp.json()["detail"].lower() + + +def test_member_cannot_invite(client, member_user, test_team): + """Test that a member cannot invite other users""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": "another@test.com", + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_member_cannot_update_roles(client, owner_user, member_user, test_team): + """Test that a member cannot change roles""" + # Get owner's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to update role as member + headers_member = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + role_data = {"role": "member"} + resp = client.put( + f"/teams/{test_team['id']}/members/{owner_id}/role", + json=role_data, + headers=headers_member + ) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_remove_member(client, owner_user, member_user, test_team): + """Test removing a member from the team""" + # Get member's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + member_data = next(m for m in members if m["email"] == member_user["email"]) + member_id = member_data["user_id"] + + # Remove member + resp = client.delete( + f"/teams/{test_team['id']}/members/{member_id}", + headers=headers + ) + + assert resp.status_code == 200 + assert resp.json()["message"] == "Member removed successfully" + + # Verify member is removed + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + emails = [m["email"] for m in members] + assert member_user["email"] not in emails + + +def test_cannot_remove_last_owner(client, owner_user, test_team): + """Test that the last owner cannot be removed""" + # Get owner's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to remove + resp = client.delete( + f"/teams/{test_team['id']}/members/{owner_id}", + headers=headers + ) + + assert resp.status_code == 400 + assert "last owner" in resp.json()["detail"].lower() + + +def test_member_cannot_remove_members(client, owner_user, member_user, test_team): + """Test that a member cannot remove other members""" + # First re-add the member + headers_owner = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": member_user["email"], + "role": "member" + } + client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers_owner) + + # Get owner's user_id + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers_owner) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to remove as member + headers_member = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.delete( + f"/teams/{test_team['id']}/members/{owner_id}", + headers=headers_member + ) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_update_team_name(client, owner_user, test_team): + """Test updating team name as owner""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + update_data = {"name": "Updated Team Name"} + resp = client.put(f"/teams/{test_team['id']}", json=update_data, headers=headers) + + assert resp.status_code == 200 + team = resp.json() + assert team["name"] == "Updated Team Name" + + +def test_member_cannot_update_team_name(client, member_user, test_team): + """Test that a member cannot update team name""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + update_data = {"name": "Hacked Name"} + resp = client.put(f"/teams/{test_team['id']}", json=update_data, headers=headers) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_delete_team(client, owner_user): + """Test deleting a team""" + # Create a team to delete + headers = {"Authorization": f"Bearer {owner_user['token']}"} + team_data = {"name": "Team to Delete"} + resp = client.post("/teams", json=team_data, headers=headers) + team = resp.json() + + # Delete it + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": team["id"] + } + resp = client.delete(f"/teams/{team['id']}", headers=headers) + + assert resp.status_code == 200 + assert resp.json()["message"] == "Team deleted" + + +def test_member_cannot_delete_team(client, member_user, test_team): + """Test that a member cannot delete a team""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.delete(f"/teams/{test_team['id']}", headers=headers) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_cannot_delete_team_with_multiple_users(client, owner_user, member_user, test_team): + """Test that a team with multiple users cannot be deleted""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.delete(f"/teams/{test_team['id']}", headers=headers) + + assert resp.status_code == 400 + assert "multiple users" in resp.json()["detail"].lower() From c47ac3b470aa7b0965fbacdad0a4f85904ef9cb3 Mon Sep 17 00:00:00 2001 From: Mina Parham Date: Tue, 18 Nov 2025 10:10:46 -0500 Subject: [PATCH 24/26] Remove foreign key --- transformerlab/shared/models/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformerlab/shared/models/models.py b/transformerlab/shared/models/models.py index 254f9d928..2d3c41e2a 100644 --- a/transformerlab/shared/models/models.py +++ b/transformerlab/shared/models/models.py @@ -1,5 +1,5 @@ from typing import Optional -from sqlalchemy import String, JSON, DateTime, func, Integer, Index, ForeignKey +from sqlalchemy import String, JSON, DateTime, func, Integer, Index from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column import uuid import enum @@ -105,6 +105,6 @@ class UserTeam(Base): __tablename__ = "users_teams" - user_id: Mapped[str] = mapped_column(String, ForeignKey("user.id"), primary_key=True) - team_id: Mapped[str] = mapped_column(String, ForeignKey("teams.id"), primary_key=True) + user_id: Mapped[str] = mapped_column(String, primary_key=True) + team_id: Mapped[str] = mapped_column(String, primary_key=True) role: Mapped[str] = mapped_column(String, nullable=False, default=TeamRole.MEMBER.value) \ No newline at end of file From 51690ddece9540d6d2b89e6840c0fcbf321d0ddd Mon Sep 17 00:00:00 2001 From: ali asaria Date: Tue, 18 Nov 2025 10:10:27 -0500 Subject: [PATCH 25/26] add create user script --- scripts/create_user.py | 70 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 scripts/create_user.py diff --git a/scripts/create_user.py b/scripts/create_user.py new file mode 100644 index 000000000..a74bff22c --- /dev/null +++ b/scripts/create_user.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +Script to create a new user by calling the /auth/register endpoint. + +Usage: + python create_user.py --user test@example.com --password password123 +""" + +import argparse +import requests +import sys +import json + + +def create_user(email: str, password: str, base_url: str = "http://127.0.0.1:8338"): + """ + Register a new user via the /auth/register endpoint. + + Args: + email: User's email address + password: User's password + base_url: Base URL of the API (default: http://127.0.0.1:8338) + + Returns: + bool: True if successful, False otherwise + """ + url = f"{base_url}/auth/register" + + payload = {"email": email, "password": password} + + headers = {"Content-Type": "application/json"} + + try: + print(f"Registering user: {email}") + response = requests.post(url, json=payload, headers=headers) + + if response.status_code == 201 or response.status_code == 200: + print(f"✓ User created successfully!") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return True + else: + print(f"✗ Failed to create user. Status code: {response.status_code}") + print(f"Response: {response.text}") + return False + + except requests.exceptions.ConnectionError: + print(f"✗ Error: Could not connect to {base_url}") + print("Make sure the API server is running.") + return False + except Exception as e: + print(f"✗ Error: {str(e)}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Create a new user by calling the /auth/register endpoint") + parser.add_argument("--user", required=True, help="User's email address") + parser.add_argument("--password", required=True, help="User's password") + parser.add_argument( + "--url", default="http://127.0.0.1:8338", help="Base URL of the API (default: http://127.0.0.1:8338)" + ) + + args = parser.parse_args() + + success = create_user(args.user, args.password, args.url) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() From 249f0f209e2a19be8d211bf2ee9f3eeac98484c2 Mon Sep 17 00:00:00 2001 From: ali asaria Date: Tue, 18 Nov 2025 10:38:12 -0500 Subject: [PATCH 26/26] create user script --- scripts/create_user.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/create_user.py b/scripts/create_user.py index a74bff22c..62d4dd776 100644 --- a/scripts/create_user.py +++ b/scripts/create_user.py @@ -35,7 +35,7 @@ def create_user(email: str, password: str, base_url: str = "http://127.0.0.1:833 response = requests.post(url, json=payload, headers=headers) if response.status_code == 201 or response.status_code == 200: - print(f"✓ User created successfully!") + print("✓ User created successfully!") print(f"Response: {json.dumps(response.json(), indent=2)}") return True else: