Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions src/tinfoil/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,18 @@ def ground_truth(self) -> Optional[GroundTruth]:
"""Returns the last verified enclave state"""
return self._ground_truth

def make_secure_http_client(self) -> httpx.Client:
"""
Build an httpx.Client that pins the enclave's TLS cert
"""
expected_fp = self.verify().public_key
wrap_socket = self._create_socket_wrapper(expected_fp)

ctx = ssl.create_default_context()
ctx.wrap_socket = wrap_socket
return httpx.Client(verify=ctx, follow_redirects=True)
@staticmethod
def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None:
Copy link

@cubic-dev-ai cubic-dev-ai bot Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: The TLSBoundHTTPSHandler._get_connection method still duplicates the fingerprint verification logic that _verify_peer_fingerprint was extracted to centralize. Consider refactoring TLSBoundHTTPSHandler to call SecureClient._verify_peer_fingerprint(cert_binary, self.expected_pubkey) (or promote the helper to a module-level function) so there's a single source of truth for cert pinning verification.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At src/tinfoil/client.py, line 100:

<comment>The `TLSBoundHTTPSHandler._get_connection` method still duplicates the fingerprint verification logic that `_verify_peer_fingerprint` was extracted to centralize. Consider refactoring `TLSBoundHTTPSHandler` to call `SecureClient._verify_peer_fingerprint(cert_binary, self.expected_pubkey)` (or promote the helper to a module-level function) so there's a single source of truth for cert pinning verification.</comment>

<file context>
@@ -96,16 +96,18 @@ def ground_truth(self) -> Optional[GroundTruth]:
-        ctx.wrap_socket = wrap_socket
-        return httpx.Client(verify=ctx, follow_redirects=True)
+    @staticmethod
+    def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None:
+        """Verify that a certificate's public key fingerprint matches the expected value."""
+        if not cert_binary:
</file context>
Fix with Cubic

"""Verify that a certificate's public key fingerprint matches the expected value."""
if not cert_binary:
raise ValueError("No certificate found")
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
pub_der = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo
)
pk_fp = hashlib.sha256(pub_der).hexdigest()
if pk_fp != expected_fp:
raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}")

def _create_socket_wrapper(self, expected_fp: str):
"""
Expand All @@ -114,28 +116,46 @@ def _create_socket_wrapper(self, expected_fp: str):
"""
def wrap_socket(*args, **kwargs) -> ssl.SSLSocket:
sock = ssl.create_default_context().wrap_socket(*args, **kwargs)
cert_binary = sock.getpeercert(binary_form=True)
if not cert_binary:
raise ValueError("No certificate found")
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
pub_der = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo
SecureClient._verify_peer_fingerprint(
sock.getpeercert(binary_form=True), expected_fp
)
pk_fp = hashlib.sha256(pub_der).hexdigest()
if pk_fp != expected_fp:
raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}")
return sock
return wrap_socket

def make_secure_async_http_client(self) -> httpx.AsyncClient:
def make_secure_http_client(self) -> httpx.Client:
"""
Build an httpx.AsyncClient that pins the enclave's TLS cert.
Build an httpx.Client that pins the enclave's TLS cert
"""
expected_fp = self.verify().public_key
wrap_socket = self._create_socket_wrapper(expected_fp)

ctx = ssl.create_default_context()
ctx.wrap_socket = wrap_socket
return httpx.Client(verify=ctx, follow_redirects=True)

def make_secure_async_http_client(self) -> httpx.AsyncClient:
"""
Build an httpx.AsyncClient that pins the enclave's TLS cert
"""
expected_fp = self.verify().public_key

ctx = ssl.create_default_context()
original_wrap_bio = ctx.wrap_bio

def pinned_wrap_bio(*args, **kwargs):
Copy link

@cubic-dev-ai cubic-dev-ai bot Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: Custom agent: Check System Design and Architectural Patterns

The async TLS pinning logic is inlined with deeply nested closures, breaking the separation-of-concerns pattern established by the sync path's _create_socket_wrapper helper. Extract a _create_bio_wrapper(self, expected_fp) method (parallel to _create_socket_wrapper) so the async pinning logic is independently testable, reusable, and consistent with the sync architecture. This also keeps make_secure_async_http_client at the same abstraction level as make_secure_http_client.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At src/tinfoil/client.py, line 145:

