1818
1919from aws_advanced_python_wrapper .utils .iam_utils import IamAuthUtils , TokenInfo
2020from aws_advanced_python_wrapper .utils .region_utils import RegionUtils
21+ from aws_advanced_python_wrapper .utils .token_utils import TokenUtils
22+ from aws_advanced_python_wrapper .utils .rds_token_utils import RDSTokenUtils
2123
2224if TYPE_CHECKING :
2325 from boto3 import Session
@@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin):
4850 _rds_utils : RdsUtils = RdsUtils ()
4951 _token_cache : Dict [str , TokenInfo ] = {}
5052
51- def __init__ (self , plugin_service : PluginService , session : Optional [Session ] = None ):
53+ def __init__ (self , plugin_service : PluginService , token_utils : TokenUtils , session : Optional [Session ] = None ):
5254 self ._plugin_service = plugin_service
5355 self ._session = session
5456
5557 self ._region_utils = RegionUtils ()
58+ self ._token_utils = token_utils
5659 telemetry_factory = self ._plugin_service .get_telemetry_factory ()
5760 self ._fetch_token_counter = telemetry_factory .create_counter ("iam.fetch_token.count" )
5861 self ._cache_size_gauge = telemetry_factory .create_gauge (
@@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
102105 else :
103106 token_expiry = datetime .now () + timedelta (seconds = token_expiration_sec )
104107 self ._fetch_token_counter .inc ()
105- token : str = IamAuthUtils .generate_authentication_token (self ._plugin_service , user , host , port , region , client_session = self ._session )
108+ token : str = self . _token_utils .generate_authentication_token (self ._plugin_service , user , host , port , region , client_session = self ._session )
106109 self ._plugin_service .driver_dialect .set_password (props , token )
107110 IamAuthPlugin ._token_cache [cache_key ] = TokenInfo (token , token_expiry )
108111
@@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
120123 # Try to generate a new token and try to connect again
121124 token_expiry = datetime .now () + timedelta (seconds = token_expiration_sec )
122125 self ._fetch_token_counter .inc ()
123- token = IamAuthUtils .generate_authentication_token (self ._plugin_service , user , host , port , region , client_session = self ._session )
126+ token = self . _token_utils .generate_authentication_token (self ._plugin_service , user , host , port , region , client_session = self ._session )
124127 self ._plugin_service .driver_dialect .set_password (props , token )
125128 IamAuthPlugin ._token_cache [cache_key ] = TokenInfo (token , token_expiry )
126129
@@ -142,4 +145,4 @@ def force_connect(
142145
143146class IamAuthPluginFactory (PluginFactory ):
144147 def get_instance (self , plugin_service : PluginService , props : Properties ) -> Plugin :
145- return IamAuthPlugin (plugin_service )
148+ return IamAuthPlugin (plugin_service , RDSTokenUtils () )
0 commit comments