From c3af93a55a13586342cadee2fb38d2bde07bc460 Mon Sep 17 00:00:00 2001 From: Sayre Blades Date: Thu, 18 Dec 2025 16:58:05 -0500 Subject: [PATCH 1/3] plan: adding the plan.md file --- PLAN.md | 528 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 528 insertions(+) create mode 100644 PLAN.md diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..770e694 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,528 @@ +# TRELLIS.2 Modal Integration Plan + +This document outlines the implementation plan for adding Modal serverless deployment to TRELLIS.2, following the patterns established in TRELLIS v1. + +## Summary of Decisions + +| Decision | Choice | Rationale | +|---------------------|---------------------------------|------------------------------------------------| +| State Serialization | Store Final MeshWithVoxel | Simpler extraction, matches v1 pattern | +| Attention Backend | flash_attn | Best A100 performance, v2 default | +| Preview Rendering | Single PBR video + bundled HDRI | Quality/simplicity balance | +| CUDA Version | 12.4 + PyTorch 2.6.0 | Matches upstream requirements | +| API Parameters | Tiered (basic + advanced) | Simple for basic use, flexible for power users | +| GLB Extraction | Expose all parameters | Maximum flexibility | +| GPU | A100-80GB | Sufficient for all pipeline types | + +## Architecture Overview + +``` ++-------------------------------------------------------------+ +| LOCAL GRADIO CLIENT | +| - Image upload, parameter controls, result display | +| - Stores LZ4-compressed MeshWithVoxel state | +| - Communicates via HTTPS | ++-----------------------------+-------------------------------+ + | + | POST /generate, POST /extract_glb + | Headers: X-API-Key: sk_xxx + v ++-------------------------------------------------------------+ +| MODAL WEB ENDPOINTS (A100-80GB) | +| | +| Authentication Layer (API Key validation from Volume) | +| | | +| v | +| +-------------------------------------------------------+ | +| | 
TRELLIS2Service Class | | +| | gpu="A100-80GB" | | +| | enable_memory_snapshot=True | | +| | | | +| | @enter(snap=True): Load pipeline to CPU | | +| | @enter(snap=False): Move pipeline to GPU | | +| | generate(): Image -> MeshWithVoxel state + video | | +| | extract_glb(): State -> GLB mesh | | +| +-------------------------------------------------------+ | +| | +| Volumes: | +| - /cache/huggingface: Model weights cache | +| - /data/keys.json: API key storage | ++-------------------------------------------------------------+ +``` + +## Implementation Phases + +### Phase 1: Modal Image Definition (image.py) + +**Goal**: Build a Modal container image with all CUDA extensions pre-compiled. + +**Test First (TDD)**: +```python +# tests/test_image_build.py +def test_image_has_required_cuda_extensions(): + """Verify all CUDA extensions are importable.""" + # This test runs inside the Modal container + +def test_pytorch_cuda_available(): + """Verify PyTorch can access CUDA.""" + +def test_flash_attn_available(): + """Verify flash_attn is installed and working.""" + +def test_trellis2_pipeline_importable(): + """Verify trellis2 package is on PYTHONPATH.""" +``` + +**Implementation**: + +1. Base image: `nvidia/cuda:12.4.0-devel-ubuntu22.04` + Python 3.10 +2. System packages: git, ninja-build, cmake, clang, OpenGL libs +3. Python packages (no GPU needed): + - pillow, imageio, imageio-ffmpeg, tqdm, easydict + - opencv-python-headless, trimesh, transformers, huggingface-hub + - lz4, fastapi, safetensors, kornia, timm +4. GPU build block (requires `gpu="T4"` for compilation): + - PyTorch 2.6.0 + CUDA 12.4 + - flash-attn 2.7.3 + - nvdiffrast (v0.4.0 tag) + - nvdiffrec (renderutils branch from JeffreyXiang fork) + - CuMesh (from JeffreyXiang/CuMesh) + - FlexGEMM (from JeffreyXiang/FlexGEMM) + - o-voxel (bundled in repo, copy and build) +5. Clone TRELLIS.2 repo to /opt/TRELLIS.2 +6. Pre-download models: + - DINOv2 (via torch.hub) + - BiRefNet for background removal +7. 
Bundle HDRI file (forest.exr) for rendering +8. Environment variables: + ``` + ATTN_BACKEND=flash_attn + PYTHONPATH=/opt/TRELLIS.2 + HF_HOME=/cache/huggingface + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + OPENCV_IO_ENABLE_OPENEXR=1 + ``` + +**Verification**: `modal run trellis2_modal/service/image.py::verify_image` + +--- + +### Phase 2: Generator Wrapper (generator.py) + +**Goal**: Wrap Trellis2ImageTo3DPipeline with Modal's two-phase loading pattern. + +**Test First (TDD)**: +```python +# tests/test_generator.py +def test_generator_load_model_cpu(): + """Test CPU loading with mock pipeline factory.""" + +def test_generator_move_model_gpu(): + """Test GPU transfer after CPU load.""" + +def test_generator_generate_returns_mesh_with_voxel(): + """Test generation returns correct state structure.""" + +def test_generator_all_pipeline_types(): + """Test 512, 1024, 1024_cascade, 1536_cascade.""" + +def test_generator_render_preview_video(): + """Test PBR video rendering with HDRI.""" +``` + +**Implementation**: + +```python +class TRELLIS2Generator: + def __init__(self, pipeline_factory=None): + self._pipeline_factory = pipeline_factory or _default_pipeline_factory + self.pipeline = None + self.envmap = None + + def load_model_cpu(self): + """@modal.enter(snap=True) - Load to CPU for snapshot.""" + self.pipeline = self._pipeline_factory(MODEL_NAME) + # Don't load envmap here - needs GPU + + def move_model_gpu(self): + """@modal.enter(snap=False) - Move to GPU after restore.""" + self.pipeline.cuda() + self._load_envmap() + + def _load_envmap(self): + """Load bundled HDRI for rendering.""" + import cv2 + from trellis2.renderers import EnvMap + hdri = cv2.imread('/opt/TRELLIS.2/assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED) + hdri = cv2.cvtColor(hdri, cv2.COLOR_BGR2RGB) + self.envmap = EnvMap(torch.tensor(hdri, dtype=torch.float32, device='cuda')) + + def generate(self, image, seed, pipeline_type, + ss_params=None, shape_params=None, tex_params=None): + 
"""Generate MeshWithVoxel from image.""" + mesh = self.pipeline.run( + image, + seed=seed, + pipeline_type=pipeline_type, + sparse_structure_sampler_params=ss_params or {}, + shape_slat_sampler_params=shape_params or {}, + tex_slat_sampler_params=tex_params or {}, + )[0] + mesh.simplify(16777216) # nvdiffrast limit + return mesh + + def render_preview_video(self, mesh, num_frames=120, fps=15): + """Render PBR video with environment lighting.""" + from trellis2.utils import render_utils + frames = render_utils.render_video(mesh, envmap=self.envmap, + num_frames=num_frames) + # Encode to MP4 bytes + ... +``` + +--- + +### Phase 3: State Serialization (state.py) + +**Goal**: Pack/unpack MeshWithVoxel for client-side storage. + +**Test First (TDD)**: +```python +# tests/test_state.py +def test_pack_state_contains_all_fields(): + """Verify packed state has vertices, faces, attrs, coords, etc.""" + +def test_unpack_state_reconstructs_mesh(): + """Verify unpacked mesh matches original.""" + +def test_state_roundtrip_preserves_data(): + """Pack -> unpack -> pack produces identical state.""" + +def test_state_is_json_serializable_after_numpy(): + """Verify numpy arrays are used (not tensors).""" +``` + +**Implementation**: + +```python +def pack_state(mesh: MeshWithVoxel) -> dict: + """Pack MeshWithVoxel into serializable dict.""" + return { + "vertices": mesh.vertices.cpu().numpy(), + "faces": mesh.faces.cpu().numpy(), + "attrs": mesh.attrs.cpu().numpy(), + "coords": mesh.coords.cpu().numpy(), + "voxel_size": float(mesh.voxel_size), + "voxel_shape": list(mesh.voxel_shape), + "origin": mesh.origin.cpu().tolist(), + "layout": {k: [v.start, v.stop] for k, v in mesh.layout.items()}, + } + +def unpack_state(state: dict) -> MeshWithVoxel: + """Reconstruct MeshWithVoxel from packed state.""" + layout = {k: slice(v[0], v[1]) for k, v in state["layout"].items()} + return MeshWithVoxel( + vertices=torch.tensor(state["vertices"], device="cuda"), + faces=torch.tensor(state["faces"], 
device="cuda"), + origin=state["origin"], + voxel_size=state["voxel_size"], + coords=torch.tensor(state["coords"], device="cuda"), + attrs=torch.tensor(state["attrs"], device="cuda"), + voxel_shape=torch.Size(state["voxel_shape"]), + layout=layout, + ) +``` + +--- + +### Phase 4: Service Endpoints (service.py) + +**Goal**: Modal web endpoints for generate and extract_glb. + +**Test First (TDD)**: +```python +# tests/test_service.py +def test_generate_endpoint_validates_api_key(): + """401 on missing/invalid key.""" + +def test_generate_endpoint_validates_image(): + """400 on invalid image payload.""" + +def test_generate_endpoint_returns_state_and_video(): + """Successful generation returns expected structure.""" + +def test_extract_glb_endpoint_returns_glb(): + """GLB extraction returns valid GLB bytes.""" + +def test_health_endpoint_no_auth_required(): + """Health check works without API key.""" +``` + +**Implementation**: + +```python +@app.cls( + image=trellis2_image, + gpu="A100-80GB", + volumes={ + HF_CACHE_PATH: hf_cache_volume, + API_KEYS_PATH: api_keys_volume, + }, + enable_memory_snapshot=True, + container_idle_timeout=300, +) +class TRELLIS2Service: + @modal.enter(snap=True) + def load_model_cpu(self): + self.generator = TRELLIS2Generator() + self.generator.load_model_cpu() + + @modal.enter(snap=False) + def load_model_gpu(self): + self.generator.move_model_gpu() + + @modal.web_endpoint(method="POST") + def generate(self, request: Request) -> dict: + """Generate 3D from image.""" + # Auth check + # Validate request + # Generate mesh + # Render video + # Pack state + # Return base64 state + video + + @modal.web_endpoint(method="POST") + def extract_glb(self, request: Request) -> dict: + """Extract GLB from state.""" + # Auth check + # Unpack state + # Call o_voxel.postprocess.to_glb() + # Return base64 GLB + + @modal.web_endpoint(method="GET") + def health(self) -> dict: + return {"status": "ok", "service": "trellis2-api"} +``` + +**API 
Request/Response**: + +``` +POST /generate +Request: +{ + "image": "", + "seed": 42, + "pipeline_type": "1024_cascade", // 512, 1024, 1024_cascade, 1536_cascade + + // Optional advanced params + "ss_sampling_steps": 12, + "ss_guidance_strength": 7.5, + "ss_guidance_rescale": 0.7, + "ss_rescale_t": 5.0, + "shape_slat_sampling_steps": 12, + "shape_slat_guidance_strength": 7.5, + "shape_slat_guidance_rescale": 0.5, + "shape_slat_rescale_t": 3.0, + "tex_slat_sampling_steps": 12, + "tex_slat_guidance_strength": 1.0, + "tex_slat_guidance_rescale": 0.0, + "tex_slat_rescale_t": 3.0 +} + +Response: +{ + "state": "", + "video": "" +} +``` + +``` +POST /extract_glb +Request: +{ + "state": "", + "decimation_target": 500000, + "texture_size": 2048, + "remesh": true, + "remesh_band": 1.0, + "remesh_project": 0.0 +} + +Response: +{ + "glb": "" +} +``` + +--- + +### Phase 5: Client (api.py, app.py) + +**Goal**: Gradio client that talks to Modal service. + +**Test First (TDD)**: +```python +# tests/test_api_client.py +def test_client_generate_sends_correct_request(): + """Verify request format matches API spec.""" + +def test_client_handles_cold_start_timeout(): + """Retry logic for slow cold starts.""" + +def test_client_decompresses_state(): + """LZ4 decompression works correctly.""" +``` + +**Implementation**: + +Adapt v1's client with updated: +- Parameter names for TRELLIS.2 +- Pipeline type selector (512/1024/1536) +- Advanced parameter accordions for each stage +- Updated preview display for PBR video + +--- + +### Phase 6: Auth & Compression (reuse from v1) + +**Goal**: Copy and adapt auth.py and compression.py from v1. + +These modules are largely unchanged: +- `auth.py`: API key validation, rate limiting, quota management +- `compression.py`: LZ4 compression for state transfer + +Minor updates: +- Update app name references +- Update volume names + +--- + +### Phase 7: Documentation & CI + +**Goal**: Docs and GitHub Actions for testing. 
+ +Files to create: +- `trellis2_modal/docs/MODAL_INTEGRATION.md` +- `trellis2_modal/docs/OPERATIONS_RUNBOOK.md` +- `.github/workflows/modal-ci.yml` + +--- + +## File Structure + +``` +TRELLIS.2/ +├── trellis2/ # Existing - unchanged +├── o-voxel/ # Existing - unchanged +├── trellis2_modal/ # NEW +│ ├── __init__.py +│ ├── service/ +│ │ ├── __init__.py +│ │ ├── config.py # GPU type, model name, defaults +│ │ ├── image.py # Modal image definition +│ │ ├── generator.py # TRELLIS2Generator class +│ │ ├── service.py # Modal endpoints +│ │ ├── state.py # Pack/unpack MeshWithVoxel +│ │ ├── auth.py # API key management (from v1) +│ │ └── streaming.py # SSE utilities (from v1) +│ ├── client/ +│ │ ├── __init__.py +│ │ ├── app.py # Gradio UI +│ │ ├── api.py # HTTP client +│ │ ├── compression.py # LZ4 (from v1) +│ │ └── requirements.txt +│ ├── tests/ +│ │ ├── __init__.py +│ │ ├── conftest.py # Mock fixtures +│ │ ├── test_generator.py +│ │ ├── test_state.py +│ │ ├── test_api_client.py +│ │ ├── test_auth.py +│ │ ├── test_compression.py +│ │ ├── test_service.py +│ │ └── test_integration.py +│ └── docs/ +│ ├── MODAL_INTEGRATION.md +│ └── OPERATIONS_RUNBOOK.md +├── .github/ +│ └── workflows/ +│ └── modal-ci.yml +└── PLAN.md # This file +``` + +--- + +## Dependencies Comparison + +### TRELLIS v1 (CUDA 11.8) +``` +torch==2.4.0+cu118 +xformers==0.0.27.post2 +spconv-cu118 +nvdiffrast (pinned commit) +diffoctreerast (pinned commit) +diff-gaussian-rasterization (mip-splatting) +``` + +### TRELLIS.2 (CUDA 12.4) +``` +torch==2.6.0+cu124 +flash-attn==2.7.3 +nvdiffrast==0.4.0 +nvdiffrec (renderutils branch) +cumesh (JeffreyXiang/CuMesh) +flex_gemm (JeffreyXiang/FlexGEMM) +o_voxel (bundled) +``` + +--- + +## Estimated Effort + +| Phase | Description | Effort | +|-------|-------------|--------| +| 1 | Modal Image | 4-6 hours (CUDA builds are finicky) | +| 2 | Generator | 2-3 hours | +| 3 | State Serialization | 1-2 hours | +| 4 | Service Endpoints | 3-4 hours | +| 5 | Client | 2-3 hours | +| 6 
| Auth & Compression | 1 hour (mostly copy) | +| 7 | Docs & CI | 2 hours | + +**Total**: ~16-21 hours (2-3 days) + +--- + +## Risk Mitigation + +| Risk | Mitigation | +|------|------------| +| CUDA extension build failures | Pin exact commits, test builds early | +| Memory issues on A100-80GB | low_vram mode is built into TRELLIS.2 | +| Cold start too slow | GPU memory snapshots (proven in v1) | +| HDRI file too large for image | Compress or use smaller resolution | +| State too large after compression | Already using LZ4, can tune if needed | + +--- + +## Verification Checklist + +After each phase, verify: + +- [ ] All tests pass locally (mocked, no GPU) +- [ ] `modal run` works for smoke tests +- [ ] `modal deploy` succeeds +- [ ] Health endpoint responds +- [ ] Generate endpoint produces valid output +- [ ] Extract GLB produces downloadable file +- [ ] Cold start time acceptable (<30s with snapshot) +- [ ] Warm request latency acceptable (<60s for 1024) + +--- + +## Next Steps + +1. Review and approve this plan +2. Start Phase 1: Modal Image Definition +3. Test image builds before proceeding +4. Continue through phases sequentially From 315f15a6d8db9d3a4c334723586229bf919c39aa Mon Sep 17 00:00:00 2001 From: Daniel Nouri Date: Sat, 20 Dec 2025 17:35:27 +0100 Subject: [PATCH 2/3] feat: Add Modal serverless GPU deployment for TRELLIS.2 Deploy TRELLIS.2 to Modal's cloud infrastructure without needing a local GPU. 
Key features: - Modal service with /generate and /extract_glb endpoints on A100-80GB - Gradio web UI and Python API client for remote access - Modal Proxy Auth for secure authentication - Automatic endpoint URL routing for Modal subdomains - Comprehensive documentation and operations runbook Technical implementation: - Container image with pre-compiled CUDA extensions - MeshWithVoxel state serialization with LZ4 compression - Cold start optimization research (GPU snapshots not effective) - 141 unit tests covering service, client, and compression --- .github/workflows/modal-ci.yml | 121 ++++ .gitignore | 8 + PLAN.md | 528 -------------- README.md | 19 + trellis2_modal/README.md | 133 ++++ trellis2_modal/__init__.py | 1 + trellis2_modal/client/__init__.py | 22 + trellis2_modal/client/api.py | 332 +++++++++ trellis2_modal/client/app.py | 379 ++++++++++ trellis2_modal/client/compression.py | 73 ++ trellis2_modal/client/requirements.txt | 7 + trellis2_modal/docs/MODAL_INTEGRATION.md | 286 ++++++++ trellis2_modal/docs/OPERATIONS_RUNBOOK.md | 282 +++++++ trellis2_modal/requirements-deploy.txt | 11 + trellis2_modal/service/__init__.py | 1 + trellis2_modal/service/auth.py | 374 ++++++++++ trellis2_modal/service/config.py | 54 ++ trellis2_modal/service/generator.py | 270 +++++++ trellis2_modal/service/image.py | 289 ++++++++ trellis2_modal/service/service.py | 685 ++++++++++++++++++ trellis2_modal/service/state.py | 76 ++ trellis2_modal/tests/__init__.py | 1 + trellis2_modal/tests/conftest.py | 11 + trellis2_modal/tests/test_api_client.py | 551 ++++++++++++++ trellis2_modal/tests/test_auth.py | 411 +++++++++++ trellis2_modal/tests/test_compression.py | 153 ++++ .../tests/test_config_consistency.py | 42 ++ trellis2_modal/tests/test_generator.py | 158 ++++ trellis2_modal/tests/test_gpu_snapshots.py | 301 ++++++++ trellis2_modal/tests/test_service.py | 332 +++++++++ trellis2_modal/tests/test_state.py | 143 ++++ 31 files changed, 5526 insertions(+), 528 deletions(-) create mode 
100644 .github/workflows/modal-ci.yml delete mode 100644 PLAN.md create mode 100644 trellis2_modal/README.md create mode 100644 trellis2_modal/__init__.py create mode 100644 trellis2_modal/client/__init__.py create mode 100644 trellis2_modal/client/api.py create mode 100644 trellis2_modal/client/app.py create mode 100644 trellis2_modal/client/compression.py create mode 100644 trellis2_modal/client/requirements.txt create mode 100644 trellis2_modal/docs/MODAL_INTEGRATION.md create mode 100644 trellis2_modal/docs/OPERATIONS_RUNBOOK.md create mode 100644 trellis2_modal/requirements-deploy.txt create mode 100644 trellis2_modal/service/__init__.py create mode 100644 trellis2_modal/service/auth.py create mode 100644 trellis2_modal/service/config.py create mode 100644 trellis2_modal/service/generator.py create mode 100644 trellis2_modal/service/image.py create mode 100644 trellis2_modal/service/service.py create mode 100644 trellis2_modal/service/state.py create mode 100644 trellis2_modal/tests/__init__.py create mode 100644 trellis2_modal/tests/conftest.py create mode 100644 trellis2_modal/tests/test_api_client.py create mode 100644 trellis2_modal/tests/test_auth.py create mode 100644 trellis2_modal/tests/test_compression.py create mode 100644 trellis2_modal/tests/test_config_consistency.py create mode 100644 trellis2_modal/tests/test_generator.py create mode 100644 trellis2_modal/tests/test_gpu_snapshots.py create mode 100644 trellis2_modal/tests/test_service.py create mode 100644 trellis2_modal/tests/test_state.py diff --git a/.github/workflows/modal-ci.yml b/.github/workflows/modal-ci.yml new file mode 100644 index 0000000..87e0d17 --- /dev/null +++ b/.github/workflows/modal-ci.yml @@ -0,0 +1,121 @@ +name: Modal CI + +on: + push: + branches: [main] + paths: + - 'trellis2_modal/**' + - '.github/workflows/modal-ci.yml' + pull_request: + branches: [main] + paths: + - 'trellis2_modal/**' + - '.github/workflows/modal-ci.yml' + +jobs: + test: + name: Run Tests + runs-on: 
ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Run unit tests + run: | + uv run --with pytest --with lz4 --with numpy --with pillow --with modal --with fastapi --with requests \ + pytest trellis2_modal/tests/ -v --tb=short + + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Check formatting with ruff + run: | + uv run --with ruff ruff format --check trellis2_modal/ + + - name: Lint with ruff + run: | + uv run --with ruff ruff check trellis2_modal/ + + verify-image: + name: Verify Modal Image + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + needs: [test, lint] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install deploy dependencies + run: pip install -r trellis2_modal/requirements-deploy.txt + + - name: Set up Modal credentials + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + run: | + modal token set --token-id "$MODAL_TOKEN_ID" --token-secret "$MODAL_TOKEN_SECRET" + + - name: Verify image builds + run: modal run trellis2_modal/service/image.py + timeout-minutes: 30 + + deploy: + name: Deploy to Modal + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + needs: [verify-image] + environment: production + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install deploy dependencies + run: pip 
install -r trellis2_modal/requirements-deploy.txt + + - name: Set up Modal credentials + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + run: | + modal token set --token-id "$MODAL_TOKEN_ID" --token-secret "$MODAL_TOKEN_SECRET" + + - name: Deploy service + run: modal deploy -m trellis2_modal.service.service + timeout-minutes: 10 + + # Note: Health check removed because endpoint URL is dynamic + # Monitor deployment status in Modal dashboard instead diff --git a/.gitignore b/.gitignore index b7faf40..33e181e 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,11 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# Temporary test files +tmp/ + +# Modal credentials - NEVER commit these! +.trellis2_modal_secrets.json +*_secrets.json +*.secrets.json diff --git a/PLAN.md b/PLAN.md deleted file mode 100644 index 770e694..0000000 --- a/PLAN.md +++ /dev/null @@ -1,528 +0,0 @@ -# TRELLIS.2 Modal Integration Plan - -This document outlines the implementation plan for adding Modal serverless deployment to TRELLIS.2, following the patterns established in TRELLIS v1. 
- -## Summary of Decisions - -| Decision | Choice | Rationale | -|---------------------|---------------------------------|------------------------------------------------| -| State Serialization | Store Final MeshWithVoxel | Simpler extraction, matches v1 pattern | -| Attention Backend | flash_attn | Best A100 performance, v2 default | -| Preview Rendering | Single PBR video + bundled HDRI | Quality/simplicity balance | -| CUDA Version | 12.4 + PyTorch 2.6.0 | Matches upstream requirements | -| API Parameters | Tiered (basic + advanced) | Simple for basic use, flexible for power users | -| GLB Extraction | Expose all parameters | Maximum flexibility | -| GPU | A100-80GB | Sufficient for all pipeline types | - -## Architecture Overview - -``` -+-------------------------------------------------------------+ -| LOCAL GRADIO CLIENT | -| - Image upload, parameter controls, result display | -| - Stores LZ4-compressed MeshWithVoxel state | -| - Communicates via HTTPS | -+-----------------------------+-------------------------------+ - | - | POST /generate, POST /extract_glb - | Headers: X-API-Key: sk_xxx - v -+-------------------------------------------------------------+ -| MODAL WEB ENDPOINTS (A100-80GB) | -| | -| Authentication Layer (API Key validation from Volume) | -| | | -| v | -| +-------------------------------------------------------+ | -| | TRELLIS2Service Class | | -| | gpu="A100-80GB" | | -| | enable_memory_snapshot=True | | -| | | | -| | @enter(snap=True): Load pipeline to CPU | | -| | @enter(snap=False): Move pipeline to GPU | | -| | generate(): Image -> MeshWithVoxel state + video | | -| | extract_glb(): State -> GLB mesh | | -| +-------------------------------------------------------+ | -| | -| Volumes: | -| - /cache/huggingface: Model weights cache | -| - /data/keys.json: API key storage | -+-------------------------------------------------------------+ -``` - -## Implementation Phases - -### Phase 1: Modal Image Definition (image.py) - -**Goal**: Build 
a Modal container image with all CUDA extensions pre-compiled. - -**Test First (TDD)**: -```python -# tests/test_image_build.py -def test_image_has_required_cuda_extensions(): - """Verify all CUDA extensions are importable.""" - # This test runs inside the Modal container - -def test_pytorch_cuda_available(): - """Verify PyTorch can access CUDA.""" - -def test_flash_attn_available(): - """Verify flash_attn is installed and working.""" - -def test_trellis2_pipeline_importable(): - """Verify trellis2 package is on PYTHONPATH.""" -``` - -**Implementation**: - -1. Base image: `nvidia/cuda:12.4.0-devel-ubuntu22.04` + Python 3.10 -2. System packages: git, ninja-build, cmake, clang, OpenGL libs -3. Python packages (no GPU needed): - - pillow, imageio, imageio-ffmpeg, tqdm, easydict - - opencv-python-headless, trimesh, transformers, huggingface-hub - - lz4, fastapi, safetensors, kornia, timm -4. GPU build block (requires `gpu="T4"` for compilation): - - PyTorch 2.6.0 + CUDA 12.4 - - flash-attn 2.7.3 - - nvdiffrast (v0.4.0 tag) - - nvdiffrec (renderutils branch from JeffreyXiang fork) - - CuMesh (from JeffreyXiang/CuMesh) - - FlexGEMM (from JeffreyXiang/FlexGEMM) - - o-voxel (bundled in repo, copy and build) -5. Clone TRELLIS.2 repo to /opt/TRELLIS.2 -6. Pre-download models: - - DINOv2 (via torch.hub) - - BiRefNet for background removal -7. Bundle HDRI file (forest.exr) for rendering -8. Environment variables: - ``` - ATTN_BACKEND=flash_attn - PYTHONPATH=/opt/TRELLIS.2 - HF_HOME=/cache/huggingface - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - OPENCV_IO_ENABLE_OPENEXR=1 - ``` - -**Verification**: `modal run trellis2_modal/service/image.py::verify_image` - ---- - -### Phase 2: Generator Wrapper (generator.py) - -**Goal**: Wrap Trellis2ImageTo3DPipeline with Modal's two-phase loading pattern. 
- -**Test First (TDD)**: -```python -# tests/test_generator.py -def test_generator_load_model_cpu(): - """Test CPU loading with mock pipeline factory.""" - -def test_generator_move_model_gpu(): - """Test GPU transfer after CPU load.""" - -def test_generator_generate_returns_mesh_with_voxel(): - """Test generation returns correct state structure.""" - -def test_generator_all_pipeline_types(): - """Test 512, 1024, 1024_cascade, 1536_cascade.""" - -def test_generator_render_preview_video(): - """Test PBR video rendering with HDRI.""" -``` - -**Implementation**: - -```python -class TRELLIS2Generator: - def __init__(self, pipeline_factory=None): - self._pipeline_factory = pipeline_factory or _default_pipeline_factory - self.pipeline = None - self.envmap = None - - def load_model_cpu(self): - """@modal.enter(snap=True) - Load to CPU for snapshot.""" - self.pipeline = self._pipeline_factory(MODEL_NAME) - # Don't load envmap here - needs GPU - - def move_model_gpu(self): - """@modal.enter(snap=False) - Move to GPU after restore.""" - self.pipeline.cuda() - self._load_envmap() - - def _load_envmap(self): - """Load bundled HDRI for rendering.""" - import cv2 - from trellis2.renderers import EnvMap - hdri = cv2.imread('/opt/TRELLIS.2/assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED) - hdri = cv2.cvtColor(hdri, cv2.COLOR_BGR2RGB) - self.envmap = EnvMap(torch.tensor(hdri, dtype=torch.float32, device='cuda')) - - def generate(self, image, seed, pipeline_type, - ss_params=None, shape_params=None, tex_params=None): - """Generate MeshWithVoxel from image.""" - mesh = self.pipeline.run( - image, - seed=seed, - pipeline_type=pipeline_type, - sparse_structure_sampler_params=ss_params or {}, - shape_slat_sampler_params=shape_params or {}, - tex_slat_sampler_params=tex_params or {}, - )[0] - mesh.simplify(16777216) # nvdiffrast limit - return mesh - - def render_preview_video(self, mesh, num_frames=120, fps=15): - """Render PBR video with environment lighting.""" - from trellis2.utils 
import render_utils - frames = render_utils.render_video(mesh, envmap=self.envmap, - num_frames=num_frames) - # Encode to MP4 bytes - ... -``` - ---- - -### Phase 3: State Serialization (state.py) - -**Goal**: Pack/unpack MeshWithVoxel for client-side storage. - -**Test First (TDD)**: -```python -# tests/test_state.py -def test_pack_state_contains_all_fields(): - """Verify packed state has vertices, faces, attrs, coords, etc.""" - -def test_unpack_state_reconstructs_mesh(): - """Verify unpacked mesh matches original.""" - -def test_state_roundtrip_preserves_data(): - """Pack -> unpack -> pack produces identical state.""" - -def test_state_is_json_serializable_after_numpy(): - """Verify numpy arrays are used (not tensors).""" -``` - -**Implementation**: - -```python -def pack_state(mesh: MeshWithVoxel) -> dict: - """Pack MeshWithVoxel into serializable dict.""" - return { - "vertices": mesh.vertices.cpu().numpy(), - "faces": mesh.faces.cpu().numpy(), - "attrs": mesh.attrs.cpu().numpy(), - "coords": mesh.coords.cpu().numpy(), - "voxel_size": float(mesh.voxel_size), - "voxel_shape": list(mesh.voxel_shape), - "origin": mesh.origin.cpu().tolist(), - "layout": {k: [v.start, v.stop] for k, v in mesh.layout.items()}, - } - -def unpack_state(state: dict) -> MeshWithVoxel: - """Reconstruct MeshWithVoxel from packed state.""" - layout = {k: slice(v[0], v[1]) for k, v in state["layout"].items()} - return MeshWithVoxel( - vertices=torch.tensor(state["vertices"], device="cuda"), - faces=torch.tensor(state["faces"], device="cuda"), - origin=state["origin"], - voxel_size=state["voxel_size"], - coords=torch.tensor(state["coords"], device="cuda"), - attrs=torch.tensor(state["attrs"], device="cuda"), - voxel_shape=torch.Size(state["voxel_shape"]), - layout=layout, - ) -``` - ---- - -### Phase 4: Service Endpoints (service.py) - -**Goal**: Modal web endpoints for generate and extract_glb. 
- -**Test First (TDD)**: -```python -# tests/test_service.py -def test_generate_endpoint_validates_api_key(): - """401 on missing/invalid key.""" - -def test_generate_endpoint_validates_image(): - """400 on invalid image payload.""" - -def test_generate_endpoint_returns_state_and_video(): - """Successful generation returns expected structure.""" - -def test_extract_glb_endpoint_returns_glb(): - """GLB extraction returns valid GLB bytes.""" - -def test_health_endpoint_no_auth_required(): - """Health check works without API key.""" -``` - -**Implementation**: - -```python -@app.cls( - image=trellis2_image, - gpu="A100-80GB", - volumes={ - HF_CACHE_PATH: hf_cache_volume, - API_KEYS_PATH: api_keys_volume, - }, - enable_memory_snapshot=True, - container_idle_timeout=300, -) -class TRELLIS2Service: - @modal.enter(snap=True) - def load_model_cpu(self): - self.generator = TRELLIS2Generator() - self.generator.load_model_cpu() - - @modal.enter(snap=False) - def load_model_gpu(self): - self.generator.move_model_gpu() - - @modal.web_endpoint(method="POST") - def generate(self, request: Request) -> dict: - """Generate 3D from image.""" - # Auth check - # Validate request - # Generate mesh - # Render video - # Pack state - # Return base64 state + video - - @modal.web_endpoint(method="POST") - def extract_glb(self, request: Request) -> dict: - """Extract GLB from state.""" - # Auth check - # Unpack state - # Call o_voxel.postprocess.to_glb() - # Return base64 GLB - - @modal.web_endpoint(method="GET") - def health(self) -> dict: - return {"status": "ok", "service": "trellis2-api"} -``` - -**API Request/Response**: - -``` -POST /generate -Request: -{ - "image": "<base64-encoded image>", - "seed": 42, - "pipeline_type": "1024_cascade", // 512, 1024, 1024_cascade, 1536_cascade - - // Optional advanced params - "ss_sampling_steps": 12, - "ss_guidance_strength": 7.5, - "ss_guidance_rescale": 0.7, - "ss_rescale_t": 5.0, - "shape_slat_sampling_steps": 12, - "shape_slat_guidance_strength": 7.5, - 
"shape_slat_guidance_rescale": 0.5, - "shape_slat_rescale_t": 3.0, - "tex_slat_sampling_steps": 12, - "tex_slat_guidance_strength": 1.0, - "tex_slat_guidance_rescale": 0.0, - "tex_slat_rescale_t": 3.0 -} - -Response: -{ - "state": "", - "video": "" -} -``` - -``` -POST /extract_glb -Request: -{ - "state": "", - "decimation_target": 500000, - "texture_size": 2048, - "remesh": true, - "remesh_band": 1.0, - "remesh_project": 0.0 -} - -Response: -{ - "glb": "" -} -``` - ---- - -### Phase 5: Client (api.py, app.py) - -**Goal**: Gradio client that talks to Modal service. - -**Test First (TDD)**: -```python -# tests/test_api_client.py -def test_client_generate_sends_correct_request(): - """Verify request format matches API spec.""" - -def test_client_handles_cold_start_timeout(): - """Retry logic for slow cold starts.""" - -def test_client_decompresses_state(): - """LZ4 decompression works correctly.""" -``` - -**Implementation**: - -Adapt v1's client with updated: -- Parameter names for TRELLIS.2 -- Pipeline type selector (512/1024/1536) -- Advanced parameter accordions for each stage -- Updated preview display for PBR video - ---- - -### Phase 6: Auth & Compression (reuse from v1) - -**Goal**: Copy and adapt auth.py and compression.py from v1. - -These modules are largely unchanged: -- `auth.py`: API key validation, rate limiting, quota management -- `compression.py`: LZ4 compression for state transfer - -Minor updates: -- Update app name references -- Update volume names - ---- - -### Phase 7: Documentation & CI - -**Goal**: Docs and GitHub Actions for testing. 
- -Files to create: -- `trellis2_modal/docs/MODAL_INTEGRATION.md` -- `trellis2_modal/docs/OPERATIONS_RUNBOOK.md` -- `.github/workflows/modal-ci.yml` - ---- - -## File Structure - -``` -TRELLIS.2/ -├── trellis2/ # Existing - unchanged -├── o-voxel/ # Existing - unchanged -├── trellis2_modal/ # NEW -│ ├── __init__.py -│ ├── service/ -│ │ ├── __init__.py -│ │ ├── config.py # GPU type, model name, defaults -│ │ ├── image.py # Modal image definition -│ │ ├── generator.py # TRELLIS2Generator class -│ │ ├── service.py # Modal endpoints -│ │ ├── state.py # Pack/unpack MeshWithVoxel -│ │ ├── auth.py # API key management (from v1) -│ │ └── streaming.py # SSE utilities (from v1) -│ ├── client/ -│ │ ├── __init__.py -│ │ ├── app.py # Gradio UI -│ │ ├── api.py # HTTP client -│ │ ├── compression.py # LZ4 (from v1) -│ │ └── requirements.txt -│ ├── tests/ -│ │ ├── __init__.py -│ │ ├── conftest.py # Mock fixtures -│ │ ├── test_generator.py -│ │ ├── test_state.py -│ │ ├── test_api_client.py -│ │ ├── test_auth.py -│ │ ├── test_compression.py -│ │ ├── test_service.py -│ │ └── test_integration.py -│ └── docs/ -│ ├── MODAL_INTEGRATION.md -│ └── OPERATIONS_RUNBOOK.md -├── .github/ -│ └── workflows/ -│ └── modal-ci.yml -└── PLAN.md # This file -``` - ---- - -## Dependencies Comparison - -### TRELLIS v1 (CUDA 11.8) -``` -torch==2.4.0+cu118 -xformers==0.0.27.post2 -spconv-cu118 -nvdiffrast (pinned commit) -diffoctreerast (pinned commit) -diff-gaussian-rasterization (mip-splatting) -``` - -### TRELLIS.2 (CUDA 12.4) -``` -torch==2.6.0+cu124 -flash-attn==2.7.3 -nvdiffrast==0.4.0 -nvdiffrec (renderutils branch) -cumesh (JeffreyXiang/CuMesh) -flex_gemm (JeffreyXiang/FlexGEMM) -o_voxel (bundled) -``` - ---- - -## Estimated Effort - -| Phase | Description | Effort | -|-------|-------------|--------| -| 1 | Modal Image | 4-6 hours (CUDA builds are finicky) | -| 2 | Generator | 2-3 hours | -| 3 | State Serialization | 1-2 hours | -| 4 | Service Endpoints | 3-4 hours | -| 5 | Client | 2-3 hours | -| 6 
| Auth & Compression | 1 hour (mostly copy) | -| 7 | Docs & CI | 2 hours | - -**Total**: ~16-21 hours (2-3 days) - ---- - -## Risk Mitigation - -| Risk | Mitigation | -|------|------------| -| CUDA extension build failures | Pin exact commits, test builds early | -| Memory issues on A100-80GB | low_vram mode is built-in to TRELLIS.2 | -| Cold start too slow | GPU memory snapshots (proven in v1) | -| HDRI file too large for image | Compress or use smaller resolution | -| State too large after compression | Already using LZ4, can tune if needed | - ---- - -## Verification Checklist - -After each phase, verify: - -- [ ] All tests pass locally (mocked, no GPU) -- [ ] `modal run` works for smoke tests -- [ ] `modal deploy` succeeds -- [ ] Health endpoint responds -- [ ] Generate endpoint produces valid output -- [ ] Extract GLB produces downloadable file -- [ ] Cold start time acceptable (<30s with snapshot) -- [ ] Warm request latency acceptable (<60s for 1024) - ---- - -## Next Steps - -1. Review and approve this plan -2. Start Phase 1: Modal Image Definition -3. Test image builds before proceeding -4. Continue through phases sequentially diff --git a/README.md b/README.md index f55e315..f24b79a 100644 --- a/README.md +++ b/README.md @@ -186,6 +186,25 @@ Then, you can access the demo at the address shown in the terminal. Will be released soon. Please stay tuned! +## ☁️ Cloud Deployment (Modal) + +This fork adds support for deploying TRELLIS.2 to [Modal](https://modal.com/)'s serverless GPU infrastructure. Run image-to-3D generation without a local GPU: + +- **No local GPU needed** - runs on cloud A100s +- **Pay-per-use** - ~$0.05 per generation +- **Web UI or API** - access from any device + +```bash +# Deploy to Modal +modal deploy -m trellis2_modal.service.service + +# Run the local client +python -m trellis2_modal.client.app +``` + +See **[trellis2_modal/README.md](trellis2_modal/README.md)** for setup instructions. 
+ + ## 🧩 Related Packages TRELLIS.2 is built upon several specialized high-performance packages developed by our team: diff --git a/trellis2_modal/README.md b/trellis2_modal/README.md new file mode 100644 index 0000000..1ed8d9d --- /dev/null +++ b/trellis2_modal/README.md @@ -0,0 +1,133 @@ +# TRELLIS.2 on Modal + +Deploy TRELLIS.2 to Modal's serverless GPU infrastructure. Generate 3D assets from images without needing a local GPU. + +**Why use this?** +- No local GPU required - runs on cloud A100s +- Pay only for what you use (~$0.05 per generation) +- Access via web UI or API from any device + +## Prerequisites + +Before starting, you need accounts and tokens from three services: + +| Service | What you need | Where to get it | +|---------|--------------|-----------------| +| **Modal** | Account + A100 GPU quota | [modal.com](https://modal.com/) | +| **HuggingFace** | Account + accept 3 model licenses | See below | +| **Modal Proxy Auth** | Token ID + Secret | Modal dashboard | + +### HuggingFace Model Access + +TRELLIS.2 requires three gated models. Click each link and accept the license: + +1. [microsoft/TRELLIS.2-4B](https://huggingface.co/microsoft/TRELLIS.2-4B) - Main model +2. [facebook/dinov3-vitl16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vitl16-pretrain-lvd1689m) - Vision encoder +3. [briaai/RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) - Background removal *(non-commercial)* + +Then create a token at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) (Read access). + +## Quick Start + +### 1. Install Modal CLI + +```bash +pip install modal +modal token new +``` + +### 2. Add HuggingFace Token to Modal + +```bash +modal secret create huggingface HF_TOKEN=hf_your_token_here +``` + +### 3. 
Deploy the Service + +```bash +cd TRELLIS.2 +pip install -r trellis2_modal/requirements-deploy.txt +modal deploy -m trellis2_modal.service.service +``` + +Note the endpoint URLs printed (e.g., `https://yourname--trellis2-3d-...generate.modal.run`). + +### 4. Create API Credentials + +1. Go to [modal.com/settings/proxy-auth-tokens](https://modal.com/settings/proxy-auth-tokens) +2. Click "New Token" +3. Save both the Token ID and Secret (secret shown only once!) + +Store them locally: +```bash +cat > ~/.trellis2_modal_secrets.json << 'EOF' +{ + "modal_key": "wk-xxxxx", + "modal_secret": "ws-xxxxx" +} +EOF +``` + +### 5. Run the Web UI + +```bash +pip install -r trellis2_modal/client/requirements.txt +export TRELLIS2_API_URL=https://yourname--trellis2-3d-trellis2service-generate.modal.run +python -m trellis2_modal.client.app +``` + +Open http://localhost:7860, upload an image, and click Generate. + +> **First request takes 2-3 minutes** (cold start). Subsequent requests take 10-90 seconds depending on resolution. + +## API Usage + +```python +from trellis2_modal.client import TRELLIS2APIClient + +client = TRELLIS2APIClient( + base_url="https://yourname--trellis2-3d-trellis2service-generate.modal.run" +) + +# Generate 3D from image +result = client.generate(image_path="input.png", pipeline_type="1024_cascade") + +# Extract GLB mesh +client.extract_glb(state=result["state"], output_path="output.glb") +``` + +Credentials are loaded automatically from `~/.trellis2_modal_secrets.json` or environment variables. 
+ +## Pipeline Options + +| Pipeline | Resolution | Time | Use Case | +|----------|------------|------|----------| +| `512` | 512³ | ~10s | Quick preview | +| `1024_cascade` | 1024³ | ~30s | Recommended | +| `1536_cascade` | 1536³ | ~90s | Maximum quality | + +## Cost + +| What | GPU Time | Cost | +|------|----------|------| +| Cold start | 2-3 min | ~$0.10 | +| Generation (1024) | 30s | ~$0.02 | +| GLB extraction | 30-60s | ~$0.03 | + +Containers stay warm for 5 minutes between requests (no cold start cost for sequential generations). + +## Troubleshooting + +**403 from HuggingFace**: Accept all three model licenses (links above). + +**Cold start every time**: Containers scale to zero after 5 minutes idle. Use a cron job to ping the health endpoint every 4 minutes to keep warm. + +**CUDA OOM**: Use a lower resolution pipeline or reduce GLB texture_size. + +## Full Documentation + +See [docs/MODAL_INTEGRATION.md](docs/MODAL_INTEGRATION.md) for: +- Detailed configuration options +- Operations runbook +- Cold start optimization strategies +- Complete API reference diff --git a/trellis2_modal/__init__.py b/trellis2_modal/__init__.py new file mode 100644 index 0000000..bb3b02d --- /dev/null +++ b/trellis2_modal/__init__.py @@ -0,0 +1 @@ +"""TRELLIS.2 Modal deployment package.""" diff --git a/trellis2_modal/client/__init__.py b/trellis2_modal/client/__init__.py new file mode 100644 index 0000000..251783d --- /dev/null +++ b/trellis2_modal/client/__init__.py @@ -0,0 +1,22 @@ +""" +TRELLIS.2 Modal Client Package. + +Provides API client and Gradio UI for interacting with the Modal service. 
+ +Usage: + # As library + from trellis2_modal.client import TRELLIS2APIClient, APIError + + # As application + python -m trellis2_modal.client.app +""" + +from .api import APIError, TRELLIS2APIClient +from .compression import compress_state, decompress_state + +__all__ = [ + "TRELLIS2APIClient", + "APIError", + "compress_state", + "decompress_state", +] diff --git a/trellis2_modal/client/api.py b/trellis2_modal/client/api.py new file mode 100644 index 0000000..9d8cf28 --- /dev/null +++ b/trellis2_modal/client/api.py @@ -0,0 +1,332 @@ +""" +API client for the Modal TRELLIS.2 service. + +Credentials can be provided via: +1. Constructor parameters: modal_key and modal_secret +2. Environment variables: TRELLIS2_MODAL_KEY and TRELLIS2_MODAL_SECRET +3. Secrets file: ~/.trellis2_modal_secrets.json +""" + +from __future__ import annotations + +import base64 +import json +import os +import time +from pathlib import Path +from typing import Any + +import requests + + +# Default path for local secrets file +SECRETS_FILE_PATH = Path.home() / ".trellis2_modal_secrets.json" + + +def load_credentials( + modal_key: str | None = None, + modal_secret: str | None = None, + secrets_file: Path | None = None, +) -> tuple[str, str]: + """ + Load Modal Proxy Auth credentials from various sources. + + Priority order: + 1. Explicit parameters (modal_key, modal_secret) + 2. Environment variables (TRELLIS2_MODAL_KEY, TRELLIS2_MODAL_SECRET) + 3. Local secrets file (~/.trellis2_modal_secrets.json) + + Args: + modal_key: Optional explicit Modal-Key value + modal_secret: Optional explicit Modal-Secret value + secrets_file: Optional path to secrets file (default: ~/.trellis2_modal_secrets.json) + + Returns: + Tuple of (modal_key, modal_secret) + + Raises: + ValueError: If credentials cannot be found from any source + """ + # 1. Check explicit parameters + if modal_key and modal_secret: + return modal_key, modal_secret + + # 2. 
Check environment variables + env_key = os.environ.get("TRELLIS2_MODAL_KEY") + env_secret = os.environ.get("TRELLIS2_MODAL_SECRET") + if env_key and env_secret: + return env_key, env_secret + + # 3. Check local secrets file + secrets_path = secrets_file or SECRETS_FILE_PATH + if secrets_path.exists(): + try: + data = json.loads(secrets_path.read_text()) + file_key = data.get("modal_key") + file_secret = data.get("modal_secret") + if file_key and file_secret: + return file_key, file_secret + except (json.JSONDecodeError, KeyError) as e: + raise ValueError(f"Invalid secrets file format: {e}") from e + + raise ValueError( + "Modal credentials not found. Provide credentials via:\n" + " 1. Constructor parameters: modal_key and modal_secret\n" + " 2. Environment variables: TRELLIS2_MODAL_KEY and TRELLIS2_MODAL_SECRET\n" + " 3. Secrets file: ~/.trellis2_modal_secrets.json with keys 'modal_key' and 'modal_secret'\n" + "\n" + "Create Proxy Auth Tokens in the Modal dashboard at /settings/proxy-auth-tokens" + ) + + +class TRELLIS2APIClient: + """Client for the Modal-deployed TRELLIS.2 service.""" + + # Default timeout for requests (10 minutes for cascade pipelines) + DEFAULT_TIMEOUT = 600 + + # Default cold start threshold in seconds + DEFAULT_COLD_START_THRESHOLD = 30.0 + + # Retry configuration + MAX_RETRIES = 2 + INITIAL_BACKOFF = 1.0 + + def __init__( + self, + base_url: str, + modal_key: str | None = None, + modal_secret: str | None = None, + ) -> None: + """ + Initialize the API client. + + Args: + base_url: Modal endpoint URL. For Modal subdomain routing, provide + the generate endpoint URL (e.g., https://...generate.modal.run). + The extract_glb URL will be derived automatically. 
+ modal_key: Modal Proxy Auth key (or use env/secrets file) + modal_secret: Modal Proxy Auth secret (or use env/secrets file) + + Raises: + ValueError: If credentials cannot be loaded from any source + """ + base_url = base_url.rstrip("/") + + # Detect Modal subdomain routing pattern and derive endpoint URLs + if "-generate.modal.run" in base_url: + # Modal subdomain pattern: derive extract_glb URL + self.generate_url = base_url + self.extract_glb_url = base_url.replace("-generate.modal.run", "-extract-glb.modal.run") + elif "-extract-glb.modal.run" in base_url: + # User provided extract_glb URL, derive generate URL + self.extract_glb_url = base_url + self.generate_url = base_url.replace("-extract-glb.modal.run", "-generate.modal.run") + else: + # Path-based routing: append /generate and /extract_glb + self.generate_url = f"{base_url}/generate" + self.extract_glb_url = f"{base_url}/extract_glb" + + self.modal_key, self.modal_secret = load_credentials(modal_key, modal_secret) + self._last_request_elapsed: float | None = None + + @property + def last_request_elapsed(self) -> float | None: + """Return elapsed time of last request in seconds, or None if no request made.""" + return self._last_request_elapsed + + def was_cold_start(self, threshold: float | None = None) -> bool: + """Check if the last request was likely a cold start (>threshold seconds).""" + if self._last_request_elapsed is None: + return False + if threshold is None: + threshold = self.DEFAULT_COLD_START_THRESHOLD + return self._last_request_elapsed > threshold + + def _headers(self) -> dict[str, str]: + """Return headers for API requests with Modal Proxy Auth.""" + return { + "Modal-Key": self.modal_key, + "Modal-Secret": self.modal_secret, + "Content-Type": "application/json", + } + + def _request_with_retry( + self, + method: str, + url: str, + **kwargs: Any, + ) -> requests.Response: + """ + Make HTTP request with retry logic for transient failures. 
+ + Retries on ConnectionError and Timeout with exponential backoff. + """ + last_exception = None + for attempt in range(self.MAX_RETRIES + 1): + try: + return requests.request(method, url, **kwargs) + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + ) as e: + last_exception = e + if attempt < self.MAX_RETRIES: + delay = self.INITIAL_BACKOFF * (2**attempt) + time.sleep(delay) + + raise last_exception # type: ignore[misc] + + def _check_error(self, response_data: dict[str, Any]) -> None: + """Check response for errors and raise APIError if found.""" + if "error" in response_data: + error = response_data["error"] + raise APIError(error["code"], error["message"]) + + def generate( + self, + image_path: str, + seed: int = 42, + pipeline_type: str = "1024_cascade", + ss_sampling_steps: int = 12, + ss_guidance_strength: float = 7.5, + shape_slat_sampling_steps: int = 12, + shape_slat_guidance_strength: float = 7.5, + tex_slat_sampling_steps: int = 12, + tex_slat_guidance_strength: float = 1.0, + ) -> dict[str, Any]: + """ + Generate 3D from image via the Modal service. 
+ + Args: + image_path: Path to input image + seed: Random seed for reproducibility + pipeline_type: One of "512", "1024", "1024_cascade", "1536_cascade" + ss_sampling_steps: Sparse structure sampling steps + ss_guidance_strength: Sparse structure guidance strength + shape_slat_sampling_steps: Shape SLAT sampling steps + shape_slat_guidance_strength: Shape SLAT guidance strength + tex_slat_sampling_steps: Texture SLAT sampling steps + tex_slat_guidance_strength: Texture SLAT guidance strength + + Returns: + Dict with 'state' (base64 compressed) and 'video' (base64) + + Raises: + APIError: If the request fails + FileNotFoundError: If image_path doesn't exist + """ + path = Path(image_path) + if not path.exists(): + raise FileNotFoundError(f"Image not found: {image_path}") + + image_b64 = base64.b64encode(path.read_bytes()).decode("utf-8") + + payload = { + "image": image_b64, + "seed": seed, + "pipeline_type": pipeline_type, + "ss_sampling_steps": ss_sampling_steps, + "ss_guidance_strength": ss_guidance_strength, + "shape_slat_sampling_steps": shape_slat_sampling_steps, + "shape_slat_guidance_strength": shape_slat_guidance_strength, + "tex_slat_sampling_steps": tex_slat_sampling_steps, + "tex_slat_guidance_strength": tex_slat_guidance_strength, + } + + start = time.perf_counter() + response = self._request_with_retry( + "POST", + self.generate_url, + headers=self._headers(), + json=payload, + timeout=self.DEFAULT_TIMEOUT, + ) + self._last_request_elapsed = time.perf_counter() - start + + result = response.json() + self._check_error(result) + + return result + + def extract_glb( + self, + state: str, + output_path: str, + decimation_target: int = 1000000, + texture_size: int = 4096, + remesh: bool = True, + remesh_band: float = 1.0, + remesh_project: float = 0.0, + ) -> str: + """ + Extract GLB mesh from generation state. 
+ + Args: + state: Base64 compressed state string from generate() + output_path: Path to write GLB file + decimation_target: Target vertex count (default 1M) + texture_size: Texture resolution (512, 1024, 2048, 4096) + remesh: Whether to remesh for cleaner topology + remesh_band: Remesh band size + remesh_project: Remesh projection factor + + Returns: + Path to written GLB file + + Raises: + APIError: If the request fails + """ + payload = { + "state": state, + "decimation_target": decimation_target, + "texture_size": texture_size, + "remesh": remesh, + "remesh_band": remesh_band, + "remesh_project": remesh_project, + } + + start = time.perf_counter() + response = self._request_with_retry( + "POST", + self.extract_glb_url, + headers=self._headers(), + json=payload, + timeout=self.DEFAULT_TIMEOUT, + ) + self._last_request_elapsed = time.perf_counter() - start + + result = response.json() + self._check_error(result) + + glb_bytes = base64.b64decode(result["glb"]) + Path(output_path).write_bytes(glb_bytes) + + return output_path + + def health_check(self) -> bool: + """ + Check if the generate endpoint is reachable. + + Returns: + True if endpoint is reachable, False otherwise + """ + try: + requests.request( + "HEAD", + self.generate_url, + headers=self._headers(), + timeout=10, + ) + return True + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + return False + + +class APIError(Exception): + """Exception raised for API errors.""" + + def __init__(self, code: str, message: str) -> None: + self.code = code + self.message = message + super().__init__(f"{code}: {message}") diff --git a/trellis2_modal/client/app.py b/trellis2_modal/client/app.py new file mode 100644 index 0000000..0bbe6a0 --- /dev/null +++ b/trellis2_modal/client/app.py @@ -0,0 +1,379 @@ +""" +Gradio-based local client for TRELLIS.2 3D generation via Modal. 
+ +Provides a user interface for: +- Image upload and preprocessing +- Resolution and generation parameter controls +- Video preview of generated 3D model +- GLB extraction and download +""" + +from __future__ import annotations + +import base64 +import os +import tempfile + +import gradio as gr +import numpy as np + +from .api import APIError, TRELLIS2APIClient + + +MAX_SEED = np.iinfo(np.int32).max + + +def get_client() -> TRELLIS2APIClient | None: + """ + Create API client from environment variables or secrets file. + + Requires TRELLIS2_API_URL to be set. + Credentials are loaded from (in order of priority): + 1. Environment variables: TRELLIS2_MODAL_KEY and TRELLIS2_MODAL_SECRET + 2. Secrets file: ~/.trellis2_modal_secrets.json + + Returns: + TRELLIS2APIClient if configured, None otherwise + """ + api_url = os.environ.get("TRELLIS2_API_URL") + if not api_url: + return None + + try: + # Credentials loaded automatically from env vars or secrets file + return TRELLIS2APIClient(base_url=api_url) + except ValueError: + # No credentials found + return None + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """Get the random seed.""" + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +def generate_3d( + image: str | None, + seed: int, + resolution: str, + ss_guidance_strength: float, + ss_sampling_steps: int, + shape_slat_guidance_strength: float, + shape_slat_sampling_steps: int, + tex_slat_guidance_strength: float, + tex_slat_sampling_steps: int, +) -> tuple[str | None, str | None, str]: + """ + Generate 3D from uploaded image via Modal API. + + Args: + image: Path to uploaded image + seed: Random seed + resolution: "512", "1024", or "1536" + ss_*: Sparse structure params + shape_slat_*: Shape SLAT params + tex_slat_*: Texture SLAT params + + Returns: + Tuple of (video_path, state_b64, status_message) + """ + if image is None: + return None, None, "⚠️ Please upload an image first." 
+ + client = get_client() + if client is None: + return None, None, "❌ Error: TRELLIS2_API_URL not set or credentials missing." + + # Map resolution to pipeline_type + pipeline_type = { + "512": "512", + "1024": "1024_cascade", + "1536": "1536_cascade", + }.get(resolution, "1024_cascade") + + try: + result = client.generate( + image_path=image, + seed=seed, + pipeline_type=pipeline_type, + ss_sampling_steps=ss_sampling_steps, + ss_guidance_strength=ss_guidance_strength, + shape_slat_sampling_steps=shape_slat_sampling_steps, + shape_slat_guidance_strength=shape_slat_guidance_strength, + tex_slat_sampling_steps=tex_slat_sampling_steps, + tex_slat_guidance_strength=tex_slat_guidance_strength, + ) + + # Decode video and save to temp file + video_bytes = base64.b64decode(result["video"]) + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(video_bytes) + video_path = f.name + + # Build status message + elapsed = client.last_request_elapsed or 0 + cold_start_msg = " (cold start)" if client.was_cold_start() else "" + status = f"✅ Generated in {elapsed:.1f}s{cold_start_msg}" + + return video_path, result["state"], status + + except APIError as e: + return None, None, f"❌ API Error: {e.code} - {e.message}" + except FileNotFoundError as e: + return None, None, f"❌ {e}" + except Exception as e: + return None, None, f"❌ Error: {e}" + + +def extract_glb( + state: str | None, + decimation_target: int, + texture_size: int, +) -> tuple[str | None, str | None, str]: + """ + Extract GLB mesh from generation state via Modal API. + + Args: + state: Base64 compressed state from generate() + decimation_target: Target vertex count + texture_size: Texture resolution + + Returns: + Tuple of (glb_path, download_path, status_message) + """ + if state is None: + return None, None, "⚠️ Please generate a 3D model first." + + client = get_client() + if client is None: + return None, None, "❌ Error: TRELLIS2_API_URL not set or credentials missing." 
+ + try: + # Create temp file for output + with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as f: + output_path = f.name + + client.extract_glb( + state=state, + output_path=output_path, + decimation_target=decimation_target, + texture_size=texture_size, + remesh=True, + remesh_band=1.0, + remesh_project=0.0, + ) + + elapsed = client.last_request_elapsed or 0 + return output_path, output_path, f"✅ Extracted in {elapsed:.1f}s" + + except APIError as e: + return None, None, f"❌ API Error: {e.code} - {e.message}" + except Exception as e: + return None, None, f"❌ Error: {e}" + + +def create_interface() -> gr.Blocks: + """Create and configure the Gradio interface.""" + with gr.Blocks(title="TRELLIS.2 3D Generator") as demo: + gr.Markdown(""" + ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) via Modal + + * Upload an image and click **Generate** to create a 3D asset. + * Click **Extract GLB** to export and download the GLB file. + * Resolution affects quality and generation time: + - **512**: ~10s, fast preview + - **1024**: ~30s, good quality (recommended) + - **1536**: ~90s, maximum quality + """) + + # State for storing generation result + generation_state = gr.State(value=None) + + with gr.Row(): + # Left column: Input and controls + with gr.Column(scale=1, min_width=360): + image_input = gr.Image( + label="Image Prompt", + type="filepath", + image_mode="RGBA", + height=400, + ) + + resolution = gr.Radio( + ["512", "1024", "1536"], + label="Resolution", + value="1024", + ) + + with gr.Row(): + seed = gr.Slider( + 0, + MAX_SEED, + label="Seed", + value=0, + step=1, + ) + randomize_seed = gr.Checkbox( + label="Randomize", + value=True, + ) + + generate_btn = gr.Button("Generate", variant="primary") + gen_status = gr.Textbox( + label="Status", + interactive=False, + show_label=False, + ) + + with gr.Accordion(label="Advanced Settings", open=False): + gr.Markdown("**Stage 1: Sparse Structure Generation**") + with gr.Row(): + 
ss_guidance_strength = gr.Slider( + 1.0, + 10.0, + label="Guidance Strength", + value=7.5, + step=0.1, + ) + ss_sampling_steps = gr.Slider( + 1, + 50, + label="Sampling Steps", + value=12, + step=1, + ) + + gr.Markdown("**Stage 2: Shape Generation**") + with gr.Row(): + shape_slat_guidance_strength = gr.Slider( + 1.0, + 10.0, + label="Guidance Strength", + value=7.5, + step=0.1, + ) + shape_slat_sampling_steps = gr.Slider( + 1, + 50, + label="Sampling Steps", + value=12, + step=1, + ) + + gr.Markdown("**Stage 3: Material Generation**") + with gr.Row(): + tex_slat_guidance_strength = gr.Slider( + 1.0, + 10.0, + label="Guidance Strength", + value=1.0, + step=0.1, + ) + tex_slat_sampling_steps = gr.Slider( + 1, + 50, + label="Sampling Steps", + value=12, + step=1, + ) + + # Right column: Output + with gr.Column(scale=2): + with gr.Tabs(): + with gr.TabItem("Preview"): + video_output = gr.Video( + label="3D Preview", + height=500, + autoplay=True, + loop=True, + ) + + with gr.TabItem("Extract GLB"): + with gr.Row(): + decimation_target = gr.Slider( + 100000, + 1000000, + label="Decimation Target (vertices)", + value=500000, + step=10000, + ) + texture_size = gr.Dropdown( + [1024, 2048, 4096], + label="Texture Size", + value=2048, + ) + + extract_btn = gr.Button("Extract GLB") + extract_status = gr.Textbox( + label="Status", + interactive=False, + show_label=False, + ) + + glb_output = gr.Model3D( + label="3D Model", + height=400, + clear_color=(0.25, 0.25, 0.25, 1.0), + ) + download_btn = gr.DownloadButton( + label="Download GLB", + visible=True, + ) + + # Wire up events + generate_btn.click( + fn=get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + fn=generate_3d, + inputs=[ + image_input, + seed, + resolution, + ss_guidance_strength, + ss_sampling_steps, + shape_slat_guidance_strength, + shape_slat_sampling_steps, + tex_slat_guidance_strength, + tex_slat_sampling_steps, + ], + outputs=[video_output, generation_state, gen_status], + ) + + 
extract_btn.click( + fn=extract_glb, + inputs=[generation_state, decimation_target, texture_size], + outputs=[glb_output, download_btn, extract_status], + ) + + return demo + + +def main() -> None: + """Entry point for the client application.""" + api_url = os.environ.get("TRELLIS2_API_URL") + + if not api_url: + print("=" * 60) + print("TRELLIS.2 Modal Client") + print("=" * 60) + print() + print("⚠️ Warning: TRELLIS2_API_URL not set.") + print() + print("Set the endpoint URL:") + print(" export TRELLIS2_API_URL=https://your-app--generate.modal.run") + print() + print("Credentials are loaded from (in order):") + print(" 1. TRELLIS2_MODAL_KEY and TRELLIS2_MODAL_SECRET env vars") + print(" 2. ~/.trellis2_modal_secrets.json") + print() + print("Starting anyway (will show error when generating)...") + print() + + demo = create_interface() + demo.launch() + + +if __name__ == "__main__": + main() diff --git a/trellis2_modal/client/compression.py b/trellis2_modal/client/compression.py new file mode 100644 index 0000000..0d1844c --- /dev/null +++ b/trellis2_modal/client/compression.py @@ -0,0 +1,73 @@ +""" +State compression utilities for TRELLIS.2 client-server communication. + +Handles LZ4 compression of MeshWithVoxel state for efficient transfer +between server and client. The compressed state is stored client-side +in gr.State to maintain stateless server design. + +Security note: Uses pickle for serialization. Only decompress data from +trusted sources (our own server). Data flow is: server creates state -> +compress -> client stores -> client sends back -> server decompresses. +""" + +from __future__ import annotations + +import pickle +from typing import Any + +import lz4.frame + + +def compress_state(state: dict[str, Any]) -> bytes: + """ + Compress generation state for storage/transfer. + + Uses pickle serialization + LZ4 frame compression. + LZ4 frame format is streaming-friendly and self-delimiting. 
+ + Args: + state: Dictionary containing numpy arrays and metadata + + Returns: + LZ4-compressed bytes + """ + serialized = pickle.dumps(state, protocol=pickle.HIGHEST_PROTOCOL) + return lz4.frame.compress(serialized) + + +def decompress_state(data: bytes) -> dict[str, Any]: + """ + Decompress generation state. + + Args: + data: LZ4-compressed bytes from compress_state + + Returns: + Original state dictionary with numpy arrays + """ + decompressed = lz4.frame.decompress(data) + return pickle.loads(decompressed) + + +def encode_image(image_path: str) -> str: + """ + Encode an image file to base64 for API transmission. + + Args: + image_path: Path to image file + + Returns: + Base64-encoded image string + """ + raise NotImplementedError("Image encoding not yet implemented") + + +def decode_file(data: str, output_path: str) -> None: + """ + Decode base64 file data and write to disk. + + Args: + data: Base64-encoded file contents + output_path: Path to write decoded file + """ + raise NotImplementedError("File decoding not yet implemented") diff --git a/trellis2_modal/client/requirements.txt b/trellis2_modal/client/requirements.txt new file mode 100644 index 0000000..3ef172c --- /dev/null +++ b/trellis2_modal/client/requirements.txt @@ -0,0 +1,7 @@ +# TRELLIS.2 Modal Client Dependencies +# Install with: pip install -r requirements.txt + +gradio>=4.0 +requests +numpy +lz4 diff --git a/trellis2_modal/docs/MODAL_INTEGRATION.md b/trellis2_modal/docs/MODAL_INTEGRATION.md new file mode 100644 index 0000000..42bec9e --- /dev/null +++ b/trellis2_modal/docs/MODAL_INTEGRATION.md @@ -0,0 +1,286 @@ +# TRELLIS.2 Modal Integration + +This guide explains how to deploy and use TRELLIS.2 on Modal's serverless GPU infrastructure. 
+ +## Prerequisites + +- Python 3.11+ +- [Modal account](https://modal.com/) with GPU quota (A100-80GB) +- Modal CLI installed and authenticated +- [HuggingFace account](https://huggingface.co/) with access to gated models + +```bash +pip install -r trellis2_modal/requirements-deploy.txt +modal token new +``` + +### HuggingFace Setup + +TRELLIS.2 uses gated models that require HuggingFace authentication: + +1. **Accept model licenses** on HuggingFace (click "Agree and access repository"): + - [microsoft/TRELLIS.2-4B](https://huggingface.co/microsoft/TRELLIS.2-4B) - Main TRELLIS.2 model + - [facebook/dinov3-vitl16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vitl16-pretrain-lvd1689m) - DINOv3 vision encoder (used for image conditioning) + - [briaai/RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) - Background removal model (CC BY-NC 4.0, **non-commercial use only**) + + **Important**: All three models must be accepted. The DINOv3 and RMBG-2.0 models are + loaded automatically by the TRELLIS.2 pipeline. + +2. **Create a HuggingFace token** at https://huggingface.co/settings/tokens + - Use "Read" access (write not needed) + +3. **Verify access** (should return model info, not 403): + ```bash + # Check all three gated models + curl -H "Authorization: Bearer hf_your_token" \ + https://huggingface.co/api/models/microsoft/TRELLIS.2-4B + curl -H "Authorization: Bearer hf_your_token" \ + https://huggingface.co/api/models/facebook/dinov3-vitl16-pretrain-lvd1689m + curl -H "Authorization: Bearer hf_your_token" \ + https://huggingface.co/api/models/briaai/RMBG-2.0 + ``` + +4. 
**Create Modal secret**: + ```bash + modal secret create huggingface HF_TOKEN=hf_your_token_here + ``` + +## Project Structure + +``` +trellis2_modal/ +├── requirements-deploy.txt # Dependencies for modal deploy +├── service/ # Modal service (runs on cloud) +│ ├── config.py # Configuration constants +│ ├── image.py # Container image definition +│ ├── generator.py # TRELLIS2Generator wrapper +│ ├── service.py # Web endpoints +│ ├── state.py # MeshWithVoxel serialization +│ └── auth.py # API key authentication +├── client/ # Local client (runs on your machine) +│ ├── api.py # HTTP client +│ ├── app.py # Gradio UI +│ ├── compression.py # State compression +│ └── requirements.txt # Client dependencies +├── tests/ # Unit tests +└── docs/ # This documentation +``` + +## Deployment + +### 1. Verify the Image Builds + +First, verify all CUDA extensions compile correctly: + +```bash +modal run trellis2_modal/service/image.py +``` + +This runs a verification that checks: +- CUDA availability and version +- flash_attn, nvdiffrast, cumesh, flex_gemm extensions +- TRELLIS.2 pipeline imports +- HDRI files exist + +Expected output: +``` +✓ Image verification PASSED +``` + +### 2. Deploy the Service + +Deploy to Modal with memory snapshots enabled: + +```bash +modal deploy -m trellis2_modal.service.service +``` + +Note the endpoint URLs printed: +``` +├── Created web endpoint for TRELLIS2Service.health => https://your-app--health.modal.run +├── Created web endpoint for TRELLIS2Service.generate => https://your-app--generate.modal.run +└── Created web endpoint for TRELLIS2Service.extract_glb => https://your-app--extract-glb.modal.run +``` + +### 3. Create Proxy Auth Tokens + +Authentication uses Modal Proxy Auth Tokens. Create tokens in the Modal dashboard: + +1. Go to https://modal.com/settings/proxy-auth-tokens +2. Click "New Token" +3. Copy both the **Token ID** (Modal-Key) and **Token Secret** (Modal-Secret) + - ⚠️ The secret is only shown once - save it immediately! 
+ +Store credentials locally (choose one method): + +**Option A: Environment variables** +```bash +export TRELLIS2_MODAL_KEY="wk-xxxxx" +export TRELLIS2_MODAL_SECRET="ws-xxxxx" +``` + +**Option B: Secrets file** (`~/.trellis2_modal_secrets.json`) +```json +{ + "modal_key": "wk-xxxxx", + "modal_secret": "ws-xxxxx" +} +``` + +The file is in your home directory and is gitignored. + +### 4. Test the Deployment + +Test the health endpoint (no auth required): + +```bash +curl https://your-app--health.modal.run +# {"status": "ok", "service": "trellis2-api"} +``` + +Test generation (auth required): + +```bash +curl -X POST https://your-app--generate.modal.run \ + -H "Modal-Key: wk-xxxxx" \ + -H "Modal-Secret: ws-xxxxx" \ + -H "Content-Type: application/json" \ + -d '{"image": "", "seed": 42}' +``` + +## Running the Client + +### Environment Setup + +The client loads credentials from (in priority order): +1. Environment variables: `TRELLIS2_MODAL_KEY` and `TRELLIS2_MODAL_SECRET` +2. Secrets file: `~/.trellis2_modal_secrets.json` + +```bash +export TRELLIS2_API_URL=https://your-app--generate.modal.run +export TRELLIS2_MODAL_KEY=wk-xxxxx +export TRELLIS2_MODAL_SECRET=ws-xxxxx +``` + +The client auto-detects Modal subdomain routing from the URL pattern and derives the `extract_glb` endpoint automatically. + +### Launch Gradio UI + +```bash +pip install -r trellis2_modal/client/requirements.txt +python -m trellis2_modal.client.app +``` + +Open http://localhost:7860 in your browser. 
+ +### Programmatic Usage + +```python +from trellis2_modal.client import TRELLIS2APIClient + +# Credentials loaded automatically from env vars or ~/.trellis2_modal_secrets.json +client = TRELLIS2APIClient(base_url="https://your-app.modal.run") + +# Or provide credentials explicitly: +# client = TRELLIS2APIClient( +# base_url="https://your-app.modal.run", +# modal_key="wk-xxxxx", +# modal_secret="ws-xxxxx", +# ) + +# Generate 3D from image +result = client.generate( + image_path="input.png", + seed=42, + pipeline_type="1024_cascade", # Options: 512, 1024, 1024_cascade, 1536_cascade +) + +# result contains: +# - state: compressed state for GLB extraction +# - video: base64-encoded MP4 preview + +# Extract GLB +client.extract_glb( + state=result["state"], + output_path="output.glb", + decimation_target=500000, + texture_size=2048, +) +``` + +## Configuration + +### GPU Type + +Default: `A100-80GB` (configured in `config.py`) + +The A100-80GB provides: +- Sufficient VRAM for all pipeline types +- Good balance of cost and performance + +### Pipeline Types + +| Type | Resolution | Time (A100) | Use Case | +|------|------------|-------------|----------| +| `512` | 512³ | ~10s | Fast preview | +| `1024` | 1024³ | ~25s | Good quality | +| `1024_cascade` | 512→1024 | ~30s | High quality (recommended) | +| `1536_cascade` | 512→1536 | ~90s | Maximum quality | + +### GLB Extraction Presets + +| Preset | decimation_target | texture_size | File Size | +|--------|-------------------|--------------|-----------| +| Quality | 1,000,000 | 4096 | ~30MB | +| Balanced | 500,000 | 2048 | ~15MB | +| Fast | 100,000 | 1024 | ~5MB | + +## Cold Starts + +TRELLIS.2 has a cold start time of approximately **2-2.5 minutes** on A100-80GB: +- Container startup: ~20s +- Model loading: ~2 minutes (loading from HuggingFace cache + GPU initialization) + +### Memory Snapshots: Not Recommended + +Both CPU and GPU Memory Snapshots were tested and found **not effective** for TRELLIS.2: + +| Configuration | 
Cold Start | Model Load | Improvement | +|---------------|------------|------------|-------------| +| No snapshots (baseline) | ~143s | ~124s | — | +| GPU Memory Snapshots | ~146s | ~131s | <5% | + +**Root cause**: TRELLIS.2's dependencies (flex_gemm, Triton, flash_attn) require +re-initialization after snapshot restoration, negating the benefits. The model +weights may be restored, but GPU state for these libraries is not preserved. + +### Recommended Cold Start Strategies + +1. **Keep containers warm** (default, recommended): + - `scaledown_window=300` keeps containers alive for 5 minutes after last request + - Cost: ~$0.10/hour for idle A100-80GB container + +2. **Periodic warm-up ping** (for consistent response times): + - Ping the service every 4 minutes to prevent scale-down + - Use a simple health check: `curl https://your-app--health.modal.run` + +3. **min_containers=1** (for always-on availability): + - Keeps one container always running + - Cost: ~$2.40/hour (continuous A100-80GB) + +## Troubleshooting + +See [OPERATIONS_RUNBOOK.md](./OPERATIONS_RUNBOOK.md) for common issues and solutions. + +## Cost Estimation + +| Operation | GPU Time | Approx Cost | +|-----------|----------|-------------| +| Cold start | 2-3 min | ~$0.10-0.15 | +| Generate (1024_cascade) | 30s | ~$0.02 | +| Extract GLB | 30-60s | ~$0.02-0.04 | + +Tips to reduce costs: +- Use `container_idle_timeout=300` to keep containers warm between requests +- Use lower resolution pipelines for previews +- Use smaller GLB presets when file size matters diff --git a/trellis2_modal/docs/OPERATIONS_RUNBOOK.md b/trellis2_modal/docs/OPERATIONS_RUNBOOK.md new file mode 100644 index 0000000..ac4884f --- /dev/null +++ b/trellis2_modal/docs/OPERATIONS_RUNBOOK.md @@ -0,0 +1,282 @@ +# TRELLIS.2 Modal Operations Runbook + +This runbook covers monitoring, troubleshooting, and operational procedures for the TRELLIS.2 Modal deployment. 
+ +## Monitoring + +### Health Check + +The `/health` endpoint provides basic liveness status: + +```bash +curl https://your-app--health.modal.run +``` + +Expected response: +```json +{"status": "ok", "service": "trellis2-api"} +``` + +### Detailed Health (via Modal method) + +For detailed diagnostics, use the `health_check` method: + +```python +# In Modal shell or via remote call +service = TRELLIS2Service() +result = service.health_check.remote() +print(result) +``` + +Response includes: +- `status`: "healthy" or "unhealthy" +- `gpu`: GPU device name +- `vram_allocated_gb`: Current VRAM usage +- `vram_total_gb`: Total VRAM +- `load_time_seconds`: Model load time + +### Modal Dashboard + +Monitor at: https://modal.com/apps + +Key metrics: +- Container count and utilization +- Request latency (P50, P95, P99) +- Error rate +- Cold start frequency + +## Common Issues + +### 1. CUDA Out of Memory (OOM) + +**Symptoms:** +```json +{"error": {"code": "cuda_oom", "message": "GPU out of memory..."}} +``` + +**Causes:** +- High-resolution pipeline (1536_cascade) with complex image +- Large texture size during GLB extraction +- Memory fragmentation from repeated requests + +**Solutions:** +1. Use lower resolution pipeline (`1024` instead of `1536_cascade`) +2. Reduce texture size (2048 instead of 4096) +3. The service automatically calls `torch.cuda.empty_cache()` after operations +4. If persistent, restart the container (Modal will auto-restart) + +### 2. Slow Cold Starts + +**Symptoms:** +- First request takes 2-2.5 minutes + +**Causes:** +- Container starting from cold (no warm instances) +- Model loading (~124s): TRELLIS.2 model + DINOv3 + RMBG-2.0 +- GPU initialization for flex_gemm, Triton, flash_attn + +**Solutions:** +1. Always use `modal deploy` for production +2. Use `scaledown_window=300` to keep containers warm (5 min idle timeout) +3. 
Implement a warm-up cron job that pings every 4 minutes: + ```bash + # Example cron: */4 * * * * curl -s https://your-app--health.modal.run + ``` + +**Note on Memory Snapshots:** + +GPU Memory Snapshots were tested (2025-12-21) and found **not effective**: +- Baseline cold start: ~143s +- With GPU snapshots: ~146s (no improvement) + +Root cause: flex_gemm, Triton, and flash_attn require re-initialization after +snapshot restoration, negating any benefits from preserved model weights. + +### 3. Authentication Failed + +**Symptoms:** +- HTTP 401 Unauthorized response +- Request rejected before reaching endpoint + +**Causes:** +- Missing `Modal-Key` or `Modal-Secret` headers +- Invalid or expired Proxy Auth Token +- Token doesn't have access to this app + +**Solutions:** +1. Verify headers are set correctly: + ```bash + curl -X POST https://your-app--generate.modal.run \ + -H "Modal-Key: wk-xxxxx" \ + -H "Modal-Secret: ws-xxxxx" \ + -H "Content-Type: application/json" \ + -d '{"image": "...", "seed": 42}' + ``` +2. Create new token at https://modal.com/settings/proxy-auth-tokens +3. Update client credentials: + - Environment variables: `TRELLIS2_MODAL_KEY` and `TRELLIS2_MODAL_SECRET` + - Or secrets file: `~/.trellis2_modal_secrets.json` + +### 4. Image Too Large + +**Symptoms:** +```json +{"error": {"code": "validation_error", "message": "Image size exceeds limit..."}} +``` + +**Causes:** +- Input image exceeds 10MB (base64 decoded size) +- Image dimensions exceed 4096x4096 + +**Solutions:** +1. Resize image before sending +2. Use JPEG instead of PNG for smaller size +3. The pipeline will preprocess/resize automatically, but the upload must be within limits + +### 5. Connection Timeouts + +**Symptoms:** +- Client throws `requests.exceptions.Timeout` +- Requests take >10 minutes + +**Causes:** +- Cold start + generation time exceeds timeout +- Modal infrastructure issues + +**Solutions:** +1. Client has 10-minute timeout by default +2. 
For 1536_cascade, total time can exceed 2 minutes +3. Check Modal status page for outages + +### 6. GLB Extraction Fails + +**Symptoms:** +```json +{"error": {"code": "extraction_error", "message": "GLB extraction failed..."}} +``` + +**Causes:** +- Invalid state data +- State from incompatible version +- VRAM exhaustion during remeshing + +**Solutions:** +1. Regenerate the 3D model +2. Try smaller decimation_target +3. Disable remesh: `remesh=false` + +## Operational Procedures + +### Deploying Updates + +```bash +# 1. Run tests locally +pytest trellis2_modal/tests/ + +# 2. Verify image builds +modal run trellis2_modal/service/image.py + +# 3. Deploy +modal deploy -m trellis2_modal.service.service + +# 4. Verify health +curl https://your-app--health.modal.run +``` + +### Rolling Back + +Modal keeps previous deployments. To rollback: + +1. Go to Modal dashboard → Apps → Your app +2. Find previous deployment in history +3. Click "Redeploy" on that version + +### Scaling + +Modal automatically scales based on request load. To adjust: + +```python +# In service.py, modify @app.cls parameters: +@app.cls( + ... + concurrency_limit=5, # Max concurrent requests per container + container_idle_timeout=300, # Seconds before scaling down + ... +) +``` + +### Managing Authentication + +Authentication uses Modal Proxy Auth Tokens. Manage tokens in the Modal dashboard: + +1. **Create new token**: https://modal.com/settings/proxy-auth-tokens → "New Token" +2. 
**Revoke token**: Click the token in the dashboard → "Delete" + +Client credentials are stored locally (not in the repository): +- Environment variables: `TRELLIS2_MODAL_KEY` and `TRELLIS2_MODAL_SECRET` +- Or secrets file: `~/.trellis2_modal_secrets.json` + +```json +{ + "modal_key": "wk-xxxxx", + "modal_secret": "ws-xxxxx" +} +``` + +### Viewing Logs + +```bash +# View recent logs +modal app logs trellis2-3d + +# Stream logs +modal app logs trellis2-3d --follow +``` + +### Checking Volume Contents + +```bash +# List files in HuggingFace cache volume +modal volume ls trellis2-hf-cache /cache/huggingface/ +``` + +## Performance Benchmarks + +Expected times on A100-80GB (warm container): + +| Operation | Time | +|-----------|------| +| Generate (512) | 8-12s | +| Generate (1024) | 20-30s | +| Generate (1024_cascade) | 25-35s | +| Generate (1536_cascade) | 80-100s | +| Extract GLB (quality) | 30-60s | +| Extract GLB (fast) | 10-20s | +| Video render | 5-10s | + +Cold start adds 90-120 seconds. + +## Alerts to Set Up + +Recommended Modal/external monitoring alerts: + +1. **Error rate > 5%** - Indicates systematic issue +2. **P95 latency > 5 minutes** - Cold starts or overload +3. **OOM errors > 10/hour** - Memory pressure +4. **Health check failures** - Service down + +## Emergency Procedures + +### Service Completely Down + +1. Check Modal status: https://status.modal.com/ +2. Try redeploying: `modal deploy -m trellis2_modal.service.service` +3. Check logs for errors: `modal app logs trellis2-3d` +4. If HuggingFace is down, cached models may still work + +### Data Recovery + +Model weights are cached in the `trellis2-hf-cache` volume. + +Authentication uses Modal Proxy Auth Tokens (managed in Modal dashboard), so there's no +authentication data to backup from volumes. 
diff --git a/trellis2_modal/requirements-deploy.txt b/trellis2_modal/requirements-deploy.txt new file mode 100644 index 0000000..8ce36ad --- /dev/null +++ b/trellis2_modal/requirements-deploy.txt @@ -0,0 +1,11 @@ +# Dependencies for deploying TRELLIS.2 to Modal +# Install with: pip install -r requirements-deploy.txt +# +# These are minimal dependencies needed locally to run: +# modal deploy -m trellis2_modal.service.service +# +# Note: The Modal container image has its own dependencies +# defined in service/image.py + +modal +fastapi diff --git a/trellis2_modal/service/__init__.py b/trellis2_modal/service/__init__.py new file mode 100644 index 0000000..aeb75f0 --- /dev/null +++ b/trellis2_modal/service/__init__.py @@ -0,0 +1 @@ +"""Service modules for Modal deployment.""" diff --git a/trellis2_modal/service/auth.py b/trellis2_modal/service/auth.py new file mode 100644 index 0000000..55bfc8d --- /dev/null +++ b/trellis2_modal/service/auth.py @@ -0,0 +1,374 @@ +""" +API key authentication for the Modal TRELLIS.2 service. + +Handles validation of API keys stored in a Modal Volume. Keys are +stored in JSON format with metadata including creation date, quotas, +and usage tracking. +""" + +from __future__ import annotations + +from typing import Any + + +def validate_api_key( + key: str | None, + keys_data: dict[str, Any], +) -> tuple[bool, dict[str, Any] | None]: + """ + Validate an API key against the provided keys data. + + Pure function that checks format and looks up the key in the provided + dictionary. Does not perform I/O - caller is responsible for loading + keys_data. + + Args: + key: API key string (format: sk_xxx_...) 
or None + keys_data: Dictionary with 'keys' field containing key->info mappings + + Returns: + Tuple of (is_valid, key_info) where: + - is_valid: True if key is valid and active + - key_info: Dict with key metadata if valid, None otherwise + """ + if not key: + return False, None + + if not key.startswith("sk_"): + return False, None + + keys = keys_data.get("keys", {}) + if key not in keys: + return False, None + + key_info = keys[key] + if not key_info.get("active", True): + return False, None + + return True, key_info + + +def load_keys(path: str) -> dict[str, Any]: + """ + Load API keys from a JSON file. + + If the file doesn't exist, returns an empty keys structure. + This simplifies initial deployment - no need to pre-create keys.json. + + Args: + path: Path to the keys.json file + + Returns: + Dictionary with 'version' and 'keys' fields + + Raises: + ValueError: If file exists but contains invalid JSON + """ + import json + from pathlib import Path + + file_path = Path(path) + if not file_path.exists(): + return {"version": 1, "keys": {}} + + try: + return json.loads(file_path.read_text()) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in {path}: {e}") from e + + +def save_keys(path: str, data: dict[str, Any]) -> None: + """ + Save API keys to a JSON file. + + Creates parent directories if they don't exist. + + Args: + path: Path to the keys.json file + data: Dictionary to save (should have 'version' and 'keys' fields) + """ + import json + from pathlib import Path + + file_path = Path(path) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(json.dumps(data, indent=2)) + + +def check_auth( + api_key: str | None, + keys_path: str = "/data/keys.json", +) -> tuple[bool, dict[str, Any] | None]: + """ + Check API key authentication. + + Loads keys from file and validates the provided key. + Convenience function that combines load_keys and validate_api_key. 
+ + Args: + api_key: API key from request header (may be None) + keys_path: Path to keys.json file + + Returns: + Tuple of (is_valid, key_info) where: + - is_valid: True if key is valid and active + - key_info: Dict with key metadata if valid, None otherwise + """ + keys_data = load_keys(keys_path) + return validate_api_key(api_key, keys_data) + + +def mask_api_key(key: str | None) -> str: + """ + Mask an API key for safe logging. + + Shows first 7 chars (sk_xxx_) and last 4 chars, masks the middle. + + Args: + key: API key to mask + + Returns: + Masked key like 'sk_dev_...abcd' or '' if None + """ + if not key: + return "" + if len(key) <= 11: + return key[:7] + "..." + return key[:7] + "..." + key[-4:] + + +def increment_usage(key: str, keys_data: dict[str, Any]) -> None: + """ + Increment the usage counter for an API key. + + Pure function that mutates keys_data in place. Caller is responsible + for persisting changes with save_keys if needed. + + Silently ignores unknown keys or missing 'keys' field. + + Args: + key: API key that was used + keys_data: Dictionary with 'keys' field containing key->info mappings + """ + from datetime import datetime, timezone + + keys = keys_data.get("keys", {}) + if key not in keys: + return + + key_info = keys[key] + key_info["usage_count"] = key_info.get("usage_count", 0) + 1 + key_info["last_used"] = datetime.now(timezone.utc).isoformat() + + +def check_rate_limit(key_info: dict[str, Any]) -> bool: + """ + Check if a key is within its rate limit quota. + + Args: + key_info: Dictionary with key metadata (usage_count, quota) + + Returns: + True if within quota (or no quota set), False if quota exceeded. + """ + quota = key_info.get("quota") + if quota is None: + return True # No quota = unlimited + + usage_count = key_info.get("usage_count", 0) + return usage_count < quota + + +def generate_api_key(prefix: str = "dev") -> str: + """ + Generate a new API key. 
+ + Keys are in format: sk_{prefix}_{random_hex} + Example: sk_dev_a1b2c3d4e5f6g7h8 + + Args: + prefix: Key prefix, typically "dev" or "live" + + Returns: + New API key string + """ + import secrets + + random_part = secrets.token_hex(16) # 32 hex chars + return f"sk_{prefix}_{random_part}" + + +def add_key( + keys_data: dict[str, Any], + name: str, + prefix: str = "dev", + quota: int | None = None, +) -> str: + """ + Add a new API key to keys_data. + + Args: + keys_data: Keys data dict to modify + name: Human-readable name for the key + prefix: Key prefix, typically "dev" or "live" + quota: Optional request quota limit + + Returns: + The generated API key + """ + from datetime import datetime, timezone + + key = generate_api_key(prefix) + + if "keys" not in keys_data: + keys_data["keys"] = {} + + key_info: dict[str, Any] = { + "name": name, + "active": True, + "created": datetime.now(timezone.utc).isoformat(), + "usage_count": 0, + } + if quota is not None: + key_info["quota"] = quota + + keys_data["keys"][key] = key_info + return key + + +def revoke_key(keys_data: dict[str, Any], key: str) -> bool: + """ + Revoke an API key. 
+ + Args: + keys_data: Keys data dict to modify + key: API key to revoke + + Returns: + True if key was found and revoked, False otherwise + """ + keys = keys_data.get("keys", {}) + if key not in keys: + return False + + keys[key]["active"] = False + return True + + +# CLI interface +def _cli_add_key(args: Any, keys_path: str) -> None: + """CLI handler for add-key command.""" + keys_data = load_keys(keys_path) + key = add_key( + keys_data, + name=args.name, + prefix=args.prefix, + quota=args.quota, + ) + save_keys(keys_path, keys_data) + print(f"Created API key: {key}") + print(f"Name: {args.name}") + if args.quota: + print(f"Quota: {args.quota}") + + +def _cli_list_keys(args: Any, keys_path: str) -> None: + """CLI handler for list-keys command.""" + keys_data = load_keys(keys_path) + keys = keys_data.get("keys", {}) + + if not keys: + print("No API keys found.") + return + + for key, info in keys.items(): + status = "active" if info.get("active", True) else "revoked" + usage = info.get("usage_count", 0) + print( + f"{mask_api_key(key):20} {info.get('name', 'unnamed'):15} {status:8} usage: {usage}" + ) + + +def _cli_revoke_key(args: Any, keys_path: str) -> None: + """CLI handler for revoke-key command.""" + keys_data = load_keys(keys_path) + if revoke_key(keys_data, args.key): + save_keys(keys_path, keys_data) + print(f"Revoked key: {mask_api_key(args.key)}") + else: + print(f"Key not found: {mask_api_key(args.key)}") + + +def _cli_usage(args: Any, keys_path: str) -> None: + """CLI handler for usage command.""" + keys_data = load_keys(keys_path) + keys = keys_data.get("keys", {}) + + if args.key not in keys: + print(f"Key not found: {mask_api_key(args.key)}") + return + + info = keys[args.key] + print(f"Key: {mask_api_key(args.key)}") + print(f"Name: {info.get('name', 'unnamed')}") + print(f"Status: {'active' if info.get('active', True) else 'revoked'}") + print(f"Created: {info.get('created', 'unknown')}") + print(f"Usage count: {info.get('usage_count', 0)}") + 
print(f"Last used: {info.get('last_used', 'never')}") + if "quota" in info: + print(f"Quota: {info['quota']}") + + +def main() -> None: + """CLI entrypoint for API key management.""" + import argparse + + from .config import API_KEYS_PATH + + parser = argparse.ArgumentParser( + description="Manage API keys for TRELLIS.2 Modal service" + ) + parser.add_argument( + "--keys-file", + default=API_KEYS_PATH, + help=f"Path to keys.json file (default: {API_KEYS_PATH})", + ) + + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # add-key command + add_parser = subparsers.add_parser("add-key", help="Add a new API key") + add_parser.add_argument( + "--name", required=True, help="Human-readable name for the key" + ) + add_parser.add_argument("--prefix", default="dev", help="Key prefix (dev or live)") + add_parser.add_argument("--quota", type=int, help="Optional request quota limit") + + # list-keys command + subparsers.add_parser("list-keys", help="List all API keys") + + # revoke-key command + revoke_parser = subparsers.add_parser("revoke-key", help="Revoke an API key") + revoke_parser.add_argument("key", help="API key to revoke") + + # usage command + usage_parser = subparsers.add_parser("usage", help="Show usage for an API key") + usage_parser.add_argument("key", help="API key to show usage for") + + args = parser.parse_args() + keys_path = args.keys_file + + if args.command == "add-key": + _cli_add_key(args, keys_path) + elif args.command == "list-keys": + _cli_list_keys(args, keys_path) + elif args.command == "revoke-key": + _cli_revoke_key(args, keys_path) + elif args.command == "usage": + _cli_usage(args, keys_path) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/trellis2_modal/service/config.py b/trellis2_modal/service/config.py new file mode 100644 index 0000000..04e7359 --- /dev/null +++ b/trellis2_modal/service/config.py @@ -0,0 +1,54 @@ +""" +Configuration constants for the Modal TRELLIS.2 service. 
+ +Contains deployment settings, resource limits, and path configurations +that are shared across the service modules. + +NOTE: Some constants are duplicated in image.py because Modal copies +that file in isolation during image builds (imports fail). The test +test_config_consistency.py verifies these stay in sync. +""" + +# Modal resource configuration +GPU_TYPE = "A100-80GB" +CONTAINER_IDLE_TIMEOUT = 300 # seconds + +# GPU snapshots don't help: flex_gemm/Triton reinit negates benefits (~143s vs ~146s) +GPU_MEMORY_SNAPSHOT = False + +# Volume mount paths +HF_CACHE_PATH = "/cache/huggingface" +API_KEYS_PATH = "/data/keys.json" + +# Build/runtime paths +TRELLIS2_PATH = "/opt/TRELLIS.2" + +# Model configuration +MODEL_NAME = "microsoft/TRELLIS.2-4B" + +# Input validation limits +MAX_IMAGE_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB (decoded binary size) +MAX_IMAGE_DIMENSION = 4096 # Max width or height in pixels + +# Generation defaults (from pipeline.json, can be overridden per-request) +DEFAULT_SEED = 42 +DEFAULT_PIPELINE_TYPE = "1024_cascade" + +# Sparse structure sampler defaults +DEFAULT_SS_SAMPLING_STEPS = 12 +DEFAULT_SS_GUIDANCE_STRENGTH = 7.5 + +# Shape SLAT sampler defaults +DEFAULT_SHAPE_SLAT_SAMPLING_STEPS = 12 +DEFAULT_SHAPE_SLAT_GUIDANCE_STRENGTH = 7.5 + +# Texture SLAT sampler defaults +DEFAULT_TEX_SLAT_SAMPLING_STEPS = 12 +DEFAULT_TEX_SLAT_GUIDANCE_STRENGTH = 1.0 + +# GLB extraction defaults (from official example.py) +DEFAULT_DECIMATION_TARGET = 1_000_000 +DEFAULT_TEXTURE_SIZE = 4096 +DEFAULT_REMESH = True +DEFAULT_REMESH_BAND = 1.0 +DEFAULT_REMESH_PROJECT = 0.0 diff --git a/trellis2_modal/service/generator.py b/trellis2_modal/service/generator.py new file mode 100644 index 0000000..7b415cd --- /dev/null +++ b/trellis2_modal/service/generator.py @@ -0,0 +1,270 @@ +""" +TRELLIS.2 Generator class for Modal deployment. + +Handles model loading, image-to-3D generation, video rendering, and GLB extraction. 
+Imports are deferred because trellis2.* requires GPU and PYTHONPATH setup. +""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Callable, Protocol + +if TYPE_CHECKING: + from PIL import Image + +from .config import MODEL_NAME + +# nvdiffrast has a maximum face count for rasterization +NVDIFFRAST_MAX_FACES = 16_777_216 # 2^24 + + +class PipelineProtocol(Protocol): + """Protocol defining the pipeline interface for type checking and mocking.""" + + models: dict[str, Any] + + def to(self, device: Any) -> None: ... + + def cuda(self) -> None: ... + + def preprocess_image(self, image: "Image") -> "Image": ... + + def run( + self, + image: "Image", + seed: int, + sparse_structure_sampler_params: dict[str, Any], + shape_slat_sampler_params: dict[str, Any], + tex_slat_sampler_params: dict[str, Any], + pipeline_type: str, + max_num_tokens: int, + **kwargs: Any, + ) -> list[Any]: ... + + +def _default_pipeline_factory(model_name: str) -> PipelineProtocol: + """Default factory that loads the real TRELLIS.2 pipeline.""" + from trellis2.pipelines import Trellis2ImageTo3DPipeline + + return Trellis2ImageTo3DPipeline.from_pretrained(model_name) + + +class TRELLIS2Generator: + """GPU-accelerated TRELLIS.2 generator for Modal deployment.""" + + def __init__( + self, + pipeline_factory: Callable[[str], PipelineProtocol] | None = None, + ) -> None: + """Initialize generator state. Model loaded via load_model().""" + self._pipeline_factory = pipeline_factory or _default_pipeline_factory + self.pipeline: PipelineProtocol | None = None + self.envmap: Any = None + self.load_time: float = 0.0 + + @property + def is_ready(self) -> bool: + """Check if generator is fully ready for inference.""" + return self.pipeline is not None and self.envmap is not None + + def load_model(self) -> None: + """Load model and HDRI. 
Called with @modal.enter().""" + import cv2 + import torch + from trellis2.renderers import EnvMap + + from .config import TRELLIS2_PATH + + start = time.perf_counter() + + # Load pipeline from HuggingFace + try: + self.pipeline = self._pipeline_factory(MODEL_NAME) + except Exception as e: + raise RuntimeError( + f"Failed to load TRELLIS.2 model '{MODEL_NAME}'. " + f"Check HuggingFace cache and network connectivity." + ) from e + + # Move pipeline to GPU + self.pipeline.cuda() + + # Load HDRI for PBR rendering + hdri_path = f"{TRELLIS2_PATH}/assets/hdri/forest.exr" + hdri = cv2.imread(hdri_path, cv2.IMREAD_UNCHANGED) + if hdri is None: + raise RuntimeError(f"Failed to load HDRI from {hdri_path}") + hdri = cv2.cvtColor(hdri, cv2.COLOR_BGR2RGB) + self.envmap = EnvMap(torch.tensor(hdri, dtype=torch.float32, device="cuda")) + + self.load_time = time.perf_counter() - start + + def _cleanup_gpu_memory(self) -> None: + """Clean up GPU memory after operations.""" + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass # Running in test environment without torch + + def generate( + self, + image: "Image", + seed: int = 42, + pipeline_type: str = "1024_cascade", + ss_params: dict[str, Any] | None = None, + shape_params: dict[str, Any] | None = None, + tex_params: dict[str, Any] | None = None, + max_num_tokens: int = 49152, + ) -> dict[str, Any]: + """ + Generate MeshWithVoxel from image. + + Args: + image: PIL Image (will be preprocessed) + seed: Random seed for reproducibility + pipeline_type: One of "512", "1024", "1024_cascade", "1536_cascade" + ss_params: Sparse structure sampler params + shape_params: Shape SLAT sampler params + tex_params: Texture SLAT sampler params + max_num_tokens: Maximum tokens for cascade (controls resolution cap) + + Returns: + Packed state dict ready for serialization + """ + from .state import pack_state + + if self.pipeline is None: + raise RuntimeError("Pipeline not loaded. 
Ensure load_model() was called.")
+
+        # Preprocess and generate
+        processed = self.pipeline.preprocess_image(image)
+        meshes = self.pipeline.run(
+            processed,
+            seed=seed,
+            pipeline_type=pipeline_type,
+            sparse_structure_sampler_params=ss_params or {},
+            shape_slat_sampler_params=shape_params or {},
+            tex_slat_sampler_params=tex_params or {},
+            max_num_tokens=max_num_tokens,
+        )
+
+        mesh = meshes[0]
+
+        # Simplify to nvdiffrast limit
+        mesh.simplify(NVDIFFRAST_MAX_FACES)
+
+        state = pack_state(mesh)
+        self._cleanup_gpu_memory()
+        return state
+
+    def render_preview_video(
+        self,
+        state: dict[str, Any],
+        num_frames: int = 120,
+        fps: int = 15,
+    ) -> bytes:
+        """
+        Render PBR preview video from packed state.
+
+        Returns MP4 video as bytes.
+        """
+        if not self.is_ready:
+            raise RuntimeError(
+                "Generator not ready. Ensure load_model() was called successfully."
+            )
+
+        import imageio
+        from trellis2.utils import render_utils
+
+        from .state import unpack_state
+
+        mesh = unpack_state(state)
+
+        # Render with environment lighting
+        raw_result = render_utils.render_video(
+            mesh,
+            resolution=512,
+            num_frames=num_frames,
+            envmap=self.envmap,
+        )
+
+        # Create PBR visualization frames
+        frames = render_utils.make_pbr_vis_frames(raw_result, resolution=512)
+
+        # Encode to MP4 in memory ("<bytes>" URI makes imageio return bytes)
+        video_bytes = imageio.mimwrite(
+            "<bytes>",
+            frames,
+            format="mp4",
+            fps=fps,
+        )
+
+        self._cleanup_gpu_memory()
+        return video_bytes
+
+    def extract_glb(
+        self,
+        state: dict[str, Any],
+        decimation_target: int = 1000000,
+        texture_size: int = 4096,
+        remesh: bool = True,
+        remesh_band: float = 1.0,
+        remesh_project: float = 0.0,
+    ) -> bytes:
+        """
+        Extract GLB mesh from packed state.
+ + Args: + state: Packed state from generate() + decimation_target: Target vertex count (default 1M for quality) + texture_size: Texture resolution (default 4096 for quality) + remesh: Whether to remesh for cleaner topology + remesh_band: Remesh band size + remesh_project: Projection factor for remesh + + Returns: + GLB file as bytes + """ + if not self.is_ready: + raise RuntimeError( + "Generator not ready. Ensure load_model() was called successfully." + ) + + import io + + import torch + from o_voxel.postprocess import to_glb + + from .state import unpack_state + + mesh = unpack_state(state) + + # Extract GLB using o_voxel + glb_mesh = to_glb( + vertices=mesh.vertices, + faces=mesh.faces, + attr_volume=mesh.attrs, + coords=mesh.coords, + attr_layout=mesh.layout, + aabb=torch.tensor([[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], device="cuda"), + voxel_size=mesh.voxel_size, + decimation_target=decimation_target, + texture_size=texture_size, + remesh=remesh, + remesh_band=remesh_band, + remesh_project=remesh_project, + verbose=False, + ) + + # Export to bytes + buffer = io.BytesIO() + glb_mesh.export(buffer, file_type="glb") + buffer.seek(0) + glb_bytes = buffer.read() + + self._cleanup_gpu_memory() + return glb_bytes diff --git a/trellis2_modal/service/image.py b/trellis2_modal/service/image.py new file mode 100644 index 0000000..88d7e40 --- /dev/null +++ b/trellis2_modal/service/image.py @@ -0,0 +1,289 @@ +""" +Modal image definition for TRELLIS.2. + +This module defines the container image with all dependencies pre-installed, +including CUDA extensions that are compiled at image build time rather than +runtime. This ensures reproducible builds and eliminates cold start compilation. 
+ +Build notes: +- gpu="T4" in run_commands creates a FRESH build context +- ALL dependencies (torch, wheel, setuptools) must be in the GPU block +- --no-build-isolation needed for packages requiring torch at build time +- clang required for CUDA extension builds +- CUDA 12.4 required for TRELLIS.2 +""" + +import modal + +# Constants duplicated from config.py - Modal copies this file to /root/image.py +# without the trellis2_modal package, so imports fail. Keep in sync manually. +# Verified by tests/test_config_consistency.py +GPU_TYPE = "A100-80GB" +HF_CACHE_PATH = "/cache/huggingface" +MODEL_NAME = "microsoft/TRELLIS.2-4B" + +# Pinned git commits for reproducible builds +# These commits are validated to work with PyTorch 2.6.0 / CUDA 12.4 +PINNED_COMMITS = { + "nvdiffrast": "253ac4fcea7de5f396371124af597e6cc957bfae", # v0.4.0 tag + "nvdiffrec": "b296927cc7fd01c2ac1087c8065c4d7248f72da4", # renderutils branch + "utils3d": "9a4eb15e4021b67b12c460c7057d642626897ec8", # TRELLIS integration + "cumesh": "d8d28794721a3f4984b1b12c24403f546f41d28c", # HEAD 2025-12-20 + "flex_gemm": "8b9afa2d56f667b709ccd761d0bd7aab48bdd7cf", # HEAD 2025-12-20 + "trellis2": "1762f493fe7731a3b7cc6b79ad5da7b015b516c1", # HEAD 2025-12-20 +} + +# Build paths (avoid magic strings) +BUILD_TMP = "/tmp" +TRELLIS2_PATH = "/opt/TRELLIS.2" + +# Modal app for TRELLIS.2 +app = modal.App("trellis2-3d") + +# Modal volumes for persistent storage +hf_cache_volume = modal.Volume.from_name("trellis2-hf-cache", create_if_missing=True) +api_keys_volume = modal.Volume.from_name("trellis2-api-keys", create_if_missing=True) + +# HuggingFace secret for accessing gated models (dinov3, etc.) 
+# Create with: modal secret create huggingface HF_TOKEN=hf_xxxxx +hf_secret = modal.Secret.from_name("huggingface") + +# System packages required for CUDA extensions and TRELLIS.2 dependencies +SYSTEM_PACKAGES = [ + "git", + "ninja-build", + "cmake", + "build-essential", + "clang", # Required for CUDA extension builds + "libgl1-mesa-glx", # OpenGL for nvdiffrast + "libglib2.0-0", + "libjpeg-dev", + "libpng-dev", + "libgomp1", # OpenMP for parallel processing + "libopenexr-dev", # For HDRI loading with OpenCV +] + +# Core Python dependencies (no GPU needed for install) +CORE_PYTHON_PACKAGES = [ + "pillow", + "imageio", + "imageio-ffmpeg", + "tqdm", + "easydict", + "opencv-python-headless", + "scipy", + "ninja", + "trimesh", + "transformers", + "huggingface-hub", + "safetensors", + "lz4", # State compression + "kornia", # Image processing + "timm", # Vision models +] + +# TRELLIS.2 image with all CUDA extensions pre-compiled +trellis2_image = ( + modal.Image.from_registry( + "nvidia/cuda:12.4.0-devel-ubuntu22.04", + add_python="3.10", + ) + .apt_install(*SYSTEM_PACKAGES) + .pip_install(*CORE_PYTHON_PACKAGES) + # GPU build block - ALL CUDA-related builds in ONE block + # because gpu="T4" creates a fresh build context + .run_commands( + # Install PyTorch with CUDA 12.4 + "pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124", + # Build tools for --no-build-isolation + "pip install wheel setuptools", + # flash-attn build dependencies + "pip install psutil packaging", + # flash-attn (needs torch at build time, use --no-build-isolation) + "pip install flash-attn==2.7.3 --no-build-isolation", + # utils3d (pinned version for TRELLIS.2) + f"pip install git+https://github.com/EasternJournalist/utils3d.git@{PINNED_COMMITS['utils3d']}", + # nvdiffrast (builds from source, pinned) + f"git clone https://github.com/NVlabs/nvdiffrast.git {BUILD_TMP}/nvdiffrast && cd {BUILD_TMP}/nvdiffrast && git checkout {PINNED_COMMITS['nvdiffrast']}", + 
f"pip install {BUILD_TMP}/nvdiffrast --no-build-isolation", + # nvdiffrec renderutils (needs torch at build time, pinned) + f"git clone https://github.com/JeffreyXiang/nvdiffrec.git {BUILD_TMP}/nvdiffrec && cd {BUILD_TMP}/nvdiffrec && git checkout {PINNED_COMMITS['nvdiffrec']}", + f"pip install {BUILD_TMP}/nvdiffrec --no-build-isolation", + # CuMesh (needs torch at build time, pinned) + f"git clone --recursive https://github.com/JeffreyXiang/CuMesh.git {BUILD_TMP}/CuMesh && cd {BUILD_TMP}/CuMesh && git checkout {PINNED_COMMITS['cumesh']}", + f"pip install {BUILD_TMP}/CuMesh --no-build-isolation", + # FlexGEMM (needs torch/triton at build time, pinned) + f"git clone --recursive https://github.com/JeffreyXiang/FlexGEMM.git {BUILD_TMP}/FlexGEMM && cd {BUILD_TMP}/FlexGEMM && git checkout {PINNED_COMMITS['flex_gemm']}", + f"pip install {BUILD_TMP}/FlexGEMM --no-build-isolation", + gpu="T4", + ) + # Clone TRELLIS.2 repository (with submodules for o-voxel) + .run_commands( + f"git clone --recursive https://github.com/microsoft/TRELLIS.2.git {TRELLIS2_PATH} && cd {TRELLIS2_PATH} && git checkout {PINNED_COMMITS['trellis2']}", + ) + # Build o-voxel from the repo's submodule + .run_commands( + f"pip install {TRELLIS2_PATH}/o-voxel --no-build-isolation", + gpu="T4", + ) + # Pre-download DINOv2 model to avoid runtime download + .run_commands( + "python -c \"import torch; torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg', pretrained=True)\"", + gpu="T4", + ) + # Set environment variables + .env( + { + "ATTN_BACKEND": "flash_attn", + "PYTHONPATH": TRELLIS2_PATH, + "HF_HOME": HF_CACHE_PATH, + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + "OPENCV_IO_ENABLE_OPENEXR": "1", + "TORCH_CUDA_ARCH_LIST": "8.0;8.6;8.9;9.0", + } + ) +) + + +@app.function(image=trellis2_image, gpu="T4", timeout=600) +def verify_image(): + """ + Verify the image is correctly built with all dependencies. 
+ + Run with: modal run trellis2_modal/service/image.py::verify_image + """ + import json + + results = {} + + # Test PyTorch + CUDA + import torch + + results["pytorch_version"] = str(torch.__version__) + results["cuda_version"] = str(torch.version.cuda) + results["cuda_available"] = torch.cuda.is_available() + results["device_name"] = ( + torch.cuda.get_device_name(0) if torch.cuda.is_available() else None + ) + + # Test CUDA compute + if torch.cuda.is_available(): + x = torch.randn(100, 100, device="cuda") + _ = torch.matmul(x, x) + results["cuda_compute_works"] = True + + # Test flash-attn + try: + from flash_attn import flash_attn_func + + results["flash_attn"] = callable(flash_attn_func) + except ImportError as e: + results["flash_attn"] = str(e) + + # Test CUDA extensions + try: + import nvdiffrast.torch as dr + + results["nvdiffrast"] = hasattr(dr, "rasterize") + except ImportError as e: + results["nvdiffrast"] = str(e) + + try: + import cumesh + + results["cumesh"] = hasattr(cumesh, "CuMesh") + except ImportError as e: + results["cumesh"] = str(e) + + try: + import flex_gemm # noqa: F401 + + results["flex_gemm"] = True + except ImportError as e: + results["flex_gemm"] = str(e) + + # Test o_voxel + try: + from o_voxel.postprocess import to_glb + + results["o_voxel"] = callable(to_glb) + except ImportError as e: + results["o_voxel"] = str(e) + + # Test utils3d + try: + import utils3d # noqa: F401 + + results["utils3d"] = True + except ImportError as e: + results["utils3d"] = str(e) + + # Test DINOv2 is cached + import os + + dinov2_hub = os.path.expanduser("~/.cache/torch/hub/facebookresearch_dinov2_main") + results["dinov2_cached"] = os.path.exists(dinov2_hub) + + # Test TRELLIS.2 imports + try: + from trellis2.pipelines import Trellis2ImageTo3DPipeline # noqa: F401 + + results["trellis2_pipeline"] = True + except ImportError as e: + results["trellis2_pipeline"] = str(e) + + try: + from trellis2.utils import render_utils + + results["trellis2_render_utils"] 
= hasattr(render_utils, "render_video") + except ImportError as e: + results["trellis2_render_utils"] = str(e) + + # Test HDRI file exists (TRELLIS2_PATH is /opt/TRELLIS.2) + hdri_path = f"{TRELLIS2_PATH}/assets/hdri/forest.exr" + results["hdri_exists"] = os.path.exists(hdri_path) + + print("=== Image Verification Results ===") + for k, v in sorted(results.items()): + status = "✓" if v is True else "✗" if v is False else "?" + print(f" {status} {k}: {v}") + + return json.dumps(results) + + +@app.local_entrypoint() +def main(): + """Run verification.""" + import json + + print("\n" + "=" * 60) + print("TRELLIS.2 Modal Image Verification") + print("=" * 60 + "\n") + + result = verify_image.remote() + data = json.loads(result) + + # Check critical components + critical = [ + "cuda_available", + "flash_attn", + "nvdiffrast", + "cumesh", + "flex_gemm", + "o_voxel", + "trellis2_pipeline", + "trellis2_render_utils", + "dinov2_cached", + "hdri_exists", + ] + + all_passed = all(data.get(k, False) is True for k in critical) + + if all_passed: + print("\n✓ Image verification PASSED") + else: + print("\n✗ Image verification FAILED") + for k in critical: + v = data.get(k, "MISSING") + status = "✓" if v is True else "✗" + print(f" {status} {k}: {v}") diff --git a/trellis2_modal/service/service.py b/trellis2_modal/service/service.py new file mode 100644 index 0000000..c052ada --- /dev/null +++ b/trellis2_modal/service/service.py @@ -0,0 +1,685 @@ +""" +Modal service for TRELLIS.2 3D generation. 
+ +This module defines the Modal classes for 3D generation: +- HealthCheckService: CPU-only health endpoint (no GPU needed) +- TRELLIS2Service: GPU-accelerated generation endpoints + +Usage: + modal run -m trellis2_modal.service.service # Test health check + modal deploy -m trellis2_modal.service.service # Deploy with snapshots + +Endpoints: + GET /health - Health check (no auth, CPU-only, no GPU allocation) + POST /generate - Image → 3D state + video (Modal Proxy Auth required) + POST /extract_glb - State → GLB mesh (Modal Proxy Auth required) + +Authentication: + POST endpoints use Modal Proxy Auth Tokens. Clients must provide: + - Modal-Key: + - Modal-Secret: + Create tokens in the Modal dashboard. + +Scaling: + max_containers=1 limits GPU container count to prevent quota exhaustion. + Retry policy handles transient OOM failures with exponential backoff. +""" + +from __future__ import annotations + +import logging +import secrets +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import modal +from fastapi import Request + +from .auth import mask_api_key # Only needed for logging +from .config import ( + API_KEYS_PATH, + DEFAULT_DECIMATION_TARGET, + DEFAULT_PIPELINE_TYPE, + DEFAULT_REMESH, + DEFAULT_REMESH_BAND, + DEFAULT_REMESH_PROJECT, + DEFAULT_SEED, + DEFAULT_SHAPE_SLAT_GUIDANCE_STRENGTH, + DEFAULT_SHAPE_SLAT_SAMPLING_STEPS, + DEFAULT_SS_GUIDANCE_STRENGTH, + DEFAULT_SS_SAMPLING_STEPS, + DEFAULT_TEX_SLAT_GUIDANCE_STRENGTH, + DEFAULT_TEX_SLAT_SAMPLING_STEPS, + DEFAULT_TEXTURE_SIZE, + GPU_MEMORY_SNAPSHOT, + GPU_TYPE, + HF_CACHE_PATH, + MAX_IMAGE_DIMENSION, + MAX_IMAGE_PAYLOAD_SIZE, +) +from .generator import TRELLIS2Generator +from .image import api_keys_volume, app, hf_cache_volume, hf_secret, trellis2_image + +logger = logging.getLogger(__name__) + +# Derive volume path from config (parent directory of API_KEYS_PATH) +# Note: os is imported at top of file via __future__ +import os as _os # noqa: E402 + 
+API_KEYS_VOLUME_PATH = _os.path.dirname(API_KEYS_PATH) + +# Valid pipeline types +VALID_PIPELINE_TYPES = frozenset({"512", "1024", "1024_cascade", "1536_cascade"}) + +# Valid texture sizes +VALID_TEXTURE_SIZES = frozenset({512, 1024, 2048, 4096}) + +if TYPE_CHECKING: + from PIL import Image + + +def generate_request_id() -> str: + """Generate a unique request ID for tracing.""" + return f"req_{secrets.token_hex(8)}" + + +@dataclass +class GenerateParams: + """Validated parameters for the generate endpoint.""" + + image: Image.Image + seed: int + pipeline_type: str + ss_sampling_steps: int + ss_guidance_strength: float + shape_slat_sampling_steps: int + shape_slat_guidance_strength: float + tex_slat_sampling_steps: int + tex_slat_guidance_strength: float + + +@dataclass +class ExtractGLBParams: + """Validated parameters for the extract_glb endpoint.""" + + state: dict[str, Any] + decimation_target: int + texture_size: int + remesh: bool + remesh_band: float + remesh_project: float + + +def _error_response(code: str, message: str) -> dict: + """Create a standardized error response.""" + return {"error": {"code": code, "message": message}} + + +def _log_request( + endpoint: str, + api_key: str | None, + duration_ms: float, + status: str, + error_code: str | None = None, + request_id: str | None = None, + extra_metrics: dict[str, Any] | None = None, +) -> None: + """Log structured request completion.""" + log_data: dict[str, Any] = { + "endpoint": endpoint, + "api_key": mask_api_key(api_key), + "duration_ms": round(duration_ms, 2), + "status": status, + } + if request_id: + log_data["request_id"] = request_id + if error_code: + log_data["error_code"] = error_code + if extra_metrics: + log_data.update(extra_metrics) + + if status == "success": + logger.info("Request completed: %s", log_data) + else: + logger.warning("Request failed: %s", log_data) + + +def _parse_generate_request(request: dict) -> GenerateParams | dict: + """ + Parse and validate a generate request. 
+ + Returns: + GenerateParams if valid, or error dict if validation fails. + """ + import base64 + from io import BytesIO + + from PIL import Image + + if not isinstance(request, dict): + return _error_response("validation_error", "Request must be JSON object") + + # Required field + image_b64 = request.get("image") + if not image_b64: + return _error_response("validation_error", "Missing 'image' field") + + # Check payload size + estimated_size = len(image_b64) * 3 // 4 + if estimated_size > MAX_IMAGE_PAYLOAD_SIZE: + return _error_response( + "validation_error", + f"Image size exceeds limit ({estimated_size // (1024 * 1024)}MB > " + f"{MAX_IMAGE_PAYLOAD_SIZE // (1024 * 1024)}MB)", + ) + + # Parse parameters with defaults + try: + seed = int(request.get("seed", DEFAULT_SEED)) + pipeline_type = str(request.get("pipeline_type", DEFAULT_PIPELINE_TYPE)) + ss_sampling_steps = int( + request.get("ss_sampling_steps", DEFAULT_SS_SAMPLING_STEPS) + ) + ss_guidance_strength = float( + request.get("ss_guidance_strength", DEFAULT_SS_GUIDANCE_STRENGTH) + ) + shape_slat_sampling_steps = int( + request.get("shape_slat_sampling_steps", DEFAULT_SHAPE_SLAT_SAMPLING_STEPS) + ) + shape_slat_guidance_strength = float( + request.get( + "shape_slat_guidance_strength", DEFAULT_SHAPE_SLAT_GUIDANCE_STRENGTH + ) + ) + tex_slat_sampling_steps = int( + request.get("tex_slat_sampling_steps", DEFAULT_TEX_SLAT_SAMPLING_STEPS) + ) + tex_slat_guidance_strength = float( + request.get( + "tex_slat_guidance_strength", DEFAULT_TEX_SLAT_GUIDANCE_STRENGTH + ) + ) + except (TypeError, ValueError) as e: + return _error_response("validation_error", f"Invalid parameter: {e}") + + # Validate pipeline_type + if pipeline_type not in VALID_PIPELINE_TYPES: + return _error_response( + "validation_error", + f"Invalid pipeline_type '{pipeline_type}'. 
" + f"Must be one of: {', '.join(sorted(VALID_PIPELINE_TYPES))}", + ) + + # Decode image + try: + image_bytes = base64.b64decode(image_b64) + image = Image.open(BytesIO(image_bytes)) + except Exception as e: + return _error_response("validation_error", f"Invalid image: {e}") + + # Validate image dimensions + width, height = image.size + if width > MAX_IMAGE_DIMENSION or height > MAX_IMAGE_DIMENSION: + return _error_response( + "validation_error", + f"Image dimensions exceed limit " + f"({width}x{height} > {MAX_IMAGE_DIMENSION}x{MAX_IMAGE_DIMENSION})", + ) + + return GenerateParams( + image=image, + seed=seed, + pipeline_type=pipeline_type, + ss_sampling_steps=ss_sampling_steps, + ss_guidance_strength=ss_guidance_strength, + shape_slat_sampling_steps=shape_slat_sampling_steps, + shape_slat_guidance_strength=shape_slat_guidance_strength, + tex_slat_sampling_steps=tex_slat_sampling_steps, + tex_slat_guidance_strength=tex_slat_guidance_strength, + ) + + +def _parse_extract_glb_request(request: dict) -> ExtractGLBParams | dict: + """ + Parse and validate an extract_glb request. + + Returns: + ExtractGLBParams if valid, or error dict if validation fails. 
+ """ + import base64 + + from trellis2_modal.client.compression import decompress_state + + if not isinstance(request, dict): + return _error_response("validation_error", "Request must be JSON object") + + state_b64 = request.get("state") + if not state_b64: + return _error_response("validation_error", "Missing 'state' field") + + # Parse parameters with defaults + try: + decimation_target = int( + request.get("decimation_target", DEFAULT_DECIMATION_TARGET) + ) + texture_size = int(request.get("texture_size", DEFAULT_TEXTURE_SIZE)) + remesh = bool(request.get("remesh", DEFAULT_REMESH)) + remesh_band = float(request.get("remesh_band", DEFAULT_REMESH_BAND)) + remesh_project = float(request.get("remesh_project", DEFAULT_REMESH_PROJECT)) + except (TypeError, ValueError) as e: + return _error_response("validation_error", f"Invalid parameter: {e}") + + # Validate ranges + if decimation_target < 1000: + return _error_response( + "validation_error", + "decimation_target must be at least 1000", + ) + if texture_size not in VALID_TEXTURE_SIZES: + return _error_response( + "validation_error", + f"texture_size must be one of: {', '.join(map(str, sorted(VALID_TEXTURE_SIZES)))}", + ) + + # Decode and decompress state + try: + compressed_state = base64.b64decode(state_b64) + state = decompress_state(compressed_state) + except Exception as e: + return _error_response("validation_error", f"Invalid state: {e}") + + return ExtractGLBParams( + state=state, + decimation_target=decimation_target, + texture_size=texture_size, + remesh=remesh, + remesh_band=remesh_band, + remesh_project=remesh_project, + ) + + +# CPU-only health check class - prevents health checks from spinning up GPU containers +# This uses a minimal image and no GPU, so health probes are fast and cheap +@app.cls( + image=modal.Image.debian_slim().pip_install("fastapi[standard]"), + cpu=1.0, + min_containers=1, # Keep one warm for instant health responses +) +class HealthCheckService: + """ + Lightweight CPU-only 
health check service. + + Separated from GPU service to prevent load balancer health probes + from consuming expensive GPU resources. This class can handle + thousands of health checks per hour at minimal cost. + """ + + @modal.fastapi_endpoint(method="GET") + def health(self) -> dict: + """GET /health - Health check for load balancers. No auth required.""" + return {"status": "ok", "service": "trellis2-api"} + + +@app.cls( + image=trellis2_image, + gpu=GPU_TYPE, + secrets=[hf_secret], + volumes={ + HF_CACHE_PATH: hf_cache_volume, + API_KEYS_VOLUME_PATH: api_keys_volume, + }, + timeout=600, + scaledown_window=300, + enable_memory_snapshot=GPU_MEMORY_SNAPSHOT, + max_containers=1, # Limit GPU container count to prevent quota exhaustion + retries=modal.Retries( + max_retries=2, + initial_delay=5.0, + backoff_coefficient=2.0, + ), +) +class TRELLIS2Service: + """ + Modal class wrapping TRELLIS2Generator. + + Uses composition to separate Modal infrastructure from domain logic. + + Note: CPU memory snapshots are not supported because TRELLIS.2 dependencies + (flex_gemm, Triton) require GPU access during import. Model loading happens + entirely in the GPU phase. + """ + + @modal.enter() + def load_model(self) -> None: + """Load model to GPU. 
Called once when container starts.""" + import torch + + self._load_start = time.time() + self.generator = TRELLIS2Generator() + + logger.info("Loading model to GPU...") + logger.info("GPU: %s", torch.cuda.get_device_name(0)) + + self.generator.load_model() + + logger.info("Model loaded: %.2fs", self.generator.load_time) + logger.info("VRAM: %.2f GB", torch.cuda.memory_allocated() / 1e9) + + @modal.method() + def health_check(self) -> dict: + """Return health status and diagnostic information.""" + import torch + + return { + "status": "healthy" if self.generator.is_ready else "unhealthy", + "gpu": torch.cuda.get_device_name(0), + "vram_allocated_gb": round(torch.cuda.memory_allocated() / 1e9, 2), + "vram_total_gb": round( + torch.cuda.get_device_properties(0).total_memory / 1e9, 2 + ), + "load_time_seconds": round(self.generator.load_time, 2), + } + + @modal.fastapi_endpoint(method="POST", requires_proxy_auth=True) + def generate(self, http_request: Request, request: dict) -> dict: + """ + POST /generate - Generate 3D model from image. + + Requires Modal Proxy Auth (Modal-Key/Modal-Secret headers). + Create tokens in the Modal dashboard. 
+ + Args: + http_request: FastAPI Request for header access + request: Dict with image (base64), seed, pipeline_type, sampler params + + Returns: + Dict with state (compressed, base64), video (base64), request_id + """ + import base64 + + import torch + + from trellis2_modal.client.compression import compress_state + + request_id = generate_request_id() + start_time = time.perf_counter() + # Note: Modal Proxy Auth handles authentication at the proxy layer + # Requests that reach this code are already authenticated + api_key = http_request.headers.get("Modal-Key", "proxy-auth") + + # Parse request + params = _parse_generate_request(request) + if isinstance(params, dict): + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "generate", + api_key, + duration_ms, + "error", + "validation_error", + request_id, + ) + return params + + # Build sampler param dicts + # Note: TRELLIS.2 samplers use "guidance_strength" (not "cfg_strength") + ss_params = { + "steps": params.ss_sampling_steps, + "guidance_strength": params.ss_guidance_strength, + } + shape_params = { + "steps": params.shape_slat_sampling_steps, + "guidance_strength": params.shape_slat_guidance_strength, + } + tex_params = { + "steps": params.tex_slat_sampling_steps, + "guidance_strength": params.tex_slat_guidance_strength, + } + + # Generate 3D + # OOM handling pattern: catch outside exception scope to properly free memory + # (Python exception frames hold tensor references, preventing empty_cache from working) + gen_oom = False + gen_error = None + state = None + try: + state = self.generator.generate( + image=params.image, + seed=params.seed, + pipeline_type=params.pipeline_type, + ss_params=ss_params, + shape_params=shape_params, + tex_params=tex_params, + ) + except torch.cuda.OutOfMemoryError: + gen_oom = True + except Exception as e: + gen_error = e + + if gen_oom: + torch.cuda.empty_cache() + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "generate", 
api_key, duration_ms, "error", "cuda_oom", request_id + ) + return _error_response( + "cuda_oom", + "GPU out of memory. Try a smaller image or lower resolution pipeline.", + ) + + if gen_error is not None: + logger.exception("Generation failed: %s", gen_error) + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "generate", + api_key, + duration_ms, + "error", + "generation_error", + request_id, + ) + return _error_response("generation_error", f"Generation failed: {gen_error}") + + # Render preview video + # OOM handling pattern: catch outside exception scope + video_oom = False + video_error = None + video_bytes = None + try: + video_bytes = self.generator.render_preview_video(state) + video_b64 = base64.b64encode(video_bytes).decode("utf-8") + except torch.cuda.OutOfMemoryError: + video_oom = True + except Exception as e: + video_error = e + + if video_oom: + torch.cuda.empty_cache() + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "generate", api_key, duration_ms, "error", "cuda_oom", request_id + ) + return _error_response( + "cuda_oom", + "GPU out of memory during video rendering.", + ) + + if video_error is not None: + logger.exception("Video rendering failed: %s", video_error) + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "generate", + api_key, + duration_ms, + "error", + "rendering_error", + request_id, + ) + return _error_response("rendering_error", f"Video rendering failed: {video_error}") + + # Compress state + try: + compressed_state = compress_state(state) + state_b64 = base64.b64encode(compressed_state).decode("utf-8") + except Exception as e: + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "generate", + api_key, + duration_ms, + "error", + "compression_error", + request_id, + ) + return _error_response( + "compression_error", f"State compression failed: {e}" + ) + + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + 
"generate", + api_key, + duration_ms, + "success", + request_id=request_id, + extra_metrics={ + "state_size_bytes": len(compressed_state), + "video_size_bytes": len(video_bytes), + }, + ) + return { + "state": state_b64, + "video": video_b64, + "request_id": request_id, + } + + @modal.fastapi_endpoint(method="POST", requires_proxy_auth=True) + def extract_glb(self, http_request: Request, request: dict) -> dict: + """ + POST /extract_glb - Extract GLB mesh from generation state. + + Requires Modal Proxy Auth (Modal-Key/Modal-Secret headers). + Create tokens in the Modal dashboard. + + Args: + http_request: FastAPI Request for header access + request: Dict with state (base64), decimation params + + Returns: + Dict with glb (base64), request_id + """ + import base64 + + import torch + + request_id = generate_request_id() + start_time = time.perf_counter() + # Note: Modal Proxy Auth handles authentication at the proxy layer + # Requests that reach this code are already authenticated + api_key = http_request.headers.get("Modal-Key", "proxy-auth") + + # Parse request + params = _parse_extract_glb_request(request) + if isinstance(params, dict): + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "extract_glb", + api_key, + duration_ms, + "error", + "validation_error", + request_id, + ) + return params + + # Extract GLB + # OOM handling pattern: catch outside exception scope + glb_oom = False + glb_error = None + glb_bytes = None + try: + glb_bytes = self.generator.extract_glb( + state=params.state, + decimation_target=params.decimation_target, + texture_size=params.texture_size, + remesh=params.remesh, + remesh_band=params.remesh_band, + remesh_project=params.remesh_project, + ) + except torch.cuda.OutOfMemoryError: + glb_oom = True + except Exception as e: + glb_error = e + + if glb_oom: + torch.cuda.empty_cache() + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "extract_glb", api_key, duration_ms, "error", "cuda_oom", 
request_id + ) + return _error_response( + "cuda_oom", + "GPU out of memory during GLB extraction. Try reducing texture size.", + ) + + if glb_error is not None: + logger.exception("GLB extraction failed: %s", glb_error) + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "extract_glb", + api_key, + duration_ms, + "error", + "extraction_error", + request_id, + ) + return _error_response("extraction_error", f"GLB extraction failed: {glb_error}") + + glb_b64 = base64.b64encode(glb_bytes).decode("utf-8") + + duration_ms = (time.perf_counter() - start_time) * 1000 + _log_request( + "extract_glb", + api_key, + duration_ms, + "success", + request_id=request_id, + extra_metrics={"glb_size_bytes": len(glb_bytes)}, + ) + return {"glb": glb_b64, "request_id": request_id} + + +@app.local_entrypoint() +def test_service(): + """Test the service with a health check.""" + import json + + print("\n" + "=" * 60) + print("TRELLIS.2 Modal Service - Health Check") + print("=" * 60 + "\n") + + service = TRELLIS2Service() + + print("Starting first call (cold start)...") + start = time.time() + result = service.health_check.remote() + first_call_time = time.time() - start + + print(f"\nFirst call completed in {first_call_time:.2f}s") + print(f"Result: {json.dumps(result, indent=2)}") + + print("\nStarting second call (warm container)...") + start = time.time() + result = service.health_check.remote() + second_call_time = time.time() - start + + print(f"\nSecond call completed in {second_call_time:.2f}s") + + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(f"Cold start: {first_call_time:.2f}s") + print(f"Warm container: {second_call_time:.2f}s") + print(f"Speedup: {first_call_time / second_call_time:.1f}x") + print(f"Model load: {result.get('load_time_seconds', 'N/A')}s") diff --git a/trellis2_modal/service/state.py b/trellis2_modal/service/state.py new file mode 100644 index 0000000..247b3ff --- /dev/null +++ b/trellis2_modal/service/state.py 
@@ -0,0 +1,76 @@ +""" +State serialization for MeshWithVoxel. + +Converts between GPU tensors and numpy arrays for network transfer. +Uses LZ4 compression on the client side (see client/compression.py). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from trellis2.representations import MeshWithVoxel + + +def pack_state(mesh: MeshWithVoxel) -> dict[str, Any]: + """ + Pack MeshWithVoxel into a serializable dictionary. + + All GPU tensors are moved to CPU and converted to numpy arrays. + Slice objects in layout are converted to [start, stop] lists. + + Args: + mesh: MeshWithVoxel from generation pipeline + + Returns: + Dictionary with numpy arrays, ready for JSON serialization + """ + return { + "vertices": mesh.vertices.cpu().numpy(), + "faces": mesh.faces.cpu().numpy(), + "attrs": mesh.attrs.cpu().numpy(), + "coords": mesh.coords.cpu().numpy(), + "voxel_size": float(mesh.voxel_size), + "voxel_shape": list(mesh.voxel_shape), + "origin": mesh.origin.cpu().tolist(), + "layout": {k: [v.start, v.stop] for k, v in mesh.layout.items()}, + } + + +def unpack_state(state: dict[str, Any]) -> MeshWithVoxel: + """ + Reconstruct MeshWithVoxel from a packed state dictionary. + + All arrays are converted to CUDA tensors. This function assumes + CUDA is available (it runs on the Modal GPU service). + + Args: + state: Dictionary from pack_state() + + Returns: + MeshWithVoxel with tensors on CUDA device + + Raises: + RuntimeError: If CUDA is not available + """ + import torch + from trellis2.representations import MeshWithVoxel + + if not torch.cuda.is_available(): + raise RuntimeError( + "CUDA not available. unpack_state() must run on GPU service." 
+ ) + + layout = {k: slice(v[0], v[1]) for k, v in state["layout"].items()} + + return MeshWithVoxel( + vertices=torch.tensor(state["vertices"], device="cuda", dtype=torch.float32), + faces=torch.tensor(state["faces"], device="cuda", dtype=torch.int32), + origin=state["origin"], + voxel_size=state["voxel_size"], + coords=torch.tensor(state["coords"], device="cuda"), + attrs=torch.tensor(state["attrs"], device="cuda", dtype=torch.float32), + voxel_shape=torch.Size(state["voxel_shape"]), + layout=layout, + ) diff --git a/trellis2_modal/tests/__init__.py b/trellis2_modal/tests/__init__.py new file mode 100644 index 0000000..cbdacd5 --- /dev/null +++ b/trellis2_modal/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for trellis2_modal package.""" diff --git a/trellis2_modal/tests/conftest.py b/trellis2_modal/tests/conftest.py new file mode 100644 index 0000000..9d6b5dd --- /dev/null +++ b/trellis2_modal/tests/conftest.py @@ -0,0 +1,11 @@ +""" +Pytest configuration and shared fixtures for trellis2_modal tests. +""" + +import sys +from pathlib import Path + +# Add the project root to Python path for imports +project_root = Path(__file__).parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) diff --git a/trellis2_modal/tests/test_api_client.py b/trellis2_modal/tests/test_api_client.py new file mode 100644 index 0000000..d8d382e --- /dev/null +++ b/trellis2_modal/tests/test_api_client.py @@ -0,0 +1,551 @@ +""" +Tests for the TRELLIS.2 API client. + +Tests validate TRELLIS2APIClient HTTP calls, error handling, and response parsing. 
+""" + +import base64 +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from trellis2_modal.client.api import APIError, TRELLIS2APIClient + + +class TestTRELLIS2APIClientInit: + """Tests for TRELLIS2APIClient initialization.""" + + def test_init_stores_credentials(self) -> None: + """Client should store modal_key and modal_secret.""" + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test123", + modal_secret="ws-secret456", + ) + assert client.modal_key == "wk-test123" + assert client.modal_secret == "ws-secret456" + + def test_init_derives_urls_for_path_routing(self) -> None: + """Path-based URLs should derive generate and extract_glb URLs.""" + client = TRELLIS2APIClient( + base_url="https://example.com/", + modal_key="wk-test", + modal_secret="ws-test", + ) + assert client.generate_url == "https://example.com/generate" + assert client.extract_glb_url == "https://example.com/extract_glb" + + def test_init_derives_urls_for_modal_subdomain(self) -> None: + """Modal subdomain URLs should derive both endpoint URLs.""" + client = TRELLIS2APIClient( + base_url="https://user--app-generate.modal.run", + modal_key="wk-test", + modal_secret="ws-test", + ) + assert client.generate_url == "https://user--app-generate.modal.run" + assert client.extract_glb_url == "https://user--app-extract-glb.modal.run" + + def test_last_request_elapsed_is_none_initially(self) -> None: + """last_request_elapsed should be None before any request.""" + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + assert client.last_request_elapsed is None + + +class TestTRELLIS2APIClientGenerate: + """Tests for TRELLIS2APIClient.generate().""" + + def test_generate_sends_correct_request(self, tmp_path: Path) -> None: + """generate() should send POST with correct headers and payload.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake 
png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test123", + modal_secret="ws-secret456", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + + client.generate( + image_path=str(image_path), + seed=42, + pipeline_type="1024_cascade", + ss_sampling_steps=12, + ss_guidance_strength=7.5, + shape_slat_sampling_steps=10, + shape_slat_guidance_strength=6.0, + tex_slat_sampling_steps=8, + tex_slat_guidance_strength=1.0, + ) + + mock_request.assert_called_once() + call_args = mock_request.call_args + + assert call_args[0][0] == "POST" + assert call_args[0][1] == "https://example.com/generate" + + assert call_args[1]["headers"]["Modal-Key"] == "wk-test123" + assert call_args[1]["headers"]["Modal-Secret"] == "ws-secret456" + assert call_args[1]["headers"]["Content-Type"] == "application/json" + + payload = call_args[1]["json"] + assert payload["seed"] == 42 + assert payload["pipeline_type"] == "1024_cascade" + assert payload["ss_sampling_steps"] == 12 + assert payload["ss_guidance_strength"] == 7.5 + assert payload["shape_slat_sampling_steps"] == 10 + assert payload["shape_slat_guidance_strength"] == 6.0 + assert payload["tex_slat_sampling_steps"] == 8 + assert payload["tex_slat_guidance_strength"] == 1.0 + assert payload["image"] == base64.b64encode(b"fake png data").decode() + + def test_generate_returns_result_dict(self, tmp_path: Path) -> None: + """generate() should return dict with state and video.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + + result = 
client.generate(image_path=str(image_path)) + + assert "state" in result + assert "video" in result + assert result["state"] == "c3RhdGU=" + assert result["video"] == "dmlkZW8=" + + def test_generate_uses_default_params(self, tmp_path: Path) -> None: + """generate() should use defaults when params not specified.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + + client.generate(image_path=str(image_path)) + + payload = mock_request.call_args[1]["json"] + assert payload["seed"] == 42 + assert payload["pipeline_type"] == "1024_cascade" + assert payload["ss_sampling_steps"] == 12 + assert payload["ss_guidance_strength"] == 7.5 + assert payload["shape_slat_sampling_steps"] == 12 + assert payload["shape_slat_guidance_strength"] == 7.5 + assert payload["tex_slat_sampling_steps"] == 12 + assert payload["tex_slat_guidance_strength"] == 1.0 + + def test_generate_raises_file_not_found(self) -> None: + """generate() should raise FileNotFoundError for missing image.""" + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with pytest.raises(FileNotFoundError, match="Image not found"): + client.generate(image_path="/nonexistent/image.png") + + def test_generate_raises_api_error(self, tmp_path: Path) -> None: + """generate() should raise APIError on server error response.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=400, + 
json=lambda: { + "error": {"code": "validation_error", "message": "Invalid image"} + }, + ) + + with pytest.raises(APIError) as exc_info: + client.generate(image_path=str(image_path)) + + assert exc_info.value.code == "validation_error" + assert exc_info.value.message == "Invalid image" + + @pytest.mark.parametrize( + "pipeline_type", + ["512", "1024", "1024_cascade", "1536_cascade"], + ) + def test_generate_accepts_all_pipeline_types( + self, tmp_path: Path, pipeline_type: str + ) -> None: + """generate() should accept all valid pipeline types.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + + client.generate(image_path=str(image_path), pipeline_type=pipeline_type) + + payload = mock_request.call_args[1]["json"] + assert payload["pipeline_type"] == pipeline_type + + +class TestTRELLIS2APIClientExtractGLB: + """Tests for TRELLIS2APIClient.extract_glb().""" + + def test_extract_glb_sends_correct_request(self, tmp_path: Path) -> None: + """extract_glb() should send POST with correct payload.""" + output_path = tmp_path / "output.glb" + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"glb": base64.b64encode(b"glb data").decode()}, + ) + + client.extract_glb( + state="c3RhdGU=", + output_path=str(output_path), + decimation_target=500000, + texture_size=2048, + remesh=False, + remesh_band=0.5, + remesh_project=0.1, + ) + + mock_request.assert_called_once() + call_args = mock_request.call_args + + assert call_args[0][0] == "POST" + assert 
call_args[0][1] == "https://example.com/extract_glb" + + payload = call_args[1]["json"] + assert payload["state"] == "c3RhdGU=" + assert payload["decimation_target"] == 500000 + assert payload["texture_size"] == 2048 + assert payload["remesh"] is False + assert payload["remesh_band"] == 0.5 + assert payload["remesh_project"] == 0.1 + + def test_extract_glb_writes_file(self, tmp_path: Path) -> None: + """extract_glb() should write GLB data to output path.""" + output_path = tmp_path / "output.glb" + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + glb_data = b"fake glb binary data" + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"glb": base64.b64encode(glb_data).decode()}, + ) + + result = client.extract_glb( + state="c3RhdGU=", + output_path=str(output_path), + ) + + assert result == str(output_path) + assert output_path.exists() + assert output_path.read_bytes() == glb_data + + def test_extract_glb_uses_default_params(self, tmp_path: Path) -> None: + """extract_glb() should use defaults when params not specified.""" + output_path = tmp_path / "output.glb" + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"glb": base64.b64encode(b"data").decode()}, + ) + + client.extract_glb(state="c3RhdGU=", output_path=str(output_path)) + + payload = mock_request.call_args[1]["json"] + assert payload["decimation_target"] == 1000000 + assert payload["texture_size"] == 4096 + assert payload["remesh"] is True + assert payload["remesh_band"] == 1.0 + assert payload["remesh_project"] == 0.0 + + def test_extract_glb_raises_api_error(self, tmp_path: Path) -> None: + """extract_glb() should raise APIError on server error.""" + output_path = 
tmp_path / "output.glb" + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=400, + json=lambda: { + "error": {"code": "cuda_oom", "message": "Out of memory"} + }, + ) + + with pytest.raises(APIError) as exc_info: + client.extract_glb(state="c3RhdGU=", output_path=str(output_path)) + + assert exc_info.value.code == "cuda_oom" + + +class TestColdStartDetection: + """Tests for cold start detection.""" + + def test_was_cold_start_true_for_slow_request(self, tmp_path: Path) -> None: + """was_cold_start() returns True if request exceeded threshold.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + with patch("time.perf_counter", side_effect=[0.0, 35.0]): + client.generate(image_path=str(image_path)) + + assert client.was_cold_start() is True + + def test_was_cold_start_false_for_fast_request(self, tmp_path: Path) -> None: + """was_cold_start() returns False if request was fast.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + with patch("time.perf_counter", side_effect=[0.0, 5.0]): + client.generate(image_path=str(image_path)) + + assert client.was_cold_start() is False + + def test_was_cold_start_false_when_no_request_made(self) -> None: + 
"""was_cold_start() returns False if no request made yet.""" + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + assert client.was_cold_start() is False + + def test_was_cold_start_custom_threshold(self, tmp_path: Path) -> None: + """was_cold_start() respects custom threshold.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ) + with patch("time.perf_counter", side_effect=[0.0, 10.0]): + client.generate(image_path=str(image_path)) + + assert client.was_cold_start(threshold=5.0) is True + assert client.was_cold_start(threshold=15.0) is False + + +class TestRetryLogic: + """Tests for request retry logic.""" + + def test_retries_on_connection_error(self, tmp_path: Path) -> None: + """Should retry on ConnectionError.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.side_effect = [ + requests.exceptions.ConnectionError(), + MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ), + ] + with patch("time.sleep"): + result = client.generate(image_path=str(image_path)) + + assert result["state"] == "c3RhdGU=" + assert mock_request.call_count == 2 + + def test_retries_on_timeout(self, tmp_path: Path) -> None: + """Should retry on Timeout.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) 
+ + with patch("requests.request") as mock_request: + mock_request.side_effect = [ + requests.exceptions.Timeout(), + MagicMock( + status_code=200, + json=lambda: {"state": "c3RhdGU=", "video": "dmlkZW8="}, + ), + ] + with patch("time.sleep"): + result = client.generate(image_path=str(image_path)) + + assert result["state"] == "c3RhdGU=" + assert mock_request.call_count == 2 + + def test_raises_after_max_retries(self, tmp_path: Path) -> None: + """Should raise after exhausting retries.""" + image_path = tmp_path / "test.png" + image_path.write_bytes(b"fake png data") + + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.side_effect = requests.exceptions.ConnectionError() + with patch("time.sleep"): + with pytest.raises(requests.exceptions.ConnectionError): + client.generate(image_path=str(image_path)) + + assert mock_request.call_count == 3 # Initial + 2 retries + + +class TestHealthCheck: + """Tests for health check functionality.""" + + def test_health_check_returns_true_on_success(self) -> None: + """health_check() returns True when service responds.""" + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.return_value = MagicMock(status_code=200) + assert client.health_check() is True + + def test_health_check_returns_false_on_connection_error(self) -> None: + """health_check() returns False on connection error.""" + client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.side_effect = requests.exceptions.ConnectionError() + assert client.health_check() is False + + def test_health_check_returns_false_on_timeout(self) -> None: + """health_check() returns False on timeout.""" + 
client = TRELLIS2APIClient( + base_url="https://example.com", + modal_key="wk-test", + modal_secret="ws-test", + ) + + with patch("requests.request") as mock_request: + mock_request.side_effect = requests.exceptions.Timeout() + assert client.health_check() is False + + +class TestAPIError: + """Tests for APIError exception.""" + + def test_api_error_stores_code_and_message(self) -> None: + """APIError should store code and message.""" + error = APIError("test_code", "test message") + assert error.code == "test_code" + assert error.message == "test message" + + def test_api_error_str_includes_code_and_message(self) -> None: + """APIError str should include code and message.""" + error = APIError("test_code", "test message") + assert str(error) == "test_code: test message" diff --git a/trellis2_modal/tests/test_auth.py b/trellis2_modal/tests/test_auth.py new file mode 100644 index 0000000..e984a88 --- /dev/null +++ b/trellis2_modal/tests/test_auth.py @@ -0,0 +1,411 @@ +""" +Tests for API key authentication. + +Tests validate_api_key function with various valid and invalid key scenarios. +Tests load_keys and save_keys for file I/O. 
+""" + +import json +from pathlib import Path + +import pytest + +from trellis2_modal.service.auth import ( + add_key, + check_auth, + generate_api_key, + increment_usage, + load_keys, + mask_api_key, + revoke_key, + save_keys, + validate_api_key, +) + + +class TestValidateApiKeyFormat: + """Test API key format validation.""" + + def test_none_key_returns_invalid(self) -> None: + """None key should return invalid.""" + keys_data = {"keys": {}} + is_valid, key_info = validate_api_key(None, keys_data) + assert is_valid is False + assert key_info is None + + def test_empty_key_returns_invalid(self) -> None: + """Empty string key should return invalid.""" + keys_data = {"keys": {}} + is_valid, key_info = validate_api_key("", keys_data) + assert is_valid is False + assert key_info is None + + def test_key_without_prefix_returns_invalid(self) -> None: + """Key without 'sk_' prefix should return invalid.""" + keys_data = {"keys": {"not_a_valid_key": {"active": True}}} + is_valid, key_info = validate_api_key("not_a_valid_key", keys_data) + assert is_valid is False + assert key_info is None + + +class TestValidateApiKeyLookup: + """Test API key lookup in keys data.""" + + def test_valid_key_returns_valid_and_info(self) -> None: + """Valid active key should return valid with key info.""" + keys_data = { + "keys": { + "sk_dev_test123": { + "name": "test", + "active": True, + "created": "2025-01-01", + } + } + } + is_valid, key_info = validate_api_key("sk_dev_test123", keys_data) + assert is_valid is True + assert key_info is not None + assert key_info["name"] == "test" + + def test_unknown_key_returns_invalid(self) -> None: + """Key not in keys_data should return invalid.""" + keys_data = {"keys": {"sk_dev_other": {"active": True}}} + is_valid, key_info = validate_api_key("sk_dev_unknown", keys_data) + assert is_valid is False + assert key_info is None + + def test_inactive_key_returns_invalid(self) -> None: + """Inactive key should return invalid.""" + keys_data = { + 
"keys": { + "sk_dev_revoked": { + "name": "revoked", + "active": False, + } + } + } + is_valid, key_info = validate_api_key("sk_dev_revoked", keys_data) + assert is_valid is False + assert key_info is None + + def test_key_without_active_field_defaults_to_active(self) -> None: + """Key without 'active' field should default to active.""" + keys_data = {"keys": {"sk_dev_noactive": {"name": "no_active_field"}}} + is_valid, key_info = validate_api_key("sk_dev_noactive", keys_data) + assert is_valid is True + assert key_info is not None + + +class TestValidateApiKeyEdgeCases: + """Test edge cases in API key validation.""" + + def test_empty_keys_data_returns_invalid(self) -> None: + """Empty keys dict should return invalid for any key.""" + keys_data = {"keys": {}} + is_valid, key_info = validate_api_key("sk_dev_test", keys_data) + assert is_valid is False + assert key_info is None + + def test_missing_keys_field_returns_invalid(self) -> None: + """Missing 'keys' field should return invalid.""" + keys_data = {} + is_valid, key_info = validate_api_key("sk_dev_test", keys_data) + assert is_valid is False + assert key_info is None + + def test_live_key_prefix_accepted(self) -> None: + """Key with 'sk_live_' prefix should be accepted.""" + keys_data = {"keys": {"sk_live_prod123": {"active": True}}} + is_valid, key_info = validate_api_key("sk_live_prod123", keys_data) + assert is_valid is True + + +class TestLoadKeys: + """Tests for load_keys function.""" + + def test_loads_valid_json_file(self, tmp_path: Path) -> None: + """Valid keys.json should load correctly.""" + keys_file = tmp_path / "keys.json" + expected = { + "version": 1, + "keys": {"sk_dev_test": {"name": "test", "active": True}}, + } + keys_file.write_text(json.dumps(expected)) + + result = load_keys(str(keys_file)) + assert result == expected + + def test_missing_file_returns_empty_structure(self, tmp_path: Path) -> None: + """Missing file should return empty keys structure.""" + missing_file = tmp_path / 
"nonexistent.json" + + result = load_keys(str(missing_file)) + assert result == {"version": 1, "keys": {}} + + def test_invalid_json_raises_error(self, tmp_path: Path) -> None: + """Invalid JSON should raise ValueError.""" + bad_file = tmp_path / "bad.json" + bad_file.write_text("not valid json {{{") + + with pytest.raises(ValueError, match="Invalid JSON"): + load_keys(str(bad_file)) + + +class TestSaveKeys: + """Tests for save_keys function.""" + + def test_saves_keys_to_file(self, tmp_path: Path) -> None: + """save_keys should write valid JSON to file.""" + keys_file = tmp_path / "keys.json" + data = {"version": 1, "keys": {"sk_dev_test": {"active": True}}} + + save_keys(str(keys_file), data) + + assert keys_file.exists() + loaded = json.loads(keys_file.read_text()) + assert loaded == data + + def test_creates_parent_directory(self, tmp_path: Path) -> None: + """save_keys should create parent directories if needed.""" + nested_file = tmp_path / "subdir" / "keys.json" + data = {"version": 1, "keys": {}} + + save_keys(str(nested_file), data) + + assert nested_file.exists() + + def test_overwrites_existing_file(self, tmp_path: Path) -> None: + """save_keys should overwrite existing file.""" + keys_file = tmp_path / "keys.json" + keys_file.write_text('{"old": "data"}') + + new_data = {"version": 2, "keys": {"sk_new": {}}} + save_keys(str(keys_file), new_data) + + loaded = json.loads(keys_file.read_text()) + assert loaded == new_data + + +class TestCheckAuth: + """Tests for check_auth convenience function.""" + + def test_valid_key_with_file(self, tmp_path: Path) -> None: + """check_auth should validate key from file.""" + keys_file = tmp_path / "keys.json" + keys_file.write_text( + json.dumps( + { + "version": 1, + "keys": {"sk_dev_test123": {"name": "test", "active": True}}, + } + ) + ) + + is_valid, key_info = check_auth("sk_dev_test123", str(keys_file)) + assert is_valid is True + assert key_info["name"] == "test" + + def test_invalid_key_with_file(self, tmp_path: 
Path) -> None: + """check_auth should reject unknown key.""" + keys_file = tmp_path / "keys.json" + keys_file.write_text(json.dumps({"version": 1, "keys": {}})) + + is_valid, key_info = check_auth("sk_dev_unknown", str(keys_file)) + assert is_valid is False + assert key_info is None + + def test_missing_file_rejects_all_keys(self, tmp_path: Path) -> None: + """check_auth with missing file should reject any key.""" + missing_file = tmp_path / "nonexistent.json" + + is_valid, key_info = check_auth("sk_dev_test", str(missing_file)) + assert is_valid is False + assert key_info is None + + +class TestMaskApiKey: + """Tests for mask_api_key function.""" + + def test_masks_long_key(self) -> None: + """Long key should be masked with ... in middle.""" + result = mask_api_key("sk_dev_abc123def456xyz") + assert result == "sk_dev_...6xyz" + + def test_masks_short_key(self) -> None: + """Short key should show first 7 chars + ...""" + result = mask_api_key("sk_dev_a") + assert result == "sk_dev_..." + + def test_none_key_returns_placeholder(self) -> None: + """None key should return ''.""" + result = mask_api_key(None) + assert result == "" + + def test_empty_key_returns_placeholder(self) -> None: + """Empty key should return ''.""" + result = mask_api_key("") + assert result == "" + + +class TestIncrementUsage: + """Tests for increment_usage function.""" + + def test_increments_usage_count(self) -> None: + """increment_usage should increase usage_count by 1.""" + keys_data = { + "keys": { + "sk_dev_test123": {"name": "test", "active": True, "usage_count": 5} + } + } + increment_usage("sk_dev_test123", keys_data) + assert keys_data["keys"]["sk_dev_test123"]["usage_count"] == 6 + + def test_initializes_usage_count_if_missing(self) -> None: + """increment_usage should initialize usage_count to 1 if not present.""" + keys_data = {"keys": {"sk_dev_test123": {"name": "test", "active": True}}} + increment_usage("sk_dev_test123", keys_data) + assert 
keys_data["keys"]["sk_dev_test123"]["usage_count"] == 1 + + def test_updates_last_used_timestamp(self) -> None: + """increment_usage should update last_used to ISO timestamp.""" + keys_data = {"keys": {"sk_dev_test123": {"name": "test", "active": True}}} + increment_usage("sk_dev_test123", keys_data) + assert "last_used" in keys_data["keys"]["sk_dev_test123"] + # Verify it's an ISO format timestamp + last_used = keys_data["keys"]["sk_dev_test123"]["last_used"] + assert last_used.startswith("20") # Year starts with 20xx + assert "T" in last_used # ISO format has T separator + + def test_does_nothing_for_unknown_key(self) -> None: + """increment_usage should silently ignore unknown keys.""" + keys_data = {"keys": {"sk_dev_other": {"active": True}}} + increment_usage("sk_dev_unknown", keys_data) + # Should not raise, just silently return + assert "sk_dev_unknown" not in keys_data["keys"] + + def test_does_nothing_for_empty_keys_data(self) -> None: + """increment_usage should handle missing 'keys' field gracefully.""" + keys_data = {} + increment_usage("sk_dev_test", keys_data) + # Should not raise or modify keys_data + assert "keys" not in keys_data + + +class TestGenerateApiKey: + """Tests for generate_api_key function.""" + + def test_generates_key_with_dev_prefix(self) -> None: + """Generated key should start with sk_dev_.""" + key = generate_api_key("dev") + assert key.startswith("sk_dev_") + + def test_generates_key_with_live_prefix(self) -> None: + """Generated key should start with sk_live_.""" + key = generate_api_key("live") + assert key.startswith("sk_live_") + + def test_generates_unique_keys(self) -> None: + """Each call should generate a unique key.""" + keys = [generate_api_key() for _ in range(100)] + assert len(set(keys)) == 100 + + def test_key_has_correct_length(self) -> None: + """Generated key should have sk_dev_ + 32 hex chars.""" + key = generate_api_key("dev") + # sk_dev_ is 7 chars, random part is 32 hex chars + assert len(key) == 7 + 32 + + 
+class TestAddKey: + """Tests for add_key function.""" + + def test_adds_key_to_keys_data(self) -> None: + """add_key should add new key to keys_data.""" + keys_data = {"keys": {}} + key = add_key(keys_data, name="test") + assert key in keys_data["keys"] + + def test_sets_key_metadata(self) -> None: + """add_key should set name, active, created, usage_count.""" + keys_data = {"keys": {}} + key = add_key(keys_data, name="mykey") + info = keys_data["keys"][key] + assert info["name"] == "mykey" + assert info["active"] is True + assert "created" in info + assert info["usage_count"] == 0 + + def test_sets_quota_when_provided(self) -> None: + """add_key should set quota when provided.""" + keys_data = {"keys": {}} + key = add_key(keys_data, name="test", quota=1000) + assert keys_data["keys"][key]["quota"] == 1000 + + def test_creates_keys_dict_if_missing(self) -> None: + """add_key should create 'keys' dict if missing.""" + keys_data = {} + key = add_key(keys_data, name="test") + assert "keys" in keys_data + assert key in keys_data["keys"] + + +class TestRevokeKey: + """Tests for revoke_key function.""" + + def test_revokes_existing_key(self) -> None: + """revoke_key should set active=False on existing key.""" + keys_data = {"keys": {"sk_dev_test": {"active": True}}} + result = revoke_key(keys_data, "sk_dev_test") + assert result is True + assert keys_data["keys"]["sk_dev_test"]["active"] is False + + def test_returns_false_for_unknown_key(self) -> None: + """revoke_key should return False for unknown key.""" + keys_data = {"keys": {}} + result = revoke_key(keys_data, "sk_dev_unknown") + assert result is False + + def test_handles_missing_keys_field(self) -> None: + """revoke_key should return False if 'keys' field missing.""" + keys_data = {} + result = revoke_key(keys_data, "sk_dev_test") + assert result is False + + +class TestCheckRateLimit: + """Tests for check_rate_limit function.""" + + def test_within_quota_returns_true(self) -> None: + """Should return True when 
usage_count < quota.""" + from trellis2_modal.service.auth import check_rate_limit + + key_info = {"usage_count": 5, "quota": 100} + assert check_rate_limit(key_info) is True + + def test_at_quota_returns_false(self) -> None: + """Should return False when usage_count >= quota.""" + from trellis2_modal.service.auth import check_rate_limit + + key_info = {"usage_count": 100, "quota": 100} + assert check_rate_limit(key_info) is False + + def test_over_quota_returns_false(self) -> None: + """Should return False when usage_count > quota.""" + from trellis2_modal.service.auth import check_rate_limit + + key_info = {"usage_count": 150, "quota": 100} + assert check_rate_limit(key_info) is False + + def test_no_quota_returns_true(self) -> None: + """Should return True when no quota is set (unlimited).""" + from trellis2_modal.service.auth import check_rate_limit + + key_info = {"usage_count": 999999} + assert check_rate_limit(key_info) is True + + def test_zero_quota_returns_false(self) -> None: + """Should return False when quota is 0 (disabled key).""" + from trellis2_modal.service.auth import check_rate_limit + + key_info = {"usage_count": 0, "quota": 0} + assert check_rate_limit(key_info) is False diff --git a/trellis2_modal/tests/test_compression.py b/trellis2_modal/tests/test_compression.py new file mode 100644 index 0000000..9406522 --- /dev/null +++ b/trellis2_modal/tests/test_compression.py @@ -0,0 +1,153 @@ +""" +Tests for LZ4 compression utilities. + +Tests the compress_state/decompress_state functions that handle +serialization and compression of TRELLIS generation state. 
+""" + +import pickle + +import numpy as np + +from trellis2_modal.client.compression import compress_state, decompress_state + + +class TestCompressState: + """Tests for compress_state function.""" + + def test_returns_bytes(self) -> None: + """Compress should return bytes.""" + state = {"key": "value"} + result = compress_state(state) + assert isinstance(result, bytes) + + def test_returns_non_empty_bytes(self) -> None: + """Compressed result should not be empty.""" + state = {"key": "value"} + result = compress_state(state) + assert len(result) > 0 + + def test_handles_nested_dicts(self) -> None: + """Should compress nested dictionary structures.""" + state = {"outer": {"inner": {"deep": 42}}} + result = compress_state(state) + assert isinstance(result, bytes) + + +class TestDecompressState: + """Tests for decompress_state function.""" + + def test_returns_dict(self) -> None: + """Decompress should return a dictionary.""" + state = {"key": "value"} + compressed = compress_state(state) + result = decompress_state(compressed) + assert isinstance(result, dict) + + +class TestCompressionRoundtrip: + """Tests for compress/decompress round-trip.""" + + def test_roundtrip_simple_dict(self) -> None: + """Simple dict should survive round-trip.""" + state = {"key": "value", "number": 42} + compressed = compress_state(state) + decompressed = decompress_state(compressed) + assert decompressed == state + + def test_roundtrip_with_list(self) -> None: + """List values should survive round-trip.""" + state = {"aabb": [0, 0, 0, 1, 1, 1]} + compressed = compress_state(state) + decompressed = decompress_state(compressed) + assert decompressed == state + + def test_roundtrip_numpy_array(self) -> None: + """Numpy array should survive round-trip.""" + original = np.array([1.0, 2.0, 3.0], dtype=np.float32) + state = {"data": original} + compressed = compress_state(state) + decompressed = decompress_state(compressed) + np.testing.assert_array_equal(decompressed["data"], original) + + 
def test_roundtrip_preserves_dtype(self) -> None: + """Numpy array dtype should be preserved.""" + original = np.array([1, 2, 3], dtype=np.int64) + state = {"data": original} + compressed = compress_state(state) + decompressed = decompress_state(compressed) + assert decompressed["data"].dtype == np.int64 + + def test_roundtrip_multidimensional_array(self) -> None: + """Multi-dimensional arrays should survive round-trip.""" + original = np.random.randn(100, 3).astype(np.float32) + state = {"positions": original} + compressed = compress_state(state) + decompressed = decompress_state(compressed) + np.testing.assert_array_equal(decompressed["positions"], original) + + def test_roundtrip_realistic_state(self) -> None: + """Realistic state structure (from pack_state) should survive.""" + state = { + "gaussian": { + "aabb": [0, 0, 0, 1, 1, 1], + "sh_degree": 0, + "scaling_activation": "exp", + "_xyz": np.random.randn(1000, 3).astype(np.float32), + "_features_dc": np.random.randn(1000, 3).astype(np.float32), + "_scaling": np.random.randn(1000, 3).astype(np.float32), + "_rotation": np.random.randn(1000, 4).astype(np.float32), + "_opacity": np.random.randn(1000, 1).astype(np.float32), + }, + "mesh": { + "vertices": np.random.randn(500, 3).astype(np.float32), + "faces": np.arange(1500).reshape(500, 3).astype(np.int64), + }, + } + compressed = compress_state(state) + decompressed = decompress_state(compressed) + + # Check scalar values + assert decompressed["gaussian"]["aabb"] == state["gaussian"]["aabb"] + assert decompressed["gaussian"]["sh_degree"] == state["gaussian"]["sh_degree"] + + # Check numpy arrays + np.testing.assert_array_equal( + decompressed["gaussian"]["_xyz"], state["gaussian"]["_xyz"] + ) + np.testing.assert_array_equal( + decompressed["mesh"]["faces"], state["mesh"]["faces"] + ) + + def test_roundtrip_empty_dict(self) -> None: + """Empty dict should survive round-trip.""" + state: dict = {} + compressed = compress_state(state) + decompressed = 
decompress_state(compressed) + assert decompressed == state + + +class TestCompressionEfficiency: + """Tests for compression efficiency.""" + + def test_compresses_repetitive_data(self) -> None: + """Repetitive data should compress well.""" + # Zeros compress very well + state = {"zeros": np.zeros((1000, 3), dtype=np.float32)} + serialized = pickle.dumps(state) + compressed = compress_state(state) + assert len(compressed) < len(serialized) + + def test_compresses_realistic_state(self) -> None: + """Realistic state should have meaningful compression.""" + state = { + "gaussian": { + "_xyz": np.random.randn(10000, 3).astype(np.float32), + "_features_dc": np.random.randn(10000, 3).astype(np.float32), + } + } + serialized = pickle.dumps(state) + compressed = compress_state(state) + # LZ4 on random floats won't compress much, but should not expand significantly + # The overhead should be minimal (LZ4 frame header is small) + assert len(compressed) < len(serialized) * 1.1 diff --git a/trellis2_modal/tests/test_config_consistency.py b/trellis2_modal/tests/test_config_consistency.py new file mode 100644 index 0000000..088c8b2 --- /dev/null +++ b/trellis2_modal/tests/test_config_consistency.py @@ -0,0 +1,42 @@ +""" +Test that ensures constants are consistent across modules. + +Modal's execution model copies image.py to /root/image.py without the +trellis2_modal package context, so imports from config.py fail. We must +duplicate constants but verify they stay in sync with this test. 
+""" + +from trellis2_modal.service import config +from trellis2_modal.service import image + + +def test_gpu_type_matches(): + """GPU_TYPE must match between config and image modules.""" + assert image.GPU_TYPE == config.GPU_TYPE, ( + f"GPU_TYPE mismatch: image.py has '{image.GPU_TYPE}' " + f"but config.py has '{config.GPU_TYPE}'" + ) + + +def test_hf_cache_path_matches(): + """HF_CACHE_PATH must match between config and image modules.""" + assert image.HF_CACHE_PATH == config.HF_CACHE_PATH, ( + f"HF_CACHE_PATH mismatch: image.py has '{image.HF_CACHE_PATH}' " + f"but config.py has '{config.HF_CACHE_PATH}'" + ) + + +def test_model_name_matches(): + """MODEL_NAME must match between config and image modules.""" + assert image.MODEL_NAME == config.MODEL_NAME, ( + f"MODEL_NAME mismatch: image.py has '{image.MODEL_NAME}' " + f"but config.py has '{config.MODEL_NAME}'" + ) + + +def test_trellis2_path_matches(): + """TRELLIS2_PATH must match between config and image modules.""" + assert image.TRELLIS2_PATH == config.TRELLIS2_PATH, ( + f"TRELLIS2_PATH mismatch: image.py has '{image.TRELLIS2_PATH}' " + f"but config.py has '{config.TRELLIS2_PATH}'" + ) diff --git a/trellis2_modal/tests/test_generator.py b/trellis2_modal/tests/test_generator.py new file mode 100644 index 0000000..f9bc0f1 --- /dev/null +++ b/trellis2_modal/tests/test_generator.py @@ -0,0 +1,158 @@ +""" +Tests for TRELLIS2Generator. + +These tests use mocks to verify generator behavior without requiring +GPU or the full TRELLIS.2 pipeline. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + + +class MockPipeline: + """Mock pipeline for testing.""" + + def __init__(self): + self.models = {"model1": MagicMock()} + self._device = "cpu" + # Use MagicMock for methods we want to verify + self.preprocess_image = MagicMock(side_effect=lambda x: x) + mock_mesh = MagicMock() + mock_mesh.simplify = MagicMock() + self.run = MagicMock(return_value=[mock_mesh]) + + def to(self, device): + self._device = str(device) + + def cuda(self): + self._device = "cuda" + + +@pytest.fixture +def mock_pipeline_factory(): + """Factory that returns a mock pipeline.""" + return lambda model_name: MockPipeline() + + +@pytest.fixture +def generator(mock_pipeline_factory): + """Create a generator with mock pipeline.""" + from trellis2_modal.service.generator import TRELLIS2Generator + + return TRELLIS2Generator(pipeline_factory=mock_pipeline_factory) + + +def test_generator_init_state(): + """Test initial state of generator.""" + from trellis2_modal.service.generator import TRELLIS2Generator + + gen = TRELLIS2Generator() + assert gen.pipeline is None + assert gen.envmap is None + assert gen.load_time == 0.0 + assert gen.is_ready is False + + +def test_generator_is_ready_property(generator): + """Test is_ready property requires both pipeline and envmap.""" + assert generator.is_ready is False + + # Manually set pipeline (simulating partial load) + generator.pipeline = MockPipeline() + assert generator.is_ready is False # envmap not set + + generator.envmap = "mock_envmap" + assert generator.is_ready is True + + +def test_generator_generate_requires_loaded_pipeline(generator): + """Test that generate fails without loaded pipeline.""" + mock_image = MagicMock() + with pytest.raises(RuntimeError, match="Pipeline not loaded"): + generator.generate(mock_image) + + +def test_generator_generate_calls_pipeline(generator): + """Test that generate calls the pipeline correctly.""" + # Manually set up pipeline for testing + 
generator.pipeline = MockPipeline() + + with patch("trellis2_modal.service.state.pack_state") as mock_pack: + mock_pack.return_value = {"test": "state"} + mock_image = MagicMock() + + generator.generate(mock_image, seed=123) + + # Verify pipeline methods were called + generator.pipeline.preprocess_image.assert_called_once_with(mock_image) + generator.pipeline.run.assert_called_once() + + # Verify pack_state was called + mock_pack.assert_called_once() + + +def test_generator_generate_passes_correct_params(generator): + """Test that generate passes parameters correctly to pipeline.""" + generator.pipeline = MockPipeline() + + with patch("trellis2_modal.service.state.pack_state") as mock_pack: + mock_pack.return_value = {} + mock_image = MagicMock() + + generator.generate( + mock_image, + seed=42, + pipeline_type="512", + ss_params={"steps": 8}, + shape_params={"steps": 10}, + tex_params={"steps": 12}, + max_num_tokens=10000, + ) + + call_kwargs = generator.pipeline.run.call_args.kwargs + assert call_kwargs["seed"] == 42 + assert call_kwargs["pipeline_type"] == "512" + assert call_kwargs["sparse_structure_sampler_params"] == {"steps": 8} + assert call_kwargs["shape_slat_sampler_params"] == {"steps": 10} + assert call_kwargs["tex_slat_sampler_params"] == {"steps": 12} + assert call_kwargs["max_num_tokens"] == 10000 + + +@pytest.mark.parametrize( + "pipeline_type", + ["512", "1024", "1024_cascade", "1536_cascade"], +) +def test_generator_accepts_all_pipeline_types(generator, pipeline_type): + """Test all 4 pipeline types are accepted.""" + generator.pipeline = MockPipeline() + + with patch("trellis2_modal.service.state.pack_state") as mock_pack: + mock_pack.return_value = {} + mock_image = MagicMock() + + # Should not raise + generator.generate(mock_image, pipeline_type=pipeline_type) + + call_kwargs = generator.pipeline.run.call_args.kwargs + assert call_kwargs["pipeline_type"] == pipeline_type + + +def test_generator_render_preview_requires_ready(generator): + """Test 
that render_preview_video fails without full initialization.""" + with pytest.raises(RuntimeError, match="Generator not ready"): + generator.render_preview_video({}) + + +def test_generator_extract_glb_requires_ready(generator): + """Test that extract_glb fails without full initialization.""" + with pytest.raises(RuntimeError, match="Generator not ready"): + generator.extract_glb({}) + + +def test_generator_load_time_initialized(): + """Test load_time starts at 0.""" + from trellis2_modal.service.generator import TRELLIS2Generator + + gen = TRELLIS2Generator() + assert gen.load_time == 0.0 diff --git a/trellis2_modal/tests/test_gpu_snapshots.py b/trellis2_modal/tests/test_gpu_snapshots.py new file mode 100644 index 0000000..d6545d4 --- /dev/null +++ b/trellis2_modal/tests/test_gpu_snapshots.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +""" +GPU Memory Snapshots Cold Start Test. + +Result: GPU snapshots don't help (~143s vs ~146s) due to flex_gemm/Triton reinit. +Kept for future retesting if Modal improves support for Triton-based models. 
+ +Usage: + pip install -r trellis2_modal/requirements-deploy.txt + python -m trellis2_modal.tests.test_gpu_snapshots full +""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +import time +from pathlib import Path + + +RESULTS_FILE = Path.home() / ".trellis2_gpu_snapshot_results.json" + + +def run_command(cmd: list[str], env: dict | None = None) -> tuple[int, str, str]: + """Run a command and return exit code, stdout, stderr.""" + full_env = os.environ.copy() + if env: + full_env.update(env) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + env=full_env, + ) + return result.returncode, result.stdout, result.stderr + + +def deploy_service(with_snapshot: bool = False) -> bool: + """Deploy the TRELLIS.2 service with or without GPU snapshots.""" + env = {"GPU_MEMORY_SNAPSHOT": "true" if with_snapshot else "false"} + snapshot_str = "enabled" if with_snapshot else "disabled" + + print(f"\n{'='*60}") + print(f"Deploying TRELLIS.2 with GPU Memory Snapshots {snapshot_str}") + print(f"{'='*60}\n") + + # Force rebuild by adding a timestamp to ensure fresh snapshot + code, stdout, stderr = run_command( + ["modal", "deploy", "-m", "trellis2_modal.service.service"], + env=env, + ) + + if code != 0: + print(f"Deploy failed (exit {code}):") + print(stderr) + return False + + print("Deploy successful") + return True + + +def stop_containers() -> bool: + """Stop all running containers to force cold start.""" + print("\nStopping all containers...") + + code, stdout, stderr = run_command( + ["modal", "app", "stop", "trellis2-3d", "--force"], + ) + + # It's okay if there are no containers to stop + if code != 0 and "not found" not in stderr.lower(): + print(f"Warning: stop command returned {code}") + + # Wait a bit for containers to fully stop + time.sleep(5) + print("Containers stopped") + return True + + +def measure_cold_start() -> tuple[float, dict]: + """ + Measure cold start time by calling 
health_check on the GPU service. + + Returns: + Tuple of (cold_start_seconds, health_check_result) + """ + print("\nMeasuring cold start time...") + print("(This may take 2-3 minutes for model loading)") + + start = time.perf_counter() + + # Use modal run to call the health_check method + code, stdout, stderr = run_command( + ["modal", "run", "-m", "trellis2_modal.service.service"], + env=None, + ) + + elapsed = time.perf_counter() - start + + if code != 0: + print(f"Cold start test failed (exit {code}):") + print(stderr) + return elapsed, {"error": stderr} + + # Parse the output to get health check result + try: + # The test_service entrypoint prints JSON result + lines = stdout.strip().split("\n") + for line in lines: + if "Result:" in line: + # Next line should be JSON + idx = lines.index(line) + if idx + 1 < len(lines): + result = json.loads(lines[idx + 1].strip()) + return elapsed, result + + # Fallback: try to extract timing from output + result = {"output": stdout} + return elapsed, result + except Exception as e: + return elapsed, {"error": str(e), "output": stdout} + + +def run_baseline_test() -> dict: + """Run baseline test without GPU snapshots.""" + print("\n" + "="*60) + print("BASELINE TEST (No GPU Snapshots)") + print("="*60) + + result = { + "test": "baseline", + "gpu_snapshot": False, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + } + + # Deploy without snapshots + if not deploy_service(with_snapshot=False): + result["error"] = "Deploy failed" + return result + + # Stop containers to force cold start + stop_containers() + + # Measure cold start + cold_start_time, health_result = measure_cold_start() + + result["cold_start_seconds"] = cold_start_time + result["health_result"] = health_result + + print(f"\n{'='*60}") + print(f"BASELINE RESULT: {cold_start_time:.1f}s cold start") + print(f"{'='*60}") + + # Save result + save_result("baseline", result) + + return result + + +def run_snapshot_test() -> dict: + """Run test with GPU Memory Snapshots 
enabled.""" + print("\n" + "="*60) + print("GPU SNAPSHOT TEST (Enabled)") + print("="*60) + + result = { + "test": "snapshot", + "gpu_snapshot": True, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + } + + # Deploy with snapshots enabled + if not deploy_service(with_snapshot=True): + result["error"] = "Deploy failed - check if GPU snapshots cause issues" + return result + + print("\nFirst run creates the snapshot (may take longer)...") + print("Waiting 60s for snapshot to be created...") + time.sleep(60) + + # Stop containers to force cold start from snapshot + stop_containers() + + # Measure cold start from snapshot + cold_start_time, health_result = measure_cold_start() + + result["cold_start_seconds"] = cold_start_time + result["health_result"] = health_result + + print(f"\n{'='*60}") + print(f"SNAPSHOT RESULT: {cold_start_time:.1f}s cold start") + print(f"{'='*60}") + + # Save result + save_result("snapshot", result) + + return result + + +def save_result(test_name: str, result: dict) -> None: + """Save test result to file.""" + results = {} + if RESULTS_FILE.exists(): + try: + results = json.loads(RESULTS_FILE.read_text()) + except json.JSONDecodeError: + pass + + results[test_name] = result + RESULTS_FILE.write_text(json.dumps(results, indent=2)) + print(f"Result saved to {RESULTS_FILE}") + + +def compare_results() -> None: + """Compare baseline and snapshot test results.""" + print("\n" + "="*60) + print("COMPARISON") + print("="*60) + + if not RESULTS_FILE.exists(): + print("No results found. 
Run baseline and snapshot tests first.") + return + + try: + results = json.loads(RESULTS_FILE.read_text()) + except json.JSONDecodeError: + print("Invalid results file.") + return + + baseline = results.get("baseline", {}) + snapshot = results.get("snapshot", {}) + + print("\nResults:") + print("-" * 40) + + if baseline: + baseline_time = baseline.get("cold_start_seconds", "N/A") + print(f"Baseline (no snapshot): {baseline_time:.1f}s" if isinstance(baseline_time, (int, float)) else f"Baseline: {baseline_time}") + else: + print("Baseline: Not run yet") + + if snapshot: + snapshot_time = snapshot.get("cold_start_seconds", "N/A") + print(f"With GPU Snapshot: {snapshot_time:.1f}s" if isinstance(snapshot_time, (int, float)) else f"Snapshot: {snapshot_time}") + else: + print("Snapshot: Not run yet") + + # Calculate improvement + if baseline and snapshot: + baseline_time = baseline.get("cold_start_seconds") + snapshot_time = snapshot.get("cold_start_seconds") + + if isinstance(baseline_time, (int, float)) and isinstance(snapshot_time, (int, float)): + improvement = baseline_time - snapshot_time + percent = (improvement / baseline_time) * 100 if baseline_time > 0 else 0 + + print("-" * 40) + print(f"Improvement: {improvement:.1f}s ({percent:.1f}%)") + + if percent > 30: + print("\n✅ GPU Memory Snapshots provide significant benefit!") + elif percent > 10: + print("\n⚡ GPU Memory Snapshots provide moderate benefit.") + elif percent > 0: + print("\n➡️ GPU Memory Snapshots provide minimal benefit.") + else: + print("\n❌ GPU Memory Snapshots don't help (or failed).") + + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Test GPU Memory Snapshots for TRELLIS.2 Modal service" + ) + parser.add_argument( + "command", + choices=["baseline", "snapshot", "compare", "full"], + help="Test command: baseline (no snapshots), snapshot (with), compare, full (all)", + ) + + args = parser.parse_args() + + if args.command == "baseline": + run_baseline_test() + elif 
args.command == "snapshot": + run_snapshot_test() + elif args.command == "compare": + compare_results() + elif args.command == "full": + run_baseline_test() + run_snapshot_test() + compare_results() + + +if __name__ == "__main__": + main() diff --git a/trellis2_modal/tests/test_service.py b/trellis2_modal/tests/test_service.py new file mode 100644 index 0000000..08ec563 --- /dev/null +++ b/trellis2_modal/tests/test_service.py @@ -0,0 +1,332 @@ +""" +Tests for TRELLIS2Service request parsing and validation. + +These tests verify the request parsing and validation logic without +requiring GPU or the full TRELLIS.2 pipeline. They use mocks for +external dependencies. +""" + +import base64 +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest +from PIL import Image + + +class TestParseGenerateRequest: + """Tests for _parse_generate_request function.""" + + @pytest.fixture + def valid_image_b64(self): + """Create a valid base64-encoded PNG image.""" + img = Image.new("RGB", (100, 100), color="red") + buffer = BytesIO() + img.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + def test_missing_image_returns_error(self): + """Missing 'image' field returns validation error.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request({"seed": 42}) + assert "error" in result + assert result["error"]["code"] == "validation_error" + assert "image" in result["error"]["message"].lower() + + def test_invalid_json_type_returns_error(self): + """Non-dict request returns validation error.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request("not a dict") + assert "error" in result + assert result["error"]["code"] == "validation_error" + + def test_invalid_base64_returns_error(self): + """Invalid base64 returns validation error.""" + from trellis2_modal.service.service import _parse_generate_request + + result 
= _parse_generate_request({"image": "not-valid-base64!!!"}) + assert "error" in result + assert result["error"]["code"] == "validation_error" + + def test_corrupt_image_returns_error(self): + """Valid base64 but corrupt image data returns error.""" + from trellis2_modal.service.service import _parse_generate_request + + garbage = base64.b64encode(b"not an image").decode("utf-8") + result = _parse_generate_request({"image": garbage}) + assert "error" in result + assert result["error"]["code"] == "validation_error" + + def test_valid_image_returns_params(self, valid_image_b64): + """Valid image returns GenerateParams dataclass.""" + from trellis2_modal.service.service import ( + GenerateParams, + _parse_generate_request, + ) + + result = _parse_generate_request({"image": valid_image_b64}) + assert isinstance(result, GenerateParams) + assert result.image.size == (100, 100) + + def test_default_seed(self, valid_image_b64): + """Default seed is 42.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request({"image": valid_image_b64}) + assert result.seed == 42 + + def test_custom_seed(self, valid_image_b64): + """Custom seed is preserved.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request({"image": valid_image_b64, "seed": 123}) + assert result.seed == 123 + + def test_default_pipeline_type(self, valid_image_b64): + """Default pipeline_type is 1024_cascade.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request({"image": valid_image_b64}) + assert result.pipeline_type == "1024_cascade" + + @pytest.mark.parametrize( + "pipeline_type", ["512", "1024", "1024_cascade", "1536_cascade"] + ) + def test_valid_pipeline_types(self, valid_image_b64, pipeline_type): + """All 4 pipeline types are accepted.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request( + 
{"image": valid_image_b64, "pipeline_type": pipeline_type} + ) + assert result.pipeline_type == pipeline_type + + def test_invalid_pipeline_type_returns_error(self, valid_image_b64): + """Invalid pipeline_type returns validation error.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request( + {"image": valid_image_b64, "pipeline_type": "invalid"} + ) + assert "error" in result + assert result["error"]["code"] == "validation_error" + assert "pipeline_type" in result["error"]["message"] + + def test_default_sampler_params(self, valid_image_b64): + """Sampler params have correct defaults.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request({"image": valid_image_b64}) + assert result.ss_sampling_steps == 12 + assert result.ss_guidance_strength == 7.5 + assert result.shape_slat_sampling_steps == 12 + assert result.shape_slat_guidance_strength == 7.5 + assert result.tex_slat_sampling_steps == 12 + assert result.tex_slat_guidance_strength == 1.0 + + def test_custom_sampler_params(self, valid_image_b64): + """Custom sampler params are preserved.""" + from trellis2_modal.service.service import _parse_generate_request + + result = _parse_generate_request( + { + "image": valid_image_b64, + "ss_sampling_steps": 8, + "ss_guidance_strength": 5.0, + "shape_slat_sampling_steps": 10, + "shape_slat_guidance_strength": 6.0, + "tex_slat_sampling_steps": 6, + "tex_slat_guidance_strength": 2.0, + } + ) + assert result.ss_sampling_steps == 8 + assert result.ss_guidance_strength == 5.0 + assert result.shape_slat_sampling_steps == 10 + assert result.shape_slat_guidance_strength == 6.0 + assert result.tex_slat_sampling_steps == 6 + assert result.tex_slat_guidance_strength == 2.0 + + def test_invalid_numeric_param_returns_error(self, valid_image_b64): + """Invalid numeric param returns validation error.""" + from trellis2_modal.service.service import _parse_generate_request + + 
result = _parse_generate_request( + { + "image": valid_image_b64, + "seed": "not a number", + } + ) + assert "error" in result + assert result["error"]["code"] == "validation_error" + + +class TestParseExtractGLBRequest: + """Tests for _parse_extract_glb_request function.""" + + @pytest.fixture + def valid_state_b64(self): + """Create a valid base64-encoded compressed state.""" + from trellis2_modal.client.compression import compress_state + import numpy as np + + state = { + "vertices": np.array([[0, 0, 0]], dtype=np.float32), + "faces": np.array([[0, 0, 0]], dtype=np.int32), + } + compressed = compress_state(state) + return base64.b64encode(compressed).decode("utf-8") + + def test_missing_state_returns_error(self): + """Missing 'state' field returns validation error.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request({}) + assert "error" in result + assert result["error"]["code"] == "validation_error" + assert "state" in result["error"]["message"].lower() + + def test_invalid_json_type_returns_error(self): + """Non-dict request returns validation error.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request("not a dict") + assert "error" in result + assert result["error"]["code"] == "validation_error" + + def test_invalid_state_returns_error(self): + """Invalid state data returns validation error.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + garbage = base64.b64encode(b"not valid state").decode("utf-8") + result = _parse_extract_glb_request({"state": garbage}) + assert "error" in result + assert result["error"]["code"] == "validation_error" + + def test_valid_state_returns_params(self, valid_state_b64): + """Valid state returns ExtractGLBParams dataclass.""" + from trellis2_modal.service.service import ( + ExtractGLBParams, + _parse_extract_glb_request, + ) + + result = _parse_extract_glb_request({"state": 
valid_state_b64}) + assert isinstance(result, ExtractGLBParams) + assert "vertices" in result.state + + def test_default_decimation_target(self, valid_state_b64): + """Default decimation_target is 1000000.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request({"state": valid_state_b64}) + assert result.decimation_target == 1000000 + + def test_default_texture_size(self, valid_state_b64): + """Default texture_size is 4096.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request({"state": valid_state_b64}) + assert result.texture_size == 4096 + + def test_default_remesh(self, valid_state_b64): + """Default remesh is True.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request({"state": valid_state_b64}) + assert result.remesh is True + + def test_custom_decimation_target(self, valid_state_b64): + """Custom decimation_target is preserved.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request( + { + "state": valid_state_b64, + "decimation_target": 500000, + } + ) + assert result.decimation_target == 500000 + + def test_decimation_target_too_small_returns_error(self, valid_state_b64): + """decimation_target < 1000 returns error.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request( + { + "state": valid_state_b64, + "decimation_target": 500, + } + ) + assert "error" in result + assert result["error"]["code"] == "validation_error" + + @pytest.mark.parametrize("texture_size", [512, 1024, 2048, 4096]) + def test_valid_texture_sizes(self, valid_state_b64, texture_size): + """Valid texture sizes are accepted.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request( + { + "state": valid_state_b64, + "texture_size": 
texture_size, + } + ) + assert result.texture_size == texture_size + + def test_invalid_texture_size_returns_error(self, valid_state_b64): + """Invalid texture_size returns error.""" + from trellis2_modal.service.service import _parse_extract_glb_request + + result = _parse_extract_glb_request( + { + "state": valid_state_b64, + "texture_size": 999, + } + ) + assert "error" in result + assert result["error"]["code"] == "validation_error" + +# Note: Authentication is now handled by Modal Proxy Auth at the proxy level. +# No _check_auth_or_error tests needed since requests that reach endpoints are pre-authenticated. + + +class TestGenerateRequestId: + """Tests for generate_request_id function.""" + + def test_returns_string(self): + """Returns a string.""" + from trellis2_modal.service.service import generate_request_id + + result = generate_request_id() + assert isinstance(result, str) + + def test_starts_with_req(self): + """Starts with 'req_' prefix.""" + from trellis2_modal.service.service import generate_request_id + + result = generate_request_id() + assert result.startswith("req_") + + def test_unique_ids(self): + """Generated IDs are unique.""" + from trellis2_modal.service.service import generate_request_id + + ids = {generate_request_id() for _ in range(100)} + assert len(ids) == 100 + + +class TestErrorResponse: + """Tests for _error_response function.""" + + def test_returns_correct_structure(self): + """Returns dict with error.code and error.message.""" + from trellis2_modal.service.service import _error_response + + result = _error_response("test_code", "test message") + assert result == { + "error": { + "code": "test_code", + "message": "test message", + } + } diff --git a/trellis2_modal/tests/test_state.py b/trellis2_modal/tests/test_state.py new file mode 100644 index 0000000..bad0364 --- /dev/null +++ b/trellis2_modal/tests/test_state.py @@ -0,0 +1,143 @@ +""" +Tests for state serialization. 
+ +These tests verify pack_state and unpack_state work correctly +without requiring GPU or the full TRELLIS.2 pipeline. +""" + +import json + +import numpy as np +import pytest + + +class MockMeshWithVoxel: + """Mock MeshWithVoxel for testing without GPU dependencies.""" + + def __init__( + self, + vertices, + faces, + origin, + voxel_size, + coords, + attrs, + voxel_shape, + layout, + ): + self.vertices = vertices + self.faces = faces + self.origin = origin + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.voxel_shape = voxel_shape + self.layout = layout + + +class MockTensor: + """Mock tensor that mimics torch.Tensor for CPU operations.""" + + def __init__(self, data, device="cpu"): + self._data = np.array(data) + self._device = device + + def cpu(self): + return MockTensor(self._data, device="cpu") + + def numpy(self): + return self._data + + def tolist(self): + return self._data.tolist() + + +@pytest.fixture +def mock_mesh(): + """Create a mock MeshWithVoxel for testing.""" + return MockMeshWithVoxel( + vertices=MockTensor(np.random.randn(100, 3).astype(np.float32)), + faces=MockTensor(np.random.randint(0, 100, (50, 3)).astype(np.int32)), + origin=MockTensor([-0.5, -0.5, -0.5]), + voxel_size=1 / 512, + coords=MockTensor(np.random.randint(0, 512, (200, 3))), + attrs=MockTensor(np.random.randn(200, 6).astype(np.float32)), + voxel_shape=(1, 6, 512, 512, 512), + layout={ + "base_color": slice(0, 3), + "metallic": slice(3, 4), + "roughness": slice(4, 5), + "alpha": slice(5, 6), + }, + ) + + +def test_pack_state_contains_all_required_fields(mock_mesh): + """Verify packed state has all MeshWithVoxel fields.""" + from trellis2_modal.service.state import pack_state + + state = pack_state(mock_mesh) + + required_fields = [ + "vertices", + "faces", + "attrs", + "coords", + "voxel_size", + "voxel_shape", + "origin", + "layout", + ] + for field in required_fields: + assert field in state, f"Missing field: {field}" + + +def 
test_pack_state_converts_tensors_to_numpy(mock_mesh): + """Verify tensors become numpy arrays.""" + from trellis2_modal.service.state import pack_state + + state = pack_state(mock_mesh) + + assert isinstance(state["vertices"], np.ndarray) + assert isinstance(state["faces"], np.ndarray) + assert isinstance(state["attrs"], np.ndarray) + assert isinstance(state["coords"], np.ndarray) + + +def test_pack_state_layout_is_json_serializable(mock_mesh): + """Verify layout uses lists not slice objects.""" + from trellis2_modal.service.state import pack_state + + state = pack_state(mock_mesh) + + # Should not raise - layout must be JSON serializable + json.dumps(state["layout"]) + + # Verify structure + assert state["layout"]["base_color"] == [0, 3] + assert state["layout"]["metallic"] == [3, 4] + assert state["layout"]["roughness"] == [4, 5] + assert state["layout"]["alpha"] == [5, 6] + + +def test_pack_state_preserves_shapes(mock_mesh): + """Verify array shapes are preserved.""" + from trellis2_modal.service.state import pack_state + + state = pack_state(mock_mesh) + + assert state["vertices"].shape == (100, 3) + assert state["faces"].shape == (50, 3) + assert state["attrs"].shape == (200, 6) + assert state["coords"].shape == (200, 3) + + +def test_pack_state_preserves_scalar_values(mock_mesh): + """Verify scalar values are preserved.""" + from trellis2_modal.service.state import pack_state + + state = pack_state(mock_mesh) + + assert state["voxel_size"] == pytest.approx(1 / 512) + assert state["voxel_shape"] == [1, 6, 512, 512, 512] + assert state["origin"] == [-0.5, -0.5, -0.5] From cabdb9dcbd7c8938ed502e33bed1f2451070065e Mon Sep 17 00:00:00 2001 From: Daniel Nouri Date: Fri, 26 Dec 2025 20:00:30 +0100 Subject: [PATCH 3/3] docs: improve trellis2_modal documentation - Add input image guidelines and preparation tips - Add object-type specific parameter tuning (hard-surface vs organic) - Document all sampler parameters with meanings and ranges - Improve troubleshooting with 
symptom-to-cause correlations - Generalize use-case presets (remove platform-specific focus) - Add known limitations and PBR material descriptions --- trellis2_modal/README.md | 174 +++++++++++++++++++++-- trellis2_modal/docs/MODAL_INTEGRATION.md | 136 +++++++++++++++--- 2 files changed, 275 insertions(+), 35 deletions(-) diff --git a/trellis2_modal/README.md b/trellis2_modal/README.md index 1ed8d9d..2801a60 100644 --- a/trellis2_modal/README.md +++ b/trellis2_modal/README.md @@ -7,6 +7,16 @@ Deploy TRELLIS.2 to Modal's serverless GPU infrastructure. Generate 3D assets fr - Pay only for what you use (~$0.05 per generation) - Access via web UI or API from any device +## What You Get + +TRELLIS.2 generates 3D meshes with full PBR (Physically-Based Rendering) materials: +- **Base Color** - diffuse texture +- **Metallic** - metal vs non-metal surfaces +- **Roughness** - shiny vs matte appearance +- **Opacity** - transparency and translucency support + +The model uses DINOv2 vision features to infer geometry even for occluded areas not visible in the input image. It handles complex topology including open surfaces (clothing, leaves), non-manifold geometry, and internal structures. + ## Prerequisites Before starting, you need accounts and tokens from three services: @@ -80,8 +90,30 @@ Open http://localhost:7860, upload an image, and click Generate. > **First request takes 2-3 minutes** (cold start). Subsequent requests take 10-90 seconds depending on resolution. 
+## Input Image Guidelines + +Quality of results depends heavily on input image preparation: + +**Optimal inputs:** +- Clean, well-lit subject against neutral/solid background +- Object centered in frame with minimal perspective distortion +- Resolution at least 512×512 (1024×1024 recommended) +- Clear separation between subject and background +- Consistent lighting without harsh shadows + +**Problematic inputs:** +- Cluttered backgrounds (confuses edge detection) +- Extreme perspective angles or fish-eye distortion +- Low resolution or heavily compressed images +- Transparent or highly reflective surfaces +- Multiple overlapping objects + +For product photography, a simple white or gray background significantly improves results. + ## API Usage +### Basic Example + ```python from trellis2_modal.client import TRELLIS2APIClient @@ -92,19 +124,140 @@ client = TRELLIS2APIClient( # Generate 3D from image result = client.generate(image_path="input.png", pipeline_type="1024_cascade") -# Extract GLB mesh +# Extract GLB mesh (high quality) client.extract_glb(state=result["state"], output_path="output.glb") ``` Credentials are loaded automatically from `~/.trellis2_modal_secrets.json` or environment variables. -## Pipeline Options +### Export for Different Platforms + +Adjust `decimation_target` and `texture_size` based on your target platform: + +```python +# High-quality for rendering, web viewers, Sketchfab +client.extract_glb( + state=result["state"], + output_path="high_quality.glb", + decimation_target=100000, + texture_size=2048, +) + +# Game engines (Unity, Unreal, Godot) +client.extract_glb( + state=result["state"], + output_path="game_asset.glb", + decimation_target=10000, + texture_size=1024, +) + +# Mobile games, AR apps +client.extract_glb( + state=result["state"], + output_path="mobile_asset.glb", + decimation_target=5000, + texture_size=512, +) +``` + +## Generation Parameters + +### Pipeline Types + +Controls the output resolution. 
Higher resolution = more detail but slower. | Pipeline | Resolution | Time | Use Case | |----------|------------|------|----------| -| `512` | 512³ | ~10s | Quick preview | -| `1024_cascade` | 1024³ | ~30s | Recommended | -| `1536_cascade` | 1536³ | ~90s | Maximum quality | +| `512` | 512³ voxels | ~10s | Quick preview, iteration | +| `1024_cascade` | 1024³ voxels | ~30s | **Recommended** for most uses | +| `1536_cascade` | 1536³ voxels | ~90s | Maximum quality, hero assets | + +### Quality Tuning + +TRELLIS.2 generates 3D in three stages, each with tunable parameters: + +| Stage | What it does | Parameters | +|-------|--------------|------------| +| **Sparse Structure** | Creates initial 3D shape from image | `ss_sampling_steps`, `ss_guidance_strength` | +| **Shape Refinement** | Adds geometric detail | `shape_slat_sampling_steps`, `shape_slat_guidance_strength` | +| **Texture Generation** | Applies PBR materials | `tex_slat_sampling_steps`, `tex_slat_guidance_strength` | + +**Parameter meanings:** + +- **`*_sampling_steps`** (default: 12): Number of denoising iterations. Higher = better quality but slower. Range: 8-20 is reasonable; beyond 20 has diminishing returns. Each additional step adds ~1-2 seconds. + +- **`*_guidance_strength`**: How closely to follow the input image. 
+ - Shape stages default to **7.5** — produces faithful reconstructions + - Texture stage defaults to **1.0** — allows natural material variation + - Higher values (10-15) = more literal interpretation, risk of artifacts + - Lower values (1-3) = more creative freedom, may drift from input + +### Tuning by Object Type + +| Object Type | SS Guidance | Shape Guidance | Notes | +|-------------|-------------|----------------|-------| +| **Hard-surface** (furniture, vehicles, architecture) | 8-9 | 7.5 | Stricter geometric adherence | +| **Organic** (characters, plants, fabric) | 7.5 | 7.5 | Default values work well | +| **Ambiguous shapes** | 5-7 | 5-7 | Lower guidance for coherence | + +**Example with custom parameters:** + +```python +result = client.generate( + image_path="input.png", + pipeline_type="1024_cascade", + seed=42, # For reproducibility + # Increase steps for higher quality (slower) + ss_sampling_steps=16, + shape_slat_sampling_steps=16, + tex_slat_sampling_steps=16, + # Adjust guidance based on object type + ss_guidance_strength=8.0, # Higher for hard-surface + shape_slat_guidance_strength=7.5, + tex_slat_guidance_strength=1.0, +) +``` + +**Tip**: Use the `seed` parameter for reproducibility. When you find settings that work well, record the seed to generate consistent results. + +## GLB Extraction Options + +After generation, extract a mesh with these options: + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `decimation_target` | 1,000,000 | Target triangle count. Lower = smaller file, less detail. 
| +| `texture_size` | 4096 | Texture resolution: 512, 1024, 2048, or 4096 | +| `remesh` | True | Clean up mesh topology (recommended) | + +### Recommended Presets by Use Case + +| Use Case | decimation_target | texture_size | File Size | +|----------|-------------------|--------------|-----------| +| **Maximum Quality** | 1,000,000 | 4096 | ~30MB | +| **Web Viewers / Sketchfab** | 100,000 | 2048 | ~5MB | +| **Game Engines** | 10,000 | 1024 | ~1MB | +| **Mobile / AR** | 5,000 | 512 | ~300KB | +| **Low-poly Stylized** | 2,000 | 512 | ~150KB | + +**Note**: Aggressive decimation may collapse fine details like fingers, facial features, or thin structural elements. Test with your specific asset types. + +## Troubleshooting + +| Symptom | Likely Cause | Solution | +|---------|--------------|----------| +| Missing geometry / holes | Poor background separation in input | Use cleaner background, better lighting | +| Distorted mesh | Perspective distortion in input | Re-photograph with less distortion | +| Low-quality textures | Low input resolution or small texture_size | Use higher resolution input (1024×1024+), increase texture_size | +| 403 from HuggingFace | Model licenses not accepted | Accept all three model licenses (links above) | +| Cold start every time | Container scaling to zero | Ping health endpoint every 4 minutes to keep warm | +| CUDA OOM | Insufficient GPU memory | Use lower resolution pipeline or reduce texture_size | + +## Known Limitations + +- **Small holes in geometry**: Generated meshes may occasionally have minor topological artifacts. For watertight meshes (e.g., 3D printing), post-processing in Blender may be needed. +- **Style variation**: This is a base model without aesthetic fine-tuning. Results reflect training data distribution. +- **Opacity in GLB**: Transparency is preserved in texture alpha channel but may need manual material setup in some applications. 
## Cost @@ -116,18 +269,9 @@ Credentials are loaded automatically from `~/.trellis2_modal_secrets.json` or en Containers stay warm for 5 minutes between requests (no cold start cost for sequential generations). -## Troubleshooting - -**403 from HuggingFace**: Accept all three model licenses (links above). - -**Cold start every time**: Containers scale to zero after 5 minutes idle. Use a cron job to ping the health endpoint every 4 minutes to keep warm. - -**CUDA OOM**: Use a lower resolution pipeline or reduce GLB texture_size. - ## Full Documentation See [docs/MODAL_INTEGRATION.md](docs/MODAL_INTEGRATION.md) for: -- Detailed configuration options +- Complete API reference - Operations runbook - Cold start optimization strategies -- Complete API reference diff --git a/trellis2_modal/docs/MODAL_INTEGRATION.md b/trellis2_modal/docs/MODAL_INTEGRATION.md index 42bec9e..de61310 100644 --- a/trellis2_modal/docs/MODAL_INTEGRATION.md +++ b/trellis2_modal/docs/MODAL_INTEGRATION.md @@ -89,7 +89,7 @@ Expected output: ### 2. Deploy the Service -Deploy to Modal with memory snapshots enabled: +Deploy to Modal: ```bash modal deploy -m trellis2_modal.service.service @@ -179,35 +179,139 @@ Open http://localhost:7860 in your browser. 
from trellis2_modal.client import TRELLIS2APIClient # Credentials loaded automatically from env vars or ~/.trellis2_modal_secrets.json -client = TRELLIS2APIClient(base_url="https://your-app.modal.run") +client = TRELLIS2APIClient(base_url="https://your-app--generate.modal.run") # Or provide credentials explicitly: # client = TRELLIS2APIClient( -# base_url="https://your-app.modal.run", +# base_url="https://your-app--generate.modal.run", # modal_key="wk-xxxxx", # modal_secret="ws-xxxxx", # ) -# Generate 3D from image -result = client.generate( - image_path="input.png", - seed=42, - pipeline_type="1024_cascade", # Options: 512, 1024, 1024_cascade, 1536_cascade -) +# Generate 3D from image (uses defaults) +result = client.generate(image_path="input.png") # result contains: # - state: compressed state for GLB extraction # - video: base64-encoded MP4 preview -# Extract GLB +# Extract GLB with default settings (high quality) client.extract_glb( state=result["state"], output_path="output.glb", - decimation_target=500000, - texture_size=2048, +) + +# Or for game engines with polygon limits: +client.extract_glb( + state=result["state"], + output_path="game_asset.glb", + decimation_target=10000, # ~10K triangles + texture_size=1024, ) ``` +## Complete API Reference + +### `client.generate()` Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `image_path` | str | *required* | Path to input image (PNG, JPEG) | +| `seed` | int | 42 | Random seed for reproducibility | +| `pipeline_type` | str | "1024_cascade" | Resolution pipeline (see below) | +| `ss_sampling_steps` | int | 12 | Sparse structure denoising steps | +| `ss_guidance_strength` | float | 7.5 | Sparse structure guidance | +| `shape_slat_sampling_steps` | int | 12 | Shape refinement denoising steps | +| `shape_slat_guidance_strength` | float | 7.5 | Shape refinement guidance | +| `tex_slat_sampling_steps` | int | 12 | Texture generation denoising steps | +| 
`tex_slat_guidance_strength` | float | 1.0 | Texture generation guidance | + +#### Understanding the Three-Stage Pipeline + +TRELLIS.2 generates 3D models in three stages: + +1. **Sparse Structure (SS)**: Creates the initial coarse 3D voxel structure from the input image. This determines the overall shape and proportions. + +2. **Shape SLAT (Structured Latent)**: Refines the geometry to the target resolution, adding fine geometric details like edges, corners, and surface features. + +3. **Texture SLAT**: Generates PBR (Physically-Based Rendering) materials including base color, metallic, roughness, and opacity. + +#### Input Image Guidelines + +Quality of results depends heavily on input image preparation: + +**Optimal inputs:** +- Clean, well-lit subject against neutral/solid background +- Object centered in frame with minimal perspective distortion +- Resolution at least 512×512 (1024×1024 recommended) +- Clear separation between subject and background +- Consistent lighting without harsh shadows or specular highlights + +**Problematic inputs:** +- Cluttered backgrounds (confuses edge detection) +- Extreme perspective angles or fish-eye distortion +- Low resolution or heavily compressed images +- Transparent or highly reflective surfaces +- Multiple overlapping objects + +#### Parameter Tuning Guide + +**Sampling Steps** (`*_sampling_steps`): +- Controls quality vs speed tradeoff +- Default of 12 is well-balanced +- 8 steps: Faster but may have artifacts +- 16-20 steps: Higher quality, diminishing returns beyond +- Each additional step adds ~1-2 seconds + +**Guidance Strength** (`*_guidance_strength`): +- Controls how closely output follows the input image +- **Shape stages (7.5 default)**: Higher values produce more faithful reconstructions + - 5.0-7.5: Good balance of accuracy and natural appearance + - 10.0-15.0: Very literal interpretation, may cause rigidity + - Below 3.0: More creative but may drift from input +- **Texture stage (1.0 default)**: Lower 
values allow natural material variation + - 1.0: Natural-looking materials + - 3.0-5.0: More constrained to input colors + - Higher: May produce flat, unnatural textures + +#### Tuning by Object Type + +| Object Type | SS Guidance | Shape Guidance | Notes | +|-------------|-------------|----------------|-------| +| **Hard-surface** (furniture, vehicles, architecture) | 8-9 | 7.5 | Stricter geometric adherence | +| **Organic** (characters, plants, fabric) | 7.5 | 7.5 | Default values work well | +| **Ambiguous shapes** | 5-7 | 5-7 | Lower guidance for coherence | + +**Tip**: Use the `seed` parameter for reproducibility. When you find settings that work well for a particular object type, record the seed to generate consistent results across similar inputs. + +### `client.extract_glb()` Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `state` | str | *required* | Base64 state from `generate()` | +| `output_path` | str | *required* | Where to save the GLB file | +| `decimation_target` | int | 1,000,000 | Target triangle count | +| `texture_size` | int | 4096 | Texture resolution (512/1024/2048/4096) | +| `remesh` | bool | True | Clean up mesh topology | +| `remesh_band` | float | 1.0 | Remesh band size | +| `remesh_project` | float | 0.0 | Remesh projection factor | + +### GLB Extraction Presets + +| Use Case | decimation_target | texture_size | Approx Size | +|----------|-------------------|--------------|-------------| +| **Maximum Quality** | 1,000,000 | 4096 | ~30MB | +| **Web Viewers / Sketchfab** | 100,000 | 2048 | ~5MB | +| **Game Engines (Unity/Unreal/Godot)** | 10,000 | 1024 | ~1MB | +| **Mobile / AR** | 5,000 | 512 | ~300KB | +| **Low-poly / Stylized** | 2,000 | 512 | ~150KB | + +**Platform-specific notes**: +- **Unity/Unreal**: GLB imports directly; may need material adjustment for opacity +- **Godot**: Native GLB support with PBR materials +- **Web (Three.js, Babylon.js)**: Use Web/Viewer preset for 
good balance +- **Some platforms require FBX**: Convert via Blender if GLB not supported + ## Configuration ### GPU Type @@ -227,14 +331,6 @@ The A100-80GB provides: | `1024_cascade` | 512→1024 | ~30s | High quality (recommended) | | `1536_cascade` | 512→1536 | ~90s | Maximum quality | -### GLB Extraction Presets - -| Preset | decimation_target | texture_size | File Size | -|--------|-------------------|--------------|-----------| -| Quality | 1,000,000 | 4096 | ~30MB | -| Balanced | 500,000 | 2048 | ~15MB | -| Fast | 100,000 | 1024 | ~5MB | - ## Cold Starts TRELLIS.2 has a cold start time of approximately **2-2.5 minutes** on A100-80GB: