Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions src/tinfoil/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,18 @@ def ground_truth(self) -> Optional[GroundTruth]:
"""Returns the last verified enclave state"""
return self._ground_truth

def make_secure_http_client(self) -> httpx.Client:
"""
Build an httpx.Client that pins the enclave's TLS cert
"""
expected_fp = self.verify().public_key
wrap_socket = self._create_socket_wrapper(expected_fp)

ctx = ssl.create_default_context()
ctx.wrap_socket = wrap_socket
return httpx.Client(verify=ctx, follow_redirects=True)
@staticmethod
def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None:
Copy link

@cubic-dev-ai cubic-dev-ai bot Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: TLSBoundHTTPSHandler._get_connection (lines 55-66) contains the same fingerprint verification logic that _verify_peer_fingerprint now encapsulates. Consider updating _get_connection to call SecureClient._verify_peer_fingerprint(cert_binary, self.expected_pubkey) to eliminate the remaining duplication and ensure all verification paths stay in sync.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At src/tinfoil/client.py, line 100:

<comment>`TLSBoundHTTPSHandler._get_connection` (lines 55-66) contains the same fingerprint verification logic that `_verify_peer_fingerprint` now encapsulates. Consider updating `_get_connection` to call `SecureClient._verify_peer_fingerprint(cert_binary, self.expected_pubkey)` to eliminate the remaining duplication and ensure all verification paths stay in sync.</comment>

<file context>
@@ -96,16 +96,18 @@ def ground_truth(self) -> Optional[GroundTruth]:
-        ctx.wrap_socket = wrap_socket
-        return httpx.Client(verify=ctx, follow_redirects=True)
+    @staticmethod
+    def _verify_peer_fingerprint(cert_binary: Optional[bytes], expected_fp: str) -> None:
+        """Verify that a certificate's public key fingerprint matches the expected value."""
+        if not cert_binary:
</file context>
Fix with Cubic

"""Verify that a certificate's public key fingerprint matches the expected value."""
if not cert_binary:
raise ValueError("No certificate found")
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
pub_der = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo
)
pk_fp = hashlib.sha256(pub_der).hexdigest()
if pk_fp != expected_fp:
raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}")

