Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 6ac8d14

Browse files
committed
add options to client argument and e2e test
Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com>
1 parent 9bba479 commit 6ac8d14

7 files changed

Lines changed: 275 additions & 43 deletions

File tree

durabletask/aio/client.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020

2121

2222
class AsyncTaskHubGrpcClient:
23-
24-
def __init__(self, *,
25-
host_address: Optional[str] = None,
26-
metadata: Optional[list[tuple[str, str]]] = None,
27-
log_handler: Optional[logging.Handler] = None,
28-
log_formatter: Optional[logging.Formatter] = None,
29-
secure_channel: bool = False,
30-
interceptors: Optional[Sequence[ClientInterceptor]] = None):
31-
23+
def __init__(
24+
self,
25+
*,
26+
host_address: Optional[str] = None,
27+
metadata: Optional[list[tuple[str, str]]] = None,
28+
log_handler: Optional[logging.Handler] = None,
29+
log_formatter: Optional[logging.Formatter] = None,
30+
secure_channel: bool = False,
31+
interceptors: Optional[Sequence[ClientInterceptor]] = None,
32+
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
33+
):
3234
if interceptors is not None:
3335
interceptors = list(interceptors)
3436
if metadata is not None:
@@ -41,7 +43,8 @@ def __init__(self, *,
4143
channel = get_grpc_aio_channel(
4244
host_address=host_address,
4345
secure_channel=secure_channel,
44-
interceptors=interceptors
46+
interceptors=interceptors,
47+
options=channel_options,
4548
)
4649
self._channel = channel
4750
self._stub = stubs.TaskHubSidecarServiceStub(channel)

durabletask/aio/internal/shared.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,72 @@
11
# Copyright (c) The Dapr Authors.
22
# Licensed under the MIT License.
33

4-
from typing import Optional, Sequence, Union
4+
from typing import Any, Optional, Sequence, Union
55

66
import grpc
77
from grpc import aio as grpc_aio
88

99
from durabletask.internal.shared import (
10-
get_default_host_address,
11-
SECURE_PROTOCOLS,
1210
INSECURE_PROTOCOLS,
11+
SECURE_PROTOCOLS,
12+
get_default_host_address,
1313
)
1414

15-
1615
ClientInterceptor = Union[
1716
grpc_aio.UnaryUnaryClientInterceptor,
1817
grpc_aio.UnaryStreamClientInterceptor,
1918
grpc_aio.StreamUnaryClientInterceptor,
20-
grpc_aio.StreamStreamClientInterceptor
19+
grpc_aio.StreamStreamClientInterceptor,
2120
]
2221

2322

2423
def get_grpc_aio_channel(
25-
host_address: Optional[str],
26-
secure_channel: bool = False,
27-
interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel:
24+
host_address: Optional[str],
25+
secure_channel: bool = False,
26+
interceptors: Optional[Sequence[ClientInterceptor]] = None,
27+
options: Optional[Sequence[tuple[str, Any]]] = None,
28+
) -> grpc_aio.Channel:
29+
"""create a grpc asyncio channel
2830
31+
Args:
32+
host_address: The host address of the gRPC server. If None, uses the default address.
33+
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
34+
interceptors: Optional sequence of client interceptors to apply to the channel.
35+
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
36+
"""
2937
if host_address is None:
3038
host_address = get_default_host_address()
3139

3240
for protocol in SECURE_PROTOCOLS:
3341
if host_address.lower().startswith(protocol):
3442
secure_channel = True
35-
host_address = host_address[len(protocol):]
43+
host_address = host_address[len(protocol) :]
3644
break
3745

3846
for protocol in INSECURE_PROTOCOLS:
3947
if host_address.lower().startswith(protocol):
4048
secure_channel = False
41-
host_address = host_address[len(protocol):]
49+
host_address = host_address[len(protocol) :]
4250
break
4351

52+
# Create the base channel
4453
if secure_channel:
45-
channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors)
46-
else:
47-
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)
54+
if options is not None:
55+
return grpc_aio.secure_channel(
56+
host_address,
57+
grpc.ssl_channel_credentials(),
58+
interceptors=interceptors,
59+
options=options,
60+
)
61+
return grpc_aio.secure_channel(
62+
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors
63+
)
4864

