Skip to content

Commit 69ac0fe

Browse files
feat: add model serialization with input types and constraint types (#103)
* feat: add model serialization with input types and constraint types - Add `type` computed_field to Constraint for JSON serialization - Add `InputType` enum (text, image, video, audio) to core - Add `supported_input_types` computed field to Model (per-capability) - Add `optional_input_types` computed field to Model (from constraints) - Add `get_required_input_types()` and `get_constraint_input_type()` helpers - Use `SerializeAsAny[Constraint]` for proper polymorphic serialization This enables the API to return complete model metadata including: - What input types each capability requires (e.g., text-generation: text) - What optional inputs are accepted via parameters (e.g., reference_images) - Constraint type identifiers for frontend rendering * chore: bump version to 0.3.8
1 parent 91c841c commit 69ac0fe

5 files changed

Lines changed: 123 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "celeste-ai"
3-
version = "0.3.7"
3+
version = "0.3.8"
44
description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface"
55
authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}]
66
readme = "README.md"

src/celeste/constraints.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from abc import ABC, abstractmethod
66
from typing import Any, get_args, get_origin
77

8-
from pydantic import BaseModel, Field
8+
from pydantic import BaseModel, Field, computed_field
99

1010
from celeste.artifacts import ImageArtifact
1111
from celeste.exceptions import ConstraintViolationError
@@ -15,6 +15,12 @@
1515
class Constraint(BaseModel, ABC):
1616
"""Base constraint for parameter validation."""
1717

18+
@computed_field # type: ignore[prop-decorator]
19+
@property
20+
def type(self) -> str:
21+
"""Constraint type identifier for serialization."""
22+
return self.__class__.__name__
23+
1824
@abstractmethod
1925
def __call__(self, value: Any) -> Any: # noqa: ANN401
2026
"""Validate value against constraint and return validated value."""

src/celeste/core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ class Capability(StrEnum):
4949
SEARCH = "search"
5050

5151

52+
class InputType(StrEnum):
53+
"""Input types for capabilities."""
54+
55+
TEXT = "text"
56+
IMAGE = "image"
57+
VIDEO = "video"
58+
AUDIO = "audio"
59+
60+
5261
class Parameter(StrEnum):
5362
"""Universal parameters across most capabilities."""
5463

@@ -77,4 +86,4 @@ class UsageField(StrEnum):
7786
CACHE_READ_INPUT_TOKENS = "cache_read_input_tokens"
7887

7988

80-
__all__ = ["Capability", "Parameter", "Provider", "UsageField"]
89+
__all__ = ["Capability", "InputType", "Parameter", "Provider", "UsageField"]

