diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index c8b0cf8..0cec78e 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -928,6 +928,20 @@ def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"] return None +def _declared_context_type_for_model(model: CallableModel) -> Optional[Type[ContextBase]]: + """Return a generated model's declared ``context_type``, if any. + + Generated ``@Flow.model`` instances expose ``FlowContext`` as their runtime + context type but may also declare a nominal ``context_type`` whose ordered + fields enable positional/string context shorthand. + """ + + generated = _generated_model_instance(model) + if generated is None: + return None + return type(generated).__flow_model_config__.declared_context_type + + def _model_context_contract( model: CallableModel, ) -> _ModelContextContract: @@ -1897,6 +1911,17 @@ def _compute_context_from_explicit(model: CallableModel, context: Any, contract: if _context_matches_type(context, model.context_type): return context return _runtime_context_for_model(model, _context_values(context)) + # Positional/string shorthand (e.g. `[start, end]` or "start,end", as used by + # Hydra `+context=[...]`) only carries field *order*, not names. A generated + # model's runtime context_type is the open FlowContext bag, which has no declared + # fields to zip against, so the shorthand would otherwise be silently dropped. + # When the model declares a context_type, validate the shorthand through it first + # (which applies the ordered `zip(model_fields, v)` mapping), then forward the + # named values into the FlowContext bag. Mappings already carry names and keep + # their existing path. + declared = _declared_context_type_for_model(model) + if declared is not None and not isinstance(context, Mapping): + return _runtime_context_for_model(model, _context_values(declared.model_validate(context))) return contract.runtime_context_type.model_validate(context) diff --git a/ccflow/tests/test_flow_model_context_shorthand.py b/ccflow/tests/test_flow_model_context_shorthand.py new file mode 100644 index 0000000..a9bd861 --- /dev/null +++ b/ccflow/tests/test_flow_model_context_shorthand.py @@ -0,0 +1,53 @@ +"""Tests for positional/string context shorthand on ``@Flow.model(context_type=...)``. + +Class-based ``CallableModel`` execution already accepts positional/string context +shorthand (the ordered ``zip(model_fields, v)`` mapping in ``ContextBase``). Generated +``@Flow.model`` instances expose ``FlowContext`` (the open bag) as their runtime context +type, so without a declared ``context_type`` there are no ordered fields to zip against. +When a ``context_type`` is declared, ``compute()`` routes the shorthand through it first. + +Scope: this covers the ``compute()`` entry point. The direct-call form (``model([...])``) +is intentionally not supported here because ``Flow.call`` validates against ``FlowContext`` +before the generated body runs; supporting it would require reverting the bag-of-types +design. +""" + +from datetime import date + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def span(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + +class TestComputeShorthand: + def test_list_shorthand(self): + assert span().flow.compute(["2025-01-02", "2026-01-01"]).value == 364 + + def test_tuple_shorthand(self): + assert span().flow.compute(("2025-01-02", "2026-01-01")).value == 364 + + def test_string_shorthand(self): + assert span().flow.compute("2025-01-02,2026-01-01").value == 364 + + def test_named_inputs_still_work(self): + assert span().flow.compute(start_date="2025-01-02", end_date="2026-01-01").value == 364 + + def test_context_object_still_works(self): + ctx = DateRangeContext(start_date=date(2025, 1, 2), end_date=date(2026, 1, 1)) + assert span().flow.compute(ctx).value == 364 + + def test_shorthand_matches_named(self): + assert span().flow.compute(["2025-01-02", "2026-01-01"]).value == span().flow.compute(start_date="2025-01-02", end_date="2026-01-01").value + + +class TestNoDeclaredContextTypeUnaffected: + def test_bag_model_still_uses_named(self): + @Flow.model + def bag(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return (end_date - start_date).days + + # Without a declared context_type there is no field order; named inputs are required. + assert bag().flow.compute(start_date="2025-01-02", end_date="2026-01-01").value == 364