Skip to content

Commit da8f15c

Browse files
committed
fixed unit tests
1 parent 37bfb90 commit da8f15c

File tree

3 files changed

+58
-16
lines changed

3 files changed

+58
-16
lines changed

tests/unit/test_federated_auth_plugin.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from aws_advanced_python_wrapper.iam_plugin import TokenInfo
2626
from aws_advanced_python_wrapper.utils.properties import (Properties,
2727
WrapperProperties)
28+
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
2829

2930
_GENERATED_TOKEN = "generated_token"
3031
_TEST_TOKEN = "test_token"
@@ -101,6 +102,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi
101102
_token_cache[_PG_CACHE_KEY] = initial_token
102103

103104
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service,
105+
RDSTokenUtils(),
104106
mock_session)
105107
key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser"
106108
_token_cache[key] = initial_token
@@ -129,7 +131,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu
129131
initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5))
130132
_token_cache[_PG_CACHE_KEY] = initial_token
131133

132-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
134+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service,
135+
mock_credentials_provider_factory,
136+
RDSTokenUtils(),
137+
mock_session)
133138

134139
target_plugin.connect(
135140
target_driver_func=mocker.MagicMock(),
@@ -154,7 +159,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
154159
test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
155160
WrapperProperties.DB_USER.set(test_props, _DB_USER)
156161

157-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
162+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service,
163+
mock_credentials_provider_factory,
164+
RDSTokenUtils(),
165+
mock_session)
158166

159167
target_plugin.connect(
160168
target_driver_func=mocker.MagicMock(),
@@ -183,7 +191,9 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess
183191
exception_message = "generic exception"
184192
mock_func.side_effect = Exception(exception_message)
185193

186-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory,
194+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service,
195+
mock_credentials_provider_factory,
196+
RDSTokenUtils(),
187197
mock_session)
188198
with pytest.raises(Exception) as e_info:
189199
target_plugin.connect(
@@ -229,7 +239,10 @@ def test_connect_with_specified_iam_host_port_region(mocker,
229239

230240
mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}"
231241

232-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
242+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service,
243+
mock_credentials_provider_factory,
244+
RDSTokenUtils(),
245+
mock_session)
233246
target_plugin.connect(
234247
target_driver_func=mocker.MagicMock(),
235248
driver_dialect=mock_dialect,

tests/unit/test_iam_plugin.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo
2727
from aws_advanced_python_wrapper.utils.properties import (Properties,
2828
WrapperProperties)
29+
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
2930

3031
_GENERATED_TOKEN = "generated_token"
3132
_TEST_TOKEN = "test_token"
@@ -99,6 +100,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi
99100
_token_cache[_PG_CACHE_KEY] = initial_token
100101

101102
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
103+
RDSTokenUtils(),
102104
mock_session)
103105
target_plugin.connect(
104106
target_driver_func=mocker.MagicMock(),
@@ -127,6 +129,7 @@ def test_pg_connect_with_invalid_port_fall_backs_to_host_port(
127129
assert test_props.get("password") is None
128130

129131
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
132+
RDSTokenUtils(),
130133
mock_session)
131134
target_plugin.connect(
132135
target_driver_func=mocker.MagicMock(),
@@ -163,6 +166,7 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port(
163166
assert test_props.get("password") is None
164167

165168
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
169+
RDSTokenUtils(),
166170
mock_session)
167171
target_plugin.connect(
168172
target_driver_func=mocker.MagicMock(),
@@ -195,7 +199,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio
195199
_token_cache[_PG_CACHE_KEY] = initial_token
196200

197201
mock_func.side_effect = Exception("generic exception")
198-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
202+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
203+
RDSTokenUtils(),
204+
mock_session)
199205
with pytest.raises(Exception):
200206
target_plugin.connect(
201207
target_driver_func=mocker.MagicMock(),
@@ -220,7 +226,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio
220226
@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache)
221227
def test_connect_empty_cache(mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect):
222228
test_props: Properties = Properties({"user": "postgresqlUser"})
223-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
229+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
230+
RDSTokenUtils(),
231+
mock_session)
224232
actual_connection = target_plugin.connect(
225233
target_driver_func=mocker.MagicMock(),
226234
driver_dialect=mock_dialect,
@@ -251,7 +259,9 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session,
251259
# Assert no password has been set
252260
assert test_props.get("password") is None
253261

254-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
262+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
263+
RDSTokenUtils(),
264+
mock_session)
255265
target_plugin.connect(
256266
target_driver_func=mocker.MagicMock(),
257267
driver_dialect=mock_dialect,
@@ -285,7 +295,9 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo
285295
# Assert no password has been set
286296
assert test_props.get("password") is None
287297

288-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
298+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
299+
RDSTokenUtils(),
300+
mock_session)
289301
target_plugin.connect(
290302
target_driver_func=mocker.MagicMock(),
291303
driver_dialect=mock_dialect,
@@ -323,7 +335,9 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session
323335
assert test_props.get("password") is None
324336

325337
mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}"
326-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
338+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
339+
RDSTokenUtils(),
340+
mock_session)
327341
target_plugin.connect(
328342
target_driver_func=mocker.MagicMock(),
329343
driver_dialect=mock_dialect,
@@ -369,7 +383,9 @@ def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service,
369383
assert test_props.get("password") is None
370384

371385
mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}"
372-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
386+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service,
387+
RDSTokenUtils(),
388+
mock_session)
373389
target_plugin.connect(
374390
target_driver_func=mocker.MagicMock(),
375391
driver_dialect=mock_dialect,
@@ -411,7 +427,7 @@ def test_aws_supported_regions_url_exists():
411427
def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect):
412428
test_props: Properties = Properties({"user": "postgresqlUser"})
413429
with pytest.raises(AwsWrapperError):
414-
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
430+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session)
415431
target_plugin.connect(
416432
target_driver_func=mocker.MagicMock(),
417433
driver_dialect=mock_dialect,

tests/unit/test_okta_plugin.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from aws_advanced_python_wrapper.okta_plugin import OktaAuthPlugin
2626
from aws_advanced_python_wrapper.utils.properties import (Properties,
2727
WrapperProperties)
28+
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
2829

2930
_GENERATED_TOKEN = "generated_token"
3031
_TEST_TOKEN = "test_token"
@@ -100,7 +101,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi
100101
initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5))
101102
_token_cache[_PG_CACHE_KEY] = initial_token
102103

