Skip to content

Commit 9c65176

Browse files
committed
More review feedback for token passing
Signed-off-by: Ryan Lettieri <ryanLettieri@microsoft.com>
1 parent 2c251ea commit 9c65176

File tree

6 files changed

+28
-49
lines changed

6 files changed

+28
-49
lines changed

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,13 @@ class DurableTaskSchedulerClient(TaskHubGrpcClient):
1212
def __init__(self, *,
1313
host_address: str,
1414
taskhub: str,
15-
secure_channel: Optional[bool] = True,
16-
metadata: Optional[list[tuple[str, str]]] = None,
17-
token_credential: Optional[TokenCredential] = None):
15+
token_credential: TokenCredential = None,
16+
secure_channel: Optional[bool] = True):
1817

1918
if taskhub == None:
2019
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
2120

22-
# Ensure metadata is a list
23-
metadata = metadata or []
24-
self._metadata = metadata.copy() # Use a copy to avoid modifying original
25-
26-
# Append DurableTask-specific metadata
27-
self._metadata.append(("taskhub", taskhub))
28-
self._metadata.append(("dts", "True"))
29-
self._metadata.append(("token_credential", token_credential))
30-
self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
21+
self._interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)]
3122

3223
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
3324
# Since the parent class doesn't use anything metadata for anything else, we can set it as None

durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,23 @@
33

44
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl
55
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
6-
6+
from azure.core.credentials import TokenCredential
77
import grpc
88

99
class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
1010
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
1111
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
1212
interceptor to add additional headers to all calls as needed."""
1313

14-
def __init__(self, metadata: list[tuple[str, str]]):
14+
def __init__(self, token_credential: TokenCredential, taskhub_name: str):
15+
metadata = [("taskhub", taskhub_name)]
1516
super().__init__(metadata)
1617

17-
self._token_credential = None
18-
19-
# Check what authentication we are using
20-
if metadata:
21-
for key, value in metadata:
22-
if key.lower() == "token_credential":
23-
self._token_credential = value
24-
25-
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
26-
token = self._token_manager.get_access_token()
27-
self._metadata.append(("authorization", token))
18+
if token_credential is not None:
19+
self._token_credential = token_credential
20+
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
21+
token = self._token_manager.get_access_token()
22+
self._metadata.append(("authorization", token))
2823

2924
def _intercept_call(
3025
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:

durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,7 @@ def __init__(self, refresh_interval_seconds: int = 600, token_credential: TokenC
1313
self._refresh_interval_seconds = refresh_interval_seconds
1414
self._logger = shared.get_logger("token_manager")
1515

16-
# Choose the appropriate credential.
17-
# Both TokenCredential and DefaultAzureCredential get_token methods return an AccessToken
18-
if token_credential:
19-
self._logger.debug("Using user provided token credentials.")
20-
self._credential = token_credential
21-
else:
22-
self._credential = DefaultAzureCredential()
23-
self._logger.debug("Using Default Azure Credentials for authentication.")
16+
self._credential = token_credential
2417

2518
self._token = self._credential.get_token(self._scope)
2619
self.expiry_time = None

durabletask-azuremanaged/durabletask/azuremanaged/worker.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,13 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
1212
def __init__(self, *,
1313
host_address: str,
1414
taskhub: str,
15-
secure_channel: Optional[bool] = True,
16-
metadata: Optional[list[tuple[str, str]]] = None,
17-
token_credential: Optional[TokenCredential] = None):
15+
token_credential: TokenCredential = None,
16+
secure_channel: Optional[bool] = True):
1817

1918
if taskhub == None:
2019
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
2120

22-
# Ensure metadata is a list
23-
metadata = metadata or []
24-
self._metadata = metadata.copy() # Copy to prevent modifying input
25-
26-
# Append DurableTask-specific metadata
27-
self._metadata.append(("taskhub", taskhub))
28-
self._metadata.append(("dts", "True"))
29-
self._metadata.append(("token_credential", token_credential))
30-
interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
21+
interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)]
3122

3223
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
3324
# Since the parent class doesn't use anything metadata for anything else, we can set it as None

examples/dts/dts_activity_sequence.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from durabletask import task
55
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
66
from durabletask.azuremanaged.client import DurableTaskSchedulerClient, OrchestrationStatus
7+
from azure.identity import DefaultAzureCredential
78

89
def hello(ctx: task.ActivityContext, name: str) -> str:
910
"""Activity function that returns a greeting"""
@@ -45,15 +46,18 @@ def sequence(ctx: task.OrchestrationContext, _):
4546
print("If you are using bash, run the following: export ENDPOINT=\"<schedulerEndpoint>\"")
4647
exit()
4748

49+
credential = DefaultAzureCredential()
4850

4951
# configure and start the worker
50-
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w:
52+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
53+
taskhub=taskhub_name, token_credential=credential) as w:
5154
w.add_orchestrator(sequence)
5255
w.add_activity(hello)
5356
w.start()
5457

5558
# Construct the client and run the orchestrations
56-
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name)
59+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
60+
taskhub=taskhub_name, token_credential=credential)
5761
instance_id = c.schedule_new_orchestration(sequence)
5862
state = c.wait_for_orchestration_completion(instance_id, timeout=60)
5963
if state and state.runtime_status == OrchestrationStatus.COMPLETED:

examples/dts/dts_fanout_fanin.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from durabletask import client, task
99
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
1010
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
11+
from azure.identity import DefaultAzureCredential
1112

1213
def get_work_items(ctx: task.ActivityContext, _) -> list[str]:
1314
"""Activity function that returns a list of work items"""
@@ -71,15 +72,19 @@ def orchestrator(ctx: task.OrchestrationContext, _):
7172
print("If you are using bash, run the following: export ENDPOINT=\"<schedulerEndpoint>\"")
7273
exit()
7374

75+
credential = DefaultAzureCredential()
76+
7477
# configure and start the worker
75-
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w:
78+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
79+
taskhub=taskhub_name, token_credential=credential) as w:
7680
w.add_orchestrator(orchestrator)
7781
w.add_activity(process_work_item)
7882
w.add_activity(get_work_items)
7983
w.start()
8084

8185
# create a client, start an orchestration, and wait for it to finish
82-
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name)
86+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
87+
taskhub=taskhub_name, token_credential=credential)
8388
instance_id = c.schedule_new_orchestration(orchestrator)
8489
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
8590
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:

0 commit comments

Comments
 (0)