def _create_socket_wrapper(self, expected_fp: str):
"""
Expand All @@ -114,28 +116,46 @@ def _create_socket_wrapper(self, expected_fp: str):
"""
def wrap_socket(*args, **kwargs) -> ssl.SSLSocket:
sock = ssl.create_default_context().wrap_socket(*args, **kwargs)
cert_binary = sock.getpeercert(binary_form=True)
if not cert_binary:
raise ValueError("No certificate found")
cert = cryptography.x509.load_der_x509_certificate(cert_binary)
pub_der = cert.public_key().public_bytes(
Encoding.DER, PublicFormat.SubjectPublicKeyInfo
SecureClient._verify_peer_fingerprint(
sock.getpeercert(binary_form=True), expected_fp
)
pk_fp = hashlib.sha256(pub_der).hexdigest()
if pk_fp != expected_fp:
raise ValueError(f"Certificate fingerprint mismatch: expected {expected_fp}, got {pk_fp}")
return sock
return wrap_socket

def make_secure_async_http_client(self) -> httpx.AsyncClient:
def make_secure_http_client(self) -> httpx.Client:
"""
Build an httpx.AsyncClient that pins the enclave's TLS cert.
Build an httpx.Client that pins the enclave's TLS cert
"""
expected_fp = self.verify().public_key
wrap_socket = self._create_socket_wrapper(expected_fp)

ctx = ssl.create_default_context()
ctx.wrap_socket = wrap_socket
return httpx.Client(verify=ctx, follow_redirects=True)

def make_secure_async_http_client(self) -> httpx.AsyncClient:
"""
Build an httpx.AsyncClient that pins the enclave's TLS cert
"""
expected_fp = self.verify().public_key

ctx = ssl.create_default_context()
original_wrap_bio = ctx.wrap_bio

def pinned_wrap_bio(*args, **kwargs):
ssl_object = original_wrap_bio(*args, **kwargs)
original_do_handshake = ssl_object.do_handshake

def checked_do_handshake():
result = original_do_handshake()
cert_binary = ssl_object.getpeercert(binary_form=True)
SecureClient._verify_peer_fingerprint(cert_binary, expected_fp)
return result

ssl_object.do_handshake = checked_do_handshake
return ssl_object

ctx.wrap_bio = pinned_wrap_bio
return httpx.AsyncClient(verify=ctx, follow_redirects=True)

def verify(self) -> GroundTruth:
Expand Down
65 changes: 65 additions & 0 deletions tests/test_attestation_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,70 @@ def test_secure_http_client():
print(f" TLS pinned to: {ground_truth.public_key}")


@pytest.mark.asyncio
async def test_secure_async_http_client():
"""
Tests that the async pinned client can connect to the enclave.
Mirrors test_secure_http_client for the async path.
"""
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")

client = SecureClient(enclave=enclave, repo=REPO)
http_client = client.make_secure_async_http_client()

ground_truth = client.ground_truth
assert ground_truth is not None

try:
response = await http_client.get(f"https://{enclave}/.well-known/tinfoil-attestation")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
finally:
await http_client.aclose()


def test_sync_pinned_client_rejects_wrong_host():
"""
A client pinned to the enclave's cert must reject connections
to a different host (whose cert won't match the pinned fingerprint).
"""
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")

client = SecureClient(enclave=enclave, repo=REPO)
http_client = client.make_secure_http_client()

try:
with pytest.raises(Exception):
Copy link

@cubic-dev-ai cubic-dev-ai bot Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P3: The async pinning rejection test is overly permissive; it will pass on any network failure. Assert the specific fingerprint-mismatch error so the test only passes when TLS pinning actually rejects the cert.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At tests/test_attestation_flow.py, line 134:

<comment>The async pinning rejection test is overly permissive; it will pass on any network failure. Assert the specific fingerprint-mismatch error so the test only passes when TLS pinning actually rejects the cert.</comment>

<file context>
@@ -93,5 +93,70 @@ def test_secure_http_client():
+    http_client = client.make_secure_http_client()
+
+    try:
+        with pytest.raises(Exception):
+            http_client.get("https://google.com")
+    finally:
</file context>
Fix with Cubic

http_client.get("https://google.com")
finally:
http_client.close()


@pytest.mark.asyncio
async def test_async_pinned_client_rejects_wrong_host():
"""
The async pinned client must reject connections to a host
whose cert doesn't match the pinned fingerprint.
"""
try:
enclave = get_router_address()
except Exception as e:
pytest.skip(f"Could not fetch router address from ATC service: {e}")

client = SecureClient(enclave=enclave, repo=REPO)
http_client = client.make_secure_async_http_client()

try:
with pytest.raises(Exception):
await http_client.get("https://google.com")
finally:
await http_client.aclose()


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
210 changes: 210 additions & 0 deletions tests/test_verification_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
connections to proceed.
"""

import ssl

import pytest
from unittest.mock import patch, MagicMock

Expand Down Expand Up @@ -294,5 +296,213 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch):
client.verify()


class TestVerifyPeerFingerprint:
"""Tests for SecureClient._verify_peer_fingerprint static method."""

def test_raises_on_none_cert(self):
"""Must raise ValueError when cert_binary is None."""
with pytest.raises(ValueError, match="No certificate found"):
SecureClient._verify_peer_fingerprint(None, "abc123")

def test_raises_on_empty_cert(self):
"""Must raise ValueError when cert_binary is empty bytes."""
with pytest.raises(ValueError, match="No certificate found"):
SecureClient._verify_peer_fingerprint(b"", "abc123")

def test_raises_on_fingerprint_mismatch(self):
"""Must raise ValueError when public key fingerprint doesn't match."""
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.oid import NameOID
import datetime

# Generate a self-signed cert to get valid DER bytes
key = ec.generate_private_key(ec.SECP256R1())
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1))
.sign(key, hashes.SHA256())
)
from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding
cert_der = cert.public_bytes(CryptoEncoding.DER)

with pytest.raises(ValueError, match="Certificate fingerprint mismatch"):
SecureClient._verify_peer_fingerprint(cert_der, "wrong_fingerprint")

def test_passes_on_fingerprint_match(self):
"""Must not raise when public key fingerprint matches."""
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import Encoding as CryptoEncoding, PublicFormat as CryptoPublicFormat
from cryptography.x509.oid import NameOID
import datetime
import hashlib

key = ec.generate_private_key(ec.SECP256R1())
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1))
.sign(key, hashes.SHA256())
)
cert_der = cert.public_bytes(CryptoEncoding.DER)

# Compute the expected fingerprint the same way the code does
pub_der = key.public_key().public_bytes(
CryptoEncoding.DER, CryptoPublicFormat.SubjectPublicKeyInfo
)
expected_fp = hashlib.sha256(pub_der).hexdigest()

# Should not raise
SecureClient._verify_peer_fingerprint(cert_der, expected_fp)


class TestAsyncTLSPinning:
"""Tests that the async httpx client correctly pins TLS certificates via wrap_bio."""

FAKE_FP = "a" * 64

def _make_client_with_fake_verify(self):
"""Create a SecureClient that returns a fake fingerprint from verify()."""
client = SecureClient(enclave="test.enclave.sh", repo="test/repo")
ground_truth = MagicMock()
ground_truth.public_key = self.FAKE_FP
client.verify = MagicMock(return_value=ground_truth)
return client

def _get_ssl_context(self, async_http_client):
"""Extract the SSL context from an httpx.AsyncClient."""
return async_http_client._transport._pool._ssl_context