103-
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_session)
104+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session)
104105
key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser"
105106
_token_cache[key] = initial_token
106107

@@ -127,7 +128,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu
127128
initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5))
128129
_token_cache[_PG_CACHE_KEY] = initial_token
129130

130-
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
131+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service,
132+
mock_credentials_provider_factory,
133+
RDSTokenUtils(),
134+
mock_session)
131135

132136
target_plugin.connect(
133137
target_driver_func=mocker.MagicMock(),
@@ -151,7 +155,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
151155
test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
152156
WrapperProperties.DB_USER.set(test_props, _DB_USER)
153157

154-
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
158+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service,
159+
mock_credentials_provider_factory,
160+
RDSTokenUtils(),
161+
mock_session)
155162

156163
target_plugin.connect(
157164
target_driver_func=mocker.MagicMock(),
@@ -179,7 +186,10 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess
179186
exception_message = "generic exception"
180187
mock_func.side_effect = Exception(exception_message)
181188

182-
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
189+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service,
190+
mock_credentials_provider_factory,
191+
RDSTokenUtils(),
192+
mock_session)
183193

184194
with pytest.raises(Exception) as e_info:
185195
target_plugin.connect(
@@ -225,7 +235,10 @@ def test_connect_with_specified_iam_host_port_region(mocker,
225235

226236
mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}"
227237

228-
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
238+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service,
239+
mock_credentials_provider_factory,
240+
RDSTokenUtils(),
241+
mock_session)
229242
target_plugin.connect(
230243
target_driver_func=mocker.MagicMock(),
231244
driver_dialect=mock_dialect,

0 commit comments

Comments
 (0)