Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 38 additions & 44 deletions src/tinfoil/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ def __init__(self, status: str, status_code: int, body: bytes):
self.body = body


def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None:
"""Verify that a certificate's public key fingerprint matches the expected value."""
if not cert_binary:
raise ValueError("No certificate found")
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
pub_der = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo
)
pk_fp = hashlib.sha256(pub_der).hexdigest()
if pk_fp != expected_fp:
raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}")


class TLSBoundHTTPSHandler(urllib.request.HTTPSHandler):
"""Custom HTTPS handler that verifies certificate public keys"""

Expand All @@ -52,20 +65,9 @@ def _get_connection(self, host, timeout=None):
if not conn.sock:
raise ValueError("No TLS connection")

cert_binary = conn.sock.getpeercert(binary_form=True)
if not cert_binary:
raise ValueError("No valid certificate")

# Parse the certificate using cryptography
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
public_key = cert.public_key()
# Get the public key in PKIX/DER format
public_key_der = public_key.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
# Hash the public key
cert_fp = hashlib.sha256(public_key_der).hexdigest()

if cert_fp != self.expected_pubkey:
raise ValueError(f"Certificate public key fingerprint mismatch: expected {self.expected_pubkey}, got {cert_fp}")
_verify_peer_fingerprint(
conn.sock.getpeercert(binary_form=True), self.expected_pubkey
)

return conn

Expand Down Expand Up @@ -96,32 +98,39 @@ def ground_truth(self) -> Optional[GroundTruth]:
"""Returns the last verified enclave state"""
return self._ground_truth

@staticmethod
def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None:
"""Verify that a certificate's public key fingerprint matches the expected value."""
if not cert_binary:
raise ValueError("No certificate found")
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
pub_der = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo
)
pk_fp = hashlib.sha256(pub_der).hexdigest()
if pk_fp != expected_fp:
raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}")

def _create_socket_wrapper(self, expected_fp: str):
"""
Creates a socket wrapper function that verifies the certificate's public key fingerprint
matches the expected fingerprint.
"""
def wrap_socket(*args, **kwargs) -> ssl.SSLSocket:
sock = ssl.create_default_context().wrap_socket(*args, **kwargs)
SecureClient._verify_peer_fingerprint(
_verify_peer_fingerprint(
sock.getpeercert(binary_form=True), expected_fp
)
return sock
return wrap_socket

def _create_bio_wrapper(self, original_wrap_bio, expected_fp: str):
"""
Creates a wrap_bio replacement that verifies the certificate's public key fingerprint
after the TLS handshake completes.
"""
def pinned_wrap_bio(*args, **kwargs):
ssl_object = original_wrap_bio(*args, **kwargs)
original_do_handshake = ssl_object.do_handshake

def checked_do_handshake():
result = original_do_handshake()
_verify_peer_fingerprint(
ssl_object.getpeercert(binary_form=True), expected_fp
)
return result

ssl_object.do_handshake = checked_do_handshake
return ssl_object
return pinned_wrap_bio

def make_secure_http_client(self) -> httpx.Client:
"""
Build an httpx.Client that pins the enclave's TLS cert
Expand All @@ -140,22 +149,7 @@ def make_secure_async_http_client(self) -> httpx.AsyncClient:
expected_fp = self.verify().public_key

ctx = ssl.create_default_context()
original_wrap_bio = ctx.wrap_bio

def pinned_wrap_bio(*args, **kwargs):
ssl_object = original_wrap_bio(*args, **kwargs)
original_do_handshake = ssl_object.do_handshake

def checked_do_handshake():
result = original_do_handshake()
cert_binary = ssl_object.getpeercert(binary_form=True)
SecureClient._verify_peer_fingerprint(cert_binary, expected_fp)
return result

ssl_object.do_handshake = checked_do_handshake
return ssl_object

ctx.wrap_bio = pinned_wrap_bio
ctx.wrap_bio = self._create_bio_wrapper(ctx.wrap_bio, expected_fp)
return httpx.AsyncClient(verify=ctx, follow_redirects=True)

def verify(self) -> GroundTruth:
Expand Down
18 changes: 9 additions & 9 deletions tests/test_verification_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
from unittest.mock import patch, MagicMock

from tinfoil.client import SecureClient
from tinfoil.client import SecureClient, _verify_peer_fingerprint
from tinfoil.attestation import (
Measurement,
PredicateType,
Expand Down Expand Up @@ -297,17 +297,17 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch):


class TestVerifyPeerFingerprint:
"""Tests for SecureClient._verify_peer_fingerprint static method."""
"""Tests for _verify_peer_fingerprint."""

def test_raises_on_none_cert(self):
"""Must raise ValueError when cert_binary is None."""
with pytest.raises(ValueError, match="No certificate found"):
SecureClient._verify_peer_fingerprint(None, "abc123")
_verify_peer_fingerprint(None, "abc123")

def test_raises_on_empty_cert(self):
"""Must raise ValueError when cert_binary is empty bytes."""
with pytest.raises(ValueError, match="No certificate found"):
SecureClient._verify_peer_fingerprint(b"", "abc123")
_verify_peer_fingerprint(b"", "abc123")

def test_raises_on_fingerprint_mismatch(self):
"""Must raise ValueError when public key fingerprint doesn't match."""
Expand All @@ -334,7 +334,7 @@ def test_raises_on_fingerprint_mismatch(self):
cert_der = cert.public_bytes(CryptoEncoding.DER)

with pytest.raises(ValueError, match="Certificate fingerprint mismatch"):
SecureClient._verify_peer_fingerprint(cert_der, "wrong_fingerprint")
_verify_peer_fingerprint(cert_der, "wrong_fingerprint")

def test_passes_on_fingerprint_match(self):
"""Must not raise when public key fingerprint matches."""
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_passes_on_fingerprint_match(self):
expected_fp = hashlib.sha256(pub_der).hexdigest()

# Should not raise
SecureClient._verify_peer_fingerprint(cert_der, expected_fp)
_verify_peer_fingerprint(cert_der, expected_fp)


class TestAsyncTLSPinning:
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_do_handshake_verifies_fingerprint_match(self):
fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)

with patch.object(SecureClient, '_verify_peer_fingerprint') as mock_verify:
with patch('tinfoil.client._verify_peer_fingerprint') as mock_verify:
result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

# do_handshake should have been replaced with the checked version
Expand All @@ -446,8 +446,8 @@ def test_do_handshake_rejects_fingerprint_mismatch(self):
fake_ssl_object.do_handshake = MagicMock(return_value=None)
fake_ssl_object.getpeercert.return_value = b"fake_cert_bytes"

with patch.object(
SecureClient, '_verify_peer_fingerprint',
with patch(
'tinfoil.client._verify_peer_fingerprint',
side_effect=ValueError("Certificate fingerprint mismatch"),
):
result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)
Expand Down