diff --git a/src/tinfoil/client.py b/src/tinfoil/client.py index 09fbc7f..5fa6a00 100644 --- a/src/tinfoil/client.py +++ b/src/tinfoil/client.py @@ -34,6 +34,19 @@ def __init__(self, status: str, status_code: int, body: bytes): self.body = body +def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None: + """Verify that a certificate's public key fingerprint matches the expected value.""" + if not cert_binary: + raise ValueError("No certificate found") + cert = cryptography.x509.load_der_x509_certificate(cert_binary) + pub_der = cert.public_key().public_bytes( + Encoding.DER, PublicFormat.SubjectPublicKeyInfo + ) + pk_fp = hashlib.sha256(pub_der).hexdigest() + if pk_fp != expected_fp: + raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") + + class TLSBoundHTTPSHandler(urllib.request.HTTPSHandler): """Custom HTTPS handler that verifies certificate public keys""" @@ -52,20 +65,9 @@ def _get_connection(self, host, timeout=None): if not conn.sock: raise ValueError("No TLS connection") - cert_binary = conn.sock.getpeercert(binary_form=True) - if not cert_binary: - raise ValueError("No valid certificate") - - # Parse the certificate using cryptography - cert = cryptography.x509.load_der_x509_certificate(cert_binary) - public_key = cert.public_key() - # Get the public key in PKIX/DER format - public_key_der = public_key.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo) - # Hash the public key - cert_fp = hashlib.sha256(public_key_der).hexdigest() - - if cert_fp != self.expected_pubkey: - raise ValueError(f"Certificate public key fingerprint mismatch: expected {self.expected_pubkey}, got {cert_fp}") + _verify_peer_fingerprint( + conn.sock.getpeercert(binary_form=True), self.expected_pubkey + ) return conn @@ -96,19 +98,6 @@ def ground_truth(self) -> Optional[GroundTruth]: """Returns the last verified enclave state""" return self._ground_truth - @staticmethod - def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None: - """Verify that a certificate's public key fingerprint matches the expected value.""" - if not cert_binary: - raise ValueError("No certificate found") - cert = cryptography.x509.load_der_x509_certificate(cert_binary) - pub_der = cert.public_key().public_bytes( - Encoding.DER, PublicFormat.SubjectPublicKeyInfo - ) - pk_fp = hashlib.sha256(pub_der).hexdigest() - if pk_fp != expected_fp: - raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}") - def _create_socket_wrapper(self, expected_fp: str): """ Creates a socket wrapper function that verifies the certificate's public key fingerprint @@ -116,12 +105,32 @@ def _create_socket_wrapper(self, expected_fp: str): """ def wrap_socket(*args, **kwargs) -> ssl.SSLSocket: sock = ssl.create_default_context().wrap_socket(*args, **kwargs) - SecureClient._verify_peer_fingerprint( + _verify_peer_fingerprint( sock.getpeercert(binary_form=True), expected_fp ) return sock return wrap_socket + def _create_bio_wrapper(self, original_wrap_bio, expected_fp: str): + """ + Creates a wrap_bio replacement that verifies the certificate's public key fingerprint + after the TLS handshake completes. + """ + def pinned_wrap_bio(*args, **kwargs): + ssl_object = original_wrap_bio(*args, **kwargs) + original_do_handshake = ssl_object.do_handshake + + def checked_do_handshake(): + result = original_do_handshake() + _verify_peer_fingerprint( + ssl_object.getpeercert(binary_form=True), expected_fp + ) + return result + + ssl_object.do_handshake = checked_do_handshake + return ssl_object + return pinned_wrap_bio + def make_secure_http_client(self) -> httpx.Client: """ Build an httpx.Client that pins the enclave's TLS cert @@ -140,22 +149,7 @@ def make_secure_async_http_client(self) -> httpx.AsyncClient: expected_fp = self.verify().public_key ctx = ssl.create_default_context() - original_wrap_bio = ctx.wrap_bio - - def pinned_wrap_bio(*args, **kwargs): - ssl_object = original_wrap_bio(*args, **kwargs) - original_do_handshake = ssl_object.do_handshake - - def checked_do_handshake(): - result = original_do_handshake() - cert_binary = ssl_object.getpeercert(binary_form=True) - SecureClient._verify_peer_fingerprint(cert_binary, expected_fp) - return result - - ssl_object.do_handshake = checked_do_handshake - return ssl_object - - ctx.wrap_bio = pinned_wrap_bio + ctx.wrap_bio = self._create_bio_wrapper(ctx.wrap_bio, expected_fp) return httpx.AsyncClient(verify=ctx, follow_redirects=True) def verify(self) -> GroundTruth: diff --git a/tests/test_verification_failures.py b/tests/test_verification_failures.py index 11db999..5518e48 100644 --- a/tests/test_verification_failures.py +++ b/tests/test_verification_failures.py @@ -15,7 +15,7 @@ import pytest from unittest.mock import patch, MagicMock -from tinfoil.client import SecureClient +from tinfoil.client import SecureClient, _verify_peer_fingerprint from tinfoil.attestation import ( Measurement, PredicateType, @@ -297,17 +297,17 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch): class TestVerifyPeerFingerprint: - """Tests for SecureClient._verify_peer_fingerprint static method.""" + """Tests for _verify_peer_fingerprint.""" def test_raises_on_none_cert(self): """Must raise ValueError when cert_binary is None.""" with pytest.raises(ValueError, match="No certificate found"): - SecureClient._verify_peer_fingerprint(None, "abc123") + _verify_peer_fingerprint(None, "abc123") def test_raises_on_empty_cert(self): """Must raise ValueError when cert_binary is empty bytes.""" with pytest.raises(ValueError, match="No certificate found"): - SecureClient._verify_peer_fingerprint(b"", "abc123") + _verify_peer_fingerprint(b"", "abc123") def test_raises_on_fingerprint_mismatch(self): """Must raise ValueError when public key fingerprint doesn't match.""" @@ -334,7 +334,7 @@ def test_raises_on_fingerprint_mismatch(self): cert_der = cert.public_bytes(CryptoEncoding.DER) with pytest.raises(ValueError, match="Certificate fingerprint mismatch"): - SecureClient._verify_peer_fingerprint(cert_der, "wrong_fingerprint") + _verify_peer_fingerprint(cert_der, "wrong_fingerprint") def test_passes_on_fingerprint_match(self): """Must not raise when public key fingerprint matches.""" @@ -367,7 +367,7 @@ def test_passes_on_fingerprint_match(self): expected_fp = hashlib.sha256(pub_der).hexdigest() # Should not raise - SecureClient._verify_peer_fingerprint(cert_der, expected_fp) + _verify_peer_fingerprint(cert_der, expected_fp) class TestAsyncTLSPinning: @@ -425,7 +425,7 @@ def test_do_handshake_verifies_fingerprint_match(self): fake_ssl_object = MagicMock() fake_ssl_object.do_handshake = MagicMock(return_value=None) - with patch.object(SecureClient, '_verify_peer_fingerprint') as mock_verify: + with patch('tinfoil.client._verify_peer_fingerprint') as mock_verify: result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object) # do_handshake should have been replaced with the checked version @@ -446,8 +446,8 @@ def test_do_handshake_rejects_fingerprint_mismatch(self): fake_ssl_object.do_handshake = MagicMock(return_value=None) fake_ssl_object.getpeercert.return_value = b"fake_cert_bytes" - with patch.object( - SecureClient, '_verify_peer_fingerprint', + with patch( + 'tinfoil.client._verify_peer_fingerprint', side_effect=ValueError("Certificate fingerprint mismatch"), ): result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)