<comment>The async TLS pinning logic is inlined with deeply nested closures, breaking the separation-of-concerns pattern established by the sync path's `_create_socket_wrapper` helper. Extract a `_create_bio_wrapper(self, expected_fp)` method (parallel to `_create_socket_wrapper`) so the async pinning logic is independently testable, reusable, and consistent with the sync architecture. This also keeps `make_secure_async_http_client` at the same abstraction level as `make_secure_http_client`.</comment>

<file context>
@@ -114,28 +116,46 @@ def _create_socket_wrapper(self, expected_fp: str):
+        ctx = ssl.create_default_context()
+        original_wrap_bio = ctx.wrap_bio
+
+        def pinned_wrap_bio(*args, **kwargs):
+            ssl_object = original_wrap_bio(*args, **kwargs)
+            original_do_handshake = ssl_object.do_handshake
</file context>
Fix with Cubic

ssl_object = original_wrap_bio(*args, **kwargs)
original_do_handshake = ssl_object.do_handshake

def checked_do_handshake():
result = original_do_handshake()
cert_binary = ssl_object.getpeercert(binary_form=True)
SecureClient._verify_peer_fingerprint(cert_binary, expected_fp)
return result

ssl_object.do_handshake = checked_do_handshake
return ssl_object

ctx.wrap_bio = pinned_wrap_bio
return httpx.AsyncClient(verify=ctx, follow_redirects=True)

def verify(self) -> GroundTruth:
Expand Down
65 changes: 65 additions & 0 deletions tests/test_attestation_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,70 @@ def test_secure_http_client():
print(f" TLS pinned to: {ground_truth.public_key}")


@pytest.mark.asyncio
async def test_secure_async_http_client():
"""
Tests that the async pinned client can connect to the enclave.
Mirrors test_secure_http_client for the async path.
"""
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")

client = SecureClient(enclave=enclave, repo=REPO)
http_client = client.make_secure_async_http_client()

ground_truth = client.ground_truth
assert ground_truth is not None

try:
response = await http_client.get(f"https://{enclave}/.well-known/tinfoil-attestation")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
finally:
await http_client.aclose()


def test_sync_pinned_client_rejects_wrong_host():
"""
A client pinned to the enclave's cert must reject connections
to a different host (whose cert won't match the pinned fingerprint).
"""
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")

client = SecureClient(enclave=enclave, repo=REPO)
http_client = client.make_secure_http_client()

try:
with pytest.raises(Exception):
http_client.get("https://google.com")
finally:
http_client.close()


@pytest.mark.asyncio
async def test_async_pinned_client_rejects_wrong_host():
"""
The async pinned client must reject connections to a host
whose cert doesn't match the pinned fingerprint.
"""
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")

client = SecureClient(enclave=enclave, repo=REPO)
http_client = client.make_secure_async_http_client()

try:
with pytest.raises(Exception):
await http_client.get("https://google.com")
finally:
await http_client.aclose()


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
210 changes: 210 additions & 0 deletions tests/test_verification_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
connections to proceed.
"""

import ssl

import pytest
from unittest.mock import patch, MagicMock

Expand Down Expand Up @@ -294,5 +296,213 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch):
client.verify()


class TestVerifyPeerFingerprint:
"""Tests for SecureClient._verify_peer_fingerprint static method."""

def test_raises_on_none_cert(self):
"""Must raise ValueError when cert_binary is None."""
with pytest.raises(ValueError, match="No certificate found"):
SecureClient._verify_peer_fingerprint(None, "abc123")

def test_raises_on_empty_cert(self):
"""Must raise ValueError when cert_binary is empty bytes."""
with pytest.raises(ValueError, match="No certificate found"):
SecureClient._verify_peer_fingerprint(b"", "abc123")

def test_raises_on_fingerprint_mismatch(self):
"""Must raise ValueError when public key fingerprint doesn't match."""
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import NameOID
import datetime

# Generate a self-signed cert to get valid DER bytes
key = ec.generate_private_key(ec.SECP256R1())
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1))
.sign(key, hashes.SHA256())
)
from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding
cert_der = cert.public_bytes(CryptoEncoding.DER)

with pytest.raises(ValueError, match="Certificate fingerprint mismatch"):
SecureClient._verify_peer_fingerprint(cert_der, "wrong_fingerprint")

def test_passes_on_fingerprint_match(self):
"""Must not raise when public key fingerprint matches."""
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding, PublicFormat as CryptoPublicFormat
from cryptography.x509.oid import NameOID
import datetime
import hashlib

