|
1 | 1 | """Input and output types for generation operations.""" |
2 | 2 |
|
3 | | -from typing import Any |
| 3 | +import inspect |
| 4 | +import types |
| 5 | +from typing import Any, get_args, get_origin |
4 | 6 |
|
5 | 7 | from pydantic import BaseModel, Field |
6 | 8 |
|
7 | | -from celeste.core import Capability |
| 9 | +from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact |
| 10 | +from celeste.constraints import Constraint |
| 11 | +from celeste.core import Capability, InputType |
8 | 12 |
|
9 | 13 |
|
10 | 14 | class Input(BaseModel): |
@@ -59,12 +63,88 @@ def get_input_class(capability: Capability) -> type[Input]: |
59 | 63 | return _inputs[capability] |
60 | 64 |
|
61 | 65 |
|
| 66 | +# Centralized mapping: field type → InputType |
| 67 | +INPUT_TYPE_MAPPING: dict[type, InputType] = { |
| 68 | + str: InputType.TEXT, |
| 69 | + ImageArtifact: InputType.IMAGE, |
| 70 | + VideoArtifact: InputType.VIDEO, |
| 71 | + AudioArtifact: InputType.AUDIO, |
| 72 | +} |
| 73 | + |
| 74 | + |
| 75 | +def get_required_input_types(capability: Capability) -> set[InputType]: |
| 76 | + """Derive required input types from Input class fields. |
| 77 | +
|
| 78 | + Introspects the Input class registered for a capability and returns |
| 79 | + the set of InputTypes based on field annotations. |
| 80 | +
|
| 81 | + Args: |
| 82 | + capability: The capability to get required input types for. |
| 83 | +
|
| 84 | + Returns: |
| 85 | + Set of InputType values required by the capability's Input class. |
| 86 | + """ |
| 87 | + input_class = get_input_class(capability) |
| 88 | + return { |
| 89 | + INPUT_TYPE_MAPPING[field.annotation] |
| 90 | + for field in input_class.model_fields.values() |
| 91 | + if field.annotation in INPUT_TYPE_MAPPING |
| 92 | + } |
| 93 | + |
| 94 | + |
| 95 | +def _extract_input_type(param_type: type) -> InputType | None: |
| 96 | + """Extract InputType from a type, handling unions and generics. |
| 97 | +
|
| 98 | + Args: |
| 99 | + param_type: The type annotation to inspect. |
| 100 | +
|
| 101 | + Returns: |
| 102 | + InputType if found in the type or its nested types, None otherwise. |
| 103 | + """ |
| 104 | + # Direct match |
| 105 | + if param_type in INPUT_TYPE_MAPPING: |
| 106 | + return INPUT_TYPE_MAPPING[param_type] |
| 107 | + |
| 108 | + # Handle union types (X | Y) and generics (list[X]) |
| 109 | + origin = get_origin(param_type) |
| 110 | + if origin is types.UnionType or origin is not None: |
| 111 | + for arg in get_args(param_type): |
| 112 | + result = _extract_input_type(arg) |
| 113 | + if result is not None: |
| 114 | + return result |
| 115 | + |
| 116 | + return None |
| 117 | + |
| 118 | + |
| 119 | +def get_constraint_input_type(constraint: Constraint) -> InputType | None: |
| 120 | + """Get InputType from constraint's __call__ signature. |
| 121 | +
|
| 122 | + Introspects the constraint's __call__ method to find what artifact type |
| 123 | + it accepts, then maps to InputType using INPUT_TYPE_MAPPING. |
| 124 | +
|
| 125 | + Args: |
| 126 | + constraint: The constraint to inspect. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + InputType if the constraint accepts a mapped artifact type, None otherwise. |
| 130 | + """ |
| 131 | + annotations = inspect.get_annotations(constraint.__call__, eval_str=True) |
| 132 | + for param_type in annotations.values(): |
| 133 | + result = _extract_input_type(param_type) |
| 134 | + if result is not None: |
| 135 | + return result |
| 136 | + return None |
| 137 | + |
| 138 | + |
62 | 139 | __all__ = [ |
| 140 | + "INPUT_TYPE_MAPPING", |
63 | 141 | "Chunk", |
64 | 142 | "FinishReason", |
65 | 143 | "Input", |
66 | 144 | "Output", |
67 | 145 | "Usage", |
| 146 | + "get_constraint_input_type", |
68 | 147 | "get_input_class", |
| 148 | + "get_required_input_types", |
69 | 149 | "register_input", |
70 | 150 | ] |
0 commit comments