-
Notifications
You must be signed in to change notification settings - Fork 2
fix: async httpx TLS pinning via wrap_bio interception #61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,5 +93,70 @@ def test_secure_http_client(): | |
| print(f" TLS pinned to: {ground_truth.public_key}") | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_secure_async_http_client(): | ||
| """ | ||
| Tests that the async pinned client can connect to the enclave. | ||
| Mirrors test_secure_http_client for the async path. | ||
| """ | ||
| try: | ||
| enclave = get_router_address() | ||
| except Exception as e: | ||
| pytest.skip(f"Could not fetch router address from ATC service: {e}") | ||
|
|
||
| client = SecureClient(enclave=enclave, repo=REPO) | ||
| http_client = client.make_secure_async_http_client() | ||
|
|
||
| ground_truth = client.ground_truth | ||
| assert ground_truth is not None | ||
|
|
||
| try: | ||
| response = await http_client.get(f"https://{enclave}/.well-known/tinfoil-attestation") | ||
| assert response.status_code == 200, f"Expected 200, got {response.status_code}" | ||
| finally: | ||
| await http_client.aclose() | ||
|
|
||
|
|
||
| def test_sync_pinned_client_rejects_wrong_host(): | ||
| """ | ||
| A client pinned to the enclave's cert must reject connections | ||
| to a different host (whose cert won't match the pinned fingerprint). | ||
| """ | ||
| try: | ||
| enclave = get_router_address() | ||
| except Exception as e: | ||
| pytest.skip(f"Could not fetch router address from ATC service: {e}") | ||
|
|
||
| client = SecureClient(enclave=enclave, repo=REPO) | ||
| http_client = client.make_secure_http_client() | ||
|
|
||
| try: | ||
| with pytest.raises(Exception): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P3: The async pinning rejection test is overly permissive; it will pass on any network failure. Assert the specific fingerprint-mismatch error so the test only passes when TLS pinning actually rejects the cert. Prompt for AI agents |
||
| http_client.get("https://google.com") | ||
| finally: | ||
| http_client.close() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_async_pinned_client_rejects_wrong_host(): | ||
| """ | ||
| The async pinned client must reject connections to a host | ||
| whose cert doesn't match the pinned fingerprint. | ||
| """ | ||
| try: | ||
| enclave = get_router_address() | ||
| except Exception as e: | ||
| pytest.skip(f"Could not fetch router address from ATC service: {e}") | ||
|
|
||
| client = SecureClient(enclave=enclave, repo=REPO) | ||
| http_client = client.make_secure_async_http_client() | ||
|
|
||
| try: | ||
| with pytest.raises(Exception): | ||
| await http_client.get("https://google.com") | ||
| finally: | ||
| await http_client.aclose() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v", "-s"]) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,8 @@ | |||||
| connections to proceed. | ||||||
| """ | ||||||
|
|
||||||
| import ssl | ||||||
|
|
||||||
| import pytest | ||||||
| from unittest.mock import patch, MagicMock | ||||||
|
|
||||||
|
|
@@ -294,5 +296,213 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch): | |||||
| client.verify() | ||||||
|
|
||||||
|
|
||||||
| class TestVerifyPeerFingerprint: | ||||||
| """Tests for SecureClient._verify_peer_fingerprint static method.""" | ||||||
|
|
||||||
| def test_raises_on_none_cert(self): | ||||||
| """Must raise ValueError when cert_binary is None.""" | ||||||
| with pytest.raises(ValueError, match="No certificate found"): | ||||||
| SecureClient._verify_peer_fingerprint(None, "abc123") | ||||||
|
|
||||||
| def test_raises_on_empty_cert(self): | ||||||
| """Must raise ValueError when cert_binary is empty bytes.""" | ||||||
| with pytest.raises(ValueError, match="No certificate found"): | ||||||
| SecureClient._verify_peer_fingerprint(b"", "abc123") | ||||||
|
|
||||||
| def test_raises_on_fingerprint_mismatch(self): | ||||||
| """Must raise ValueError when public key fingerprint doesn't match.""" | ||||||
| from cryptography import x509 | ||||||
| from cryptography.hazmat.primitives import hashes | ||||||
| from cryptography.hazmat.primitives.asymmetric import ec | ||||||
| from cryptography.x509.oid import NameOID | ||||||
| import datetime | ||||||
|
|
||||||
| # Generate a self-signed cert to get valid DER bytes | ||||||
| key = ec.generate_private_key(ec.SECP256R1()) | ||||||
| subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) | ||||||
| cert = ( | ||||||
| x509.CertificateBuilder() | ||||||
| .subject_name(subject) | ||||||
| .issuer_name(subject) | ||||||
| .public_key(key.public_key()) | ||||||
| .serial_number(x509.random_serial_number()) | ||||||
| .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) | ||||||
| .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) | ||||||
| .sign(key, hashes.SHA256()) | ||||||
| ) | ||||||
| from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding | ||||||
| cert_der = cert.public_bytes(CryptoEncoding.DER) | ||||||
|
|
||||||
| with pytest.raises(ValueError, match="Certificate fingerprint mismatch"): | ||||||
| SecureClient._verify_peer_fingerprint(cert_der, "wrong_fingerprint") | ||||||
|
|
||||||
| def test_passes_on_fingerprint_match(self): | ||||||
| """Must not raise when public key fingerprint matches.""" | ||||||
| from cryptography import x509 | ||||||
| from cryptography.hazmat.primitives import hashes | ||||||
| from cryptography.hazmat.primitives.asymmetric import ec | ||||||
| from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding, PublicFormat as CryptoPublicFormat | ||||||
| from cryptography.x509.oid import NameOID | ||||||
| import datetime | ||||||
| import hashlib | ||||||
|
|
||||||
| key = ec.generate_private_key(ec.SECP256R1()) | ||||||
| subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) | ||||||
| cert = ( | ||||||
| x509.CertificateBuilder() | ||||||
| .subject_name(subject) | ||||||
| .issuer_name(subject) | ||||||
| .public_key(key.public_key()) | ||||||
| .serial_number(x509.random_serial_number()) | ||||||
| .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) | ||||||
| .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) | ||||||
| .sign(key, hashes.SHA256()) | ||||||
| ) | ||||||
| cert_der = cert.public_bytes(CryptoEncoding.DER) | ||||||
|
|
||||||
| # Compute the expected fingerprint the same way the code does | ||||||
| pub_der = key.public_key().public_bytes( | ||||||
| CryptoEncoding.DER, CryptoPublicFormat.SubjectPublicKeyInfo | ||||||
| ) | ||||||
| expected_fp = hashlib.sha256(pub_der).hexdigest() | ||||||
|
|
||||||
| # Should not raise | ||||||
| SecureClient._verify_peer_fingerprint(cert_der, expected_fp) | ||||||
|
|
||||||
|
|
||||||
| class TestAsyncTLSPinning: | ||||||
| """Tests that the async httpx client correctly pins TLS certificates via wrap_bio.""" | ||||||
|
|
||||||
| FAKE_FP = "a" * 64 | ||||||
|
|
||||||
| def _make_client_with_fake_verify(self): | ||||||
| """Create a SecureClient that returns a fake fingerprint from verify().""" | ||||||
| client = SecureClient(enclave="test.enclave.sh", repo="test/repo") | ||||||
| ground_truth = MagicMock() | ||||||
| ground_truth.public_key = self.FAKE_FP | ||||||
| client.verify = MagicMock(return_value=ground_truth) | ||||||
| return client | ||||||
|
|
||||||
| def _get_ssl_context(self, async_http_client): | ||||||
| """Extract the SSL context from an httpx.AsyncClient.""" | ||||||
| return async_http_client._transport._pool._ssl_context | ||||||
|
|
||||||
| def _call_pinned_wrap_bio(self, ssl_ctx, fake_ssl_object): | ||||||
| """ | ||||||
| Call the patched wrap_bio on ssl_ctx, intercepting the real | ||||||
| original_wrap_bio so it returns our fake_ssl_object instead of | ||||||
| attempting a real SSL operation. | ||||||
|
|
||||||
| The pinned_wrap_bio closure holds a reference to original_wrap_bio | ||||||
| (the real ctx.wrap_bio at creation time). We patch the underlying | ||||||
| ctx._wrap_bio C method so that the call chain | ||||||
| pinned_wrap_bio -> original_wrap_bio -> SSLContext.wrap_bio -> _wrap_bio | ||||||
| returns our fake. | ||||||
| """ | ||||||
| with patch.object(ssl_ctx, '_wrap_bio', create=True) as mock_inner: | ||||||
| # ssl.SSLContext.wrap_bio calls sslobject_class._create which | ||||||
| # calls context._wrap_bio. We need to go one level deeper and | ||||||
| # patch SSLObject._create to just return our fake. | ||||||
| with patch('ssl.SSLObject._create', return_value=fake_ssl_object): | ||||||
| return ssl_ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO()) | ||||||
|
|
||||||
| def test_wrap_bio_is_monkey_patched(self): | ||||||
| """make_secure_async_http_client() must replace ctx.wrap_bio.""" | ||||||
| client = self._make_client_with_fake_verify() | ||||||
| async_http = client.make_secure_async_http_client() | ||||||
|
|
||||||
| # The underlying SSL context's wrap_bio should no longer be the | ||||||
| # original C-level method — it should be our pinned_wrap_bio closure. | ||||||
| ssl_ctx = self._get_ssl_context(async_http) | ||||||
| assert ssl_ctx.wrap_bio is not ssl.SSLContext.wrap_bio | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. P2: This assertion doesn’t prove wrap_bio was monkey-patched because bound method objects are never identical to the class function; the test will pass even when no patch occurred. Check for an instance-level override (e.g., presence in Prompt for AI agents
Suggested change
|
||||||
|
|
||||||
| def test_do_handshake_verifies_fingerprint_match(self): | ||||||
| """After handshake, matching fingerprint must not raise.""" | ||||||
| client = self._make_client_with_fake_verify() | ||||||
| async_http = client.make_secure_async_http_client() | ||||||
| ssl_ctx = self._get_ssl_context(async_http) | ||||||
|
|
||||||
| fake_ssl_object = MagicMock() | ||||||
| fake_ssl_object.do_handshake = MagicMock(return_value=None) | ||||||
|
|
||||||
| with patch.object(SecureClient, '_verify_peer_fingerprint') as mock_verify: | ||||||
| result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) | ||||||
|
|
||||||
| # do_handshake should have been replaced with the checked version | ||||||
| result.do_handshake() | ||||||
|
|
||||||
| mock_verify.assert_called_once_with( | ||||||
| fake_ssl_object.getpeercert(binary_form=True), | ||||||
| self.FAKE_FP, | ||||||
| ) | ||||||
|
|
||||||
| def test_do_handshake_rejects_fingerprint_mismatch(self): | ||||||
| """After handshake, mismatched fingerprint must raise ValueError.""" | ||||||
| client = self._make_client_with_fake_verify() | ||||||
| async_http = client.make_secure_async_http_client() | ||||||
| ssl_ctx = self._get_ssl_context(async_http) | ||||||
|
|
||||||
| fake_ssl_object = MagicMock() | ||||||
| fake_ssl_object.do_handshake = MagicMock(return_value=None) | ||||||
| fake_ssl_object.getpeercert.return_value = b"fake_cert_bytes" | ||||||
|
|
||||||
| with patch.object( | ||||||
| SecureClient, '_verify_peer_fingerprint', | ||||||
| side_effect=ValueError("Certificate fingerprint mismatch"), | ||||||
| ): | ||||||
| result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) | ||||||
|
|
||||||
| with pytest.raises(ValueError, match="Certificate fingerprint mismatch"): | ||||||
| result.do_handshake() | ||||||
|
|
||||||
| def test_do_handshake_rejects_missing_cert(self): | ||||||
| """After handshake, missing peer cert must raise ValueError.""" | ||||||
| client = self._make_client_with_fake_verify() | ||||||
| async_http = client.make_secure_async_http_client() | ||||||
| ssl_ctx = self._get_ssl_context(async_http) | ||||||
|
|
||||||
| fake_ssl_object = MagicMock() | ||||||
| fake_ssl_object.do_handshake = MagicMock(return_value=None) | ||||||
| fake_ssl_object.getpeercert.return_value = None | ||||||
|
|
||||||
| result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) | ||||||
|
|
||||||
| with pytest.raises(ValueError, match="No certificate found"): | ||||||
| result.do_handshake() | ||||||
|
|
||||||
| def test_ssl_want_read_propagates_without_cert_check(self): | ||||||
| """SSLWantReadError during handshake must propagate without checking cert.""" | ||||||
| client = self._make_client_with_fake_verify() | ||||||
| async_http = client.make_secure_async_http_client() | ||||||
| ssl_ctx = self._get_ssl_context(async_http) | ||||||
|
|
||||||
| fake_ssl_object = MagicMock() | ||||||
| fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantReadError()) | ||||||
|
|
||||||
| result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) | ||||||
|
|
||||||
| with pytest.raises(ssl.SSLWantReadError): | ||||||
| result.do_handshake() | ||||||
|
|
||||||
| # getpeercert should NOT have been called — handshake isn't done yet | ||||||
| fake_ssl_object.getpeercert.assert_not_called() | ||||||
|
|
||||||
| def test_ssl_want_write_propagates_without_cert_check(self): | ||||||
| """SSLWantWriteError during handshake must propagate without checking cert.""" | ||||||
| client = self._make_client_with_fake_verify() | ||||||
| async_http = client.make_secure_async_http_client() | ||||||
| ssl_ctx = self._get_ssl_context(async_http) | ||||||
|
|
||||||
| fake_ssl_object = MagicMock() | ||||||
| fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantWriteError()) | ||||||
|
|
||||||
| result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) | ||||||
|
|
||||||
| with pytest.raises(ssl.SSLWantWriteError): | ||||||
| result.do_handshake() | ||||||
|
|
||||||
| fake_ssl_object.getpeercert.assert_not_called() | ||||||
|
|
||||||
|
|
||||||
| if __name__ == "__main__": | ||||||
| pytest.main([__file__, "-v"]) | ||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
P2:
TLSBoundHTTPSHandler._get_connection(lines 55-66) contains the same fingerprint verification logic that_verify_peer_fingerprintnow encapsulates. Consider updating_get_connectionto callSecureClient._verify_peer_fingerprint(cert_binary, self.expected_pubkey)to eliminate the remaining duplication and ensure all verification paths stay in sync.Prompt for AI agents