49-
return channel
65+
if options is not None:
66+
# validate all options keys prefix starts with `grpc.`
67+
if not all(key.startswith('grpc.') for key, _ in options):
68+
raise ValueError(
69+
f'All options keys must start with `grpc.`. Invalid options: {options}'
70+
)
71+
return grpc_aio.insecure_channel(host_address, interceptors=interceptors, options=options)
72+
return grpc_aio.insecure_channel(host_address, interceptors=interceptors)

durabletask/client.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,17 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
9191

9292

9393
class TaskHubGrpcClient:
94-
95-
def __init__(self, *,
96-
host_address: Optional[str] = None,
97-
metadata: Optional[list[tuple[str, str]]] = None,
98-
log_handler: Optional[logging.Handler] = None,
99-
log_formatter: Optional[logging.Formatter] = None,
100-
secure_channel: bool = False,
101-
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
102-
94+
def __init__(
95+
self,
96+
*,
97+
host_address: Optional[str] = None,
98+
metadata: Optional[list[tuple[str, str]]] = None,
99+
log_handler: Optional[logging.Handler] = None,
100+
log_formatter: Optional[logging.Formatter] = None,
101+
secure_channel: bool = False,
102+
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
103+
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
104+
):
103105
# If the caller provided metadata, we need to create a new interceptor for it and
104106
# add it to the list of interceptors.
105107
if interceptors is not None:
@@ -114,7 +116,8 @@ def __init__(self, *,
114116
channel = shared.get_grpc_channel(
115117
host_address=host_address,
116118
secure_channel=secure_channel,
117-
interceptors=interceptors
119+
interceptors=interceptors,
120+
options=channel_options,
118121
)
119122
self._stub = stubs.TaskHubSidecarServiceStub(channel)
120123
self._logger = shared.get_logger("client", log_handler, log_formatter)

durabletask/internal/shared.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,54 @@ def get_default_host_address() -> str:
5151

5252

5353
def get_grpc_channel(
54-
host_address: Optional[str],
55-
secure_channel: bool = False,
56-
interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel:
54+
host_address: Optional[str],
55+
secure_channel: bool = False,
56+
interceptors: Optional[Sequence[ClientInterceptor]] = None,
57+
options: Optional[Sequence[tuple[str, Any]]] = None,
58+
) -> grpc.Channel:
59+
"""create a grpc channel
60+
61+
Args:
62+
host_address: The host address of the gRPC server. If None, uses the default address.
63+
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
64+
interceptors: Optional sequence of client interceptors to apply to the channel.
65+
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
66+
"""
5767
if host_address is None:
5868
host_address = get_default_host_address()
5969

6070
for protocol in SECURE_PROTOCOLS:
6171
if host_address.lower().startswith(protocol):
6272
secure_channel = True
6373
# remove the protocol from the host name
64-
host_address = host_address[len(protocol):]
74+
host_address = host_address[len(protocol) :]
6575
break
6676

6777
for protocol in INSECURE_PROTOCOLS:
6878
if host_address.lower().startswith(protocol):
6979
secure_channel = False
7080
# remove the protocol from the host name
71-
host_address = host_address[len(protocol):]
81+
host_address = host_address[len(protocol) :]
7282
break
7383

7484
# Create the base channel
75-
if secure_channel:
76-
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
85+
if options is not None:
86+
# validate all options keys prefix starts with `grpc.`
87+
if not all(key.startswith('grpc.') for key, _ in options):
88+
raise ValueError(
89+
f'All options keys must start with `grpc.`. Invalid options: {options}'
90+
)
91+
if secure_channel:
92+
channel = grpc.secure_channel(
93+
host_address, grpc.ssl_channel_credentials(), options=options
94+
)
95+
else:
96+
channel = grpc.insecure_channel(host_address, options=options)
7797
else:
78-
channel = grpc.insecure_channel(host_address)
98+
if secure_channel:
99+
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
100+
else:
101+
channel = grpc.insecure_channel(host_address)
79102

80103
# Apply interceptors ONLY if they exist
81104
if interceptors:
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import json
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from durabletask.aio.internal.shared import get_grpc_aio_channel
7+
8+
HOST_ADDRESS = 'localhost:50051'
9+
10+
11+
def _find_option(options, key):
12+
for k, v in options:
13+
if k == key:
14+
return v
15+
raise AssertionError(f'Option with key {key} not found in options: {options}')
16+
17+
18+
def test_aio_channel_passes_base_options_and_max_lengths():
19+
base_options = [
20+
('grpc.max_send_message_length', 4321),
21+
('grpc.max_receive_message_length', 8765),
22+
('grpc.primary_user_agent', 'durabletask-aio-tests'),
23+
]
24+
with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
25+
get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options)
26+
# Ensure called with options kwarg
27+
assert mock_channel.call_count == 1
28+
args, kwargs = mock_channel.call_args
29+
assert args[0] == HOST_ADDRESS
30+
assert 'options' in kwargs
31+
opts = kwargs['options']
32+
# Check our base options made it through
33+
assert ('grpc.max_send_message_length', 4321) in opts
34+
assert ('grpc.max_receive_message_length', 8765) in opts
35+
assert ('grpc.primary_user_agent', 'durabletask-aio-tests') in opts
36+
37+
38+
def test_aio_channel_merges_env_keepalive_and_retry(monkeypatch: pytest.MonkeyPatch):
39+
# retry grpc option
40+
# service_config ref => https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L44
41+
max_attempts = 4
42+
initial_backoff_ms = 250
43+
max_backoff_ms = 2000
44+
backoff_multiplier = 1.5
45+
codes = ['RESOURCE_EXHAUSTED']
46+
service_config = {
47+
'methodConfig': [
48+
{
49+
'name': [{'service': ''}], # match all services/methods
50+
'retryPolicy': {
51+
'maxAttempts': max_attempts,
52+
'initialBackoff': f'{initial_backoff_ms / 1000.0}s',
53+
'maxBackoff': f'{max_backoff_ms / 1000.0}s',
54+
'backoffMultiplier': backoff_multiplier,
55+
'retryableStatusCodes': codes,
56+
},
57+
}
58+
]
59+
}
60+
61+
base_options = [('grpc.service_config', json.dumps(service_config))]
62+
63+
with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel:
64+
get_grpc_aio_channel(HOST_ADDRESS, False, options=base_options)
65+
66+
args, kwargs = mock_channel.call_args
67+
assert args[0] == HOST_ADDRESS
68+
assert 'options' in kwargs
69+
opts = kwargs['options']
70+
71+
# Retry service config present and parses correctly
72+
svc_cfg_str = _find_option(opts, 'grpc.service_config')
73+
svc_cfg = json.loads(svc_cfg_str)
74+
assert 'methodConfig' in svc_cfg and isinstance(svc_cfg['methodConfig'], list)
75+
retry_policy = svc_cfg['methodConfig'][0]['retryPolicy']
76+
assert retry_policy['maxAttempts'] == 4
77+
assert retry_policy['initialBackoff'] == f'{250 / 1000.0}s'
78+
assert retry_policy['maxBackoff'] == f'{2000 / 1000.0}s'
79+
assert retry_policy['backoffMultiplier'] == 1.5
80+
# Codes are upper-cased list
81+
assert 'RESOURCE_EXHAUSTED' in retry_policy['retryableStatusCodes']
82+
83+
84+
def test_aio_secure_channel_receives_options_when_secure_true():
85+
base_options = [('grpc.max_receive_message_length', 999999)]
86+
with (
87+
patch('durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_channel,
88+
patch('grpc.ssl_channel_credentials') as mock_credentials,
89+
):
90+
get_grpc_aio_channel(HOST_ADDRESS, True, options=base_options)
91+
args, kwargs = mock_channel.call_args
92+
assert args[0] == HOST_ADDRESS
93+
assert args[1] == mock_credentials.return_value
94+
assert ('grpc.max_receive_message_length', 999999) in kwargs.get('options', [])

0 commit comments

Comments
 (0)