Skip to content

Commit b4d02f4

Browse files
authored
[python] support Multi clouds for ARM SDK (#5925)
- Fix #5783 - The expected SDK API: https://github.com/Azure/azure-sdk-for-python/pull/38250/files
1 parent be2ca06 commit b4d02f4

8 files changed

Lines changed: 85 additions & 22 deletions

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
changeKind: feature
3+
packages:
4+
- "@typespec/http-client-python"
5+
---
6+
7+
Improve user experience in multi clouds scenario

packages/http-client-python/generator/pygen/codegen/models/client.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
OverloadedRequestBuilder,
1616
get_request_builder,
1717
)
18-
from .parameter import Parameter, ParameterMethodLocation
18+
from .parameter import Parameter, ParameterMethodLocation, ParameterLocation
1919
from .lro_operation import LROOperation
2020
from .lro_paging_operation import LROPagingOperation
2121
from ...utils import extract_original_name, NAME_LENGTH_LIMIT
@@ -54,7 +54,7 @@ def name(self) -> str:
5454
return self.yaml_data["name"]
5555

5656

57-
class Client(_ClientConfigBase[ClientGlobalParameterList]):
57+
class Client(_ClientConfigBase[ClientGlobalParameterList]): # pylint: disable=too-many-public-methods
5858
"""Model representing our service client"""
5959

6060
def __init__(
@@ -79,6 +79,27 @@ def __init__(
7979
self.request_id_header_name = self.yaml_data.get("requestIdHeaderName", None)
8080
self.has_etag: bool = yaml_data.get("hasEtag", False)
8181

82+
# update the host parameter value. In later logic, SDK will overwrite it
83+
# with value from cloud_setting if users don't provide it.
84+
if self.need_cloud_setting:
85+
for p in self.parameters.parameters:
86+
if p.location == ParameterLocation.ENDPOINT_PATH:
87+
p.client_default_value = None
88+
p.optional = True
89+
break
90+
91+
@property
92+
def need_cloud_setting(self) -> bool:
93+
return bool(
94+
self.code_model.options.get("azure_arm", False)
95+
and self.credential_scopes is not None
96+
and self.endpoint_parameter is not None
97+
)
98+
99+
@property
100+
def endpoint_parameter(self) -> Optional[Parameter]:
101+
return next((p for p in self.parameters.parameters if p.location == ParameterLocation.ENDPOINT_PATH), None)
102+
82103
def _build_request_builders(
83104
self,
84105
) -> List[Union[RequestBuilder, OverloadedRequestBuilder]]:
@@ -233,6 +254,10 @@ def _imports_shared(self, async_mode: bool, **kwargs) -> FileImport:
233254
"Self",
234255
ImportType.STDLIB,
235256
)
257+
if self.need_cloud_setting:
258+
file_import.add_submodule_import("typing", "cast", ImportType.STDLIB)
259+
file_import.add_submodule_import("azure.core.settings", "settings", ImportType.SDKCORE)
260+
file_import.add_submodule_import("azure.mgmt.core.tools", "get_arm_endpoints", ImportType.SDKCORE)
236261
return file_import
237262

238263
@property
@@ -332,6 +357,18 @@ def imports_for_multiapi(self, async_mode: bool, **kwargs) -> FileImport:
332357
)
333358
return file_import
334359