key = ec.generate_private_key(ec.SECP256R1())
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1))
.sign(key, hashes.SHA256())
)
cert_der = cert.public_bytes(CryptoEncoding.DER)

# Compute the expected fingerprint the same way the code does
pub_der = key.public_key().public_bytes(
CryptoEncoding.DER, CryptoPublicFormat.SubjectPublicKeyInfo
)
expected_fp = hashlib.sha256(pub_der).hexdigest()

# Should not raise
SecureClient._verify_peer_fingerprint(cert_der, expected_fp)


class TestAsyncTLSPinning:
"""Tests that the async httpx client correctly pins TLS certificates via wrap_bio."""

FAKE_FP = "a" * 64

def _make_client_with_fake_verify(self):
"""Create a SecureClient that returns a fake fingerprint from verify()."""
client = SecureClient(enclave="test.enclave.sh", repo="test/repo")
ground_truth = MagicMock()
ground_truth.public_key = self.FAKE_FP
client.verify = MagicMock(return_value=ground_truth)
return client

def _get_ssl_context(self, async_http_client):
"""Extract the SSL context from an httpx.AsyncClient."""
return async_http_client._transport._pool._ssl_context

def _call_pinned_wrap_bio(self, ssl_ctx, fake_ssl_object):
"""
Call the patched wrap_bio on ssl_ctx, intercepting the real
original_wrap_bio so it returns our fake_ssl_object instead of
attempting a real SSL operation.

The pinned_wrap_bio closure holds a reference to original_wrap_bio
(the real ctx.wrap_bio at creation time). We patch the underlying
ctx._wrap_bio C method so that the call chain
pinned_wrap_bio -> original_wrap_bio -> SSLContext.wrap_bio -> _wrap_bio
returns our fake.
"""
with patch.object(ssl_ctx, '_wrap_bio', create=True) as mock_inner:
# ssl.SSLContext.wrap_bio calls sslobject_class._create which
# calls context._wrap_bio. We need to go one level deeper and
# patch SSLObject._create to just return our fake.
with patch('ssl.SSLObject._create', return_value=fake_ssl_object):
return ssl_ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO())

def test_wrap_bio_is_monkey_patched(self):
"""make_secure_async_http_client() must replace ctx.wrap_bio."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()

# The underlying SSL context's wrap_bio should no longer be the
# original C-level method — it should be our pinned_wrap_bio closure.
ssl_ctx = self._get_ssl_context(async_http)
assert ssl_ctx.wrap_bio is not ssl.SSLContext.wrap_bio

def test_do_handshake_verifies_fingerprint_match(self):
"""After handshake, matching fingerprint must not raise."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)

with patch.object(SecureClient, '_verify_peer_fingerprint') as mock_verify:
result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

# do_handshake should have been replaced with the checked version
result.do_handshake()

mock_verify.assert_called_once_with(
fake_ssl_object.getpeercert(binary_form=True),
self.FAKE_FP,
)

def test_do_handshake_rejects_fingerprint_mismatch(self):
"""After handshake, mismatched fingerprint must raise ValueError."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)
fake_ssl_object.getpeercert.return_value = b"fake_cert_bytes"

with patch.object(
SecureClient, '_verify_peer_fingerprint',
side_effect=ValueError("Certificate fingerprint mismatch"),
):
result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ValueError, match="Certificate fingerprint mismatch"):
result.do_handshake()

def test_do_handshake_rejects_missing_cert(self):
"""After handshake, missing peer cert must raise ValueError."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)
fake_ssl_object.getpeercert.return_value = None

result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ValueError, match="No certificate found"):
result.do_handshake()

def test_ssl_want_read_propagates_without_cert_check(self):
"""SSLWantReadError during handshake must propagate without checking cert."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantReadError())

result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ssl.SSLWantReadError):
result.do_handshake()

# getpeercert should NOT have been called — handshake isn't done yet
fake_ssl_object.getpeercert.assert_not_called()

def test_ssl_want_write_propagates_without_cert_check(self):
"""SSLWantWriteError during handshake must propagate without checking cert."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantWriteError())

result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ssl.SSLWantWriteError):
result.do_handshake()

fake_ssl_object.getpeercert.assert_not_called()


if __name__ == "__main__":
pytest.main([__file__, "-v"])