Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 117 additions & 4 deletions ccflow/_flow_model_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def __repr__(self) -> str:
return "FromContext"


class _DepMarker:
def __repr__(self) -> str:
return "Dep"


class FromContext:
"""Marker used in ``@Flow.model`` signatures for runtime/contextual inputs."""

Expand All @@ -70,11 +75,22 @@ def __class_getitem__(cls, item):
return Annotated[item, _LazyMarker()]


class Dep:
    """Marker used in ``@Flow.model`` signatures for explicit dependency leaves.

    ``Dep`` is never instantiated. Subscripting it (``Dep[T]``) yields
    ``Annotated[T, <Dep marker>]`` so the marker travels with the annotation.
    """

    def __class_getitem__(cls, item):
        # Dep[T] -> Annotated[T, _DepMarker()]
        return Annotated[item, _DepMarker()]

    def __new__(cls, *args, **kwargs):
        # Calling Dep(...) is always a usage error; point the user at Dep[T].
        message = "Dep is an annotation marker; use Dep[T] in @Flow.model signatures."
        raise TypeError(message)


@dataclass(frozen=True)
class _ParsedAnnotation:
base: Any
is_lazy: bool
is_from_context: bool
is_dep: bool
optional_context: bool = False


Expand All @@ -87,6 +103,7 @@ class _FlowModelParam:
has_function_default: bool
function_default: Any = _UNSET
context_validation_annotation: Any = _UNSET
has_dep_slots: bool = False

@property
def validation_annotation(self) -> Any:
Expand Down Expand Up @@ -170,6 +187,7 @@ class _SerializedFlowModelParam(NamedTuple):
has_function_default: bool
function_default: Any
context_validation_annotation: _SerializedAnnotation
has_dep_slots: bool = False


class _SerializedFlowModelConfig(NamedTuple):
Expand Down Expand Up @@ -298,20 +316,24 @@ def _serialize_flow_model_param(param: _FlowModelParam) -> _SerializedFlowModelP
has_function_default=param.has_function_default,
function_default=param.function_default,
context_validation_annotation=_serialize_annotation(param.context_validation_annotation),
has_dep_slots=param.has_dep_slots,
)


def _restore_flow_model_param(payload: _SerializedFlowModelParam) -> _FlowModelParam:
    """Rebuild a ``_FlowModelParam`` from its serialized payload.

    Args:
        payload: The serialized parameter record.

    Returns:
        A ``_FlowModelParam`` whose annotations have been restored with
        ``_restore_annotation``.

    Raises:
        TypeError: If ``payload`` is not a ``_SerializedFlowModelParam``.
    """
    if not isinstance(payload, _SerializedFlowModelParam):
        raise TypeError(f"Unknown Flow.model parameter payload: {payload!r}")

    # Restore each annotation exactly once and reuse the result below. Passing
    # ``annotation=`` / ``context_validation_annotation=`` a second time with a
    # fresh ``_restore_annotation(...)`` call is a repeated keyword argument.
    annotation = _restore_annotation(payload.annotation)
    context_validation_annotation = _restore_annotation(payload.context_validation_annotation)
    return _FlowModelParam(
        name=payload.name,
        annotation=annotation,
        is_contextual=payload.is_contextual,
        is_lazy=payload.is_lazy,
        has_function_default=payload.has_function_default,
        function_default=payload.function_default,
        context_validation_annotation=context_validation_annotation,
        # NOTE(review): the fallback is evaluated eagerly on every call; it is
        # only selected when the payload lacks a ``has_dep_slots`` attribute --
        # confirm whether such payloads can still reach this point.
        has_dep_slots=getattr(payload, "has_dep_slots", _annotation_contains_dep(annotation)),
    )


