Skip to content

Commit d22c555

Browse files
committed
Add support for DSQL iam authentication
1 parent ddbb65f commit d22c555

File tree

10 files changed

+266
-51
lines changed

10 files changed

+266
-51
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils
20+
from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin
21+
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
22+
from aws_advanced_python_wrapper.utils.properties import (Properties)
23+
24+
if TYPE_CHECKING:
25+
from aws_advanced_python_wrapper.plugin_service import PluginService
26+
27+
class DsqlIamAuthPluginFactory(PluginFactory):
28+
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
29+
return IamAuthPlugin(plugin_service, DSQLTokenUtils())

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
2020
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
21+
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
22+
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
2123

2224
if TYPE_CHECKING:
2325
from boto3 import Session
@@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin):
4850
_rds_utils: RdsUtils = RdsUtils()
4951
_token_cache: Dict[str, TokenInfo] = {}
5052

51-
def __init__(self, plugin_service: PluginService, session: Optional[Session] = None):
53+
def __init__(self, plugin_service: PluginService, token_utils: TokenUtils, session: Optional[Session] = None):
5254
self._plugin_service = plugin_service
5355
self._session = session
5456

5557
self._region_utils = RegionUtils()
58+
self._token_utils = token_utils
5659
telemetry_factory = self._plugin_service.get_telemetry_factory()
5760
self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count")
5861
self._cache_size_gauge = telemetry_factory.create_gauge(
@@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
102105
else:
103106
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
104107
self._fetch_token_counter.inc()
105-
token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
108+
token: str = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
106109
self._plugin_service.driver_dialect.set_password(props, token)
107110
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
108111

@@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
120123
# Try to generate a new token and try to connect again
121124
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
122125
self._fetch_token_counter.inc()
123-
token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
126+
token = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
124127
self._plugin_service.driver_dialect.set_password(props, token)
125128
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
126129

@@ -142,4 +145,4 @@ def force_connect(
142145

143146
class IamAuthPluginFactory(PluginFactory):
144147
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
145-
return IamAuthPlugin(plugin_service)
148+
return IamAuthPlugin(plugin_service, RDSTokenUtils())

aws_advanced_python_wrapper/plugin_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
HostMonitoringPluginFactory
7474
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
7575
from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory
76+
from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import DsqlIamAuthPluginFactory
7677
from aws_advanced_python_wrapper.plugin import CanReleaseResources
7778
from aws_advanced_python_wrapper.read_write_splitting_plugin import \
7879
ReadWriteSplittingPluginFactory
@@ -716,6 +717,7 @@ class PluginManager(CanReleaseResources):
716717

717718
PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = {
718719
"iam": IamAuthPluginFactory,
720+
"iam_dsql": DsqlIamAuthPluginFactory,
719721
"aws_secrets_manager": AwsSecretsManagerPluginFactory,
720722
"aurora_connection_tracker": AuroraConnectionTrackerPluginFactory,
721723
"host_monitoring": HostMonitoringPluginFactory,
@@ -748,6 +750,7 @@ class PluginManager(CanReleaseResources):
748750
HostMonitoringPluginFactory: 500,
749751
FastestResponseStrategyPluginFactory: 600,
750752
IamAuthPluginFactory: 700,
753+
DsqlIamAuthPluginFactory: 710,
751754
AwsSecretsManagerPluginFactory: 800,
752755
FederatedAuthPluginFactory: 900,
753756
LimitlessPluginFactory: 950,
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import ABC, abstractmethod
18+
from typing import TYPE_CHECKING, Dict, Optional
19+
from aws_advanced_python_wrapper.utils.log import Logger
20+
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
21+
TelemetryTraceLevel
22+
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
23+
24+
if TYPE_CHECKING:
25+
from aws_advanced_python_wrapper.plugin_service import PluginService
26+
from boto3 import Session
27+
28+
import boto3
29+
30+
logger = Logger(__name__)
31+
32+
class DSQLTokenUtils(TokenUtils):
33+
def generate_authentication_token(
34+
self,
35+
plugin_service: PluginService,
36+
user: Optional[str],
37+
host_name: Optional[str],
38+
port: Optional[int],
39+
region: Optional[str],
40+
credentials: Optional[Dict[str, str]] = None,
41+
client_session: Optional[Session] = None) -> str:
42+
telemetry_factory = plugin_service.get_telemetry_factory()
43+
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)
44+
45+
try:
46+
client = boto3.client("dsql", region_name=region)
47+
48+
if user == "admin":
49+
token = client.generate_db_connect_admin_auth_token(host_name, region)
50+
else:
51+
token = client.generate_db_connect_auth_token(host_name, region)
52+
53+
logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
54+
return token
55+
except Exception as ex:
56+
context.set_success(False)
57+
context.set_exception(ex)
58+
raise ex
59+
finally:
60+
context.close_context()

aws_advanced_python_wrapper/utils/iam_utils.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -70,53 +70,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int)
7070
def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
7171
return f"{region}:{hostname}:{port}:{user}"
7272

73-
@staticmethod
74-
def generate_authentication_token(
75-
plugin_service: PluginService,
76-
user: Optional[str],
77-
host_name: Optional[str],
78-
port: Optional[int],
79-
region: Optional[str],
80-
credentials: Optional[Dict[str, str]] = None,
81-
client_session: Optional[Session] = None) -> str:
82-
telemetry_factory = plugin_service.get_telemetry_factory()
83-
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)
84-
85-
try:
86-
session = client_session if client_session else boto3.Session()
87-
88-
if credentials is not None:
89-
client = session.client(
90-
'rds',
91-
region_name=region,
92-
aws_access_key_id=credentials.get('AccessKeyId'),
93-
aws_secret_access_key=credentials.get('SecretAccessKey'),
94-
aws_session_token=credentials.get('SessionToken')
95-
)
96-
else:
97-
client = session.client(
98-
'rds',
99-
region_name=region
100-
)
101-
102-
token = client.generate_db_auth_token(
103-
DBHostname=host_name,
104-
Port=port,
105-
DBUsername=user
106-
)
107-
108-
client.close()
109-
110-
logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
111-
return token
112-
except Exception as ex:
113-
context.set_success(False)
114-
context.set_exception(ex)
115-
raise ex
116-
finally:
117-
context.close_context()
118-
119-
12073
class TokenInfo:
12174
@property
12275
def token(self):
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import ABC, abstractmethod
18+
from typing import TYPE_CHECKING, Dict, Optional
19+
from aws_advanced_python_wrapper.utils.log import Logger
20+
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
21+
TelemetryTraceLevel
22+
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
23+
24+
if TYPE_CHECKING:
25+
from aws_advanced_python_wrapper.plugin_service import PluginService
26+
from boto3 import Session
27+
28+
import boto3
29+
30+
logger = Logger(__name__)
31+
32+
class RDSTokenUtils(TokenUtils):
33+
def generate_authentication_token(
34+
self,
35+
plugin_service: PluginService,
36+
user: Optional[str],
37+
host_name: Optional[str],
38+
port: Optional[int],
39+
region: Optional[str],
40+
credentials: Optional[Dict[str, str]] = None,
41+
client_session: Optional[Session] = None) -> str:
42+
43+
telemetry_factory = plugin_service.get_telemetry_factory()
44+
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)
45+
46+
try:
47+
session = client_session if client_session else boto3.Session()
48+
49+
if credentials is not None:
50+
client = session.client(
51+
'rds',
52+
region_name=region,
53+
aws_access_key_id=credentials.get('AccessKeyId'),
54+
aws_secret_access_key=credentials.get('SecretAccessKey'),
55+
aws_session_token=credentials.get('SessionToken')
56+
)
57+
else:
58+
client = session.client(
59+
'rds',
60+
region_name=region
61+
)
62+
63+
token = client.generate_db_auth_token(
64+
DBHostname=host_name,
65+
Port=port,
66+
DBUsername=user
67+
)
68+
69+
client.close()
70+
71+
logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
72+
return token
73+
except Exception as ex:
74+
context.set_success(False)
75+
context.set_exception(ex)
76+
raise ex
77+
finally:
78+
context.close_context()

