From da33daede3d0dbb9677c1fc15ae15783d0bee6d4 Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Sat, 16 May 2026 14:47:02 -0700 Subject: [PATCH 1/3] feat(parallelism): make sub-application ID an overridable hook (#761) Previously, the sub-application id for parallel sub-apps was computed inline inside MapActionsAndStates._create_task as a deterministic hash of (parent_app_id, key). That hash is what enables resume-on-rebuild for parallel sub-apps (retry-on-failure, crash recovery), but it also means that with a cascading state initializer the sub-apps will replay stale state on every parent invocation -- the case reported in #761. The previous attempt (#778, closed) tried to auto-fix this by salting sub-app ids with context.sequence_id. That broke test_end_to_end_parallel_collatz_many_unreliable_tasks, because sequence_id advances on parent rebuilds and the retry path could no longer find the previously-persisted sub-app checkpoints. This change keeps the default behavior unchanged (deterministic (parent_app_id, key) hash, so retry-on-failure still works) and instead exposes the id computation as a named, overridable method on TaskBasedParallelAction. Users who hit #761 override the hook to add whatever salt they need; everyone else is unaffected. Adds one test exercising the new hook with a cascading initializer. --- burr/core/parallelism.py | 36 +++++++++++- tests/core/test_parallelism.py | 102 ++++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index 857fed333..3864d6ee7 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -316,6 +316,35 @@ def is_async(self) -> bool: """ return False + def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str: + """Compute the application_id for a sub-task. 
+ + Default: deterministic hash of (parent_app_id, key) -- stable across parent + rebuilds, which is what enables sub-app checkpoint resume on crash recovery. + If the parent application is rebuilt (e.g. as part of retry-on-failure or + a resume-from-persistence flow), each sub-task gets the same id it had + before, so a cascading state initializer can find its prior checkpoint + and pick up where it left off. + + Override to customize cache/resume behavior: + + - Fresh execution per invocation: salt with something that advances + per-call (e.g. ``context.sequence_id``, a uuid, your own counter). + This is the workaround for issue #761, + where a cascading state initializer combined with deterministic sub-app + ids causes a parallel action to replay stale sub-app state on every + invocation instead of running fresh. Note that opting into per-invocation + ids gives up the resume-on-rebuild guarantee above. + - Pin to a business key: derive from ``state`` contents (e.g. a record + id) so re-runs against the same logical input reuse the same sub-app. + + :param key: Per-task key (unique within this invocation of the parent action). + :param state: State that will be passed to the sub-task. + :param context: Parent application context. + :return: Application id to use for the sub-task. 
+ """ + return _stable_app_id_hash(context.app_id, key) + @property def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: """Inputs from this -- if you want to override you'll want to call super() @@ -509,7 +538,7 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask: graph=RunnableGraph.create(action), inputs=inputs, state=substate, - application_id=_stable_app_id_hash(context.app_id, key), + application_id=self.sub_application_id(key, substate, context), tracker=tracker, state_persister=state_persister, state_initializer=state_initializer, @@ -518,7 +547,10 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask: def _tasks() -> Generator[SubGraphTask, None, None]: for i, action in enumerate(self.actions(state, context, inputs)): for j, substate in enumerate(self.states(state, context, inputs)): - key = f"{i}-{j}" # this is a stable hash for now but will not handle caching + # Per-task key is stable across rebuilds. The actual sub-app id is + # computed via ``sub_application_id``; override that hook to opt + # into per-invocation ids (see issue #761). 
+ key = f"{i}-{j}" yield _create_task(key, action, substate) async def _atasks() -> AsyncGenerator[SubGraphTask, None]: diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 25d37cc24..02674195c 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -45,7 +45,12 @@ _cascade_adapter, map_reduce_action, ) -from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData +from burr.core.persistence import ( + BaseStateLoader, + BaseStateSaver, + InMemoryPersister, + PersistedStateData, +) from burr.tracking.base import SyncTrackingClient from burr.visibility import ActionSpan @@ -1227,3 +1232,98 @@ def reads(self) -> list[str]: assert task.state_initializer is not None assert task.tracker is not None assert task.state_persister is task.state_initializer # This ensures they're the same + + +def test_sub_application_id_override_enables_fresh_execution_with_cascading_initializer(): + """Regression test for #761. + + With a cascading state initializer + the default deterministic sub-app id, + re-invoking the parallel action on the same parent reuses prior sub-app + state via the initializer, so per-invocation work does not actually re-run. + Overriding ``sub_application_id`` to salt with a per-invocation value + restores fresh execution while leaving the default (resume-on-rebuild) + behavior alone for everyone else. 
+ """ + invocation_count = {"n": 0} + shared_persister = InMemoryPersister() + + @old_action(reads=["input_number"], writes=["output_number", "invocation"]) + def record_invocation(state: State) -> State: + invocation_count["n"] += 1 + return state.update( + output_number=state["input_number"], invocation=invocation_count["n"] + ) + + class SaltedMapStates(MapStates): + call_index = 0 + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + for input_number in state["input_numbers_in_state"]: + yield state.update(input_number=input_number) + + def action( + self, state: State, inputs: Dict[str, Any] + ) -> Union[Action, Callable, RunnableGraph]: + return record_invocation + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + return state.update( + invocations=[output_state["invocation"] for output_state in states] + ) + + # Pin sub-app persistence to the shared persister. This mirrors the + # #761 setup where a cascading initializer makes sub-apps resume. + def state_initializer(self, **kwargs): + return shared_persister + + def state_persister(self, **kwargs): + return shared_persister + + def sub_application_id( + self, key: str, state: State, context: ApplicationContext + ) -> str: + # Per-invocation salt -- each top-level run gets fresh sub-app ids. 
+ return f"{context.app_id}:{key}:call-{type(self).call_index}" + + @property + def writes(self) -> list[str]: + return ["invocations"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + def _build_and_run(): + app = ( + ApplicationBuilder() + .with_actions( + initial=Input("input_numbers_in_state"), + map_action=SaltedMapStates(), + final=Result("invocations"), + ) + .with_transitions(("initial", "map_action"), ("map_action", "final")) + .with_entrypoint("initial") + .with_identifiers(app_id="parent-app-761") + .build() + ) + _, _, state = app.run( + halt_after=["final"], inputs={"input_numbers_in_state": [1, 2, 3]} + ) + return state + + # Three independent parent invocations against the same parent app_id and + # the same sub-app persister. With the override, every sub-task should + # actually execute on every run (no stale-replay caching). + SaltedMapStates.call_index = 0 + _build_and_run() + SaltedMapStates.call_index = 1 + _build_and_run() + SaltedMapStates.call_index = 2 + final_state = _build_and_run() + + # 3 inputs * 3 invocations = 9 actual executions. + assert invocation_count["n"] == 9 + # The latest run's invocations all come from the most recent counter range. 
+ assert all(inv > 6 for inv in final_state["invocations"]) From 500559ce9ec1402a93969eee6ea3e3f063f08a7b Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Sat, 16 May 2026 16:32:31 -0700 Subject: [PATCH 2/3] style: black formatting on new test (line-length=100) --- tests/core/test_parallelism.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 02674195c..3931eeac9 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -1250,9 +1250,7 @@ def test_sub_application_id_override_enables_fresh_execution_with_cascading_init @old_action(reads=["input_number"], writes=["output_number", "invocation"]) def record_invocation(state: State) -> State: invocation_count["n"] += 1 - return state.update( - output_number=state["input_number"], invocation=invocation_count["n"] - ) + return state.update(output_number=state["input_number"], invocation=invocation_count["n"]) class SaltedMapStates(MapStates): call_index = 0 @@ -1269,9 +1267,7 @@ def action( return record_invocation def reduce(self, state: State, states: Generator[State, None, None]) -> State: - return state.update( - invocations=[output_state["invocation"] for output_state in states] - ) + return state.update(invocations=[output_state["invocation"] for output_state in states]) # Pin sub-app persistence to the shared persister. This mirrors the # #761 setup where a cascading initializer makes sub-apps resume. @@ -1281,9 +1277,7 @@ def state_initializer(self, **kwargs): def state_persister(self, **kwargs): return shared_persister - def sub_application_id( - self, key: str, state: State, context: ApplicationContext - ) -> str: + def sub_application_id(self, key: str, state: State, context: ApplicationContext) -> str: # Per-invocation salt -- each top-level run gets fresh sub-app ids. 
return f"{context.app_id}:{key}:call-{type(self).call_index}" @@ -1308,9 +1302,7 @@ def _build_and_run(): .with_identifiers(app_id="parent-app-761") .build() ) - _, _, state = app.run( - halt_after=["final"], inputs={"input_numbers_in_state": [1, 2, 3]} - ) + _, _, state = app.run(halt_after=["final"], inputs={"input_numbers_in_state": [1, 2, 3]}) return state # Three independent parent invocations against the same parent app_id and From f3fee186a586e77613b9e126c13158f9918413cd Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Sat, 16 May 2026 21:36:21 -0700 Subject: [PATCH 3/3] docs: route example tasks() through sub_application_id hook --- burr/core/parallelism.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index 3864d6ee7..8a5d046e2 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -220,12 +220,18 @@ def tasks(state: State, context: ApplicationContext) -> Generator[SubGraphTask, query_llm.bind(model="o1").with_name("o1_answer"), query_llm.bind(model="claude").with_name("claude_answer"), ] + # Route the application_id through self.sub_application_id so + # subclasses can customize the sub-app cache/resume behavior -- + # e.g. salt with a fresh value per invocation to bypass a + # cascading state_initializer's checkpoint hits. The default + # hook is a stable deterministic hash and is what most users + # want (it's load-bearing for retry-on-failure resume). + key = f"{prompt}:{action.name}" # any stable key you choose yield SubGraphTask( action=action, # can be a RunnableGraph as well state=state.update(prompt=prompt), inputs={}, - # stable hash -- up to you to ensure uniqueness - application_id=hashlib.sha256(context.application_id + action.name + prompt).hexdigest(), + application_id=self.sub_application_id(key, state, context), # a few other parameters we might add -- see advanced usage -- failure conditions, etc... )