Expand Down Expand Up @@ -391,6 +413,7 @@ def _resolved_flow_signature(
def _parse_annotation(annotation: Any) -> _ParsedAnnotation:
is_lazy = False
is_from_context = False
is_dep = False
optional_context = False

while get_origin(annotation) is Annotated:
Expand All @@ -401,6 +424,8 @@ def _parse_annotation(annotation: Any) -> _ParsedAnnotation:
is_lazy = True
elif isinstance(metadata, _FromContextMarker):
is_from_context = True
elif isinstance(metadata, _DepMarker):
is_dep = True

# Detect markers nested inside a top-level Optional/Union, e.g.
# ``Optional[FromContext[int]]`` == ``Union[Annotated[int, FromContext], None]``.
Expand Down Expand Up @@ -432,8 +457,16 @@ def _parse_annotation(annotation: Any) -> _ParsedAnnotation:
raise TypeError("FromContext is an annotation marker; use FromContext[T] in @Flow.model signatures.")
if annotation is Lazy:
raise TypeError("Lazy is an annotation marker; use Lazy[T] in @Flow.model signatures.")

return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context, optional_context=optional_context)
if annotation is Dep:
raise TypeError("Dep is an annotation marker; use Dep[T] in @Flow.model signatures.")

return _ParsedAnnotation(
base=annotation,
is_lazy=is_lazy,
is_from_context=is_from_context,
is_dep=is_dep,
optional_context=optional_context,
)


def _strip_annotated(annotation: Any) -> Any:
Expand All @@ -442,6 +475,78 @@ def _strip_annotated(annotation: Any) -> Any:
return annotation


def _pop_dep_marker(annotation: Any) -> Tuple[Any, bool]:
    """Strip Dep markers from an ``Annotated`` annotation, keeping other metadata.

    Only the given ``Annotated`` level is inspected; nested type arguments are
    left untouched. Every Dep marker found in that level's metadata is removed.

    Returns:
        ``(annotation_without_dep, found)``, where ``found`` tells whether at
        least one Dep marker was removed. Anything that is not ``Annotated``
        is returned unchanged with ``found=False``.
    """
    if get_origin(annotation) is not Annotated:
        return annotation, False

    base, *extras = get_args(annotation)
    kept = tuple(extra for extra in extras if not isinstance(extra, _DepMarker))
    found_dep = len(kept) < len(extras)
    if kept:
        # Non-Dep metadata (for example pydantic Field constraints) must stay on
        # the annotation used to validate literals and resolved dependency
        # results. Subscript ``Annotated`` with a prebuilt tuple: star-unpacking
        # inside a subscript (``Annotated[base, *kept]``) needs Python 3.11+,
        # while ``Annotated[(base, *kept)]`` is portable.
        return Annotated[(base, *kept)], found_dep
    # Nothing but Dep markers was attached: hand back the bare base type.
    return base, found_dep


def _annotation_contains_dep(annotation: Any) -> bool:
    """Return True when a Dep marker appears anywhere within ``annotation``.

    The annotation itself is checked first; if it carries no Dep marker, each
    of its type arguments is searched recursively.
    """
    stripped, found = _pop_dep_marker(annotation)
    if found:
        return True
    for nested in get_args(stripped):
        if _annotation_contains_dep(nested):
            return True
    return False