aws_advanced_python_wrapper/utils/rds_url_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ def __init__(self, is_rds: bool, is_rds_cluster: bool):
3434
RDS_PROXY = True, False,
3535
RDS_INSTANCE = True, False,
3636
RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = True, False,
37+
DSQL_CLUSTER = False, False,
3738
OTHER = False, False

aws_advanced_python_wrapper/utils/rdsutils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ class RdsUtils:
108108
r"(?P<dns>cluster-|cluster-ro-)+" \
109109
r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \
110110
r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$"
111+
AURORA_DSQL_CLUSTER_PATTERN = r"^(?P<instance>[^.]+)\." \
112+
r"(?P<dns>dsql(?:-[^.]+)?)\." \
113+
r"(?P<domain>(?P<region>[a-zA-Z0-9\-]+)" \
114+
r"\.on\.aws\.?)$"
111115
ELB_PATTERN = r"^(?<instance>.+)\.elb\.((?<region>[a-zA-Z0-9\-]+)\.amazonaws\.com)$"
112116

113117
IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \
@@ -148,6 +152,14 @@ def is_rds_dns(self, host: str) -> bool:
148152

149153
def is_rds_instance(self, host: str) -> bool:
150154
return self._get_dns_group(host) is None and self.is_rds_dns(host)
155+
156+
def is_dsql_cluster(self, host: str) -> bool:
157+
if not host or not host.strip():
158+
return False
159+
160+
pattern = self._find(host, [RdsUtils.AURORA_DSQL_CLUSTER_PATTERN])
161+
162+
return pattern is not None
151163

152164
def is_rds_proxy_dns(self, host: str) -> bool:
153165
dns_group = self._get_dns_group(host)
@@ -257,6 +269,8 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType:
257269
return RdsUrlType.RDS_PROXY
258270
elif self.is_rds_instance(host):
259271
return RdsUrlType.RDS_INSTANCE
272+
elif self.is_dsql_cluster(host):
273+
return RdsUrlType.DSQL_CLUSTER
260274

261275
return RdsUrlType.OTHER
262276

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import ABC, abstractmethod
18+
from typing import TYPE_CHECKING, Dict, Optional
19+
20+
if TYPE_CHECKING:
21+
from aws_advanced_python_wrapper.plugin_service import PluginService
22+
from boto3 import Session
23+
24+
class TokenUtils(ABC):
25+
@abstractmethod
26+
def generate_authentication_token(
27+
self,
28+
plugin_service: PluginService,
29+
user: Optional[str],
30+
host_name: Optional[str],
31+
port: Optional[int],
32+
region: Optional[str],
33+
credentials: Optional[Dict[str, str]] = None,
34+
client_session: Optional[Session] = None) -> str:
35+
pass

0 commit comments

Comments
 (0)