diff --git a/src/tinfoil/client.py b/src/tinfoil/client.py index 8462b82..09fbc7f 100644 --- a/src/tinfoil/client.py +++ b/src/tinfoil/client.py @@ -96,16 +96,18 @@ def ground_truth(self) -> Optional[GroundTruth]: """Returns the last verified enclave state""" return self._ground_truth - def make_secure_http_client(self) -> httpx.Client: - """ - Build an httpx.Client that pins the enclave's TLS cert - """ - expected_fp = self.verify().public_key - wrap_socket = self._create_socket_wrapper(expected_fp) - - ctx = ssl.create_default_context() - ctx.wrap_socket = wrap_socket - return httpx.Client(verify=ctx, follow_redirects=True) + @staticmethod + def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None: + """Verify that a certificate's public key fingerprint matches the expected value.""" + if not cert_binary: + raise ValueError("No certificate found") + cert = cryptography.x509.load_der_x509_certificate(cert_binary) + pub_der = cert.public_key().public_bytes( + Encoding.DER, PublicFormat.SubjectPublicKeyInfo + ) + pk_fp = hashlib.sha256(pub_der).hexdigest() + if pk_fp != expected_fp: + raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") def _create_socket_wrapper(self, expected_fp: str): """ @@ -114,28 +116,46 @@ def _create_socket_wrapper(self, expected_fp: str): """ def wrap_socket(*args, **kwargs) -> ssl.SSLSocket: sock = ssl.create_default_context().wrap_socket(*args, **kwargs) - cert_binary = sock.getpeercert(binary_form=True) - if not cert_binary: - raise ValueError("No certificate found") - cert = cryptography.x509.load_der_x509_certificate(cert_binary) - pub_der = cert.public_key().public_bytes( - Encoding.DER, PublicFormat.SubjectPublicKeyInfo + SecureClient._verify_peer_fingerprint( + sock.getpeercert(binary_form=True), expected_fp ) - pk_fp = hashlib.sha256(pub_der).hexdigest() - if pk_fp != expected_fp: - raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") return sock return wrap_socket - def make_secure_async_http_client(self) -> httpx.AsyncClient: + def make_secure_http_client(self) -> httpx.Client: """ - Build an httpx.AsyncClient that pins the enclave's TLS cert. + Build an httpx.Client that pins the enclave's TLS cert """ expected_fp = self.verify().public_key wrap_socket = self._create_socket_wrapper(expected_fp) ctx = ssl.create_default_context() ctx.wrap_socket = wrap_socket + return httpx.Client(verify=ctx, follow_redirects=True) + + def make_secure_async_http_client(self) -> httpx.AsyncClient: + """ + Build an httpx.AsyncClient that pins the enclave's TLS cert + """ + expected_fp = self.verify().public_key + + ctx = ssl.create_default_context() + original_wrap_bio = ctx.wrap_bio + + def pinned_wrap_bio(*args, **kwargs): + ssl_object = original_wrap_bio(*args, **kwargs) + original_do_handshake = ssl_object.do_handshake + + def checked_do_handshake(): + result = original_do_handshake() + cert_binary = ssl_object.getpeercert(binary_form=True) + SecureClient._verify_peer_fingerprint(cert_binary, expected_fp) + return result + + ssl_object.do_handshake = checked_do_handshake + return ssl_object + + ctx.wrap_bio = pinned_wrap_bio return httpx.AsyncClient(verify=ctx, follow_redirects=True) def verify(self) -> GroundTruth: diff --git a/tests/test_attestation_flow.py b/tests/test_attestation_flow.py index 750c911..dc3ec05 100644 --- a/tests/test_attestation_flow.py +++ b/tests/test_attestation_flow.py @@ -93,5 +93,70 @@ def test_secure_http_client(): print(f" TLS pinned to: {ground_truth.public_key}") +@pytest.mark.asyncio +async def test_secure_async_http_client(): + """ + Tests that the async pinned client can connect to the enclave. + Mirrors test_secure_http_client for the async path. + """ + try: + enclave = get_router_address() + except Exception as e: + pytest.skip(f"Could not fetch router address from ATC service: {e}") + + client = SecureClient(enclave=enclave, repo=REPO) + http_client = client.make_secure_async_http_client() + + ground_truth = client.ground_truth + assert ground_truth is not None + + try: + response = await http_client.get(f"https://{enclave}/.well-known/tinfoil-attestation") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + finally: + await http_client.aclose() + + +def test_sync_pinned_client_rejects_wrong_host(): + """ + A client pinned to the enclave's cert must reject connections + to a different host (whose cert won't match the pinned fingerprint). + """ + try: + enclave = get_router_address() + except Exception as e: + pytest.skip(f"Could not fetch router address from ATC service: {e}") + + client = SecureClient(enclave=enclave, repo=REPO) + http_client = client.make_secure_http_client() + + try: + with pytest.raises(Exception): + http_client.get("https://google.com") + finally: + http_client.close() + + +@pytest.mark.asyncio +async def test_async_pinned_client_rejects_wrong_host(): + """ + The async pinned client must reject connections to a host + whose cert doesn't match the pinned fingerprint. + """ + try: + enclave = get_router_address() + except Exception as e: + pytest.skip(f"Could not fetch router address from ATC service: {e}") + + client = SecureClient(enclave=enclave, repo=REPO) + http_client = client.make_secure_async_http_client() + + try: + with pytest.raises(Exception): + await http_client.get("https://google.com") + finally: + await http_client.aclose() + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_verification_failures.py b/tests/test_verification_failures.py index 0cc4f18..11db999 100644 --- a/tests/test_verification_failures.py +++ b/tests/test_verification_failures.py @@ -10,6 +10,8 @@ connections to proceed. """ +import ssl + import pytest from unittest.mock import patch, MagicMock @@ -294,5 +296,213 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch): client.verify() +class TestVerifyPeerFingerprint: + """Tests for SecureClient._verify_peer_fingerprint static method.""" + + def test_raises_on_none_cert(self): + """Must raise ValueError when cert_binary is None.""" + with pytest.raises(ValueError, match="No certificate found"): + SecureClient._verify_peer_fingerprint(None, "abc123") + + def test_raises_on_empty_cert(self): + """Must raise ValueError when cert_binary is empty bytes.""" + with pytest.raises(ValueError, match="No certificate found"): + SecureClient._verify_peer_fingerprint(b"", "abc123") + + def test_raises_on_fingerprint_mismatch(self): + """Must raise ValueError when public key fingerprint doesn't match.""" + from cryptography import x509 + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.x509.oid import NameOID + import datetime + + # Generate a self-signed cert to get valid DER bytes + key = ec.generate_private_key(ec.SECP256R1()) + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding + cert_der = cert.public_bytes(CryptoEncoding.DER) + + with pytest.raises(ValueError, match="Certificate fingerprint mismatch"): + SecureClient._verify_peer_fingerprint(cert_der, "wrong_fingerprint") + + def test_passes_on_fingerprint_match(self): + """Must not raise when public key fingerprint matches.""" + from cryptography import x509 + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding, PublicFormat as CryptoPublicFormat + from cryptography.x509.oid import NameOID + import datetime + import hashlib + + key = ec.generate_private_key(ec.SECP256R1()) + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + cert_der = cert.public_bytes(CryptoEncoding.DER) + + # Compute the expected fingerprint the same way the code does + pub_der = key.public_key().public_bytes( + CryptoEncoding.DER, CryptoPublicFormat.SubjectPublicKeyInfo + ) + expected_fp = hashlib.sha256(pub_der).hexdigest() + + # Should not raise + SecureClient._verify_peer_fingerprint(cert_der, expected_fp) + + +class TestAsyncTLSPinning: + """Tests that the async httpx client correctly pins TLS certificates via wrap_bio.""" + + FAKE_FP = "a" * 64 + + def _make_client_with_fake_verify(self): + """Create a SecureClient that returns a fake fingerprint from verify().""" + client = SecureClient(enclave="test.enclave.sh", repo="test/repo") + ground_truth = MagicMock() + ground_truth.public_key = self.FAKE_FP + client.verify = MagicMock(return_value=ground_truth) + return client + + def _get_ssl_context(self, async_http_client): + """Extract the SSL context from an httpx.AsyncClient.""" + return async_http_client._transport._pool._ssl_context + + def _call_pinned_wrap_bio(self, ssl_ctx, fake_ssl_object): + """ + Call the patched wrap_bio on ssl_ctx, intercepting the real + original_wrap_bio so it returns our fake_ssl_object instead of + attempting a real SSL operation. + + The pinned_wrap_bio closure holds a reference to original_wrap_bio + (the real ctx.wrap_bio at creation time). We patch the underlying + ctx._wrap_bio C method so that the call chain + pinned_wrap_bio -> original_wrap_bio -> SSLContext.wrap_bio -> _wrap_bio + returns our fake. + """ + with patch.object(ssl_ctx, '_wrap_bio', create=True) as mock_inner: + # ssl.SSLContext.wrap_bio calls sslobject_class._create which + # calls context._wrap_bio. We need to go one level deeper and + # patch SSLObject._create to just return our fake. + with patch('ssl.SSLObject._create', return_value=fake_ssl_object): + return ssl_ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO()) + + def test_wrap_bio_is_monkey_patched(self): + """make_secure_async_http_client() must replace ctx.wrap_bio.""" + client = self._make_client_with_fake_verify() + async_http = client.make_secure_async_http_client() + + # The underlying SSL context's wrap_bio should no longer be the + # original C-level method — it should be our pinned_wrap_bio closure. + ssl_ctx = self._get_ssl_context(async_http) + assert ssl_ctx.wrap_bio is not ssl.SSLContext.wrap_bio + + def test_do_handshake_verifies_fingerprint_match(self): + """After handshake, matching fingerprint must not raise.""" + client = self._make_client_with_fake_verify() + async_http = client.make_secure_async_http_client() + ssl_ctx = self._get_ssl_context(async_http) + + fake_ssl_object = MagicMock() + fake_ssl_object.do_handshake = MagicMock(return_value=None) + + with patch.object(SecureClient, '_verify_peer_fingerprint') as mock_verify: + result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) + + # do_handshake should have been replaced with the checked version + result.do_handshake() + + mock_verify.assert_called_once_with( + fake_ssl_object.getpeercert(binary_form=True), + self.FAKE_FP, + ) + + def test_do_handshake_rejects_fingerprint_mismatch(self): + """After handshake, mismatched fingerprint must raise ValueError.""" + client = self._make_client_with_fake_verify() + async_http = client.make_secure_async_http_client() + ssl_ctx = self._get_ssl_context(async_http) + + fake_ssl_object = MagicMock() + fake_ssl_object.do_handshake = MagicMock(return_value=None) + fake_ssl_object.getpeercert.return_value = b"fake_cert_bytes" + + with patch.object( + SecureClient, '_verify_peer_fingerprint', + side_effect=ValueError("Certificate fingerprint mismatch"), + ): + result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) + + with pytest.raises(ValueError, match="Certificate fingerprint mismatch"): + result.do_handshake() + + def test_do_handshake_rejects_missing_cert(self): + """After handshake, missing peer cert must raise ValueError.""" + client = self._make_client_with_fake_verify() + async_http = client.make_secure_async_http_client() + ssl_ctx = self._get_ssl_context(async_http) + + fake_ssl_object = MagicMock() + fake_ssl_object.do_handshake = MagicMock(return_value=None) + fake_ssl_object.getpeercert.return_value = None + + result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) + + with pytest.raises(ValueError, match="No certificate found"): + result.do_handshake() + + def test_ssl_want_read_propagates_without_cert_check(self): + """SSLWantReadError during handshake must propagate without checking cert.""" + client = self._make_client_with_fake_verify() + async_http = client.make_secure_async_http_client() + ssl_ctx = self._get_ssl_context(async_http) + + fake_ssl_object = MagicMock() + fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantReadError()) + + result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) + + with pytest.raises(ssl.SSLWantReadError): + result.do_handshake() + + # getpeercert should NOT have been called — handshake isn't done yet + fake_ssl_object.getpeercert.assert_not_called() + + def test_ssl_want_write_propagates_without_cert_check(self): + """SSLWantWriteError during handshake must propagate without checking cert.""" + client = self._make_client_with_fake_verify() + async_http = client.make_secure_async_http_client() + ssl_ctx = self._get_ssl_context(async_http) + + fake_ssl_object = MagicMock() + fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantWriteError()) + + result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) + + with pytest.raises(ssl.SSLWantWriteError): + result.do_handshake() + + fake_ssl_object.getpeercert.assert_not_called() + + if __name__ == "__main__": pytest.main([__file__, "-v"])