def _validate_dep_annotation(annotation: Any, *, in_dep: bool = False, dep_allowed: bool = False) -> None:
    """Validate the deliberately small Dep marker language.

    Dep marks exact substitution slots. It is allowed inside container values,
    but not inside another Dep, not in dict keys, and not mixed with Lazy or
    FromContext markers inside a Dep slot.

    Args:
        annotation: The annotation (or nested type argument) to check.
        in_dep: True while validating the contents of an enclosing ``Dep[...]``.
        dep_allowed: True when ``annotation`` sits directly in a container value
            position (list item, tuple item, dict value) where Dep may appear.

    Raises:
        TypeError: If a Dep marker appears where it is not supported.
    """

    # ``Annotated`` flattens nested ``Annotated`` types, so ``Dep[Dep[T]]``
    # arrives here as a single ``Annotated[T, Dep, Dep]``. Count the markers
    # before they are stripped; otherwise direct nesting would go unnoticed.
    dep_marker_count = 0
    if get_origin(annotation) is Annotated:
        dep_marker_count = sum(isinstance(item, _DepMarker) for item in get_args(annotation)[1:])

    annotation, has_dep = _pop_dep_marker(annotation)
    if has_dep:
        if not dep_allowed:
            raise TypeError("Dep[...] is only supported in regular parameter container values.")
        if in_dep or dep_marker_count > 1:
            raise TypeError("Dep[...] cannot contain another Dep[...] marker.")
        # Everything below this marker is now "inside a Dep slot".
        _validate_dep_annotation(annotation, in_dep=True, dep_allowed=False)
        return

    if in_dep and get_origin(annotation) is Annotated:
        metadata = get_args(annotation)[1:]
        if any(isinstance(item, (_LazyMarker, _FromContextMarker)) for item in metadata):
            raise TypeError("Dep[...] cannot contain Lazy[...] or FromContext[...] markers.")

    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin in _UNION_ORIGINS and any(_annotation_contains_dep(arg) for arg in args):
        raise TypeError("Dep[...] is not supported inside union annotations.")
    if origin is list and len(args) == 1:
        _validate_dep_annotation(args[0], in_dep=in_dep, dep_allowed=True)
        return
    if origin is tuple and args:
        # ``tuple[T, ...]`` has a single item type; the Ellipsis is not a type.
        item_args = args[:1] if len(args) == 2 and args[1] is Ellipsis else args
        for arg in item_args:
            _validate_dep_annotation(arg, in_dep=in_dep, dep_allowed=True)
        return
    if origin is dict and len(args) == 2:
        key_annotation, value_annotation = args
        if _annotation_contains_dep(key_annotation):
            raise TypeError("Dep[...] is not supported in dict keys.")
        _validate_dep_annotation(value_annotation, in_dep=in_dep, dep_allowed=True)
        return

    # Any other generic: keep descending, but Dep is not allowed directly in
    # its argument positions.
    for arg in args:
        _validate_dep_annotation(arg, in_dep=in_dep, dep_allowed=False)


def _is_result_annotation(annotation: Any) -> bool:
annotation = _strip_annotated(annotation)
origin = get_origin(annotation) or annotation
Expand Down Expand Up @@ -542,6 +647,13 @@ def _analyze_flow_function(
parsed = _parse_annotation(param.annotation)
if parsed.is_lazy and parsed.is_from_context:
raise TypeError(f"Parameter '{param.name}' cannot combine Lazy[...] and FromContext[...].")
if (parsed.is_dep or _annotation_contains_dep(parsed.base)) and (parsed.is_lazy or parsed.is_from_context):
marker = "Lazy" if parsed.is_lazy else "FromContext"
raise TypeError(f"Parameter '{param.name}' cannot combine Dep[...] and {marker}[...].")
if parsed.is_dep:
raise TypeError("Dep[...] is only supported in regular parameter container values.")
_validate_dep_annotation(parsed.base)
has_dep_slots = _annotation_contains_dep(parsed.base)
has_default = param.default is not inspect.Parameter.empty
if parsed.is_lazy and has_default and not is_model_dependency(param.default):
raise TypeError(f"Parameter '{param.name}' is marked Lazy[...] and must default to a CallableModel dependency.")
Expand All @@ -568,6 +680,7 @@ def _analyze_flow_function(
is_lazy=parsed.is_lazy,
has_function_default=stored_has_default,
function_default=stored_default,
has_dep_slots=has_dep_slots,
)
)

Expand Down
30 changes: 19 additions & 11 deletions ccflow/examples/flow_model/flow_model_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
1. define stages as plain Python functions,
2. compose stages by passing upstream models as ordinary arguments,
3. rewrite contextual inputs on one dependency edge with `.flow.with_context(...)`,
4. execute the configured graph with `model.flow.compute(...)`.
4. use `Dep[...]` for model leaves inside regular container inputs,
5. execute the configured graph with `model.flow.compute(...)`.

