diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 6e4f73351f..e205d9be52 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -79,6 +79,7 @@ class OAuth2Auth(BaseModelWithConfig): auth_code: Optional[str] = None access_token: Optional[str] = None refresh_token: Optional[str] = None + id_token: Optional[str] = None expires_at: Optional[int] = None expires_in: Optional[int] = None audience: Optional[str] = None diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py index d2f40b339f..888c8d0d39 100644 --- a/src/google/adk/auth/oauth2_credential_util.py +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -107,11 +107,13 @@ def update_credential_with_tokens( auth_credential: The authentication credential to update. tokens: The OAuth2Token object containing new token information. """ - auth_credential.oauth2.access_token = tokens.get("access_token") - auth_credential.oauth2.refresh_token = tokens.get("refresh_token") - auth_credential.oauth2.expires_at = ( - int(tokens.get("expires_at")) if tokens.get("expires_at") else None - ) - auth_credential.oauth2.expires_in = ( - int(tokens.get("expires_in")) if tokens.get("expires_in") else None - ) + if auth_credential.oauth2: + auth_credential.oauth2.access_token = tokens.get("access_token") + auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.id_token = tokens.get("id_token") + auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py index 1e499ca741..88398c482a 100644 --- a/tests/unittests/auth/test_oauth2_credential_util.py +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -222,6 +222,7 @@ def test_update_credential_with_tokens(self): tokens = OAuth2Token({ "access_token": "new_access_token", "refresh_token": "new_refresh_token", + "id_token": "new_id_token", "expires_at": expected_expires_at, "expires_in": 3600, }) @@ -230,5 +231,16 @@ def test_update_credential_with_tokens(self): assert credential.oauth2.access_token == "new_access_token" assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.id_token == "new_id_token" assert credential.oauth2.expires_at == expected_expires_at assert credential.oauth2.expires_in == 3600 + + def test_update_credential_with_tokens_none(self) -> None: + credential = AuthCredential( + auth_type=AuthCredentialTypes.API_KEY, + ) + tokens = OAuth2Token({"access_token": "new_access_token"}) + + # Should not raise any exceptions when oauth2 is None + update_credential_with_tokens(credential, tokens) + assert credential.oauth2 is None