diff --git a/ccflow/_flow_model_binding.py b/ccflow/_flow_model_binding.py index 5057ba3..3b98577 100644 --- a/ccflow/_flow_model_binding.py +++ b/ccflow/_flow_model_binding.py @@ -574,7 +574,12 @@ def _analyze_flow_function( return tuple(analyzed_params) -def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[_FlowModelParam, ...]) -> Type[ContextBase]: +def _validate_declared_context_type( + context_type: Any, + contextual_params: Tuple[_FlowModelParam, ...], + *, + strict: bool = False, +) -> Type[ContextBase]: if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}") @@ -585,14 +590,20 @@ def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[ if missing: raise TypeError(f"context_type {context_type.__name__} must define fields for all FromContext parameters: {', '.join(missing)}") - required_extra_fields = sorted( - name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in contextual_names and info.is_required() - ) - if required_extra_fields: - raise TypeError( - f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: " - f"{', '.join(required_extra_fields)}" + if strict: + # Strict mode requires a full bijection: every required context_type field must be a + # FromContext parameter. The default (subset) mode lets the declared context act as an + # "omnibus" carrier whose extra fields this model simply does not consume. + required_extra_fields = sorted( + name + for name, info in context_fields.items() + if name not in ContextBase.model_fields and name not in contextual_names and info.is_required() ) + if required_extra_fields: + raise TypeError( + f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: " + f"{', '.join(required_extra_fields)}" + ) for param in contextual_params: ctx_field = context_fields[param.name] @@ -612,6 +623,7 @@ def _analyze_flow_model( context_type: Optional[Type[ContextBase]], auto_unwrap: bool, is_model_dependency: Callable[[Any], bool], + strict: bool = False, ) -> _FlowModelConfig: parameters = _analyze_flow_function(fn, sig, is_model_dependency=is_model_dependency) reserved = sorted(param.name for param in parameters if param.name in _RESERVED_FLOW_MODEL_PARAM_NAMES) @@ -624,7 +636,7 @@ def _analyze_flow_model( if context_type is not None and not contextual_params: raise TypeError("context_type=... requires FromContext[...] parameters.") if context_type is not None: - declared_context_type = _validate_declared_context_type(context_type, contextual_params) + declared_context_type = _validate_declared_context_type(context_type, contextual_params, strict=strict) if declared_context_type is not None: updated_params = [] diff --git a/ccflow/callable.py b/ccflow/callable.py index cff2a32..acdb2c8 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -507,6 +507,11 @@ def model(*args, **kwargs): applies the returned decorator. context_type: Optional ``ContextBase`` subclass used to validate ``FromContext[...]`` fields together. + strict: When ``context_type`` is given, require a full bijection + between the declared context's required fields and the + ``FromContext[...]`` parameters. Defaults to ``False``, which + allows the declared context to be an omnibus superset carrying + extra fields the model does not consume. auto_unwrap: When ``True`` and the function's plain return value is auto-wrapped in ``GenericResult[T]``, external ``model.flow.compute(...)`` calls return the raw ``T`` value diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index c8b0cf8..a274c7c 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -1086,6 +1086,17 @@ def _validate_declared_context_values(config: _FlowModelConfig, values: Dict[str if config.declared_context_type is None: return values + # Subset (omnibus) mode: when the declared context carries required fields this model + # does not consume, constructing the whole declared context from only the consumed + # fields would fail. Validate the consumed fields individually against their declared + # annotations instead. When the declared context is a full bijection (or only adds + # optional extras), construct the whole model so its cross-field validators still run. + declared_fields = config.declared_context_type.model_fields + consumed = set(config.contextual_param_names) + has_unconsumed_required = any(name not in consumed and info.is_required() for name, info in declared_fields.items()) + if has_unconsumed_required: + return {param.name: _coerce_contextual_value(config, param, values[param.name], "Context field") for param in config.contextual_params} + validated = config.declared_context_type.model_validate(values) return {param.name: getattr(validated, param.name) for param in config.contextual_params} @@ -3058,6 +3069,7 @@ def flow_model( func: Optional[_AnyCallable] = None, *, context_type: Optional[Type[ContextBase]] = None, + strict: bool = False, auto_unwrap: bool = False, model_base: Type[CallableModel] = CallableModel, cacheable: Any = _UNSET, @@ -3084,6 +3096,14 @@ def flow_model( context_type: Optional ``ContextBase`` subclass used to validate all contextual inputs together after individual ``FromContext[...]`` fields are resolved. + strict: When ``context_type`` is given, controls how strictly the + ``FromContext[...]`` parameters must match the declared context. The + default (``False``) allows the declared context to be an "omnibus" + superset: every ``FromContext`` field must exist on the context with a + compatible type, but the context may also carry extra fields this + model does not consume. ``strict=True`` additionally requires that + every required field of ``context_type`` is declared as a + ``FromContext`` parameter (a full bijection). auto_unwrap: When ``True`` and ccflow auto-wraps a plain return annotation in ``GenericResult[T]``, external ``model.flow.compute(...)`` calls return the raw ``T`` value instead @@ -3116,6 +3136,7 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: context_type=context_type, auto_unwrap=auto_unwrap, is_model_dependency=_is_model_dependency, + strict=strict, ) factory_kwargs = { "context_type": context_type, diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index a7adf9a..a395d59 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -2721,12 +2721,20 @@ class ExtraRequiredContext(ContextBase): end_date: date label: str + # Under the default (strict=False) an unconsumed required field is allowed: the declared + # context acts as an omnibus superset. strict=True restores the full-bijection requirement. with pytest.raises(TypeError, match="has required fields that are not declared as FromContext parameters"): - @Flow.model(context_type=ExtraRequiredContext) + @Flow.model(context_type=ExtraRequiredContext, strict=True) def bad_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int: return 0 + @Flow.model(context_type=ExtraRequiredContext) + def ok_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + assert ok_extra().flow.compute(ExtraRequiredContext(start_date=date(2025, 1, 1), end_date=date(2025, 1, 8), label="x")).value == 7 + class BadAnnotationContext(ContextBase): value: str diff --git a/ccflow/tests/test_flow_model_strict_context.py b/ccflow/tests/test_flow_model_strict_context.py new file mode 100644 index 0000000..4c3d4ef --- /dev/null +++ b/ccflow/tests/test_flow_model_strict_context.py @@ -0,0 +1,78 @@ +"""Tests for the ``Flow.model(context_type=, strict=...)`` subset/omnibus behavior.""" + +from datetime import date + +import pytest + +from ccflow import ContextBase, Flow, FromContext + + +class OmnibusContext(ContextBase): + start_date: date + end_date: date + region: str # extra required field a span model does not consume + + +class TestSubsetDefault: + def test_subset_model_builds_and_runs(self): + # Default strict=False: FromContext fields are a typed subset of the omnibus. + @Flow.model(context_type=OmnibusContext) + def span(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + model = span() + # The omnibus carries `region`, which this model ignores. + ctx = OmnibusContext(start_date=date(2025, 1, 1), end_date=date(2025, 1, 8), region="us") + assert model.flow.compute(ctx).value == 7 + + def test_subset_compute_with_named_inputs(self): + @Flow.model(context_type=OmnibusContext) + def span(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + assert span().flow.compute(start_date="2025-01-01", end_date="2025-01-08").value == 7 + + +class TestStrict: + def test_strict_rejects_unconsumed_required_field(self): + with pytest.raises(TypeError, match="has required fields that are not declared as FromContext"): + + @Flow.model(context_type=OmnibusContext, strict=True) + def span(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + def test_strict_accepts_full_bijection(self): + class ExactContext(ContextBase): + start_date: date + end_date: date + + @Flow.model(context_type=ExactContext, strict=True) + def span(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + assert span().flow.compute(start_date="2025-01-01", end_date="2025-01-08").value == 7 + + +class TestSharedChecks: + def test_missing_field_errors_in_both_modes(self): + class CtxNoFoo(ContextBase): + start_date: date + + for strict in (False, True): + with pytest.raises(TypeError, match="must define fields for all FromContext"): + + @Flow.model(context_type=CtxNoFoo, strict=strict) + def span(start_date: FromContext[date], foo: FromContext[int]) -> int: + return foo + + def test_type_mismatch_errors_in_both_modes(self): + class CtxBadType(ContextBase): + start_date: date + end_date: date + + for strict in (False, True): + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=CtxBadType, strict=strict) + def span(start_date: FromContext[int], end_date: FromContext[date]) -> int: + return 0