Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions ccflow/flow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,20 @@ def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]
return None


def _declared_context_type_for_model(model: CallableModel) -> Optional[Type[ContextBase]]:
"""Return a generated model's declared ``context_type``, if any.

Generated ``@Flow.model`` instances expose ``FlowContext`` as their runtime
context type but may also declare a nominal ``context_type`` whose ordered
fields enable positional/string context shorthand.
"""

generated = _generated_model_instance(model)
if generated is None:
return None
return type(generated).__flow_model_config__.declared_context_type


def _model_context_contract(
model: CallableModel,
) -> _ModelContextContract:
Expand Down Expand Up @@ -1897,6 +1911,17 @@ def _compute_context_from_explicit(model: CallableModel, context: Any, contract:
if _context_matches_type(context, model.context_type):
return context
return _runtime_context_for_model(model, _context_values(context))
# Positional/string shorthand (e.g. `[start, end]` or "start,end", as used by
# Hydra `+context=[...]`) only carries field *order*, not names. A generated
# model's runtime context_type is the open FlowContext bag, which has no declared
# fields to zip against, so the shorthand would otherwise be silently dropped.
# When the model declares a context_type, validate the shorthand through it first
# (which applies the ordered `zip(model_fields, v)` mapping), then forward the
# named values into the FlowContext bag. Mappings already carry names and keep
# their existing path.
declared = _declared_context_type_for_model(model)
if declared is not None and not isinstance(context, Mapping):
return _runtime_context_for_model(model, _context_values(declared.model_validate(context)))
return contract.runtime_context_type.model_validate(context)


Expand Down
53 changes: 53 additions & 0 deletions ccflow/tests/test_flow_model_context_shorthand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for positional/string context shorthand on ``@Flow.model(context_type=...)``.

Class-based ``CallableModel`` execution already accepts positional/string context
shorthand (the ordered ``zip(model_fields, v)`` mapping in ``ContextBase``). Generated
``@Flow.model`` instances expose ``FlowContext`` (the open bag) as their runtime context
type, so without a declared ``context_type`` there are no ordered fields to zip against.
When a ``context_type`` is declared, ``compute()`` routes the shorthand through it first.

Scope: this covers the ``compute()`` entry point. The direct-call form (``model([...])``)
is intentionally not supported here because ``Flow.call`` validates against ``FlowContext``
before the generated body runs; supporting it would require reverting the bag-of-types
design.
"""

from datetime import date

from ccflow import DateRangeContext, Flow, FromContext


@Flow.model(context_type=DateRangeContext)
def span(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days


class TestComputeShorthand:
def test_list_shorthand(self):
assert span().flow.compute(["2025-01-02", "2026-01-01"]).value == 364

def test_tuple_shorthand(self):
assert span().flow.compute(("2025-01-02", "2026-01-01")).value == 364

def test_string_shorthand(self):
assert span().flow.compute("2025-01-02,2026-01-01").value == 364

def test_named_inputs_still_work(self):
assert span().flow.compute(start_date="2025-01-02", end_date="2026-01-01").value == 364

def test_context_object_still_works(self):
ctx = DateRangeContext(start_date=date(2025, 1, 2), end_date=date(2026, 1, 1))
assert span().flow.compute(ctx).value == 364

def test_shorthand_matches_named(self):
assert span().flow.compute(["2025-01-02", "2026-01-01"]).value == span().flow.compute(start_date="2025-01-02", end_date="2026-01-01").value


class TestNoDeclaredContextTypeUnaffected:
def test_bag_model_still_uses_named(self):
@Flow.model
def bag(start_date: FromContext[date], end_date: FromContext[date]) -> int:
return (end_date - start_date).days

# Without a declared context_type there is no field order; named inputs are required.
assert bag().flow.compute(start_date="2025-01-02", end_date="2026-01-01").value == 364
Loading