Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit a7ff4a1

Browse files
authored
Merge pull request #44 from CasperGN/fix(bug)--shutdown-on-iterator
Fix(bug): shutdown on iterator
2 parents 6b4a726 + b98630e commit a7ff4a1

5 files changed

Lines changed: 84 additions & 11 deletions

File tree

durabletask/aio/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# If `opentelemetry-instrumentation-grpc` is available, enable the gRPC client interceptor
2828
try:
2929
from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient
30+
3031
GrpcInstrumentorClient().instrument()
3132
except ImportError:
3233
pass

durabletask/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
# If `opentelemetry-instrumentation-grpc` is available, enable the gRPC client interceptor
2525
try:
2626
from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient
27+
2728
GrpcInstrumentorClient().instrument()
2829
except ImportError:
2930
pass
3031

32+
3133
class OrchestrationStatus(Enum):
3234
"""The status of an orchestration instance."""
3335

durabletask/worker.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from datetime import datetime, timedelta
1313
from threading import Event, Thread
1414
from types import GeneratorType
15-
from typing import Any, Generator, Optional, Sequence, TypeVar, Union
15+
from typing import Any, Generator, Iterator, Optional, Sequence, TypeVar, Union
1616

1717
import grpc
1818
from google.protobuf import empty_pb2
@@ -38,7 +38,6 @@
3838
otel_tracer = None
3939

4040

41-
4241
class VersionNotRegisteredException(Exception):
4342
pass
4443

@@ -283,7 +282,7 @@ class TaskHubGrpcWorker:
283282
activity function.
284283
"""
285284

286-
_response_stream: Optional[grpc.Future] = None
285+
_response_stream: Optional[Union[Iterator[grpc.Future], grpc.Future]] = None
287286
_interceptors: Optional[list[shared.ClientInterceptor]] = None
288287

289288
def __init__(
@@ -421,9 +420,12 @@ def invalidate_connection():
421420
# Cancel the response stream first to signal the reader thread to stop
422421
if self._response_stream is not None:
423422
try:
424-
self._response_stream.cancel()
425-
except Exception:
426-
pass
423+
if hasattr(self._response_stream, "call"):
424+
self._response_stream.call.cancel() # type: ignore
425+
else:
426+
self._response_stream.cancel() # type: ignore
427+
except Exception as e:
428+
self._logger.warning(f"Error cancelling response stream: {e}")
427429
self._response_stream = None
428430

429431
# Wait for the reader thread to finish
@@ -740,7 +742,13 @@ def stop(self):
740742

741743
self._logger.info("Stopping gRPC worker...")
742744
if self._response_stream is not None:
743-
self._response_stream.cancel()
745+
try:
746+
if hasattr(self._response_stream, "call"):
747+
self._response_stream.call.cancel() # type: ignore
748+
else:
749+
self._response_stream.cancel() # type: ignore
750+
except Exception as e:
751+
self._logger.warning(f"Error cancelling response stream: {e}")
744752
self._shutdown.set()
745753
# Explicitly close the gRPC channel to ensure OTel interceptors and other resources are cleaned up
746754
if self._current_channel is not None:
@@ -854,13 +862,15 @@ def _execute_activity(
854862

855863
if otel_tracer is not None:
856864
span_context = otel_tracer.start_as_current_span(
857-
name=f'activity: {req.name}',
858-
context=otel_propagator.extract(carrier={"traceparent": req.parentTraceContext.traceParent}),
865+
name=f"activity: {req.name}",
866+
context=otel_propagator.extract(
867+
carrier={"traceparent": req.parentTraceContext.traceParent}
868+
),
859869
attributes={
860870
"durabletask.task.instance_id": instance_id,
861871
"durabletask.task.id": req.taskId,
862872
"durabletask.activity.name": req.name,
863-
}
873+
},
864874
)
865875
else:
866876
span_context = contextlib.nullcontext()

tests/durabletask/test_registry.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def activity2(ctx, input):
165165
assert registry.get_activity(name1) is activity1
166166
assert registry.get_activity(name2) is activity2
167167

168+
168169
def test_registry_add_named_versioned_orchestrators():
169170
"""Test adding versioned orchestrators."""
170171
registry = worker._Registry()
@@ -179,7 +180,9 @@ def orchestrator3(ctx, input):
179180
return "two"
180181

181182
registry.add_named_orchestrator(name="orchestrator", fn=orchestrator1, version_name="v1")
182-
registry.add_named_orchestrator(name="orchestrator", fn=orchestrator2, version_name="v2", is_latest=True)
183+
registry.add_named_orchestrator(
184+
name="orchestrator", fn=orchestrator2, version_name="v2", is_latest=True
185+
)
183186
registry.add_named_orchestrator(name="orchestrator", fn=orchestrator3, version_name="v3")
184187

185188
orquestrator, version = registry.get_orchestrator(name="orchestrator")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import grpc
4+
5+
from durabletask.worker import TaskHubGrpcWorker
6+
7+
8+
# Helper to create a running worker with a mocked runLoop
9+
def _make_running_worker():
10+
worker = TaskHubGrpcWorker()
11+
worker._is_running = True
12+
worker._runLoop = MagicMock()
13+
worker._runLoop.is_alive.return_value = False
14+
return worker
15+
16+
17+
def test_stop_with_grpc_future():
18+
worker = _make_running_worker()
19+
mock_future = MagicMock(spec=grpc.Future)
20+
worker._response_stream = mock_future
21+
worker.stop()
22+
mock_future.cancel.assert_called_once()
23+
24+
25+
def test_stop_with_generator_call():
26+
worker = _make_running_worker()
27+
mock_call = MagicMock()
28+
mock_stream = MagicMock()
29+
mock_stream.call = mock_call
30+
worker._response_stream = mock_stream
31+
worker.stop()
32+
mock_call.cancel.assert_called_once()
33+
34+
35+
def test_stop_with_unknown_stream_type(caplog):
36+
worker = _make_running_worker()
37+
# Not a grpc.Future, no 'call' attribute
38+
worker._response_stream = object()
39+
with caplog.at_level("WARNING"):
40+
worker.stop()
41+
assert any("Error cancelling response stream: " in m for m in caplog.text.splitlines())
42+
43+
44+
def test_stop_with_none_stream():
45+
worker = _make_running_worker()
46+
worker._response_stream = None
47+
# Should not raise
48+
worker.stop()
49+
50+
51+
def test_stop_when_not_running():
52+
worker = TaskHubGrpcWorker()
53+
worker._is_running = False
54+
# Should return immediately, not set _shutdown
55+
with patch.object(worker._shutdown, "set") as shutdown_set:
56+
worker.stop()
57+
shutdown_set.assert_not_called()

0 commit comments

Comments
 (0)