diff --git a/api.py b/api.py index 800985714..08b78e0d5 100644 --- a/api.py +++ b/api.py @@ -50,6 +50,8 @@ batched_prompts, recipes, remote, + auth2, + teams, ) import torch @@ -82,6 +84,9 @@ from dotenv import load_dotenv + +from transformerlab.shared.models.user_model import create_db_and_tables + load_dotenv() # The following environment variable can be used by other scripts @@ -110,6 +115,7 @@ async def lifespan(app: FastAPI): galleries.update_gallery_cache() spawn_fastchat_controller_subprocess() await db.init() + await create_db_and_tables() print("✅ SEED DATA") # Initialize experiments and cancel any running jobs seed_default_experiments() @@ -230,6 +236,8 @@ async def validation_exception_handler(request, exc): app.include_router(batched_prompts.router) app.include_router(remote.router) app.include_router(fastchat_openai_api.router) +app.include_router(teams.router) +app.include_router(auth2.router) # Authentication and session management routes if os.getenv("TFL_MULTITENANT") == "true": diff --git a/requirements-no-gpu-uv.txt b/requirements-no-gpu-uv.txt index 9b82f0f9a..2c0de50d0 100644 --- a/requirements-no-gpu-uv.txt +++ b/requirements-no-gpu-uv.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements.in -o requirements-no-gpu-uv.txt --index-strategy unsafe-best-match +# uv pip compile requirements.in -o requirements-no-gpu-uv.txt absl-py==2.1.0 # via tensorboard accelerate==1.3.0 @@ -27,6 +27,10 @@ anyio==4.8.0 # sse-starlette # starlette # watchfiles +argon2-cffi==23.1.0 + # via pwdlib +argon2-cffi-bindings==25.1.0 + # via argon2-cffi attrs==25.1.0 # via aiohttp audioread==3.0.1 @@ -39,6 +43,8 @@ azure-core==1.33.0 # azure-identity azure-identity==1.21.0 # via markitdown +bcrypt==4.3.0 + # via pwdlib beautifulsoup4==4.13.3 # via # markdownify @@ -51,6 +57,7 @@ certifi==2022.12.7 # sentry-sdk cffi==1.17.1 # via + # argon2-cffi-bindings # cryptography # soundfile charset-normalizer==2.1.1 @@ -77,6 +84,7 @@ cryptography==44.0.2 # msal # pdfminer-six # pyjwt + # python-jose # workos datasets==3.6.0 # via @@ -93,12 +101,18 @@ dill==0.3.8 # datasets # evaluate # multiprocess +dnspython==2.8.0 + # via email-validator docker-pycreds==0.4.0 # via wandb +ecdsa==0.19.1 + # via python-jose einops==0.8.0 # via # -r requirements.in # controlnet-aux +email-validator==2.3.0 + # via fastapi-users et-xmlfile==2.0.0 # via openpyxl evaluate==0.4.3 @@ -106,7 +120,14 @@ evaluate==0.4.3 fastapi==0.115.7 # via # -r requirements.in + # fastapi-users # transformerlab-inference +fastapi-users==15.0.1 + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy +fastapi-users-db-sqlalchemy==7.0.0 + # via fastapi-users filelock==3.13.1 # via # controlnet-aux @@ -169,6 +190,7 @@ humanfriendly==10.0 idna==3.4 # via # anyio + # email-validator # httpx # requests # yarl @@ -205,6 +227,8 @@ macmon-python==0.1.2 # via -r requirements.in magika==0.6.1 # via markitdown +makefun==1.16.0 + # via fastapi-users mammoth==1.9.0 # via markitdown markdown==3.7 @@ -352,8 +376,14 @@ psutil==6.1.1 # peft # transformerlab-inference # wandb +pwdlib==0.2.1 + # via fastapi-users pyarrow==19.0.0 # via datasets +pyasn1==0.6.1 + # via + # python-jose + # rsa pycparser==2.22 # via cffi pydantic==2.11.7 @@ -381,6 +411,7 @@ pygments==2.19.1 # rich pyjwt==2.10.1 # via + # fastapi-users # msal # workos pytest==8.4.2 @@ -393,9 +424,12 @@ python-dotenv==1.0.1 # magika # mcp # pydantic-settings +python-jose==3.5.0 + # via -r requirements.in python-multipart==0.0.20 # via # -r requirements.in + # fastapi-users # mcp python-pptx==1.0.2 # via markitdown @@ -436,6 +470,8 @@ rich==13.9.4 # via # transformerlab-inference # typer +rsa==4.9.1 + # via python-jose safetensors==0.5.3 # via # accelerate @@ -477,6 +513,7 @@ six==1.17.0 # via # azure-core # docker-pycreds + # ecdsa # markdownify # python-dateutil # tensorboard @@ -496,7 +533,9 @@ soxr==0.5.0.post1 speechrecognition==3.14.2 # via markitdown sqlalchemy==2.0.38 - # via -r requirements.in + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp starlette==0.45.3 diff --git a/requirements-rocm-uv.txt b/requirements-rocm-uv.txt index 96193fc00..53d2b2792 100644 --- a/requirements-rocm-uv.txt +++ b/requirements-rocm-uv.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements-rocm.in -o requirements-rocm-uv.txt --index-strategy unsafe-best-match +# uv pip compile requirements-rocm.in -o requirements-rocm-uv.txt --index-strategy unsafe-best-match --python-platform linux absl-py==2.2.2 # via tensorboard accelerate==1.6.0 @@ -27,6 +27,10 @@ anyio==4.9.0 # sse-starlette # starlette # watchfiles +argon2-cffi==23.1.0 + # via pwdlib +argon2-cffi-bindings==25.1.0 + # via argon2-cffi attrs==25.3.0 # via aiohttp audioread==3.0.1 @@ -39,6 +43,8 @@ azure-core==1.33.0 # azure-identity azure-identity==1.21.0 # via markitdown +bcrypt==4.3.0 + # via pwdlib beautifulsoup4==4.13.4 # via # markdownify @@ -51,6 +57,7 @@ certifi==2022.12.7 # sentry-sdk cffi==1.17.1 # via + # argon2-cffi-bindings # cryptography # soundfile charset-normalizer==2.1.1 @@ -77,6 +84,7 @@ cryptography==44.0.2 # msal # pdfminer-six # pyjwt + # python-jose # workos datasets==3.6.0 # via @@ -93,12 +101,18 @@ dill==0.3.8 # datasets # evaluate # multiprocess +dnspython==2.8.0 + # via email-validator docker-pycreds==0.4.0 # via wandb +ecdsa==0.19.1 + # via python-jose einops==0.8.1 # via # -r requirements-rocm.in # controlnet-aux +email-validator==2.3.0 + # via fastapi-users et-xmlfile==2.0.0 # via openpyxl evaluate==0.4.3 @@ -106,7 +120,14 @@ evaluate==0.4.3 fastapi==0.115.12 # via # -r requirements-rocm.in + # fastapi-users # transformerlab-inference +fastapi-users==15.0.1 + # via + # -r requirements-rocm.in + # fastapi-users-db-sqlalchemy +fastapi-users-db-sqlalchemy==7.0.0 + # via fastapi-users filelock==3.13.1 # via # controlnet-aux @@ -169,6 +190,7 @@ humanfriendly==10.0 idna==3.4 # via # anyio + # email-validator # httpx # requests # yarl @@ -205,6 +227,8 @@ macmon-python==0.1.2 # via -r requirements-rocm.in magika==0.6.1 # via markitdown +makefun==1.16.0 + # via fastapi-users mammoth==1.9.0 # via markitdown markdown==3.8 @@ -350,8 +374,14 @@ psutil==7.0.0 # peft # transformerlab-inference # wandb +pwdlib==0.2.1 + # via fastapi-users pyarrow==20.0.0 # via datasets +pyasn1==0.6.1 + # via + # python-jose + # rsa pycparser==2.22 # via cffi pydantic==2.11.7 @@ -379,6 +409,7 @@ pygments==2.19.1 # rich pyjwt==2.10.1 # via + # fastapi-users # msal # workos pyrsmi==0.2.0 @@ -393,9 +424,12 @@ python-dotenv==1.1.0 # magika # mcp # pydantic-settings +python-jose==3.5.0 + # via -r requirements-rocm.in python-multipart==0.0.20 # via # -r requirements-rocm.in + # fastapi-users # mcp python-pptx==1.0.2 # via markitdown @@ -438,6 +472,8 @@ rich==14.0.0 # via # transformerlab-inference # typer +rsa==4.9.1 + # via python-jose safetensors==0.5.3 # via # accelerate @@ -480,6 +516,7 @@ six==1.17.0 # via # azure-core # docker-pycreds + # ecdsa # markdownify # python-dateutil # tensorboard @@ -499,7 +536,9 @@ soxr==0.5.0.post1 speechrecognition==3.14.2 # via markitdown sqlalchemy==2.0.40 - # via -r requirements-rocm.in + # via + # -r requirements-rocm.in + # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp starlette==0.46.2 diff --git a/requirements-rocm.in b/requirements-rocm.in index caa6c46a6..24aa5c262 100644 --- a/requirements-rocm.in +++ b/requirements-rocm.in @@ -5,10 +5,12 @@ datasets==3.6.0 einops evaluate fastapi +fastapi-users[sqlalchemy] packaging psutil python-multipart python-dotenv +python-jose[cryptography] pydantic>= 2.0 nltk==3.9.1 scipy diff --git a/requirements-uv.txt b/requirements-uv.txt index a1ea3f83a..8162f8c05 100644 --- a/requirements-uv.txt +++ b/requirements-uv.txt @@ -27,6 +27,10 @@ anyio==4.8.0 # sse-starlette # starlette # watchfiles +argon2-cffi==23.1.0 + # via pwdlib +argon2-cffi-bindings==25.1.0 + # via argon2-cffi attrs==25.1.0 # via aiohttp audioread==3.0.1 @@ -39,6 +43,8 @@ azure-core==1.33.0 # azure-identity azure-identity==1.21.0 # via markitdown +bcrypt==4.3.0 + # via pwdlib beautifulsoup4==4.13.3 # via # markdownify @@ -51,6 +57,7 @@ certifi==2022.12.7 # sentry-sdk cffi==1.17.1 # via + # argon2-cffi-bindings # cryptography # soundfile charset-normalizer==2.1.1 @@ -77,6 +84,7 @@ cryptography==44.0.2 # msal # pdfminer-six # pyjwt + # python-jose # workos datasets==3.6.0 # via @@ -93,12 +101,18 @@ dill==0.3.8 # datasets # evaluate # multiprocess +dnspython==2.8.0 + # via email-validator docker-pycreds==0.4.0 # via wandb +ecdsa==0.19.1 + # via python-jose einops==0.8.0 # via # -r requirements.in # controlnet-aux +email-validator==2.3.0 + # via fastapi-users et-xmlfile==2.0.0 # via openpyxl evaluate==0.4.3 @@ -106,7 +120,14 @@ evaluate==0.4.3 fastapi==0.115.7 # via # -r requirements.in + # fastapi-users # transformerlab-inference +fastapi-users==15.0.1 + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy +fastapi-users-db-sqlalchemy==7.0.0 + # via fastapi-users filelock==3.13.1 # via # controlnet-aux @@ -169,6 +190,7 @@ humanfriendly==10.0 idna==3.4 # via # anyio + # email-validator # httpx # requests # yarl @@ -205,6 +227,8 @@ macmon-python==0.1.2 # via -r requirements.in magika==0.6.1 # via markitdown +makefun==1.16.0 + # via fastapi-users mammoth==1.9.0 # via markitdown markdown==3.7 @@ -279,45 +303,8 @@ numpy==2.1.2 # torchvision # transformerlab-inference # transformers -nvidia-cublas-cu12==12.8.4.1 - # via - # nvidia-cudnn-cu12 - # nvidia-cusolver-cu12 - # torch -nvidia-cuda-cupti-cu12==12.8.90 - # via torch -nvidia-cuda-nvrtc-cu12==12.8.93 - # via torch -nvidia-cuda-runtime-cu12==12.8.90 - # via torch -nvidia-cudnn-cu12==9.10.2.21 - # via torch -nvidia-cufft-cu12==11.3.3.83 - # via torch -nvidia-cufile-cu12==1.13.1.3 - # via torch -nvidia-curand-cu12==10.3.9.90 - # via torch -nvidia-cusolver-cu12==11.7.3.90 - # via torch -nvidia-cusparse-cu12==12.5.8.93 - # via - # nvidia-cusolver-cu12 - # torch -nvidia-cusparselt-cu12==0.7.1 - # via torch nvidia-ml-py==12.575.51 # via -r requirements.in -nvidia-nccl-cu12==2.27.3 - # via torch -nvidia-nvjitlink-cu12==12.8.93 - # via - # nvidia-cufft-cu12 - # nvidia-cusolver-cu12 - # nvidia-cusparse-cu12 - # torch -nvidia-nvtx-cu12==12.8.90 - # via torch olefile==0.47 # via markitdown onnxruntime==1.21.0 @@ -389,8 +376,14 @@ psutil==6.1.1 # peft # transformerlab-inference # wandb +pwdlib==0.2.1 + # via fastapi-users pyarrow==19.0.0 # via datasets +pyasn1==0.6.1 + # via + # python-jose + # rsa pycparser==2.22 # via cffi pydantic==2.11.7 @@ -418,6 +411,7 @@ pygments==2.19.1 # rich pyjwt==2.10.1 # via + # fastapi-users # msal # workos pytest==8.4.2 @@ -430,9 +424,12 @@ python-dotenv==1.0.1 # magika # mcp # pydantic-settings +python-jose==3.5.0 + # via -r requirements.in python-multipart==0.0.20 # via # -r requirements.in + # fastapi-users # mcp python-pptx==1.0.2 # via markitdown @@ -473,6 +470,8 @@ rich==13.9.4 # via # transformerlab-inference # typer +rsa==4.9.1 + # via python-jose safetensors==0.5.3 # via # accelerate @@ -505,7 +504,6 @@ setproctitle==1.3.5 setuptools==70.2.0 # via # tensorboard - # triton # wandb shellingham==1.5.4 # via typer @@ -515,6 +513,7 @@ six==1.17.0 # via # azure-core # docker-pycreds + # ecdsa # markdownify # python-dateutil # tensorboard @@ -534,7 +533,9 @@ soxr==0.5.0.post1 speechrecognition==3.14.2 # via markitdown sqlalchemy==2.0.38 - # via -r requirements.in + # via + # -r requirements.in + # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp starlette==0.45.3 @@ -605,8 +606,6 @@ transformers==4.57.1 # -r requirements.in # peft # sentence-transformers -triton==3.4.0 - # via torch typer==0.15.4 # via mcp typing-extensions==4.12.2 diff --git a/requirements.in b/requirements.in index 5128b6c72..f1d41c95b 100644 --- a/requirements.in +++ b/requirements.in @@ -5,9 +5,11 @@ datasets==3.6.0 einops evaluate fastapi +fastapi-users[sqlalchemy] packaging psutil python-dotenv +python-jose[cryptography] python-multipart pydantic>= 2.0 nltk==3.9.1 diff --git a/scripts/create_user.py b/scripts/create_user.py new file mode 100644 index 000000000..62d4dd776 --- /dev/null +++ b/scripts/create_user.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +""" +Script to create a new user by calling the /auth/register endpoint. + +Usage: + python create_user.py --user test@example.com --password password123 +""" + +import argparse +import requests +import sys +import json + + +def create_user(email: str, password: str, base_url: str = "http://127.0.0.1:8338"): + """ + Register a new user via the /auth/register endpoint. + + Args: + email: User's email address + password: User's password + base_url: Base URL of the API (default: http://127.0.0.1:8338) + + Returns: + bool: True if successful, False otherwise + """ + url = f"{base_url}/auth/register" + + payload = {"email": email, "password": password} + + headers = {"Content-Type": "application/json"} + + try: + print(f"Registering user: {email}") + response = requests.post(url, json=payload, headers=headers) + + if response.status_code == 201 or response.status_code == 200: + print("✓ User created successfully!") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return True + else: + print(f"✗ Failed to create user. Status code: {response.status_code}") + print(f"Response: {response.text}") + return False + + except requests.exceptions.ConnectionError: + print(f"✗ Error: Could not connect to {base_url}") + print("Make sure the API server is running.") + return False + except Exception as e: + print(f"✗ Error: {str(e)}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Create a new user by calling the /auth/register endpoint") + parser.add_argument("--user", required=True, help="User's email address") + parser.add_argument("--password", required=True, help="User's password") + parser.add_argument( + "--url", default="http://127.0.0.1:8338", help="Base URL of the API (default: http://127.0.0.1:8338)" + ) + + args = parser.parse_args() + + success = create_user(args.user, args.password, args.url) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/test/api/test_teams.py b/test/api/test_teams.py new file mode 100644 index 000000000..c229edf87 --- /dev/null +++ b/test/api/test_teams.py @@ -0,0 +1,451 @@ +import pytest + + +@pytest.fixture(scope="module") +def owner_user(client): + """Create and authenticate an owner user""" + # Register + user_data = { + "email": "owner@test.com", + "password": "password123", + } + resp = client.post("/auth/register", json=user_data) + assert resp.status_code in (200, 201, 400) # 400 if already exists + + # Login + login_data = { + "username": "owner@test.com", + "password": "password123", + } + resp = client.post("/auth/jwt/login", data=login_data) + assert resp.status_code == 200 + token = resp.json()["access_token"] + + return {"email": "owner@test.com", "token": token} + + +@pytest.fixture(scope="module") +def member_user(client): + """Create and authenticate a member user""" + # Register + user_data = { + "email": "member@test.com", + "password": "password123", + } + resp = client.post("/auth/register", json=user_data) + assert resp.status_code in (200, 201, 400) # 400 if already exists + + # Login + login_data = { + "username": "member@test.com", + "password": "password123", + } + resp = client.post("/auth/jwt/login", data=login_data) + assert resp.status_code == 200 + token = resp.json()["access_token"] + + return {"email": "member@test.com", "token": token} + + +@pytest.fixture(scope="module") +def test_team(client, owner_user): + """Create a test team""" + headers = {"Authorization": f"Bearer {owner_user['token']}"} + team_data = {"name": "Test Team"} + resp = client.post("/teams", json=team_data, headers=headers) + assert resp.status_code == 200 + team = resp.json() + assert "id" in team + assert team["name"] == "Test Team" + return team + + +def test_create_team(client, owner_user): + """Test creating a new team""" + headers = {"Authorization": f"Bearer {owner_user['token']}"} + team_data = {"name": "New Team"} + resp = client.post("/teams", json=team_data, headers=headers) + + assert resp.status_code == 200 + team = resp.json() + assert "id" in team + assert team["name"] == "New Team" + + +def test_get_user_teams(client, owner_user, test_team): + """Test getting user's teams""" + headers = {"Authorization": f"Bearer {owner_user['token']}"} + resp = client.get("/users/me/teams", headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert "teams" in data + assert len(data["teams"]) > 0 + + # Check that test_team is in the list + team_ids = [t["id"] for t in data["teams"]] + assert test_team["id"] in team_ids + + # Check that the user has owner role for the test team + test_team_data = next(t for t in data["teams"] if t["id"] == test_team["id"]) + assert test_team_data["role"] == "owner" + + +def test_list_team_members(client, owner_user, test_team): + """Test listing team members""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert "members" in data + assert len(data["members"]) >= 1 + + # Check owner is in the list + emails = [m["email"] for m in data["members"]] + assert owner_user["email"] in emails + + # Check owner has owner role + owner_data = next(m for m in data["members"] if m["email"] == owner_user["email"]) + assert owner_data["role"] == "owner" + + +def test_invite_member(client, owner_user, member_user, test_team): + """Test inviting a member to the team""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": member_user["email"], + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "User invited successfully" + assert data["email"] == member_user["email"] + assert data["role"] == "member" + + +def test_invite_duplicate_member(client, owner_user, member_user, test_team): + """Test inviting a member who is already in the team""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": member_user["email"], + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 400 + assert "already a member" in resp.json()["detail"] + + +def test_invite_nonexistent_user(client, owner_user, test_team): + """Test inviting a user that doesn't exist""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": "nonexistent@test.com", + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 404 + assert "not found" in resp.json()["detail"].lower() + + +def test_member_can_view_members(client, member_user, test_team): + """Test that a member can view team members""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + + assert resp.status_code == 200 + data = resp.json() + assert "members" in data + assert len(data["members"]) >= 2 # owner and member + + +def test_update_member_role_to_owner(client, owner_user, member_user, test_team): + """Test promoting a member to owner""" + # First get the member's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + member_data = next(m for m in members if m["email"] == member_user["email"]) + member_id = member_data["user_id"] + + # Promote to owner + role_data = {"role": "owner"} + resp = client.put( + f"/teams/{test_team['id']}/members/{member_id}/role", + json=role_data, + headers=headers + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Role updated successfully" + assert data["new_role"] == "owner" + + +def test_update_member_role_to_member(client, owner_user, member_user, test_team): + """Test demoting an owner to member""" + # First get the member's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + member_data = next(m for m in members if m["email"] == member_user["email"]) + member_id = member_data["user_id"] + + # Demote to member + role_data = {"role": "member"} + resp = client.put( + f"/teams/{test_team['id']}/members/{member_id}/role", + json=role_data, + headers=headers + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Role updated successfully" + assert data["new_role"] == "member" + + +def test_cannot_demote_last_owner(client, owner_user, test_team): + """Test that the last owner cannot be demoted""" + # Get owner's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to demote + role_data = {"role": "member"} + resp = client.put( + f"/teams/{test_team['id']}/members/{owner_id}/role", + json=role_data, + headers=headers + ) + + assert resp.status_code == 400 + assert "last owner" in resp.json()["detail"].lower() + + +def test_member_cannot_invite(client, member_user, test_team): + """Test that a member cannot invite other users""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": "another@test.com", + "role": "member" + } + resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_member_cannot_update_roles(client, owner_user, member_user, test_team): + """Test that a member cannot change roles""" + # Get owner's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to update role as member + headers_member = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + role_data = {"role": "member"} + resp = client.put( + f"/teams/{test_team['id']}/members/{owner_id}/role", + json=role_data, + headers=headers_member + ) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_remove_member(client, owner_user, member_user, test_team): + """Test removing a member from the team""" + # Get member's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + member_data = next(m for m in members if m["email"] == member_user["email"]) + member_id = member_data["user_id"] + + # Remove member + resp = client.delete( + f"/teams/{test_team['id']}/members/{member_id}", + headers=headers + ) + + assert resp.status_code == 200 + assert resp.json()["message"] == "Member removed successfully" + + # Verify member is removed + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + emails = [m["email"] for m in members] + assert member_user["email"] not in emails + + +def test_cannot_remove_last_owner(client, owner_user, test_team): + """Test that the last owner cannot be removed""" + # Get owner's user_id + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to remove + resp = client.delete( + f"/teams/{test_team['id']}/members/{owner_id}", + headers=headers + ) + + assert resp.status_code == 400 + assert "last owner" in resp.json()["detail"].lower() + + +def test_member_cannot_remove_members(client, owner_user, member_user, test_team): + """Test that a member cannot remove other members""" + # First re-add the member + headers_owner = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + invite_data = { + "email": member_user["email"], + "role": "member" + } + client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers_owner) + + # Get owner's user_id + resp = client.get(f"/teams/{test_team['id']}/members", headers=headers_owner) + members = resp.json()["members"] + owner_data = next(m for m in members if m["email"] == owner_user["email"]) + owner_id = owner_data["user_id"] + + # Try to remove as member + headers_member = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.delete( + f"/teams/{test_team['id']}/members/{owner_id}", + headers=headers_member + ) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_update_team_name(client, owner_user, test_team): + """Test updating team name as owner""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + update_data = {"name": "Updated Team Name"} + resp = client.put(f"/teams/{test_team['id']}", json=update_data, headers=headers) + + assert resp.status_code == 200 + team = resp.json() + assert team["name"] == "Updated Team Name" + + +def test_member_cannot_update_team_name(client, member_user, test_team): + """Test that a member cannot update team name""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + update_data = {"name": "Hacked Name"} + resp = client.put(f"/teams/{test_team['id']}", json=update_data, headers=headers) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_delete_team(client, owner_user): + """Test deleting a team""" + # Create a team to delete + headers = {"Authorization": f"Bearer {owner_user['token']}"} + team_data = {"name": "Team to Delete"} + resp = client.post("/teams", json=team_data, headers=headers) + team = resp.json() + + # Delete it + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": team["id"] + } + resp = client.delete(f"/teams/{team['id']}", headers=headers) + + assert resp.status_code == 200 + assert resp.json()["message"] == "Team deleted" + + +def test_member_cannot_delete_team(client, member_user, test_team): + """Test that a member cannot delete a team""" + headers = { + "Authorization": f"Bearer {member_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.delete(f"/teams/{test_team['id']}", headers=headers) + + assert resp.status_code == 403 + assert "owner" in resp.json()["detail"].lower() + + +def test_cannot_delete_team_with_multiple_users(client, owner_user, member_user, test_team): + """Test that a team with multiple users cannot be deleted""" + headers = { + "Authorization": f"Bearer {owner_user['token']}", + "X-Team-Id": test_team["id"] + } + resp = client.delete(f"/teams/{test_team['id']}", headers=headers) + + assert resp.status_code == 400 + assert "multiple users" in resp.json()["detail"].lower() diff --git a/transformerlab/models/users.py b/transformerlab/models/users.py new file mode 100644 index 000000000..4170e8dcd --- /dev/null +++ b/transformerlab/models/users.py @@ -0,0 +1,151 @@ +# users.py +import uuid +from typing import Optional +from fastapi import Depends, Request +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas +from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy +from fastapi_users.db import SQLAlchemyUserDatabase +from transformerlab.shared.models.user_model import User, get_async_session, create_default_team +from transformerlab.shared.models.models import UserTeam +from sqlalchemy.ext.asyncio import AsyncSession +from jose import jwt as _jose_jwt +from datetime import datetime, timedelta + + +# --- Pydantic Schemas for API interactions --- +class UserRead(schemas.BaseUser[uuid.UUID]): + # Add your custom fields here if you added them to the User model + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass + + +# --- User Manager (Handles registration, password reset, etc.) --- +SECRET = "YOUR_STRONG_SECRET" # !! CHANGE THIS IN PRODUCTION !! +REFRESH_SECRET = "YOUR_REFRESH_TOKEN_SECRET" # !! USE A DIFFERENT SECRET !! +REFRESH_LIFETIME = 60 * 60 * 24 * 7 # 7 days + + +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + reset_password_token_secret = SECRET + verification_token_secret = SECRET + + # Optional: Define custom logic after registration + async def on_after_register(self, user: User, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + # Add to default team + async with self.user_db.session as session: + team = await create_default_team(session) + user_team = UserTeam(user_id=str(user.id), team_id=team.id) + session.add(user_team) + await session.commit() + + async def on_after_forgot_password(self, user: User, token: str, request: Request | None = None): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify(self, user: User, token: str, request: Request | None = None): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + +async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(session, User) + + +async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): + yield UserManager(user_db) + + +# --- Authentication Backend (JWT/Bearer Token) --- +bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") + + +def get_jwt_strategy() -> JWTStrategy: + # Token lasts for 3600 seconds (1 hour) + return JWTStrategy(secret=SECRET, lifetime_seconds=3600) + + +auth_backend = AuthenticationBackend( + name="jwt", + transport=bearer_transport, + get_strategy=get_jwt_strategy, +) + +# --- FastAPIUsers Instance (The main utility) --- +fastapi_users = FastAPIUsers[User, uuid.UUID]( + get_user_manager, + [auth_backend], # Add more backends (like Google OAuth) here +) + +# --- Dependency for Protected Routes --- +# This is what you'll use in your route decorators +current_active_user = fastapi_users.current_user(active=True) + + +def get_refresh_strategy() -> JWTStrategy: + return JWTStrategy(secret=REFRESH_SECRET, lifetime_seconds=REFRESH_LIFETIME) + + +# --- Small helper to create access + refresh tokens for manual flows (e.g. refresh endpoint) --- + + +class _JWTAuthenticationHelper: + """Minimal helper that mirrors a login response (access + refresh token). + + We keep this small and explicit so callers (like the `refresh` endpoint in + `routers/auth.py`) can create new access tokens when given a valid + refresh token. + """ + + def __init__( + self, + access_secret: str, + refresh_secret: str, + access_lifetime: int = 3600, + refresh_lifetime: int = REFRESH_LIFETIME, + ): + self.access_secret = access_secret + self.refresh_secret = refresh_secret + self.access_lifetime = access_lifetime + self.refresh_lifetime = refresh_lifetime + + def _create_token(self, user, secret: str, lifetime_seconds: int) -> str: + now = datetime.utcnow() + exp = now + timedelta(seconds=lifetime_seconds) + payload = { + "sub": str(user.id), + "email": getattr(user, "email", None), + "exp": int(exp.timestamp()), + } + return _jose_jwt.encode(payload, secret, algorithm="HS256") + + def get_login_response(self, user) -> dict: + """Return a dict similar to what FastAPI-Users returns on login. + + Keys: + - access_token: short-lived JWT + - refresh_token: long-lived JWT (can be validated with refresh strategy) + - token_type: 'bearer' + - expires_in: seconds until access token expiry + """ + access = self._create_token(user, self.access_secret, self.access_lifetime) + refresh = self._create_token(user, self.refresh_secret, self.refresh_lifetime) + return { + "access_token": access, + "refresh_token": refresh, + "token_type": "bearer", + "expires_in": self.access_lifetime, + } + + +# Module-level helpers for imports elsewhere +jwt_authentication = _JWTAuthenticationHelper( + SECRET, REFRESH_SECRET, access_lifetime=3600, refresh_lifetime=REFRESH_LIFETIME +) +access_strategy = get_jwt_strategy() +refresh_strategy = get_refresh_strategy() diff --git a/transformerlab/routers/auth2.py b/transformerlab/routers/auth2.py new file mode 100644 index 000000000..91fd0f71a --- /dev/null +++ b/transformerlab/routers/auth2.py @@ -0,0 +1,183 @@ +from fastapi import APIRouter, Depends, HTTPException, Header +from transformerlab.shared.models.user_model import User, get_async_session, create_default_team +from transformerlab.shared.models.models import Team, UserTeam, TeamRole +from transformerlab.models.users import ( + fastapi_users, + auth_backend, + current_active_user, + UserRead, + UserCreate, + UserUpdate, + get_user_manager, + get_refresh_strategy, + jwt_authentication, +) +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from jose import jwt, JWTError + +router = APIRouter(tags=["users"]) + + +# Include Auth and Registration Routers +router.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth/jwt", + tags=["auth"], +) +router.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +# Include User Management Router (allows authenticated users to view/update their profile) +router.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) + + +async def get_user_and_team( + user: User = Depends(current_active_user), + x_team: str | None = Header(None, alias="X-Team-Id"), + session: AsyncSession = Depends(get_async_session), +): + """ + Dependency to validate user authentication and team membership. + Extracts X-Team-Id header and verifies user belongs to that team. + Returns user, team_id, and role. + """ + if not x_team: + raise HTTPException(status_code=400, detail="X-Team-Id header missing") + + # Verify user is associated with the provided team id + stmt = select(UserTeam).where( + UserTeam.user_id == str(user.id), + UserTeam.team_id == x_team, + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if user_team is None: + raise HTTPException(status_code=403, detail="User is not a member of the specified team") + + return {"user": user, "team_id": x_team, "role": user_team.role} + + +async def require_team_owner( + user: User = Depends(current_active_user), + x_team: str | None = Header(None, alias="X-Team-Id"), + session: AsyncSession = Depends(get_async_session), +): + """ + Dependency to validate user authentication and ensure user is an owner of the team. + Extracts X-Team-Id header and verifies user has owner role. + """ + if not x_team: + raise HTTPException(status_code=400, detail="X-Team-Id header missing") + + # Verify user is an owner of the team + stmt = select(UserTeam).where( + UserTeam.user_id == str(user.id), + UserTeam.team_id == x_team, + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if user_team is None: + raise HTTPException(status_code=403, detail="User is not a member of the specified team") + + if user_team.role != TeamRole.OWNER.value: + raise HTTPException(status_code=403, detail="Only team owners can perform this action") + + return {"user": user, "team_id": x_team, "role": user_team.role} + + +@router.get("/test-users/authenticated-route") +async def authenticated_route(user_and_team=Depends(get_user_and_team)): + user = user_and_team["user"] + team_id = user_and_team["team_id"] + return {"message": f"Hello, {user.email}! You are authenticated and acting as part of team {team_id}."} + + +# To test this, register a new user via /auth/register +# curl -X POST 'http://127.0.0.1:8338/auth/register' \ +# -H 'Content-Type: application/json' \ +# -d '{ +# "email": "test@example.com", +# "password": "password123" +# }' + +# Then +# curl -X POST 'http://127.0.0.1:8338/auth/jwt/login' \ +# -H 'Content-Type: application/x-www-form-urlencoded' \ +# -d 'username=test@example.com&password=password123' + +# This will return a token, which you can use to access the authenticated route: +# curl -X GET 'http://127.0.0.1:8338/authenticated-route' \ +# -H 'Authorization: Bearer ' + + +@router.post("/auth/refresh") +async def refresh_access_token( + refresh_token: str, # Sent by the client in the request body + user_manager=Depends(get_user_manager), +): + try: + # 1. Decode and Validate the Refresh Token + # Get a fresh refresh strategy instance and use its secret to decode + refresh_strategy = get_refresh_strategy() + payload = jwt.decode(refresh_token, str(refresh_strategy.secret), algorithms=["HS256"]) + user_id = payload.get("sub") + + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid refresh token payload") + + # 2. Get the user object from the database + user = await user_manager.get(user_id) + if user is None or not user.is_active: + raise HTTPException(status_code=401, detail="User inactive or not found") + + # 3. Create a NEW Access Token (using the short-lived strategy from the main JWT) + new_access_token = jwt_authentication.get_login_response(user) # Needs custom helper + + return {"access_token": new_access_token["access_token"], "token_type": "bearer"} + + except JWTError: + raise HTTPException(status_code=401, detail="Expired or invalid refresh token") + + +@router.get("/users/me/teams") +async def get_user_teams(user: User = Depends(current_active_user), session: AsyncSession = Depends(get_async_session)): + # Check if user has any team associations + stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)) + result = await session.execute(stmt) + user_teams = result.scalars().all() + + # If user has no team associations, assign them to default team as owner + if not user_teams: + default_team = await create_default_team(session) + user_team = UserTeam(user_id=str(user.id), team_id=default_team.id, role=TeamRole.OWNER.value) + session.add(user_team) + await session.commit() + await session.refresh(user_team) + return {"user_id": str(user.id), "teams": [{"id": default_team.id, "name": default_team.name, "role": TeamRole.OWNER.value}]} + + # User has team associations, get the actual team objects + team_ids = [ut.team_id for ut in user_teams] + stmt = select(Team).where(Team.id.in_(team_ids)) + result = await session.execute(stmt) + teams = result.scalars().all() + + # Create a mapping of team_id to team + teams_dict = {team.id: team for team in teams} + + # Return teams with role information + teams_with_roles = [ + {"id": ut.team_id, "name": teams_dict[ut.team_id].name, "role": ut.role} + for ut in user_teams if ut.team_id in teams_dict + ] + + return {"user_id": str(user.id), "teams": teams_with_roles} diff --git a/transformerlab/routers/teams.py b/transformerlab/routers/teams.py new file mode 100644 index 000000000..796d44780 --- /dev/null +++ b/transformerlab/routers/teams.py @@ -0,0 +1,301 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from transformerlab.shared.models.user_model import User, get_async_session +from transformerlab.shared.models.models import Team, UserTeam, TeamRole +from transformerlab.models.users import current_active_user +from transformerlab.routers.auth2 import require_team_owner, get_user_and_team +from pydantic import BaseModel +from sqlalchemy import select, delete, update, func + + +class TeamCreate(BaseModel): + name: str + + +class TeamUpdate(BaseModel): + name: str + + +class TeamResponse(BaseModel): + id: str + name: str + + +class InviteMemberRequest(BaseModel): + email: str + role: str = TeamRole.MEMBER.value + + +class UpdateMemberRoleRequest(BaseModel): + role: str + + +class MemberResponse(BaseModel): + user_id: str + email: str + role: str + + +router = APIRouter(tags=["teams"]) + + +@router.post("/teams", response_model=TeamResponse) +async def create_team( + team_data: TeamCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + # Create team + team = Team(name=team_data.name) + session.add(team) + await session.commit() + await session.refresh(team) + + # Add user to the team as owner + user_team = UserTeam(user_id=str(user.id), team_id=team.id, role=TeamRole.OWNER.value) + session.add(user_team) + await session.commit() + + return TeamResponse(id=team.id, name=team.name) + + +@router.put("/teams/{team_id}", response_model=TeamResponse) +async def update_team( + team_id: str, + team_data: TeamUpdate, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Update team name. Only team owners can update the team.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Update + stmt = update(Team).where(Team.id == team_id).values(name=team_data.name) + await session.execute(stmt) + await session.commit() + + # Fetch updated + stmt = select(Team).where(Team.id == team_id) + result = await session.execute(stmt) + team = result.scalar_one() + + return TeamResponse(id=team.id, name=team.name) + + +@router.delete("/teams/{team_id}") +async def delete_team( + team_id: str, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Delete a team. Only team owners can delete the team.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + user = owner_info["user"] + + # Check if user has other teams + stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)) + result = await session.execute(stmt) + user_teams = result.scalars().all() + if len(user_teams) <= 1: + raise HTTPException(status_code=400, detail="Cannot delete the last team") + + # Check if team has only this user + stmt = select(UserTeam).where(UserTeam.team_id == team_id) + result = await session.execute(stmt) + team_users = result.scalars().all() + if len(team_users) > 1: + raise HTTPException(status_code=400, detail="Cannot delete team with multiple users. Remove other members first.") + + # Delete associations and team + stmt = delete(UserTeam).where(UserTeam.team_id == team_id) + await session.execute(stmt) + stmt = delete(Team).where(Team.id == team_id) + await session.execute(stmt) + await session.commit() + + return {"message": "Team deleted"} + + +@router.get("/teams/{team_id}/members") +async def get_team_members( + team_id: str, + user_and_team=Depends(get_user_and_team), + session: AsyncSession = Depends(get_async_session), +): + """Get all members of a team. Any team member can view this.""" + # Verify team_id matches the one in header + if team_id != user_and_team["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Get all members of the team + stmt = select(UserTeam).where(UserTeam.team_id == team_id) + result = await session.execute(stmt) + user_teams = result.scalars().all() + + # Get user details + user_ids = [ut.user_id for ut in user_teams] + stmt = select(User).where(User.id.in_(user_ids)) + result = await session.execute(stmt) + users = result.scalars().all() + + # Create a mapping + users_dict = {str(user.id): user for user in users} + + members = [ + MemberResponse( + user_id=ut.user_id, + email=users_dict[ut.user_id].email if ut.user_id in users_dict else "unknown", + role=ut.role + ) + for ut in user_teams + ] + + return {"team_id": team_id, "members": members} + + +@router.post("/teams/{team_id}/members") +async def invite_member( + team_id: str, + invite_data: InviteMemberRequest, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Invite a user to the team. Only team owners can invite members.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Validate role + if invite_data.role not in [TeamRole.OWNER.value, TeamRole.MEMBER.value]: + raise HTTPException(status_code=400, detail="Invalid role. Must be 'owner' or 'member'") + + # Find user by email + stmt = select(User).where(User.email == invite_data.email) + result = await session.execute(stmt) + invited_user = result.scalar_one_or_none() + + if not invited_user: + raise HTTPException(status_code=404, detail="User not found") + + # Check if user is already in the team + stmt = select(UserTeam).where( + UserTeam.user_id == str(invited_user.id), + UserTeam.team_id == team_id + ) + result = await session.execute(stmt) + if result.scalar_one_or_none(): + raise HTTPException(status_code=400, detail="User is already a member of this team") + + # Add user to team + user_team = UserTeam(user_id=str(invited_user.id), team_id=team_id, role=invite_data.role) + session.add(user_team) + await session.commit() + + return { + "message": "User invited successfully", + "user_id": str(invited_user.id), + "email": invited_user.email, + "role": invite_data.role + } + + +@router.delete("/teams/{team_id}/members/{user_id}") +async def remove_member( + team_id: str, + user_id: str, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Remove a member from the team. Only team owners can remove members.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Check if the user to be removed exists in the team + stmt = select(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if not user_team: + raise HTTPException(status_code=404, detail="User is not a member of this team") + + # If removing an owner, check that there's at least one other owner + if user_team.role == TeamRole.OWNER.value: + stmt = select(func.count()).select_from(UserTeam).where( + UserTeam.team_id == team_id, + UserTeam.role == TeamRole.OWNER.value + ) + result = await session.execute(stmt) + owner_count = result.scalar() + + if owner_count <= 1: + raise HTTPException(status_code=400, detail="Cannot remove the last owner from the team") + + # Remove user from team + stmt = delete(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ) + await session.execute(stmt) + await session.commit() + + return {"message": "Member removed successfully"} + + +@router.put("/teams/{team_id}/members/{user_id}/role") +async def update_member_role( + team_id: str, + user_id: str, + role_data: UpdateMemberRoleRequest, + owner_info=Depends(require_team_owner), + session: AsyncSession = Depends(get_async_session), +): + """Update a member's role. Only team owners can change roles.""" + # Verify team_id matches the one in header + if team_id != owner_info["team_id"]: + raise HTTPException(status_code=400, detail="Team ID mismatch") + + # Validate role + if role_data.role not in [TeamRole.OWNER.value, TeamRole.MEMBER.value]: + raise HTTPException(status_code=400, detail="Invalid role. Must be 'owner' or 'member'") + + # Check if the user exists in the team + stmt = select(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ) + result = await session.execute(stmt) + user_team = result.scalar_one_or_none() + + if not user_team: + raise HTTPException(status_code=404, detail="User is not a member of this team") + + # If demoting from owner to member, check that there's at least one other owner + if user_team.role == TeamRole.OWNER.value and role_data.role == TeamRole.MEMBER.value: + stmt = select(func.count()).select_from(UserTeam).where( + UserTeam.team_id == team_id, + UserTeam.role == TeamRole.OWNER.value + ) + result = await session.execute(stmt) + owner_count = result.scalar() + + if owner_count <= 1: + raise HTTPException(status_code=400, detail="Cannot demote the last owner") + + # Update role + stmt = update(UserTeam).where( + UserTeam.user_id == user_id, + UserTeam.team_id == team_id + ).values(role=role_data.role) + await session.execute(stmt) + await session.commit() + + return {"message": "Role updated successfully", "user_id": user_id, "new_role": role_data.role} \ No newline at end of file diff --git a/transformerlab/shared/models/models.py b/transformerlab/shared/models/models.py index 70584c775..2d3c41e2a 100644 --- a/transformerlab/shared/models/models.py +++ b/transformerlab/shared/models/models.py @@ -1,6 +1,8 @@ from typing import Optional from sqlalchemy import String, JSON, DateTime, func, Integer, Index from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +import uuid +import enum class Base(DeclarativeBase): @@ -82,3 +84,27 @@ class WorkflowRun(Base): updated_at: Mapped[DateTime] = mapped_column( DateTime, server_default=func.now(), onupdate=func.now(), nullable=False ) + +class Team(Base): + """Team model.""" + + __tablename__ = "teams" + + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + name: Mapped[str] = mapped_column(String, nullable=False) + + +class TeamRole(str, enum.Enum): + """Enum for user roles within a team.""" + OWNER = "owner" + MEMBER = "member" + + +class UserTeam(Base): + """User-Team association model.""" + + __tablename__ = "users_teams" + + user_id: Mapped[str] = mapped_column(String, primary_key=True) + team_id: Mapped[str] = mapped_column(String, primary_key=True) + role: Mapped[str] = mapped_column(String, nullable=False, default=TeamRole.MEMBER.value) \ No newline at end of file diff --git a/transformerlab/shared/models/user_model.py b/transformerlab/shared/models/user_model.py new file mode 100644 index 000000000..8ff244a94 --- /dev/null +++ b/transformerlab/shared/models/user_model.py @@ -0,0 +1,45 @@ +# database.py +from typing import AsyncGenerator +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from fastapi_users.db import SQLAlchemyBaseUserTableUUID +from sqlalchemy import select + +# Replace with your actual database URL (e.g., PostgreSQL, SQLite) +from transformerlab.db.constants import DATABASE_URL +from .models import Base, Team + + +# 1. Define your User Model (inherits from a FastAPI Users base class) +class User(SQLAlchemyBaseUserTableUUID, Base): + pass # You can add custom fields here later, like 'first_name: str' + + +# 2. Setup the Async Engine and Session +engine = create_async_engine(DATABASE_URL) +AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +# 3. Utility to create tables (run this on app startup) +async def create_db_and_tables(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + +# 4. Database session dependency +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSessionLocal() as session: + yield session + + +# 5. Create default team if not exists +async def create_default_team(session: AsyncSession) -> Team: + stmt = select(Team).where(Team.name == "Default Team") + result = await session.execute(stmt) + team = result.scalar_one_or_none() + if not team: + team = Team(name="Default Team") + session.add(team) + await session.commit() + await session.refresh(team) + return team