Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions ccflow/_flow_model_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class _ParsedAnnotation:
base: Any
is_lazy: bool
is_from_context: bool
optional_context: bool = False


@dataclass(frozen=True)
Expand Down Expand Up @@ -390,6 +391,7 @@ def _resolved_flow_signature(
def _parse_annotation(annotation: Any) -> _ParsedAnnotation:
is_lazy = False
is_from_context = False
optional_context = False

while get_origin(annotation) is Annotated:
args = get_args(annotation)
Expand All @@ -400,12 +402,38 @@ def _parse_annotation(annotation: Any) -> _ParsedAnnotation:
elif isinstance(metadata, _FromContextMarker):
is_from_context = True

# Detect markers nested inside a top-level Optional/Union, e.g.
# ``Optional[FromContext[int]]`` == ``Union[Annotated[int, FromContext], None]``.
# Such a parameter is contextual but not required: when absent from the runtime
# context it is bound to ``None`` (an implicit default synthesized in
# ``_analyze_flow_function``). This is distinct from ``FromContext[Optional[int]]``,
# which is required-in-context but may carry a ``None`` value.
if not is_from_context and not is_lazy and get_origin(annotation) in _UNION_ORIGINS:
members = get_args(annotation)
non_none = [member for member in members if member is not type(None)]
has_none = len(non_none) != len(members)
marked = [(member, _parse_annotation(member)) for member in non_none]
marker_members = [(member, parsed) for member, parsed in marked if parsed.is_from_context or parsed.is_lazy]
if marker_members:
if any(parsed.is_lazy for _, parsed in marker_members):
raise TypeError("Lazy[...] cannot be nested inside Optional/Union; mark the whole parameter as Lazy[T].")
if len(marker_members) != 1 or not has_none or len(non_none) != 1:
raise TypeError(
"FromContext[...] inside a Union is only supported as Optional[FromContext[T]] "
"(exactly one contextual member plus None). Use FromContext[Optional[T]] for a "
"required-but-nullable contextual input."
)
inner = marker_members[0][1].base
annotation = Optional[inner]
is_from_context = True
optional_context = True

if annotation is FromContext:
raise TypeError("FromContext is an annotation marker; use FromContext[T] in @Flow.model signatures.")
if annotation is Lazy:
raise TypeError("Lazy is an annotation marker; use Lazy[T] in @Flow.model signatures.")

return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context)
return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context, optional_context=optional_context)


def _strip_annotated(annotation: Any) -> Any:
Expand Down Expand Up @@ -520,14 +548,26 @@ def _analyze_flow_function(
if parsed.is_from_context and has_default and is_model_dependency(param.default):
raise TypeError(f"Parameter '{param.name}' is marked FromContext[...] and cannot default to a CallableModel.")

# ``Optional[FromContext[T]]`` is contextual with an implicit ``None`` default unless
# the signature already provides one. An explicit default always wins.
if has_default:
stored_has_default = True
stored_default = param.default
elif parsed.optional_context:
stored_has_default = True
stored_default = None
else:
stored_has_default = False
stored_default = _UNSET

analyzed_params.append(
_FlowModelParam(
name=param.name,
annotation=parsed.base,
is_contextual=parsed.is_from_context,
is_lazy=parsed.is_lazy,
has_function_default=has_default,
function_default=param.default if has_default else _UNSET,
has_function_default=stored_has_default,
function_default=stored_default,
)
)

Expand Down
100 changes: 100 additions & 0 deletions ccflow/tests/test_flow_model_optional_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Tests for ``Optional[FromContext[T]]`` vs ``FromContext[Optional[T]]`` semantics.

Two spellings are both contextual and share the same validation base (``Optional[T]``);
they differ only in required-ness:

- ``FromContext[Optional[int]]`` is required-in-context but its value may be ``None``.
- ``Optional[FromContext[int]]`` is optional: when absent from context it is bound to ``None``.
"""

from typing import Optional, Union

import pytest

from ccflow import Flow, FlowContext, FromContext, Lazy, ModelEvaluationContext
from ccflow.evaluators.common import cache_key


@Flow.model
def required_ctx(a: FromContext[Optional[int]]) -> int:
return -1 if a is None else a


@Flow.model
def optional_ctx(a: Optional[FromContext[int]]) -> int:
return -1 if a is None else a


@Flow.model
def explicit_none_default(a: FromContext[Optional[int]] = None) -> int:
return -1 if a is None else a


@Flow.model
def optional_with_default(a: Optional[FromContext[int]] = 3) -> int:
return -1 if a is None else a


class TestRequiredNullableContext:
def test_missing_raises(self):
with pytest.raises(TypeError, match="Missing contextual input"):
required_ctx().flow.compute()

def test_none_is_valid(self):
assert required_ctx().flow.compute(a=None).value == -1

def test_value(self):
assert required_ctx().flow.compute(a=5).value == 5

def test_inspect_required(self):
insp = required_ctx().flow.inspect()
assert "a" in insp.required_inputs


class TestOptionalContext:
def test_missing_binds_none(self):
assert optional_ctx().flow.compute().value == -1

def test_none_explicit(self):
assert optional_ctx().flow.compute(a=None).value == -1

def test_value(self):
assert optional_ctx().flow.compute(a=5).value == 5

def test_inspect_not_required(self):
insp = optional_ctx().flow.inspect()
assert "a" not in insp.required_inputs
assert "a" in insp.context_inputs


class TestConsistency:
def test_explicit_none_default_equiv_optional(self):
# FromContext[Optional[int]] = None is equivalent to Optional[FromContext[int]].
assert explicit_none_default().flow.compute().value == -1
assert optional_ctx().flow.compute().value == -1

def test_optional_with_explicit_default_wins(self):
# An explicit default overrides the implicit None of Optional[FromContext[int]].
assert optional_with_default().flow.compute().value == 3

def test_node_keys_distinguish_required_vs_optional(self):
req_key = cache_key(ModelEvaluationContext(model=required_ctx(), context=FlowContext(a=5)))
opt_key = cache_key(ModelEvaluationContext(model=optional_ctx(), context=FlowContext(a=5)))
# Different required-ness must give distinct logical identities.
assert req_key != opt_key


class TestRejections:
def test_nested_lazy_rejected(self):
with pytest.raises(TypeError, match="Lazy"):

@Flow.model
def bad(a: Optional[Lazy[int]]) -> int:
return 0

def test_non_optional_union_with_fromcontext_rejected(self):
with pytest.raises(TypeError, match="only supported as Optional"):

@Flow.model
def bad(a: Union[FromContext[int], str]) -> int:
return 0
Loading