def _call_pinned_wrap_bio(self, ssl_ctx, fake_ssl_object):
"""
Call the patched wrap_bio on ssl_ctx, intercepting the real
original_wrap_bio so it returns our fake_ssl_object instead of
attempting a real SSL operation.

The pinned_wrap_bio closure holds a reference to original_wrap_bio
(the real ctx.wrap_bio at creation time). We patch the underlying
ctx._wrap_bio C method so that the call chain
pinned_wrap_bio -> original_wrap_bio -> SSLContext.wrap_bio -> _wrap_bio
returns our fake.
"""
with patch.object(ssl_ctx, '_wrap_bio', create=True) as mock_inner:
# ssl.SSLContext.wrap_bio calls sslobject_class._create which
# calls context._wrap_bio. We need to go one level deeper and
# patch SSLObject._create to just return our fake.
with patch('ssl.SSLObject._create', return_value=fake_ssl_object):
return ssl_ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO())

def test_wrap_bio_is_monkey_patched(self):
"""make_secure_async_http_client() must replace ctx.wrap_bio."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()

# The underlying SSL context's wrap_bio should no longer be the
# original C-level method — it should be our pinned_wrap_bio closure.
ssl_ctx = self._get_ssl_context(async_http)
assert ssl_ctx.wrap_bio is not ssl.SSLContext.wrap_bio
Copy link

@cubic-dev-ai cubic-dev-ai bot Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: This assertion doesn’t prove wrap_bio was monkey-patched because bound method objects are never identical to the class function; the test will pass even when no patch occurred. Check for an instance-level override (e.g., presence in __dict__) or compare the underlying function to make the test meaningful.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At tests/test_verification_failures.py, line 417:

<comment>This assertion doesn’t prove wrap_bio was monkey-patched because bound method objects are never identical to the class function; the test will pass even when no patch occurred. Check for an instance-level override (e.g., presence in `__dict__`) or compare the underlying function to make the test meaningful.</comment>

<file context>
@@ -294,5 +296,213 @@ def test_snp_measurement_mismatch_raises(self, mock_fetch):
+        # The underlying SSL context's wrap_bio should no longer be the
+        # original C-level method — it should be our pinned_wrap_bio closure.
+        ssl_ctx = self._get_ssl_context(async_http)
+        assert ssl_ctx.wrap_bio is not ssl.SSLContext.wrap_bio
+
+    def test_do_handshake_verifies_fingerprint_match(self):
</file context>
Suggested change
assert ssl_ctx.wrap_bio is not ssl.SSLContext.wrap_bio
assert "wrap_bio" in ssl_ctx.__dict__
Fix with Cubic


def test_do_handshake_verifies_fingerprint_match(self):
"""After handshake, matching fingerprint must not raise."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)

with patch.object(SecureClient, '_verify_peer_fingerprint') as mock_verify:
result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

# do_handshake should have been replaced with the checked version
result.do_handshake()

mock_verify.assert_called_once_with(
fake_ssl_object.getpeercert(binary_form=True),
self.FAKE_FP,
)

def test_do_handshake_rejects_fingerprint_mismatch(self):
"""After handshake, mismatched fingerprint must raise ValueError."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)
fake_ssl_object.getpeercert.return_value = b"fake_cert_bytes"

with patch.object(
SecureClient, '_verify_peer_fingerprint',
side_effect=ValueError("Certificate fingerprint mismatch"),
):
result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ValueError, match="Certificate fingerprint mismatch"):
result.do_handshake()

def test_do_handshake_rejects_missing_cert(self):
"""After handshake, missing peer cert must raise ValueError."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(return_value=None)
fake_ssl_object.getpeercert.return_value = None

result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ValueError, match="No certificate found"):
result.do_handshake()

def test_ssl_want_read_propagates_without_cert_check(self):
"""SSLWantReadError during handshake must propagate without checking cert."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantReadError())

result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ssl.SSLWantReadError):
result.do_handshake()

# getpeercert should NOT have been called — handshake isn't done yet
fake_ssl_object.getpeercert.assert_not_called()

def test_ssl_want_write_propagates_without_cert_check(self):
"""SSLWantWriteError during handshake must propagate without checking cert."""
client = self._make_client_with_fake_verify()
async_http = client.make_secure_async_http_client()
ssl_ctx = self._get_ssl_context(async_http)

fake_ssl_object = MagicMock()
fake_ssl_object.do_handshake = MagicMock(side_effect=ssl.SSLWantWriteError())

result = self._call_pinned_wrap_bio(ssl_ctx, fake_ssl_object)

with pytest.raises(ssl.SSLWantWriteError):
result.do_handshake()

fake_ssl_object.getpeercert.assert_not_called()


if __name__ == "__main__":
pytest.main([__file__, "-v"])