Skip to content

Commit b640243

Browse files
committed
Fixed errors with updated interfaces for federated and octa plugins.
1 parent d22c555 commit b640243

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from aws_advanced_python_wrapper.utils.properties import (Properties,
4545
WrapperProperties)
4646
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
47+
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
48+
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
4749

4850
logger = Logger(__name__)
4951

@@ -55,12 +57,13 @@ class FederatedAuthPlugin(Plugin):
5557
_rds_utils: RdsUtils = RdsUtils()
5658
_token_cache: Dict[str, TokenInfo] = {}
5759

58-
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
60+
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None):
5961
self._plugin_service = plugin_service
6062
self._credentials_provider_factory = credentials_provider_factory
6163
self._session = session
6264

6365
self._region_utils = RegionUtils()
66+
self._token_utils = token_utils
6467
telemetry_factory = self._plugin_service.get_telemetry_factory()
6568
self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count")
6669
self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache))
@@ -145,7 +148,7 @@ def _update_authentication_token(self,
145148
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)
146149

147150
self._fetch_token_counter.inc()
148-
token: str = IamAuthUtils.generate_authentication_token(
151+
token: str = self._token_utils.generate_authentication_token(
149152
self._plugin_service,
150153
user,
151154
host_info.host,
@@ -159,7 +162,7 @@ def _update_authentication_token(self,
159162

160163
class FederatedAuthPluginFactory(PluginFactory):
161164
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
162-
return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props))
165+
return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils())
163166

164167
def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory:
165168
idp_name = WrapperProperties.IDP_NAME.get(props)

aws_advanced_python_wrapper/okta_plugin.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from aws_advanced_python_wrapper.utils.properties import (Properties,
4242
WrapperProperties)
4343
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
44+
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
45+
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
4446

4547
logger = Logger(__name__)
4648

@@ -51,12 +53,13 @@ class OktaAuthPlugin(Plugin):
5153
_rds_utils: RdsUtils = RdsUtils()
5254
_token_cache: Dict[str, TokenInfo] = {}
5355

54-
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
56+
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None):
5557
self._plugin_service = plugin_service
5658
self._credentials_provider_factory = credentials_provider_factory
5759
self._session = session
5860

5961
self._region_utils = RegionUtils()
62+
self._token_utils = token_utils
6063
telemetry_factory = self._plugin_service.get_telemetry_factory()
6164
self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count")
6265
self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache))
@@ -140,7 +143,7 @@ def _update_authentication_token(self,
140143
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
141144
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)
142145

143-
token: str = IamAuthUtils.generate_authentication_token(
146+
token: str = self._token_utils.generate_authentication_token(
144147
self._plugin_service,
145148
user,
146149
host_info.host,
@@ -228,7 +231,7 @@ def get_saml_assertion(self, props: Properties):
228231

229232
class OktaAuthPluginFactory(PluginFactory):
230233
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
231-
return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props))
234+
return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils())
232235

233236
def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory:
234237
return OktaCredentialsProviderFactory(plugin_service, props)

0 commit comments

Comments
 (0)