|
1 | | -from typing import Literal, Optional, List |
| 1 | +from typing import Literal, Optional, List, Generic, TypeVar |
2 | 2 | from pydantic import BaseModel, Field, ConfigDict |
3 | 3 | from .errors import ModelNotFoundError |
4 | 4 | from .types import FileInput, MotionTrajectoryInput |
|
17 | 17 | ImageModels = Literal["lucy-pro-t2i", "lucy-pro-i2i"] |
18 | 18 | Model = Literal[RealTimeModels, VideoModels, ImageModels] |
19 | 19 |
|
| 20 | +# Type variable for model name |
| 21 | +ModelT = TypeVar("ModelT", bound=str) |
| 22 | + |
20 | 23 |
|
21 | 24 | class DecartBaseModel(BaseModel): |
22 | 25 | model_config = ConfigDict(arbitrary_types_allowed=True) |
23 | 26 |
|
24 | 27 |
|
25 | | -class ModelDefinition(DecartBaseModel): |
26 | | - name: str |
| 28 | +class ModelDefinition(DecartBaseModel, Generic[ModelT]): |
| 29 | + name: ModelT |
27 | 30 | url_path: str |
28 | 31 | fps: int = Field(ge=1) |
29 | 32 | width: int = Field(ge=1) |
30 | 33 | height: int = Field(ge=1) |
31 | 34 | input_schema: type[BaseModel] |
32 | 35 |
|
33 | 36 |
|
| 37 | +# Type aliases for model definitions that support specific APIs |
| 38 | +ImageModelDefinition = ModelDefinition[ImageModels] |
| 39 | +"""Type alias for model definitions that support synchronous processing (process API).""" |
| 40 | + |
| 41 | +VideoModelDefinition = ModelDefinition[VideoModels] |
| 42 | +"""Type alias for model definitions that support queue processing (queue API).""" |
| 43 | + |
| 44 | +RealTimeModelDefinition = ModelDefinition[RealTimeModels] |
| 45 | +"""Type alias for model definitions that support realtime streaming.""" |
| 46 | + |
| 47 | + |
34 | 48 | class TextToVideoInput(BaseModel): |
35 | 49 | prompt: str = Field(..., min_length=1, max_length=1000) |
36 | 50 | seed: Optional[int] = None |
@@ -212,23 +226,45 @@ class ImageToImageInput(DecartBaseModel): |
212 | 226 |
|
213 | 227 | class Models: |
214 | 228 | @staticmethod |
215 | | - def realtime(model: RealTimeModels) -> ModelDefinition: |
| 229 | + def realtime(model: RealTimeModels) -> RealTimeModelDefinition: |
| 230 | + """Get a realtime model definition for WebRTC streaming.""" |
216 | 231 | try: |
217 | | - return _MODELS["realtime"][model] |
| 232 | + return _MODELS["realtime"][model] # type: ignore[return-value] |
218 | 233 | except KeyError: |
219 | 234 | raise ModelNotFoundError(model) |
220 | 235 |
|
221 | 236 | @staticmethod |
222 | | - def video(model: VideoModels) -> ModelDefinition: |
| 237 | + def video(model: VideoModels) -> VideoModelDefinition: |
| 238 | + """ |
| 239 | + Get a video model definition. |
| 240 | + Video models only support the queue API. |
| 241 | +
|
| 242 | + Available models: |
| 243 | + - "lucy-pro-t2v" - Text-to-video |
| 244 | + - "lucy-pro-i2v" - Image-to-video |
| 245 | + - "lucy-pro-v2v" - Video-to-video |
| 246 | + - "lucy-pro-flf2v" - First-last-frame-to-video |
| 247 | + - "lucy-dev-i2v" - Image-to-video (Dev quality) |
| 248 | + - "lucy-fast-v2v" - Video-to-video (Fast quality) |
| 249 | + - "lucy-motion" - Image-to-motion-video |
| 250 | + """ |
223 | 251 | try: |
224 | | - return _MODELS["video"][model] |
| 252 | + return _MODELS["video"][model] # type: ignore[return-value] |
225 | 253 | except KeyError: |
226 | 254 | raise ModelNotFoundError(model) |
227 | 255 |
|
228 | 256 | @staticmethod |
229 | | - def image(model: ImageModels) -> ModelDefinition: |
| 257 | + def image(model: ImageModels) -> ImageModelDefinition: |
| 258 | + """ |
| 259 | + Get an image model definition. |
| 260 | + Image models only support the process (sync) API. |
| 261 | +
|
| 262 | + Available models: |
| 263 | + - "lucy-pro-t2i" - Text-to-image |
| 264 | + - "lucy-pro-i2i" - Image-to-image |
| 265 | + """ |
230 | 266 | try: |
231 | | - return _MODELS["image"][model] |
| 267 | + return _MODELS["image"][model] # type: ignore[return-value] |
232 | 268 | except KeyError: |
233 | 269 | raise ModelNotFoundError(model) |
234 | 270 |
|
|
0 commit comments