Skip to content

Commit 1e67651

Browse files
committed
Migrating shared access token logic to new grpc class
Signed-off-by: Ryan Lettieri <ryanLettieri@microsoft.com>
1 parent f8d79d3 commit 1e67651

File tree

7 files changed

+94
-126
lines changed

7 files changed

+94
-126
lines changed

durabletask/internal/shared.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import grpc
1111

1212
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
13+
from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
1314

1415
# Field name used to indicate that an object was automatically serialized
1516
# and should be deserialized as a SimpleNamespace
@@ -50,7 +51,12 @@ def get_grpc_channel(
5051
channel = grpc.insecure_channel(host_address)
5152

5253
if metadata is not None and len(metadata) > 0:
53-
interceptors = [DefaultClientInterceptorImpl(metadata)]
54+
for key, _ in metadata:
55+
# Check if we are using DTS as the backend and, if so, construct the DTS-specific interceptors
56+
if key == "dts":
57+
interceptors = [DTSDefaultClientInterceptorImpl(metadata)]
58+
else:
59+
interceptors = [DefaultClientInterceptorImpl(metadata)]
5460
channel = grpc.intercept_channel(channel, *interceptors)
5561
return channel
5662

examples/dts/dts_activity_sequence.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from durabletask import client, task
55
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
66
from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient
7-
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
87

98
def hello(ctx: task.ActivityContext, name: str) -> str:
109
"""Activity function that returns a greeting"""
@@ -48,15 +47,15 @@ def sequence(ctx: task.OrchestrationContext, _):
4847

4948

5049
# configure and start the worker
51-
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w:
50+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, use_managed_identity=False, client_id="", taskhub=taskhub_name) as w:
5251
w.add_orchestrator(sequence)
5352
w.add_activity(hello)
5453
w.start()
5554

5655
# Construct the client and run the orchestrations
5756
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name)
5857
instance_id = c.schedule_new_orchestration(sequence)
59-
state = c.wait_for_orchestration_completion(instance_id, timeout=45)
58+
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
6059
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
6160
print(f'Orchestration completed! Result: {state.serialized_output}')
6261
elif state:

examples/dts/dts_fanout_fanin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from durabletask import client, task
99
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
1010
from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient
11-
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
1211

1312

1413
def get_work_items(ctx: task.ActivityContext, _) -> list[str]:
@@ -74,7 +73,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
7473
exit()
7574

7675
# configure and start the worker
77-
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w:
76+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w:
7877
w.add_orchestrator(orchestrator)
7978
w.add_activity(process_work_item)
8079
w.add_activity(get_work_items)
@@ -88,3 +87,4 @@ def orchestrator(ctx: task.OrchestrationContext, _):
8887
print(f'Orchestration completed! Result: {state.serialized_output}')
8988
elif state:
9089
print(f'Orchestration failed: {state.failure_details}')
90+
exit()
Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
4-
from datetime import datetime, timedelta
4+
from datetime import datetime, timedelta, timezone
55
from typing import Optional
66

7+
# By default, refresh the token when there are 10 minutes left before it expires
78
class AccessTokenManager:
8-
def __init__(self, refresh_buffer: int = 60, use_managed_identity: bool = False, client_id: Optional[str] = None):
9+
def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None):
910
self.scope = "https://durabletask.io/.default"
1011
self.refresh_buffer = refresh_buffer
11-
12+
self._use_managed_identity = False
13+
self._metadata = metadata
14+
self._client_id = None
15+
16+
if metadata: # Ensure metadata is not None
17+
for key, value in metadata:
18+
if key == "use_managed_identity":
19+
self._use_managed_identity = value.lower() == "true" # Properly convert string to bool
20+
elif key == "client_id":
21+
self._client_id = value # Directly assign string
22+
1223
# Choose the appropriate credential based on use_managed_identity
13-
if use_managed_identity:
14-
if not client_id:
24+
if self._use_managed_identity:
25+
if not self._client_id:
1526
print("Using System Assigned Managed Identity for authentication.")
1627
self.credential = ManagedIdentityCredential()
1728
else:
1829
print("Using User Assigned Managed Identity for authentication.")
19-
self.credential = ManagedIdentityCredential(client_id)
30+
self.credential = ManagedIdentityCredential(client_id=self._client_id)
2031
else:
2132
self.credential = DefaultAzureCredential()
2233
print("Using Default Azure Credentials for authentication.")
@@ -29,13 +40,18 @@ def get_access_token(self) -> str:
2940
self.refresh_token()
3041
return self.token
3142

