Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions ccflow/_flow_model_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,12 @@ def _analyze_flow_function(
return tuple(analyzed_params)


def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[_FlowModelParam, ...]) -> Type[ContextBase]:
def _validate_declared_context_type(
context_type: Any,
contextual_params: Tuple[_FlowModelParam, ...],
*,
strict: bool = False,
) -> Type[ContextBase]:
if not isinstance(context_type, type) or not issubclass(context_type, ContextBase):
raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}")

Expand All @@ -585,14 +590,20 @@ def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[
if missing:
raise TypeError(f"context_type {context_type.__name__} must define fields for all FromContext parameters: {', '.join(missing)}")

required_extra_fields = sorted(
name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in contextual_names and info.is_required()
)
if required_extra_fields:
raise TypeError(
f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: "
f"{', '.join(required_extra_fields)}"
if strict:
# Strict mode requires a full bijection: every required context_type field must be a
# FromContext parameter. The default (subset) mode lets the declared context act as an
# "omnibus" carrier whose extra fields this model simply does not consume.
required_extra_fields = sorted(
name
for name, info in context_fields.items()
if name not in ContextBase.model_fields and name not in contextual_names and info.is_required()
)
if required_extra_fields:
raise TypeError(
f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: "
f"{', '.join(required_extra_fields)}"
)

for param in contextual_params:
ctx_field = context_fields[param.name]
Expand All @@ -612,6 +623,7 @@ def _analyze_flow_model(
context_type: Optional[Type[ContextBase]],
auto_unwrap: bool,
is_model_dependency: Callable[[Any], bool],
strict: bool = False,
) -> _FlowModelConfig:
parameters = _analyze_flow_function(fn, sig, is_model_dependency=is_model_dependency)
reserved = sorted(param.name for param in parameters if param.name in _RESERVED_FLOW_MODEL_PARAM_NAMES)
Expand All @@ -624,7 +636,7 @@ def _analyze_flow_model(
if context_type is not None and not contextual_params:
raise TypeError("context_type=... requires FromContext[...] parameters.")
if context_type is not None:
declared_context_type = _validate_declared_context_type(context_type, contextual_params)
declared_context_type = _validate_declared_context_type(context_type, contextual_params, strict=strict)

if declared_context_type is not None:
updated_params = []
Expand Down
5 changes: 5 additions & 0 deletions ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,11 @@ def model(*args, **kwargs):
applies the returned decorator.
context_type: Optional ``ContextBase`` subclass used to validate
``FromContext[...]`` fields together.
strict: When ``context_type`` is given, require a full bijection
between the declared context's required fields and the
``FromContext[...]`` parameters. Defaults to ``False``, which
allows the declared context to be an omnibus superset carrying
extra fields the model does not consume.
auto_unwrap: When ``True`` and the function's plain return value is
auto-wrapped in ``GenericResult[T]``, external
``model.flow.compute(...)`` calls return the raw ``T`` value
Expand Down
21 changes: 21 additions & 0 deletions ccflow/flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,17 @@ def _validate_declared_context_values(config: _FlowModelConfig, values: Dict[str
if config.declared_context_type is None:
return values

# Subset (omnibus) mode: when the declared context carries required fields this model
# does not consume, constructing the whole declared context from only the consumed
# fields would fail. Validate the consumed fields individually against their declared
# annotations instead. When the declared context is a full bijection (or only adds
# optional extras), construct the whole model so its cross-field validators still run.
declared_fields = config.declared_context_type.model_fields
consumed = set(config.contextual_param_names)
has_unconsumed_required = any(name not in consumed and info.is_required() for name, info in declared_fields.items())
if has_unconsumed_required:
return {param.name: _coerce_contextual_value(config, param, values[param.name], "Context field") for param in config.contextual_params}

validated = config.declared_context_type.model_validate(values)
return {param.name: getattr(validated, param.name) for param in config.contextual_params}

Expand Down Expand Up @@ -3058,6 +3069,7 @@ def flow_model(
func: Optional[_AnyCallable] = None,
*,
context_type: Optional[Type[ContextBase]] = None,
strict: bool = False,
auto_unwrap: bool = False,
model_base: Type[CallableModel] = CallableModel,
cacheable: Any = _UNSET,
Expand All @@ -3084,6 +3096,14 @@ def flow_model(
context_type: Optional ``ContextBase`` subclass used to validate all
contextual inputs together after individual ``FromContext[...]``
fields are resolved.
strict: When ``context_type`` is given, controls how strictly the
``FromContext[...]`` parameters must match the declared context. The
default (``False``) allows the declared context to be an "omnibus"
superset: every ``FromContext`` field must exist on the context with a
compatible type, but the context may also carry extra fields this
model does not consume. ``strict=True`` additionally requires that
every required field of ``context_type`` is declared as a
``FromContext`` parameter (a full bijection).
auto_unwrap: When ``True`` and ccflow auto-wraps a plain return
annotation in ``GenericResult[T]``, external
``model.flow.compute(...)`` calls return the raw ``T`` value instead
Expand Down Expand Up @@ -3116,6 +3136,7 @@ def decorator(fn: _AnyCallable) -> _AnyCallable:
context_type=context_type,
auto_unwrap=auto_unwrap,
is_model_dependency=_is_model_dependency,
strict=strict,
)
factory_kwargs = {
"context_type": context_type,
Expand Down
10 changes: 9 additions & 1 deletion ccflow/tests/test_flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2721,12 +2721,20 @@ class ExtraRequiredContext(ContextBase):
end_date: date
label: str

# Under the default (strict=False) an unconsumed required field is allowed: the declared
# context acts as an omnibus superset. strict=True restores the full-bijection requirement.
with pytest.raises(TypeError, match="has required fields that are not declared as FromContext parameters"):

@Flow.model(context_type=ExtraRequiredContext)
@Flow.model(context_type=ExtraRequiredContext, strict=True)
def bad_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return 0

@Flow.model(context_type=ExtraRequiredContext)
def ok_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days

assert ok_extra().flow.compute(ExtraRequiredContext(start_date=date(2025, 1, 1), end_date=date(2025, 1, 8), label="x")).value == 7

class BadAnnotationContext(ContextBase):
value: str

Expand Down
78 changes: 78 additions & 0 deletions ccflow/tests/test_flow_model_strict_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Tests for the ``Flow.model(context_type=, strict=...)`` subset/omnibus behavior."""

from datetime import date

import pytest

from ccflow import ContextBase, Flow, FromContext


class OmnibusContext(ContextBase):
start_date: date
end_date: date
region: str # extra required field a span model does not consume


class TestSubsetDefault:
def test_subset_model_builds_and_runs(self):
# Default strict=False: FromContext fields are a typed subset of the omnibus.
@Flow.model(context_type=OmnibusContext)
def span(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days

model = span()
# The omnibus carries `region`, which this model ignores.
ctx = OmnibusContext(start_date=date(2025, 1, 1), end_date=date(2025, 1, 8), region="us")
assert model.flow.compute(ctx).value == 7

def test_subset_compute_with_named_inputs(self):
@Flow.model(context_type=OmnibusContext)
def span(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days

assert span().flow.compute(start_date="2025-01-01", end_date="2025-01-08").value == 7


class TestStrict:
def test_strict_rejects_unconsumed_required_field(self):
with pytest.raises(TypeError, match="has required fields that are not declared as FromContext"):

@Flow.model(context_type=OmnibusContext, strict=True)
def span(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days

def test_strict_accepts_full_bijection(self):
class ExactContext(ContextBase):
start_date: date
end_date: date

@Flow.model(context_type=ExactContext, strict=True)
def span(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days

assert span().flow.compute(start_date="2025-01-01", end_date="2025-01-08").value == 7


class TestSharedChecks:
def test_missing_field_errors_in_both_modes(self):
class CtxNoFoo(ContextBase):
start_date: date

for strict in (False, True):
with pytest.raises(TypeError, match="must define fields for all FromContext"):

@Flow.model(context_type=CtxNoFoo, strict=strict)
def span(start_date: FromContext[date], foo: FromContext[int]) -> int:
return foo

def test_type_mismatch_errors_in_both_modes(self):
class CtxBadType(ContextBase):
start_date: date
end_date: date

for strict in (False, True):
with pytest.raises(TypeError, match="annotates"):

@Flow.model(context_type=CtxBadType, strict=strict)
def span(start_date: FromContext[int], end_date: FromContext[date]) -> int:
return 0
Loading