diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 40d857f4..7896cb7e 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -18,19 +18,23 @@ import jwt from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError, DecodeError from jwt import InvalidAudienceError, InvalidIssuerError, InvalidSignatureError +import requests from firebase_admin import _utils +from firebase_admin import _http_client +from firebase_admin import exceptions _APP_CHECK_ATTRIBUTE = '_app_check' def _get_app_check_service(app) -> Any: return _utils.get_app_service(app, _APP_CHECK_ATTRIBUTE, _AppCheckService) -def verify_token(token: str, app=None) -> Dict[str, Any]: +def verify_token(token: str, app=None, consume: bool = False) -> Dict[str, Any]: """Verifies a Firebase App Check token. Args: token: A token from App Check. app: An App instance (optional). + consume: A boolean indicating whether to consume the token (optional). Returns: Dict[str, Any]: The token's decoded claims. @@ -40,16 +44,18 @@ def verify_token(token: str, app=None) -> Dict[str, Any]: or if the token's headers or payload are invalid. PyJWKClientError: If PyJWKClient fails to fetch a valid signing key. """ - return _get_app_check_service(app).verify_token(token) + return _get_app_check_service(app).verify_token(token, consume) class _AppCheckService: """Service class that implements Firebase App Check functionality.""" _APP_CHECK_ISSUER = 'https://firebaseappcheck.googleapis.com/' _JWKS_URL = 'https://firebaseappcheck.googleapis.com/v1/jwks' + _APP_CHECK_V1BETA_URL = 'https://firebaseappcheck.googleapis.com/v1beta' _project_id = None _scoped_project_id = None _jwks_client = None + _http_client = None _APP_CHECK_HEADERS = { 'x-goog-api-client': _utils.get_metrics_header(), @@ -68,9 +74,12 @@ def __init__(self, app): # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). self._jwks_client = PyJWKClient( self._JWKS_URL, lifespan=21600, headers=self._APP_CHECK_HEADERS) + self._http_client = _http_client.JsonHttpClient( + credential=app.credential, + base_url=self._APP_CHECK_V1BETA_URL) - def verify_token(self, token: str) -> Dict[str, Any]: + def verify_token(self, token: str, consume: bool = False) -> Dict[str, Any]: """Verifies a Firebase App Check token.""" _Validators.check_string("app check token", token) @@ -87,8 +96,29 @@ def verify_token(self, token: str) -> Dict[str, Any]: ) from exception verified_claims['app_id'] = verified_claims.get('sub') + if consume: + already_consumed = self._verify_replay_protection(token) + verified_claims['already_consumed'] = already_consumed return verified_claims + def _verify_replay_protection(self, token: str) -> bool: + """Verifies the token's consumption status.""" + path = f'{self._scoped_project_id}:verifyAppCheckToken' + body = {'app_check_token': token} + try: + response = self._http_client.body('post', path, json=body) + if not isinstance(response, dict): + raise exceptions.UnknownError( + 'Unexpected response from App Check service. ' + f'Expected a JSON object, but got {type(response).__name__}.') + return response.get('alreadyConsumed', False) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + except ValueError as error: + raise exceptions.UnknownError( + 'Unexpected response from App Check service. ' + f'Error: {error}') + def _has_valid_token_headers(self, headers: Any) -> None: """Checks whether the token has valid headers for App Check.""" # Ensure the token's header has type JWT diff --git a/tests/test_app_check.py b/tests/test_app_check.py index e55ae39d..76bd2661 100644 --- a/tests/test_app_check.py +++ b/tests/test_app_check.py @@ -18,8 +18,9 @@ from jwt import PyJWK, InvalidAudienceError, InvalidIssuerError from jwt import ExpiredSignatureError, InvalidSignatureError +import requests import firebase_admin -from firebase_admin import app_check +from firebase_admin import app_check, exceptions from tests import testutils NON_STRING_ARGS = [[], tuple(), {}, True, False, 1, 0] @@ -58,6 +59,21 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() +@pytest.fixture +def app_check_mock(mocker): + """Fixture to mock JWT functions and provide a fresh app.""" + mocker.patch("jwt.decode", return_value=JWT_PAYLOAD_SAMPLE) + mocker.patch("jwt.PyJWKClient.get_signing_key_from_jwt", return_value=PyJWK(signing_key)) + mocker.patch("jwt.get_unverified_header", return_value=JWT_PAYLOAD_SAMPLE.get("headers")) + mock_http_client = mocker.patch("firebase_admin._http_client.JsonHttpClient") + + cred = testutils.MockCredential() + app = firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}, name='test_consume_app') + + yield mock_http_client, app + + firebase_admin.delete_app(app) + class TestVerifyToken(TestBatch): def test_no_project_id(self): @@ -232,6 +248,54 @@ def test_verify_token(self, mocker): expected['app_id'] = APP_ID assert payload == expected + def test_verify_token_with_consume(self, app_check_mock): + """Test verify_token with consume=True.""" + mock_http_client, app = app_check_mock + mock_http_client.return_value.body.return_value = {'alreadyConsumed': True} + + payload = app_check.verify_token("encoded", app, consume=True) + expected = JWT_PAYLOAD_SAMPLE.copy() + expected['app_id'] = APP_ID + expected['already_consumed'] = True + assert payload == expected + mock_http_client.return_value.body.assert_called_once_with( + 'post', + f'{SCOPED_PROJECT_ID}:verifyAppCheckToken', + json={'app_check_token': 'encoded'}) + + def test_verify_token_with_consume_network_error(self, app_check_mock): + """Test verify_token with consume=True handles network errors.""" + mock_http_client, app = app_check_mock + mock_http_client.return_value.body.side_effect = requests.exceptions.RequestException( + "Network error") + + with pytest.raises(exceptions.UnknownError) as excinfo: + app_check.verify_token("encoded", app, consume=True) + assert str(excinfo.value) == ( + "Unknown error while making a remote service call: Network error") + + def test_verify_token_with_consume_non_dict_response(self, app_check_mock): + """Test verify_token with consume=True handles non-dict response.""" + mock_http_client, app = app_check_mock + mock_http_client.return_value.body.return_value = ["not", "a", "dict"] + + with pytest.raises(exceptions.UnknownError) as excinfo: + app_check.verify_token("encoded", app, consume=True) + assert str(excinfo.value) == ( + 'Unexpected response from App Check service. ' + 'Expected a JSON object, but got list.') + + def test_verify_token_with_consume_malformed_json(self, app_check_mock): + """Test verify_token with consume=True handles malformed JSON response.""" + mock_http_client, app = app_check_mock + mock_http_client.return_value.body.side_effect = ValueError("Malformed JSON") + + with pytest.raises(exceptions.UnknownError) as excinfo: + app_check.verify_token("encoded", app, consume=True) + assert str(excinfo.value) == ( + 'Unexpected response from App Check service. ' + 'Error: Malformed JSON') + def test_verify_token_with_non_list_audience_raises_error(self, mocker): jwt_with_non_list_audience = JWT_PAYLOAD_SAMPLE.copy() jwt_with_non_list_audience["aud"] = '1234'