Skip to content

Commit 41afe24

Browse files
committed
Add orchestration version to orchestration context
1 parent 0f0b30e commit 41afe24

File tree

3 files changed

+94
-15
lines changed

3 files changed

+94
-15
lines changed

durabletask/task.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def instance_id(self) -> str:
3535
"""
3636
pass
3737

38+
@property
39+
@abstractmethod
40+
def version(self) -> Optional[str]:
41+
"""Get the version of the orchestration instance.
42+
43+
This version is set when the orchestration is scheduled and can be used
44+
to determine which version of the orchestrator function is being executed.
45+
46+
Returns
47+
-------
48+
Optional[str]
49+
The version of the orchestration instance, or None if not set.
50+
"""
51+
pass
52+
3853
@property
3954
@abstractmethod
4055
def current_utc_datetime(self) -> datetime:

durabletask/worker.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ def __init__(self, instance_id: str, registry: _Registry):
651651
self._current_utc_datetime = datetime(1000, 1, 1)
652652
self._instance_id = instance_id
653653
self._registry = registry
654+
self._version: Optional[str] = None
654655
self._completion_status: Optional[pb.OrchestrationStatus] = None
655656
self._received_events: dict[str, list[Any]] = {}
656657
self._pending_events: dict[str, list[task.CompletableTask]] = {}
@@ -776,6 +777,10 @@ def next_sequence_number(self) -> int:
776777
def instance_id(self) -> str:
777778
return self._instance_id
778779

780+
@property
781+
def version(self) -> Optional[str]:
782+
return self._version
783+
779784
@property
780785
def current_utc_datetime(self) -> datetime:
781786
return self._current_utc_datetime
@@ -977,11 +982,12 @@ def execute(
977982

978983
# Process versioning if applicable
979984
execution_started_events = [e.executionStarted for e in old_events if e.HasField("executionStarted")]
985+
# We only check versioning if there are executionStarted events - otherwise, on the first replay when
986+
# ctx.version will be Null, we may invalidate orchestrations early depending on the versioning strategy.
980987
if self._registry.versioning and len(execution_started_events) > 0:
981-
execution_started_event = execution_started_events[-1]
982988
version_failure = self.evaluate_orchestration_versioning(
983989
self._registry.versioning,
984-
execution_started_event.version.value if execution_started_event.version else None,
990+
ctx.version
985991
)
986992
if version_failure:
987993
self._logger.warning(
@@ -1059,6 +1065,9 @@ def process_event(
10591065
f"A '{event.executionStarted.name}' orchestrator was not registered."
10601066
)
10611067

1068+
if event.executionStarted.version:
1069+
ctx._version = event.executionStarted.version.value
1070+
10621071
# deserialize the input, if any
10631072
input = None
10641073
if (

tests/durabletask-azuremanaged/test_dts_orchestration_versioning_e2e.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def plus_one(_: task.ActivityContext, input: int) -> int:
2323
return input + 1
2424

2525

26+
def plus_two(_: task.ActivityContext, input: int) -> int:
27+
return input + 2
28+
29+
2630
def single_activity(ctx: task.OrchestrationContext, start_val: int):
2731
yield ctx.call_activity(plus_one, input=start_val)
2832
return "Success"
@@ -153,6 +157,35 @@ def test_upper_version_worker_succeeds():
153157

154158

155159
def test_upper_version_worker_strict_fails():
160+
# Start a worker, which will connect to the sidecar in a background thread
161+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
162+
taskhub=taskhub_name, token_credential=None) as w:
163+
w.add_orchestrator(single_activity)
164+
w.add_activity(plus_one)
165+
w.use_versioning(worker.VersioningOptions(
166+
version="1.1.0",
167+
default_version="1.1.0",
168+
match_strategy=worker.VersionMatchStrategy.STRICT,
169+
failure_strategy=worker.VersionFailureStrategy.FAIL
170+
))
171+
w.start()
172+
173+
task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
174+
taskhub=taskhub_name, token_credential=None,
175+
default_version="1.1.0")
176+
id = task_hub_client.schedule_new_orchestration(single_activity, input=1, version="1.0.0")
177+
state = task_hub_client.wait_for_orchestration_completion(
178+
id, timeout=30)
179+
180+
assert state is not None
181+
assert state.name == task.get_name(single_activity)
182+
assert state.instance_id == id
183+
assert state.runtime_status == client.OrchestrationStatus.FAILED
184+
assert state.failure_details is not None
185+
assert state.failure_details.message.find("The orchestration version '1.0.0' does not match the worker version '1.1.0'.") >= 0
186+
187+
188+
def test_reject_abandons_and_reprocess():
156189
# Start a worker, which will connect to the sidecar in a background thread
157190
instance_id: str = ''
158191
thrown = False
@@ -206,36 +239,58 @@ def test_upper_version_worker_strict_fails():
206239
assert state.failure_details is None
207240

208241

209-
def test_reject_abandons_and_reprocess():
242+
def multiversion_sequence(ctx: task.OrchestrationContext, start_val: int):
243+
if ctx.version == "1.0.0":
244+
result = yield ctx.call_activity(plus_one, input=start_val)
245+
elif ctx.version == "1.1.0":
246+
result = yield ctx.call_activity(plus_two, input=start_val)
247+
else:
248+
raise ValueError(f"Unsupported version: {ctx.version}")
249+
return result
250+
251+
252+
def test_multiversion_orchestration_succeeds():
210253
# Start a worker, which will connect to the sidecar in a background thread
211254
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
212255
taskhub=taskhub_name, token_credential=None) as w:
213-
w.add_orchestrator(single_activity)
256+
w.add_orchestrator(multiversion_sequence)
214257
w.add_activity(plus_one)
258+
w.add_activity(plus_two)
215259
w.use_versioning(worker.VersioningOptions(
216260
version="1.1.0",
217261
default_version="1.1.0",
218-
match_strategy=worker.VersionMatchStrategy.STRICT,
262+
match_strategy=worker.VersionMatchStrategy.CURRENT_OR_OLDER,
219263
failure_strategy=worker.VersionFailureStrategy.FAIL
220264
))
221265
w.start()
222266

223267
task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
224268
taskhub=taskhub_name, token_credential=None,
225269
default_version="1.1.0")
226-
id = task_hub_client.schedule_new_orchestration(single_activity, input=1, version="1.0.0")
227-
state = task_hub_client.wait_for_orchestration_completion(
228-
id, timeout=30)
270+
id = task_hub_client.schedule_new_orchestration(multiversion_sequence, input=1, version="1.0.0")
271+
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
229272

230-
assert state is not None
231-
assert state.name == task.get_name(single_activity)
232-
assert state.instance_id == id
233-
assert state.runtime_status == client.OrchestrationStatus.FAILED
234-
assert state.failure_details is not None
235-
assert state.failure_details.message.find("The orchestration version '1.0.0' does not match the worker version '1.1.0'.") >= 0
273+
id_2 = task_hub_client.schedule_new_orchestration(multiversion_sequence, input=1, version="1.1.0")
274+
state_2 = task_hub_client.wait_for_orchestration_completion(id_2, timeout=30)
236275

276+
print(state.failure_details.message if state and state.failure_details else "State is None")
277+
print(state_2.failure_details.message if state_2 and state_2.failure_details else "State is None")
237278

238-
# Sub-orchestration tests
279+
assert state is not None
280+
assert state.name == task.get_name(multiversion_sequence)
281+
assert state.instance_id == id
282+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
283+
assert state.failure_details is None
284+
assert state.serialized_input == json.dumps(1)
285+
assert state.serialized_output == json.dumps(2)
286+
287+
assert state_2 is not None
288+
assert state_2.name == task.get_name(multiversion_sequence)
289+
assert state_2.instance_id == id_2
290+
assert state_2.runtime_status == client.OrchestrationStatus.COMPLETED
291+
assert state_2.failure_details is None
292+
assert state_2.serialized_input == json.dumps(1)
293+
assert state_2.serialized_output == json.dumps(3)
239294

240295

241296
def sequence_suborchestator(ctx: task.OrchestrationContext, start_val: int):

0 commit comments

Comments
 (0)