43+
# Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds.
44+
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
45+
# we will grab a new token when there are 30 minutes left in the token's lifespan
3246
def is_token_expired(self) -> bool:
3347
if self.expiry_time is None:
3448
return True
35-
return datetime.utcnow() >= (self.expiry_time - timedelta(seconds=self.refresh_buffer))
49+
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer))
3650

3751
def refresh_token(self):
3852
new_token = self.credential.get_token(self.scope)
3953
self.token = f"Bearer {new_token.token}"
40-
self.expiry_time = datetime.utcnow() + timedelta(seconds=new_token.expires_on - int(datetime.utcnow().timestamp()))
54+
55+
# Convert UNIX timestamp to timezone-aware datetime
56+
self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc)
4157
print(f"Token refreshed. Expires at: {self.expiry_time}")
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl
5+
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
6+
7+
import grpc
8+
9+
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
10+
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
11+
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
12+
interceptor to add additional headers to all calls as needed."""
13+
14+
def __init__(self, metadata: list[tuple[str, str]]):
15+
super().__init__(metadata)
16+
self._token_manager = AccessTokenManager(metadata=self._metadata)
17+
18+
def _intercept_call(
19+
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
20+
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
21+
call details."""
22+
# Refresh the auth token if it is present and needed
23+
if self._metadata is not None:
24+
for i, (key, _) in enumerate(self._metadata):
25+
if key.lower() == "authorization": # Ensure case-insensitive comparison
26+
new_token = self._token_manager.get_access_token() # Get the new token
27+
self._metadata[i] = ("authorization", new_token) # Update the token
28+
29+
return super()._intercept_call(client_call_details)
Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
14
from typing import Optional
25
from durabletask.client import TaskHubGrpcClient
36
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
47

58
class DurableTaskSchedulerClient(TaskHubGrpcClient):
6-
def __init__(self, *args,
9+
def __init__(self,
10+
host_address: str,
11+
secure_channel: bool,
712
metadata: Optional[list[tuple[str, str]]] = None,
13+
use_managed_identity: Optional[bool] = False,
814
client_id: Optional[str] = None,
9-
taskhub: str,
15+
taskhub: str = None,
1016
**kwargs):
1117
if metadata is None:
1218
metadata = [] # Ensure metadata is initialized
1319
self._metadata = metadata
20+
self._use_managed_identity = use_managed_identity
1421
self._client_id = client_id
1522
self._metadata.append(("taskhub", taskhub))
16-
self._access_token_manager = AccessTokenManager(client_id=self._client_id)
23+
self._metadata.append(("dts", "True"))
24+
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
25+
self._metadata.append(("client_id", str(client_id)))
26+
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
1727
self.__update_metadata_with_token()
18-
super().__init__(*args, metadata=self._metadata, **kwargs)
28+
super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs)
1929