360+
@property
361+
def credential_scopes(self) -> Optional[List[str]]:
362+
"""Credential scopes for this client"""
363+
364+
if self.credential:
365+
if hasattr(getattr(self.credential.type, "policy", None), "credential_scopes"):
366+
return self.credential.type.policy.credential_scopes # type: ignore
367+
for t in getattr(self.credential.type, "types", []):
368+
if hasattr(getattr(t, "policy", None), "credential_scopes"):
369+
return t.policy.credential_scopes
370+
return None
371+
335372
@classmethod
336373
def from_yaml(
337374
cls,

packages/http-client-python/generator/pygen/codegen/serializers/client_serializer.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# --------------------------------------------------------------------------
6-
from typing import List
6+
from typing import List, cast
77

88
from . import utils
9-
from ..models import Client, ParameterMethodLocation
9+
from ..models import Client, ParameterMethodLocation, Parameter, ParameterLocation
1010
from .parameter_serializer import ParameterSerializer, PopKwargType
1111
from ...utils import build_policies
1212

@@ -77,17 +77,40 @@ def property_descriptions(self, async_mode: bool) -> List[str]:
7777
retval.append('"""')
7878
return retval
7979

80-
def initialize_config(self) -> str:
80+
def initialize_config(self) -> List[str]:
81+
retval = []
82+
additional_signatures = []
83+
if self.client.need_cloud_setting:
84+
additional_signatures.append("credential_scopes=credential_scopes")
85+
endpoint_parameter = cast(Parameter, self.client.endpoint_parameter)
86+
retval.extend(
87+
[
88+
'_cloud = kwargs.pop("cloud_setting", None) or settings.current.azure_cloud # type: ignore',
89+
"_endpoints = get_arm_endpoints(_cloud)",
90+
f"if not {endpoint_parameter.client_name}:",
91+
f' {endpoint_parameter.client_name} = _endpoints["resource_manager"]',
92+
'credential_scopes = kwargs.pop("credential_scopes", _endpoints["credential_scopes"])',
93+
]
94+
)
8195
config_name = f"{self.client.name}Configuration"
8296
config_call = ", ".join(
8397
[
84-
f"{p.client_name}={p.client_name}"
98+
(
99+
f"{p.client_name}="
100+
+ (
101+
f"cast(str, {p.client_name})"
102+
if self.client.need_cloud_setting and p.location == ParameterLocation.ENDPOINT_PATH
103+
else p.client_name
104+
)
105+
)
85106
for p in self.client.config.parameters.method
86107
if p.method_location != ParameterMethodLocation.KWARG
87108
]
109+
+ additional_signatures
88110
+ ["**kwargs"]
89111
)
90-
return f"self._config = {config_name}({config_call})"
112+
retval.append(f"self._config = {config_name}({config_call})")
113+
return retval
91114

92115
@property
93116
def host_variable_name(self) -> str:
@@ -104,8 +127,11 @@ def initialize_pipeline_client(self, async_mode: bool) -> List[str]:
104127
result = []
105128
pipeline_client_name = self.client.pipeline_class(async_mode)
106129
endpoint_name = "base_url" if self.client.code_model.is_azure_flavor else "endpoint"
130+
host_variable_name = (
131+
f"cast(str, {self.host_variable_name})" if self.client.need_cloud_setting else self.host_variable_name
132+
)
107133
params = {
108-
endpoint_name: self.host_variable_name,
134+
endpoint_name: host_variable_name,
109135
"policies": "_policies",
110136
}
111137
if not self.client.code_model.is_legacy and self.client.request_id_header_name:

packages/http-client-python/generator/pygen/codegen/serializers/general_serializer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
VERSION_MAP = {
2020
"msrest": "0.7.1",
2121
"isodate": "0.6.1",
22-
"azure-mgmt-core": "1.3.2",
22+
"azure-mgmt-core": "1.5.0",
2323
"azure-core": "1.30.0",
2424
"typing-extensions": "4.6.0",
2525
"corehttp": "1.0.0b6",

packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _imports(self) -> FileImportSerializer:
6262
ImportType.SDKCORE,
6363
)
6464
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
65-
if not param.client_default_value and not param.optional and param.wire_name in self.sample_params:
65+
if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
6666
imports.merge(param.type.imports_for_sample())
6767
return FileImportSerializer(imports, True)
6868

@@ -80,7 +80,7 @@ def _client_params(self) -> Dict[str, Any]:
8080
for p in (
8181
self.code_model.clients[0].parameters.positional + self.code_model.clients[0].parameters.keyword_only
8282
)
83-
if not (p.optional or p.client_default_value)
83+
if not p.optional and p.client_default_value is None
8484
]
8585
client_params = {
8686
p.client_name: special_param.get(

packages/http-client-python/generator/pygen/codegen/templates/client.py.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
{% if client.has_parameterized_host %}
1010
{{ serializer.host_variable_name }} = {{ keywords.escape_str(client.url) }}
1111
{% endif %}
12-
{{ serializer.initialize_config() }}
12+
{{ op_tools.serialize(serializer.initialize_config()) | indent(8) }}
1313
{{ op_tools.serialize(serializer.initialize_pipeline_client(async_mode)) | indent(8) }}
1414

1515
{{ op_tools.serialize(serializer.serializers_and_operation_groups_properties()) | indent(8) }}

packages/http-client-python/generator/pygen/codegen/templates/config.py.jinja2

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,8 @@ class {{ client.name }}Configuration: {{ client.config.pylint_disable() }}
2121
{% if serializer.set_constants() %}
2222
{{ op_tools.serialize(serializer.set_constants()) | indent(8) -}}
2323
{% endif %}
24-
{% if client.credential %}
25-
{% set cred_scopes = client.credential.type if client.credential.type.policy is defined and client.credential.type.policy.credential_scopes is defined %}
26-
{% if not cred_scopes %}
27-
{% set cred_scopes = client.credential.type.types | selectattr("policy.credential_scopes") | first if client.credential.type.types is defined %}
28-
{% endif %}
29-
{% if cred_scopes %}
30-
self.credential_scopes = kwargs.pop('credential_scopes', {{ cred_scopes.policy.credential_scopes }})
31-
{% endif %}
24+
{% if client.credential_scopes is not none %}
25+
self.credential_scopes = kwargs.pop('credential_scopes', {{ client.credential_scopes }})
3226
{% endif %}
3327
kwargs.setdefault('sdk_moniker', '{{ client.config.sdk_moniker }}/{}'.format(VERSION))
3428
self.polling_interval = kwargs.get("polling_interval", 30)

packages/http-client-python/generator/test/azure/requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
-r ../dev_requirements.txt
22
-e ../../
3-
azure-core==1.30.0
4-
azure-mgmt-core==1.3.2
3+
azure-mgmt-core==1.5.0
54

65
# only for azure
76
-e ./generated/azure-client-generator-core-access

0 commit comments

Comments
 (0)