diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py index 78373e0..5057ba3 100644 --- a/ccflow/_flow_model_binding.py +++ b/ccflow/_flow_model_binding.py @@ -75,6 +75,7 @@ class _ParsedAnnotation: base: Any is_lazy: bool is_from_context: bool + optional_context: bool = False @dataclass(frozen=True) @@ -390,6 +391,7 @@ def _resolved_flow_signature( def _parse_annotation(annotation: Any) -> _ParsedAnnotation: is_lazy = False is_from_context = False + optional_context = False while get_origin(annotation) is Annotated: args = get_args(annotation) @@ -400,12 +402,38 @@ def _parse_annotation(annotation: Any) -> _ParsedAnnotation: elif isinstance(metadata, _FromContextMarker): is_from_context = True + # Detect markers nested inside a top-level Optional/Union, e.g. + # ``Optional[FromContext[int]]`` == ``Union[Annotated[int, FromContext], None]``. + # Such a parameter is contextual but not required: when absent from the runtime + # context it is bound to ``None`` (an implicit default synthesized in + # ``_analyze_flow_function``). This is distinct from ``FromContext[Optional[int]]``, + # which is required-in-context but may carry a ``None`` value. + if not is_from_context and not is_lazy and get_origin(annotation) in _UNION_ORIGINS: + members = get_args(annotation) + non_none = [member for member in members if member is not type(None)] + has_none = len(non_none) != len(members) + marked = [(member, _parse_annotation(member)) for member in non_none] + marker_members = [(member, parsed) for member, parsed in marked if parsed.is_from_context or parsed.is_lazy] + if marker_members: + if any(parsed.is_lazy for _, parsed in marker_members): + raise TypeError("Lazy[...] cannot be nested inside Optional/Union; mark the whole parameter as Lazy[T].") + if len(marker_members) != 1 or not has_none or len(non_none) != 1: + raise TypeError( + "FromContext[...] inside a Union is only supported as Optional[FromContext[T]] " + "(exactly one contextual member plus None). Use FromContext[Optional[T]] for a " + "required-but-nullable contextual input." + ) + inner = marker_members[0][1].base + annotation = Optional[inner] + is_from_context = True + optional_context = True + if annotation is FromContext: raise TypeError("FromContext is an annotation marker; use FromContext[T] in @Flow.model signatures.") if annotation is Lazy: raise TypeError("Lazy is an annotation marker; use Lazy[T] in @Flow.model signatures.") - return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context) + return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context, optional_context=optional_context) def _strip_annotated(annotation: Any) -> Any: @@ -520,14 +548,26 @@ def _analyze_flow_function( if parsed.is_from_context and has_default and is_model_dependency(param.default): raise TypeError(f"Parameter '{param.name}' is marked FromContext[...] and cannot default to a CallableModel.") + # ``Optional[FromContext[T]]`` is contextual with an implicit ``None`` default unless + # the signature already provides one. An explicit default always wins. + if has_default: + stored_has_default = True + stored_default = param.default + elif parsed.optional_context: + stored_has_default = True + stored_default = None + else: + stored_has_default = False + stored_default = _UNSET + analyzed_params.append( _FlowModelParam( name=param.name, annotation=parsed.base, is_contextual=parsed.is_from_context, is_lazy=parsed.is_lazy, - has_function_default=has_default, - function_default=param.default if has_default else _UNSET, + has_function_default=stored_has_default, + function_default=stored_default, ) ) diff --git a/ccflow/tests/test_flow_model_optional_context.py b/ccflow/tests/test_flow_model_optional_context.py new file mode 100644 index 0000000..3f067e4 --- /dev/null +++ b/ccflow/tests/test_flow_model_optional_context.py @@ -0,0 +1,100 @@ +"""Tests for ``Optional[FromContext[T]]`` vs ``FromContext[Optional[T]]`` semantics. + +Two spellings are both contextual and share the same validation base (``Optional[T]``); +they differ only in required-ness: + +- ``FromContext[Optional[int]]`` is required-in-context but its value may be ``None``. +- ``Optional[FromContext[int]]`` is optional: when absent from context it is bound to ``None``. +""" + +from typing import Optional, Union + +import pytest + +from ccflow import Flow, FlowContext, FromContext, Lazy, ModelEvaluationContext +from ccflow.evaluators.common import cache_key + + +@Flow.model +def required_ctx(a: FromContext[Optional[int]]) -> int: + return -1 if a is None else a + + +@Flow.model +def optional_ctx(a: Optional[FromContext[int]]) -> int: + return -1 if a is None else a + + +@Flow.model +def explicit_none_default(a: FromContext[Optional[int]] = None) -> int: + return -1 if a is None else a + + +@Flow.model +def optional_with_default(a: Optional[FromContext[int]] = 3) -> int: + return -1 if a is None else a + + +class TestRequiredNullableContext: + def test_missing_raises(self): + with pytest.raises(TypeError, match="Missing contextual input"): + required_ctx().flow.compute() + + def test_none_is_valid(self): + assert required_ctx().flow.compute(a=None).value == -1 + + def test_value(self): + assert required_ctx().flow.compute(a=5).value == 5 + + def test_inspect_required(self): + insp = required_ctx().flow.inspect() + assert "a" in insp.required_inputs + + +class TestOptionalContext: + def test_missing_binds_none(self): + assert optional_ctx().flow.compute().value == -1 + + def test_none_explicit(self): + assert optional_ctx().flow.compute(a=None).value == -1 + + def test_value(self): + assert optional_ctx().flow.compute(a=5).value == 5 + + def test_inspect_not_required(self): + insp = optional_ctx().flow.inspect() + assert "a" not in insp.required_inputs + assert "a" in insp.context_inputs + + +class TestConsistency: + def test_explicit_none_default_equiv_optional(self): + # FromContext[Optional[int]] = None is equivalent to Optional[FromContext[int]]. + assert explicit_none_default().flow.compute().value == -1 + assert optional_ctx().flow.compute().value == -1 + + def test_optional_with_explicit_default_wins(self): + # An explicit default overrides the implicit None of Optional[FromContext[int]]. + assert optional_with_default().flow.compute().value == 3 + + def test_node_keys_distinguish_required_vs_optional(self): + req_key = cache_key(ModelEvaluationContext(model=required_ctx(), context=FlowContext(a=5))) + opt_key = cache_key(ModelEvaluationContext(model=optional_ctx(), context=FlowContext(a=5))) + # Different required-ness must give distinct logical identities. + assert req_key != opt_key + + +class TestRejections: + def test_nested_lazy_rejected(self): + with pytest.raises(TypeError, match="Lazy"): + + @Flow.model + def bad(a: Optional[Lazy[int]]) -> int: + return 0 + + def test_non_optional_union_with_fromcontext_rejected(self): + with pytest.raises(TypeError, match="only supported as Optional"): + + @Flow.model + def bad(a: Union[FromContext[int], str]) -> int: + return 0