|
1 | 1 | import base64 |
| 2 | +import importlib.metadata |
2 | 3 | import json |
3 | 4 | import platform |
4 | | -from unittest.mock import AsyncMock, patch |
| 5 | +from unittest.mock import AsyncMock, MagicMock, patch |
5 | 6 |
|
6 | 7 | import pytest |
7 | 8 |
|
@@ -63,7 +64,7 @@ def test_default_factory(self): |
63 | 64 |
|
64 | 65 | @patch( |
65 | 66 | "auth0_server_python.telemetry.importlib.metadata.version", |
66 | | - side_effect=Exception("not installed"), |
| 67 | + side_effect=importlib.metadata.PackageNotFoundError("not installed"), |
67 | 68 | ) |
68 | 69 | def test_default_factory_unknown_version_on_error(self, _mock): |
69 | 70 | telemetry = Telemetry.default() |
@@ -96,23 +97,40 @@ def test_server_client_telemetry_payload_structure(self): |
96 | 97 | assert "version" in decoded |
97 | 98 | assert "python" in decoded["env"] |
98 | 99 |
|
99 | | - def test_get_http_client_includes_telemetry_headers(self): |
| 100 | + @pytest.mark.asyncio |
| 101 | + async def test_get_http_client_includes_telemetry_headers(self): |
100 | 102 | client = self._make_client() |
101 | 103 | http_client = client._get_http_client() |
102 | 104 | for key, value in client._telemetry_headers.items(): |
103 | 105 | assert http_client.headers.get(key) == value |
| 106 | + await http_client.aclose() |
| 107 | + |
| 108 | + @pytest.mark.asyncio |
| 109 | + async def test_get_http_client_per_request_headers_do_not_override_telemetry(self): |
| 110 | + client = self._make_client() |
| 111 | + http_client = client._get_http_client(headers={"User-Agent": "custom", "X-Custom": "val"}) |
| 112 | + # Telemetry headers must win over caller-provided duplicates |
| 113 | + assert http_client.headers.get("User-Agent") == client._telemetry_headers["User-Agent"] |
| 114 | + assert http_client.headers.get("Auth0-Client") == client._telemetry_headers["Auth0-Client"] |
| 115 | + # Non-conflicting caller headers are preserved |
| 116 | + assert http_client.headers.get("X-Custom") == "val" |
| 117 | + await http_client.aclose() |
104 | 118 |
|
105 | 119 | def test_my_account_client_receives_telemetry_headers(self): |
106 | 120 | client = self._make_client() |
107 | 121 | assert client._my_account_client._headers == client._telemetry_headers |
108 | 122 |
|
| 123 | + def test_mfa_client_receives_telemetry_headers(self): |
| 124 | + client = self._make_client() |
| 125 | + assert client._mfa_client._headers == client._telemetry_headers |
| 126 | + |
109 | 127 | @pytest.mark.asyncio |
110 | 128 | async def test_fetch_oidc_metadata_sends_telemetry(self, mocker): |
111 | 129 | client = self._make_client() |
112 | | - mock_response = AsyncMock() |
| 130 | + mock_response = MagicMock() |
113 | 131 | mock_response.status_code = 200 |
114 | 132 | mock_response.json.return_value = {"issuer": "https://auth0.local/"} |
115 | | - mock_response.raise_for_status = AsyncMock() |
| 133 | + mock_response.raise_for_status = MagicMock() |
116 | 134 |
|
117 | 135 | mock_http_client = AsyncMock() |
118 | 136 | mock_http_client.get.return_value = mock_response |
|
0 commit comments