2030
def __update_metadata_with_token(self):
2131
"""
@@ -37,40 +47,4 @@ def __update_metadata_with_token(self):
3747

3848
# If not updated, add a new entry
3949
if not updated:
40-
self._metadata.append(("authorization", token))
41-
42-
def schedule_new_orchestration(self, *args, **kwargs) -> str:
43-
self.__update_metadata_with_token()
44-
return super().schedule_new_orchestration(*args, **kwargs)
45-
46-
def get_orchestration_state(self, *args, **kwargs):
47-
self.__update_metadata_with_token()
48-
super().get_orchestration_state(*args, **kwargs)
49-
50-
def wait_for_orchestration_start(self, *args, **kwargs):
51-
self.__update_metadata_with_token()
52-
super().wait_for_orchestration_start(*args, **kwargs)
53-
54-
def wait_for_orchestration_completion(self, *args, **kwargs):
55-
self.__update_metadata_with_token()
56-
super().wait_for_orchestration_completion(*args, **kwargs)
57-
58-
def raise_orchestration_event(self, *args, **kwargs):
59-
self.__update_metadata_with_token()
60-
super().raise_orchestration_event(*args, **kwargs)
61-
62-
def terminate_orchestration(self, *args, **kwargs):
63-
self.__update_metadata_with_token()
64-
super().terminate_orchestration(*args, **kwargs)
65-
66-
def suspend_orchestration(self, *args, **kwargs):
67-
self.__update_metadata_with_token()
68-
super().suspend_orchestration(*args, **kwargs)
69-
70-
def resume_orchestration(self, *args, **kwargs):
71-
self.__update_metadata_with_token()
72-
super().resume_orchestration(*args, **kwargs)
73-
74-
def purge_orchestration(self, *args, **kwargs):
75-
self.__update_metadata_with_token()
76-
super().purge_orchestration(*args, **kwargs)
50+
self._metadata.append(("authorization", token))
Lines changed: 15 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,32 @@
1-
import concurrent.futures
2-
from threading import Thread
3-
from google.protobuf import empty_pb2
4-
import grpc
5-
import durabletask.internal.orchestrator_service_pb2 as pb
6-
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
7-
import durabletask.internal.shared as shared
8-
from typing import Optional
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
93

4+
from typing import Optional
105
from durabletask.worker import TaskHubGrpcWorker
116
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
127

8+
# Worker class used for Durable Task Scheduler (DTS)
139
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
14-
def __init__(self, *args,
10+
def __init__(self,
11+
host_address: str,
12+
secure_channel: bool,
1513
metadata: Optional[list[tuple[str, str]]] = None,
14+
use_managed_identity: Optional[bool] = False,
1615
client_id: Optional[str] = None,
17-
taskhub: str,
16+
taskhub: str = None,
1817
**kwargs):
1918
if metadata is None:
2019
metadata = [] # Ensure metadata is initialized
2120
self._metadata = metadata
21+
self._use_managed_identity = use_managed_identity
2222
self._client_id = client_id
2323
self._metadata.append(("taskhub", taskhub))
24-
self._access_token_manager = AccessTokenManager(client_id=self._client_id)
24+
self._metadata.append(("dts", "True"))
25+
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
26+
self._metadata.append(("client_id", str(client_id)))
27+
self._access_token_manager = AccessTokenManager(metadata=self._metadata)
2528
self.__update_metadata_with_token()
26-
super().__init__(*args, metadata=self._metadata, **kwargs)
27-
29+
super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs)
2830

2931
def __update_metadata_with_token(self):
3032
"""
@@ -47,61 +49,3 @@ def __update_metadata_with_token(self):
4749
# If not updated, add a new entry
4850
if not updated:
4951
self._metadata.append(("authorization", token))
50-
51-
def start(self):
52-
"""Starts the worker on a background thread and begins listening for work items."""
53-
channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel)
54-
stub = stubs.TaskHubSidecarServiceStub(channel)
55-
56-
if self._is_running:
57-
raise RuntimeError('The worker is already running.')
58-
59-
def run_loop():
60-
# TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity
61-
# functions. We'd need to know ahead of time whether a function is async or not.
62-
# TODO: Max concurrency configuration settings
63-
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
64-
while not self._shutdown.is_set():
65-
try:
66-
self.__update_metadata_with_token()
67-
# send a "Hello" message to the sidecar to ensure that it's listening
68-
stub.Hello(empty_pb2.Empty())
69-
70-
# stream work items
71-
self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest())
72-
self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...')
73-
74-
# The stream blocks until either a work item is received or the stream is canceled
75-
# by another thread (see the stop() method).
76-
for work_item in self._response_stream: # type: ignore
77-
request_type = work_item.WhichOneof('request')
78-
self._logger.debug(f'Received "{request_type}" work item')
79-
if work_item.HasField('orchestratorRequest'):
80-
executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken)
81-
elif work_item.HasField('activityRequest'):
82-
executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken)
83-
elif work_item.HasField('healthPing'):
84-
pass # no-op
85-
else:
86-
self._logger.warning(f'Unexpected work item type: {request_type}')
87-
88-
except grpc.RpcError as rpc_error:
89-
if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
90-
self._logger.info(f'Disconnected from {self._host_address}')
91-
elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
92-
self._logger.warning(
93-
f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
94-
else:
95-
self._logger.warning(f'Unexpected error: {rpc_error}')
96-
except Exception as ex:
97-
self._logger.warning(f'Unexpected error: {ex}')
98-
99-
# CONSIDER: exponential backoff
100-
self._shutdown.wait(5)
101-
self._logger.info("No longer listening for work items")
102-
return
103-
104-
self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
105-
self._runLoop = Thread(target=run_loop)
106-
self._runLoop.start()
107-
self._is_running = True

0 commit comments

Comments
 (0)