src/celeste/io.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Input and output types for generation operations."""
22

3-
from typing import Any
3+
import inspect
4+
import types
5+
from typing import Any, get_args, get_origin
46

57
from pydantic import BaseModel, Field
68

7-
from celeste.core import Capability
9+
from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact
10+
from celeste.constraints import Constraint
11+
from celeste.core import Capability, InputType
812

913

1014
class Input(BaseModel):
@@ -59,12 +63,88 @@ def get_input_class(capability: Capability) -> type[Input]:
5963
return _inputs[capability]
6064

6165

66+
# Centralized mapping: field type → InputType
67+
INPUT_TYPE_MAPPING: dict[type, InputType] = {
68+
str: InputType.TEXT,
69+
ImageArtifact: InputType.IMAGE,
70+
VideoArtifact: InputType.VIDEO,
71+
AudioArtifact: InputType.AUDIO,
72+
}
73+
74+
75+
def get_required_input_types(capability: Capability) -> set[InputType]:
76+
"""Derive required input types from Input class fields.
77+
78+
Introspects the Input class registered for a capability and returns
79+
the set of InputTypes based on field annotations.
80+
81+
Args:
82+
capability: The capability to get required input types for.
83+
84+
Returns:
85+
Set of InputType values required by the capability's Input class.
86+
"""
87+
input_class = get_input_class(capability)
88+
return {
89+
INPUT_TYPE_MAPPING[field.annotation]
90+
for field in input_class.model_fields.values()
91+
if field.annotation in INPUT_TYPE_MAPPING
92+
}
93+
94+
95+
def _extract_input_type(param_type: type) -> InputType | None:
96+
"""Extract InputType from a type, handling unions and generics.
97+
98+
Args:
99+
param_type: The type annotation to inspect.
100+
101+
Returns:
102+
InputType if found in the type or its nested types, None otherwise.
103+
"""
104+
# Direct match
105+
if param_type in INPUT_TYPE_MAPPING:
106+
return INPUT_TYPE_MAPPING[param_type]
107+
108+
# Handle union types (X | Y) and generics (list[X])
109+
origin = get_origin(param_type)
110+
if origin is types.UnionType or origin is not None:
111+
for arg in get_args(param_type):
112+
result = _extract_input_type(arg)
113+
if result is not None:
114+
return result
115+
116+
return None
117+
118+
119+
def get_constraint_input_type(constraint: Constraint) -> InputType | None:
120+
"""Get InputType from constraint's __call__ signature.
121+
122+
Introspects the constraint's __call__ method to find what artifact type
123+
it accepts, then maps to InputType using INPUT_TYPE_MAPPING.
124+
125+
Args:
126+
constraint: The constraint to inspect.
127+
128+
Returns:
129+
InputType if the constraint accepts a mapped artifact type, None otherwise.
130+
"""
131+
annotations = inspect.get_annotations(constraint.__call__, eval_str=True)
132+
for param_type in annotations.values():
133+
result = _extract_input_type(param_type)
134+
if result is not None:
135+
return result
136+
return None
137+
138+
62139
__all__ = [
140+
"INPUT_TYPE_MAPPING",
63141
"Chunk",
64142
"FinishReason",
65143
"Input",
66144
"Output",
67145
"Usage",
146+
"get_constraint_input_type",
68147
"get_input_class",
148+
"get_required_input_types",
69149
"register_input",
70150
]

src/celeste/models.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Models and model registry for Celeste."""
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, SerializeAsAny, computed_field
44

55
from celeste.constraints import Constraint
6-
from celeste.core import Capability, Provider
6+
from celeste.core import Capability, InputType, Provider
7+
from celeste.io import get_constraint_input_type, get_required_input_types
78

89

910
class Model(BaseModel):
@@ -13,14 +14,33 @@ class Model(BaseModel):
1314
provider: Provider
1415
display_name: str
1516
capabilities: set[Capability] = Field(default_factory=set)
16-
parameter_constraints: dict[str, Constraint] = Field(default_factory=dict)
17+
parameter_constraints: dict[str, SerializeAsAny[Constraint]] = Field(
18+
default_factory=dict
19+
)
1720
streaming: bool = Field(default=False)
1821

1922
@property
2023
def supported_parameters(self) -> set[str]:
2124
"""Compute supported parameter names from parameter_constraints."""
2225
return set(self.parameter_constraints.keys())
2326

27+
@computed_field # type: ignore[prop-decorator]
28+
@property
29+
def supported_input_types(self) -> dict[Capability, set[InputType]]:
30+
"""Input types supported per capability (derived from Input class fields)."""
31+
return {cap: get_required_input_types(cap) for cap in self.capabilities}
32+
33+
@computed_field # type: ignore[prop-decorator]
34+
@property
35+
def optional_input_types(self) -> set[InputType]:
36+
"""Optional input types accepted via parameter_constraints."""
37+
types: set[InputType] = set()
38+
for constraint in self.parameter_constraints.values():
39+
input_type = get_constraint_input_type(constraint)
40+
if input_type is not None:
41+
types.add(input_type)
42+
return types
43+
2444

2545
# Module-level registry mapping (model_id, provider) to model
2646
_models: dict[tuple[str, Provider], Model] = {}

0 commit comments

Comments
 (0)