Run with:
python ccflow/examples/flow_model/flow_model_example.py
python -m ccflow.examples.flow_model.flow_model_example
"""

from datetime import date, timedelta

from ccflow import DateRangeContext, Flow, FromContext
from ccflow import DateRangeContext, Dep, Flow, FromContext


def _format_input_names(inputs: dict[str, object]) -> str:
Expand All @@ -23,10 +24,18 @@ def _format_input_names(inputs: dict[str, object]) -> str:


def _format_bound_inputs(inputs: dict[str, object]) -> str:
def display_value(value: object) -> str:
if hasattr(value, "flow"):
return "model"
if isinstance(value, list):
return "[" + ", ".join(display_value(item) for item in value) + "]"
if isinstance(value, tuple):
return "(" + ", ".join(display_value(item) for item in value) + ")"
return repr(value)

parts = []
for name, value in inputs.items():
display = "model" if hasattr(value, "flow") else repr(value)
parts.append(f"{name}={display}")
parts.append(f"{name}={display_value(value)}")
return ", ".join(parts) or "(none)"


Expand All @@ -45,13 +54,13 @@ def count_visitors(

@Flow.model(context_type=DateRangeContext)
def visitor_delta(
current: int,
previous: int,
counts: list[Dep[int]],
label: str,
start_date: FromContext[date],
end_date: FromContext[date],
) -> dict[str, object]:
"""Return both visitor counts plus their difference."""
current, previous = counts
return {
"label": label,
"window": f"{start_date} -> {end_date}",
Expand All @@ -75,8 +84,7 @@ def build_visitor_pipeline(location: str):
current = count_visitors(location=location)
previous = current.flow.with_context(shift_window(days=7))
return visitor_delta(
current=current,
previous=previous,
counts=[current, previous],
label="previous_week",
)

Expand All @@ -101,8 +109,8 @@ def main() -> None:
print("\nPipeline:")
print(" model: visitor_delta")
pipeline_inspection = pipeline.flow.inspect()
current_inspection = pipeline.current.flow.inspect()
previous_inspection = pipeline.previous.flow.inspect()
current_inspection = pipeline.counts[0].flow.inspect()
previous_inspection = pipeline.counts[1].flow.inspect()
print(f" bound inputs: {_format_bound_inputs(pipeline_inspection.bound_inputs)}")
print(f" declared context inputs: {_format_input_names(pipeline_inspection.context_inputs)}")
print(f" runtime inputs: {_format_input_names(pipeline_inspection.runtime_inputs)}")
Expand Down
24 changes: 16 additions & 8 deletions ccflow/examples/flow_model/flow_model_hydra_builder_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@

- keep runtime context (`start_date`, `end_date`) as runtime inputs,
- use a plain Python builder function for graph construction,
- use `Dep[...]` when a regular container input holds upstream models,
- let Hydra instantiate that builder and register the returned model.

Run with:
python ccflow/examples/flow_model/flow_model_hydra_builder_demo.py
python -m ccflow.examples.flow_model.flow_model_hydra_builder_demo
"""

from datetime import date, timedelta
from pathlib import Path

from ccflow import CallableModel, DateRangeContext, Flow, FromContext, ModelRegistry
from ccflow import CallableModel, DateRangeContext, Dep, Flow, FromContext, ModelRegistry

CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml"

Expand All @@ -30,10 +31,18 @@ def _format_input_names(inputs: dict[str, object]) -> str:


def _format_bound_inputs(inputs: dict[str, object]) -> str:
def display_value(value: object) -> str:
if hasattr(value, "flow"):
return "model"
if isinstance(value, list):
return "[" + ", ".join(display_value(item) for item in value) + "]"
if isinstance(value, tuple):
return "(" + ", ".join(display_value(item) for item in value) + ")"
return repr(value)

parts = []
for name, value in inputs.items():
display = "model" if hasattr(value, "flow") else repr(value)
parts.append(f"{name}={display}")
parts.append(f"{name}={display_value(value)}")
return ", ".join(parts) or "(none)"


Expand All @@ -56,13 +65,13 @@ def count_visitors(location: str, start_date: FromContext[date], end_date: FromC

@Flow.model(context_type=DateRangeContext)
def visitor_delta(
current: int,
previous: int,
counts: list[Dep[int]],
label: str,
start_date: FromContext[date],
end_date: FromContext[date],
) -> dict[str, object]:
"""Return both visitor counts plus their difference."""
current, previous = counts
return {
"label": label,
"window": f"{start_date} -> {end_date}",
Expand All @@ -85,8 +94,7 @@ def build_visitor_delta(current: CallableModel, *, label: str, days_back: int):
"""Hydra-friendly builder that returns a configured visitor-count model."""
previous = current.flow.with_context(shift_window(days=days_back))
return visitor_delta(
current=current,
previous=previous,
counts=[current, previous],
label=label,
)

Expand Down
Loading
Loading